diff --git a/src/all_types.hpp b/src/all_types.hpp index 82d2e2cdd..aa7ff06ce 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -3638,6 +3638,7 @@ static const size_t err_union_err_index = 0; static const size_t err_union_payload_index = 1; static const size_t coro_resume_index_index = 0; +static const size_t coro_arg_start = 1; // TODO call graph analysis to find out what this number needs to be for every function // MUST BE A POWER OF TWO. diff --git a/src/analyze.cpp b/src/analyze.cpp index 91a100795..aff11e017 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -1891,6 +1891,21 @@ static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) { field_names.append("resume_index"); field_types.append(g->builtin_types.entry_usize); + for (size_t arg_i = 0; arg_i < fn->type_entry->data.fn.fn_type_id.param_count; arg_i += 1) { + FnTypeParamInfo *param_info = &fn->type_entry->data.fn.fn_type_id.param_info[arg_i]; + AstNode *param_decl_node = get_param_decl_node(fn, arg_i); + Buf *param_name; + bool is_var_args = param_decl_node && param_decl_node->data.param_decl.is_var_args; + if (param_decl_node && !is_var_args) { + param_name = param_decl_node->data.param_decl.name; + } else { + param_name = buf_sprintf("arg%" ZIG_PRI_usize "", arg_i); + } + ZigType *param_type = param_info[arg_i].type; + field_names.append(buf_ptr(param_name)); + field_types.append(param_type); + } + assert(field_names.length == field_types.length); frame_type->data.frame.locals_struct = get_struct_type(g, buf_ptr(&frame_type->name), field_names.items, field_types.items, field_names.length); @@ -7058,19 +7073,22 @@ void resolve_llvm_types_fn(CodeGen *g, ZigType *fn_type, ZigFn *fn) { // +1 for maybe first argument the error return trace // +2 for maybe arguments async allocator and error code pointer ZigList param_di_types = {}; - param_di_types.append(get_llvm_di_type(g, fn_type_id->return_type)); ZigType *gen_return_type; if (is_async) { gen_return_type = g->builtin_types.entry_usize; + param_di_types.append(get_llvm_di_type(g, gen_return_type)); } else if (!type_has_bits(fn_type_id->return_type)) { gen_return_type = g->builtin_types.entry_void; + param_di_types.append(get_llvm_di_type(g, gen_return_type)); } else if (first_arg_return) { + gen_return_type = g->builtin_types.entry_void; + param_di_types.append(get_llvm_di_type(g, gen_return_type)); ZigType *gen_type = get_pointer_to_type(g, fn_type_id->return_type, false); gen_param_types.append(get_llvm_type(g, gen_type)); param_di_types.append(get_llvm_di_type(g, gen_type)); - gen_return_type = g->builtin_types.entry_void; } else { gen_return_type = fn_type_id->return_type; + param_di_types.append(get_llvm_di_type(g, gen_return_type)); } fn_type->data.fn.gen_return_type = gen_return_type; @@ -7080,36 +7098,43 @@ void resolve_llvm_types_fn(CodeGen *g, ZigType *fn_type, ZigFn *fn) { param_di_types.append(get_llvm_di_type(g, gen_type)); } if (is_async) { + fn_type->data.fn.gen_param_info = allocate(1); + ZigType *frame_type = (fn == nullptr) ? g->builtin_types.entry_frame_header : get_coro_frame_type(g, fn); ZigType *ptr_type = get_pointer_to_type(g, frame_type, false); gen_param_types.append(get_llvm_type(g, ptr_type)); param_di_types.append(get_llvm_di_type(g, ptr_type)); - } - fn_type->data.fn.gen_param_info = allocate(fn_type_id->param_count); - for (size_t i = 0; i < fn_type_id->param_count; i += 1) { - FnTypeParamInfo *src_param_info = &fn_type->data.fn.fn_type_id.param_info[i]; - ZigType *type_entry = src_param_info->type; - FnGenParamInfo *gen_param_info = &fn_type->data.fn.gen_param_info[i]; + fn_type->data.fn.gen_param_info[0].src_index = 0; + fn_type->data.fn.gen_param_info[0].gen_index = 0; + fn_type->data.fn.gen_param_info[0].type = ptr_type; - gen_param_info->src_index = i; - gen_param_info->gen_index = SIZE_MAX; + } else { + fn_type->data.fn.gen_param_info = allocate(fn_type_id->param_count); + for (size_t i = 0; i < fn_type_id->param_count; i += 1) { + FnTypeParamInfo *src_param_info = &fn_type->data.fn.fn_type_id.param_info[i]; + ZigType *type_entry = src_param_info->type; + FnGenParamInfo *gen_param_info = &fn_type->data.fn.gen_param_info[i]; - if (is_c_abi || !type_has_bits(type_entry)) - continue; + gen_param_info->src_index = i; + gen_param_info->gen_index = SIZE_MAX; - ZigType *gen_type; - if (handle_is_ptr(type_entry)) { - gen_type = get_pointer_to_type(g, type_entry, true); - gen_param_info->is_byval = true; - } else { - gen_type = type_entry; + if (is_c_abi || !type_has_bits(type_entry)) + continue; + + ZigType *gen_type; + if (handle_is_ptr(type_entry)) { + gen_type = get_pointer_to_type(g, type_entry, true); + gen_param_info->is_byval = true; + } else { + gen_type = type_entry; + } + gen_param_info->gen_index = gen_param_types.length; + gen_param_info->type = gen_type; + gen_param_types.append(get_llvm_type(g, gen_type)); + + param_di_types.append(get_llvm_di_type(g, gen_type)); } - gen_param_info->gen_index = gen_param_types.length; - gen_param_info->type = gen_type; - gen_param_types.append(get_llvm_type(g, gen_type)); - - param_di_types.append(get_llvm_di_type(g, gen_type)); } if (is_c_abi) { diff --git a/src/codegen.cpp b/src/codegen.cpp index fa5f3ef8e..9184bd7c8 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -1965,10 +1965,12 @@ static bool iter_function_params_c_abi(CodeGen *g, ZigType *fn_type, FnWalk *fn_ } case FnWalkIdInits: { clear_debug_source_node(g); - LLVMValueRef arg = LLVMGetParam(llvm_fn, fn_walk->data.inits.gen_i); - LLVMTypeRef ptr_to_int_type_ref = LLVMPointerType(LLVMIntType((unsigned)ty_size * 8), 0); - LLVMValueRef bitcasted = LLVMBuildBitCast(g->builder, var->value_ref, ptr_to_int_type_ref, ""); - gen_store_untyped(g, arg, bitcasted, var->align_bytes, false); + if (fn_walk->data.inits.fn->resume_blocks.length == 0) { + LLVMValueRef arg = LLVMGetParam(llvm_fn, fn_walk->data.inits.gen_i); + LLVMTypeRef ptr_to_int_type_ref = LLVMPointerType(LLVMIntType((unsigned)ty_size * 8), 0); + LLVMValueRef bitcasted = LLVMBuildBitCast(g->builder, var->value_ref, ptr_to_int_type_ref, ""); + gen_store_untyped(g, arg, bitcasted, var->align_bytes, false); + } if (var->decl_node) { gen_var_debug_decl(g, var); } @@ -2061,7 +2063,7 @@ void walk_function_params(CodeGen *g, ZigType *fn_type, FnWalk *fn_walk) { assert(variable); assert(variable->value_ref); - if (!handle_is_ptr(variable->var_type)) { + if (!handle_is_ptr(variable->var_type) && fn_walk->data.inits.fn->resume_blocks.length == 0) { clear_debug_source_node(g); ZigType *fn_type = fn_table_entry->type_entry; unsigned gen_arg_index = fn_type->data.fn.gen_param_info[variable->src_arg_index].gen_index; @@ -3471,8 +3473,6 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr if (prefix_arg_err_ret_stack) { zig_panic("TODO"); } - - gen_param_values.append(result_loc); } else { if (first_arg_ret) { gen_param_values.append(result_loc); @@ -3504,6 +3504,15 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr LLVMCallConv llvm_cc = get_llvm_cc(g, cc); LLVMValueRef result; + if (instruction->is_async) { + for (size_t arg_i = 0; arg_i < gen_param_values.length; arg_i += 1) { + LLVMValueRef arg_ptr = LLVMBuildStructGEP(g->builder, result_loc, coro_arg_start + arg_i, ""); + LLVMBuildStore(g->builder, gen_param_values.at(arg_i), arg_ptr); + } + ZigLLVMBuildCall(g->builder, fn_val, &result_loc, 1, llvm_cc, fn_inline, ""); + return nullptr; + } + if (instruction->new_stack == nullptr) { result = ZigLLVMBuildCall(g->builder, fn_val, gen_param_values.items, (unsigned)gen_param_values.length, llvm_cc, fn_inline, ""); @@ -3519,11 +3528,6 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr LLVMBuildCall(g->builder, stackrestore_fn_val, &old_stack_ref, 1, ""); } - - if (instruction->is_async) { - return nullptr; - } - if (src_return_type->id == ZigTypeIdUnreachable) { return LLVMBuildUnreachable(g->builder); } else if (!ret_has_bits) { @@ -6285,7 +6289,9 @@ static void do_code_gen(CodeGen *g) { build_all_basic_blocks(g, fn_table_entry); clear_debug_source_node(g); - if (want_sret || fn_table_entry->resume_blocks.length != 0) { + bool is_async = cc == CallingConventionAsync || fn_table_entry->resume_blocks.length != 0; + + if (want_sret || is_async) { g->cur_ret_ptr = LLVMGetParam(fn, 0); } else if (handle_is_ptr(fn_type_id->return_type)) { g->cur_ret_ptr = build_alloca(g, fn_type_id->return_type, "result", 0); @@ -6303,7 +6309,6 @@ static void do_code_gen(CodeGen *g) { } // error return tracing setup - bool is_async = cc == CallingConventionAsync; bool have_err_ret_trace_stack = g->have_err_ret_tracing && fn_table_entry->calls_or_awaits_errorable_fn && !is_async && !have_err_ret_trace_arg; LLVMValueRef err_ret_array_val = nullptr; if (have_err_ret_trace_stack) { @@ -6378,7 +6383,9 @@ static void do_code_gen(CodeGen *g) { FnGenParamInfo *gen_info = &fn_table_entry->type_entry->data.fn.gen_param_info[var->src_arg_index]; assert(gen_info->gen_index != SIZE_MAX); - if (handle_is_ptr(var->var_type)) { + if (is_async) { + var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_arg_start + var_i, ""); + } else if (handle_is_ptr(var->var_type)) { if (gen_info->is_byval) { gen_type = var->var_type; } else { diff --git a/test/stage1/behavior/coroutines.zig b/test/stage1/behavior/coroutines.zig index fd07790e7..3a5465702 100644 --- a/test/stage1/behavior/coroutines.zig +++ b/test/stage1/behavior/coroutines.zig @@ -2,33 +2,34 @@ const std = @import("std"); const builtin = @import("builtin"); const expect = std.testing.expect; -var x: i32 = 1; +var global_x: i32 = 1; test "simple coroutine suspend and resume" { const p = async simpleAsyncFn(); - expect(x == 2); + expect(global_x == 2); resume p; - expect(x == 3); + expect(global_x == 3); } fn simpleAsyncFn() void { - x += 1; + global_x += 1; suspend; - x += 1; + global_x += 1; } -//test "create a coroutine and cancel it" { -// const p = try async simpleAsyncFn(); -// comptime expect(@typeOf(p) == promise->void); -// cancel p; -// expect(x == 2); -//} -//async fn simpleAsyncFn() void { -// x += 1; -// suspend; -// x += 1; -//} -// -//test "coroutine suspend, resume, cancel" { +var global_y: i32 = 1; + +test "pass parameter to coroutine" { + const p = async simpleAsyncFnWithArg(2); + expect(global_y == 3); + resume p; + expect(global_y == 5); +} +fn simpleAsyncFnWithArg(delta: i32) void { + global_y += delta; + suspend; + global_y += delta; +} +//test "coroutine suspend, resume" { // seq('a'); // const p = try async testAsyncSeq(); // seq('c');