implement async function parameters
This commit is contained in:
parent
11bd50f2b2
commit
59bf9ca58c
|
@ -3638,6 +3638,7 @@ static const size_t err_union_err_index = 0;
|
|||
static const size_t err_union_payload_index = 1;
|
||||
|
||||
static const size_t coro_resume_index_index = 0;
|
||||
static const size_t coro_arg_start = 1;
|
||||
|
||||
// TODO call graph analysis to find out what this number needs to be for every function
|
||||
// MUST BE A POWER OF TWO.
|
||||
|
|
|
@ -1891,6 +1891,21 @@ static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) {
|
|||
field_names.append("resume_index");
|
||||
field_types.append(g->builtin_types.entry_usize);
|
||||
|
||||
for (size_t arg_i = 0; arg_i < fn->type_entry->data.fn.fn_type_id.param_count; arg_i += 1) {
|
||||
FnTypeParamInfo *param_info = &fn->type_entry->data.fn.fn_type_id.param_info[arg_i];
|
||||
AstNode *param_decl_node = get_param_decl_node(fn, arg_i);
|
||||
Buf *param_name;
|
||||
bool is_var_args = param_decl_node && param_decl_node->data.param_decl.is_var_args;
|
||||
if (param_decl_node && !is_var_args) {
|
||||
param_name = param_decl_node->data.param_decl.name;
|
||||
} else {
|
||||
param_name = buf_sprintf("arg%" ZIG_PRI_usize "", arg_i);
|
||||
}
|
||||
ZigType *param_type = param_info[arg_i].type;
|
||||
field_names.append(buf_ptr(param_name));
|
||||
field_types.append(param_type);
|
||||
}
|
||||
|
||||
assert(field_names.length == field_types.length);
|
||||
frame_type->data.frame.locals_struct = get_struct_type(g, buf_ptr(&frame_type->name),
|
||||
field_names.items, field_types.items, field_names.length);
|
||||
|
@ -7058,19 +7073,22 @@ void resolve_llvm_types_fn(CodeGen *g, ZigType *fn_type, ZigFn *fn) {
|
|||
// +1 for maybe first argument the error return trace
|
||||
// +2 for maybe arguments async allocator and error code pointer
|
||||
ZigList<ZigLLVMDIType *> param_di_types = {};
|
||||
param_di_types.append(get_llvm_di_type(g, fn_type_id->return_type));
|
||||
ZigType *gen_return_type;
|
||||
if (is_async) {
|
||||
gen_return_type = g->builtin_types.entry_usize;
|
||||
param_di_types.append(get_llvm_di_type(g, gen_return_type));
|
||||
} else if (!type_has_bits(fn_type_id->return_type)) {
|
||||
gen_return_type = g->builtin_types.entry_void;
|
||||
param_di_types.append(get_llvm_di_type(g, gen_return_type));
|
||||
} else if (first_arg_return) {
|
||||
gen_return_type = g->builtin_types.entry_void;
|
||||
param_di_types.append(get_llvm_di_type(g, gen_return_type));
|
||||
ZigType *gen_type = get_pointer_to_type(g, fn_type_id->return_type, false);
|
||||
gen_param_types.append(get_llvm_type(g, gen_type));
|
||||
param_di_types.append(get_llvm_di_type(g, gen_type));
|
||||
gen_return_type = g->builtin_types.entry_void;
|
||||
} else {
|
||||
gen_return_type = fn_type_id->return_type;
|
||||
param_di_types.append(get_llvm_di_type(g, gen_return_type));
|
||||
}
|
||||
fn_type->data.fn.gen_return_type = gen_return_type;
|
||||
|
||||
|
@ -7080,36 +7098,43 @@ void resolve_llvm_types_fn(CodeGen *g, ZigType *fn_type, ZigFn *fn) {
|
|||
param_di_types.append(get_llvm_di_type(g, gen_type));
|
||||
}
|
||||
if (is_async) {
|
||||
fn_type->data.fn.gen_param_info = allocate<FnGenParamInfo>(1);
|
||||
|
||||
ZigType *frame_type = (fn == nullptr) ? g->builtin_types.entry_frame_header : get_coro_frame_type(g, fn);
|
||||
ZigType *ptr_type = get_pointer_to_type(g, frame_type, false);
|
||||
gen_param_types.append(get_llvm_type(g, ptr_type));
|
||||
param_di_types.append(get_llvm_di_type(g, ptr_type));
|
||||
}
|
||||
|
||||
fn_type->data.fn.gen_param_info = allocate<FnGenParamInfo>(fn_type_id->param_count);
|
||||
for (size_t i = 0; i < fn_type_id->param_count; i += 1) {
|
||||
FnTypeParamInfo *src_param_info = &fn_type->data.fn.fn_type_id.param_info[i];
|
||||
ZigType *type_entry = src_param_info->type;
|
||||
FnGenParamInfo *gen_param_info = &fn_type->data.fn.gen_param_info[i];
|
||||
fn_type->data.fn.gen_param_info[0].src_index = 0;
|
||||
fn_type->data.fn.gen_param_info[0].gen_index = 0;
|
||||
fn_type->data.fn.gen_param_info[0].type = ptr_type;
|
||||
|
||||
gen_param_info->src_index = i;
|
||||
gen_param_info->gen_index = SIZE_MAX;
|
||||
} else {
|
||||
fn_type->data.fn.gen_param_info = allocate<FnGenParamInfo>(fn_type_id->param_count);
|
||||
for (size_t i = 0; i < fn_type_id->param_count; i += 1) {
|
||||
FnTypeParamInfo *src_param_info = &fn_type->data.fn.fn_type_id.param_info[i];
|
||||
ZigType *type_entry = src_param_info->type;
|
||||
FnGenParamInfo *gen_param_info = &fn_type->data.fn.gen_param_info[i];
|
||||
|
||||
if (is_c_abi || !type_has_bits(type_entry))
|
||||
continue;
|
||||
gen_param_info->src_index = i;
|
||||
gen_param_info->gen_index = SIZE_MAX;
|
||||
|
||||
ZigType *gen_type;
|
||||
if (handle_is_ptr(type_entry)) {
|
||||
gen_type = get_pointer_to_type(g, type_entry, true);
|
||||
gen_param_info->is_byval = true;
|
||||
} else {
|
||||
gen_type = type_entry;
|
||||
if (is_c_abi || !type_has_bits(type_entry))
|
||||
continue;
|
||||
|
||||
ZigType *gen_type;
|
||||
if (handle_is_ptr(type_entry)) {
|
||||
gen_type = get_pointer_to_type(g, type_entry, true);
|
||||
gen_param_info->is_byval = true;
|
||||
} else {
|
||||
gen_type = type_entry;
|
||||
}
|
||||
gen_param_info->gen_index = gen_param_types.length;
|
||||
gen_param_info->type = gen_type;
|
||||
gen_param_types.append(get_llvm_type(g, gen_type));
|
||||
|
||||
param_di_types.append(get_llvm_di_type(g, gen_type));
|
||||
}
|
||||
gen_param_info->gen_index = gen_param_types.length;
|
||||
gen_param_info->type = gen_type;
|
||||
gen_param_types.append(get_llvm_type(g, gen_type));
|
||||
|
||||
param_di_types.append(get_llvm_di_type(g, gen_type));
|
||||
}
|
||||
|
||||
if (is_c_abi) {
|
||||
|
|
|
@ -1965,10 +1965,12 @@ static bool iter_function_params_c_abi(CodeGen *g, ZigType *fn_type, FnWalk *fn_
|
|||
}
|
||||
case FnWalkIdInits: {
|
||||
clear_debug_source_node(g);
|
||||
LLVMValueRef arg = LLVMGetParam(llvm_fn, fn_walk->data.inits.gen_i);
|
||||
LLVMTypeRef ptr_to_int_type_ref = LLVMPointerType(LLVMIntType((unsigned)ty_size * 8), 0);
|
||||
LLVMValueRef bitcasted = LLVMBuildBitCast(g->builder, var->value_ref, ptr_to_int_type_ref, "");
|
||||
gen_store_untyped(g, arg, bitcasted, var->align_bytes, false);
|
||||
if (fn_walk->data.inits.fn->resume_blocks.length == 0) {
|
||||
LLVMValueRef arg = LLVMGetParam(llvm_fn, fn_walk->data.inits.gen_i);
|
||||
LLVMTypeRef ptr_to_int_type_ref = LLVMPointerType(LLVMIntType((unsigned)ty_size * 8), 0);
|
||||
LLVMValueRef bitcasted = LLVMBuildBitCast(g->builder, var->value_ref, ptr_to_int_type_ref, "");
|
||||
gen_store_untyped(g, arg, bitcasted, var->align_bytes, false);
|
||||
}
|
||||
if (var->decl_node) {
|
||||
gen_var_debug_decl(g, var);
|
||||
}
|
||||
|
@ -2061,7 +2063,7 @@ void walk_function_params(CodeGen *g, ZigType *fn_type, FnWalk *fn_walk) {
|
|||
assert(variable);
|
||||
assert(variable->value_ref);
|
||||
|
||||
if (!handle_is_ptr(variable->var_type)) {
|
||||
if (!handle_is_ptr(variable->var_type) && fn_walk->data.inits.fn->resume_blocks.length == 0) {
|
||||
clear_debug_source_node(g);
|
||||
ZigType *fn_type = fn_table_entry->type_entry;
|
||||
unsigned gen_arg_index = fn_type->data.fn.gen_param_info[variable->src_arg_index].gen_index;
|
||||
|
@ -3471,8 +3473,6 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
|
|||
if (prefix_arg_err_ret_stack) {
|
||||
zig_panic("TODO");
|
||||
}
|
||||
|
||||
gen_param_values.append(result_loc);
|
||||
} else {
|
||||
if (first_arg_ret) {
|
||||
gen_param_values.append(result_loc);
|
||||
|
@ -3504,6 +3504,15 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
|
|||
LLVMCallConv llvm_cc = get_llvm_cc(g, cc);
|
||||
LLVMValueRef result;
|
||||
|
||||
if (instruction->is_async) {
|
||||
for (size_t arg_i = 0; arg_i < gen_param_values.length; arg_i += 1) {
|
||||
LLVMValueRef arg_ptr = LLVMBuildStructGEP(g->builder, result_loc, coro_arg_start + arg_i, "");
|
||||
LLVMBuildStore(g->builder, gen_param_values.at(arg_i), arg_ptr);
|
||||
}
|
||||
ZigLLVMBuildCall(g->builder, fn_val, &result_loc, 1, llvm_cc, fn_inline, "");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (instruction->new_stack == nullptr) {
|
||||
result = ZigLLVMBuildCall(g->builder, fn_val,
|
||||
gen_param_values.items, (unsigned)gen_param_values.length, llvm_cc, fn_inline, "");
|
||||
|
@ -3519,11 +3528,6 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
|
|||
LLVMBuildCall(g->builder, stackrestore_fn_val, &old_stack_ref, 1, "");
|
||||
}
|
||||
|
||||
|
||||
if (instruction->is_async) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (src_return_type->id == ZigTypeIdUnreachable) {
|
||||
return LLVMBuildUnreachable(g->builder);
|
||||
} else if (!ret_has_bits) {
|
||||
|
@ -6285,7 +6289,9 @@ static void do_code_gen(CodeGen *g) {
|
|||
build_all_basic_blocks(g, fn_table_entry);
|
||||
clear_debug_source_node(g);
|
||||
|
||||
if (want_sret || fn_table_entry->resume_blocks.length != 0) {
|
||||
bool is_async = cc == CallingConventionAsync || fn_table_entry->resume_blocks.length != 0;
|
||||
|
||||
if (want_sret || is_async) {
|
||||
g->cur_ret_ptr = LLVMGetParam(fn, 0);
|
||||
} else if (handle_is_ptr(fn_type_id->return_type)) {
|
||||
g->cur_ret_ptr = build_alloca(g, fn_type_id->return_type, "result", 0);
|
||||
|
@ -6303,7 +6309,6 @@ static void do_code_gen(CodeGen *g) {
|
|||
}
|
||||
|
||||
// error return tracing setup
|
||||
bool is_async = cc == CallingConventionAsync;
|
||||
bool have_err_ret_trace_stack = g->have_err_ret_tracing && fn_table_entry->calls_or_awaits_errorable_fn && !is_async && !have_err_ret_trace_arg;
|
||||
LLVMValueRef err_ret_array_val = nullptr;
|
||||
if (have_err_ret_trace_stack) {
|
||||
|
@ -6378,7 +6383,9 @@ static void do_code_gen(CodeGen *g) {
|
|||
FnGenParamInfo *gen_info = &fn_table_entry->type_entry->data.fn.gen_param_info[var->src_arg_index];
|
||||
assert(gen_info->gen_index != SIZE_MAX);
|
||||
|
||||
if (handle_is_ptr(var->var_type)) {
|
||||
if (is_async) {
|
||||
var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_arg_start + var_i, "");
|
||||
} else if (handle_is_ptr(var->var_type)) {
|
||||
if (gen_info->is_byval) {
|
||||
gen_type = var->var_type;
|
||||
} else {
|
||||
|
|
|
@ -2,33 +2,34 @@ const std = @import("std");
|
|||
const builtin = @import("builtin");
|
||||
const expect = std.testing.expect;
|
||||
|
||||
var x: i32 = 1;
|
||||
var global_x: i32 = 1;
|
||||
|
||||
test "simple coroutine suspend and resume" {
|
||||
const p = async simpleAsyncFn();
|
||||
expect(x == 2);
|
||||
expect(global_x == 2);
|
||||
resume p;
|
||||
expect(x == 3);
|
||||
expect(global_x == 3);
|
||||
}
|
||||
fn simpleAsyncFn() void {
|
||||
x += 1;
|
||||
global_x += 1;
|
||||
suspend;
|
||||
x += 1;
|
||||
global_x += 1;
|
||||
}
|
||||
|
||||
//test "create a coroutine and cancel it" {
|
||||
// const p = try async<allocator> simpleAsyncFn();
|
||||
// comptime expect(@typeOf(p) == promise->void);
|
||||
// cancel p;
|
||||
// expect(x == 2);
|
||||
//}
|
||||
//async fn simpleAsyncFn() void {
|
||||
// x += 1;
|
||||
// suspend;
|
||||
// x += 1;
|
||||
//}
|
||||
//
|
||||
//test "coroutine suspend, resume, cancel" {
|
||||
var global_y: i32 = 1;
|
||||
|
||||
test "pass parameter to coroutine" {
|
||||
const p = async simpleAsyncFnWithArg(2);
|
||||
expect(global_y == 3);
|
||||
resume p;
|
||||
expect(global_y == 5);
|
||||
}
|
||||
fn simpleAsyncFnWithArg(delta: i32) void {
|
||||
global_y += delta;
|
||||
suspend;
|
||||
global_y += delta;
|
||||
}
|
||||
//test "coroutine suspend, resume" {
|
||||
// seq('a');
|
||||
// const p = try async<allocator> testAsyncSeq();
|
||||
// seq('c');
|
||||
|
|
Loading…
Reference in New Issue
Block a user