From ead2d32be871411685f846e604ec7e4253b9f25a Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 25 Jul 2019 00:03:06 -0400 Subject: [PATCH] calling an inferred async function --- src/all_types.hpp | 17 ++- src/analyze.cpp | 220 +++++++++++++++++----------- src/codegen.cpp | 72 +++++++-- src/ir.cpp | 26 ++-- src/zig_llvm.cpp | 4 + src/zig_llvm.h | 1 + test/stage1/behavior/coroutines.zig | 16 ++ 7 files changed, 238 insertions(+), 118 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index a68f19a87..d67356b17 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -35,6 +35,7 @@ struct ConstExprValue; struct IrInstruction; struct IrInstructionCast; struct IrInstructionAllocaGen; +struct IrInstructionCallGen; struct IrBasicBlock; struct ScopeDecls; struct ZigWindowsSDK; @@ -1336,11 +1337,6 @@ struct GlobalExport { GlobalLinkageId linkage; }; -struct FnCall { - AstNode *source_node; - ZigFn *callee; -}; - struct ZigFn { LLVMValueRef llvm_value; const char *llvm_name; @@ -1387,7 +1383,7 @@ struct ZigFn { ZigFn *inferred_async_fn; ZigList export_list; - ZigList call_list; + ZigList call_list; LLVMValueRef valgrind_client_request_array; LLVMBasicBlockRef preamble_llvm_block; @@ -2585,6 +2581,8 @@ struct IrInstructionCallGen { size_t arg_count; IrInstruction **args; IrInstruction *result_loc; + IrInstruction *frame_result_loc; + IrBasicBlock *resume_block; IrInstruction *new_stack; FnInline fn_inline; @@ -3645,7 +3643,12 @@ static const size_t err_union_err_index = 0; static const size_t err_union_payload_index = 1; static const size_t coro_resume_index_index = 0; -static const size_t coro_arg_start = 1; +static const size_t coro_fn_ptr_index = 1; +static const size_t coro_awaiter_index = 2; +static const size_t coro_arg_start = 3; + +// one for the GetSize block, one for the Entry block, resume blocks are indexed after that. +static const size_t coro_extra_resume_block_count = 2; // TODO call graph analysis to find out what this number needs to be for every function // MUST BE A POWER OF TWO. diff --git a/src/analyze.cpp b/src/analyze.cpp index 957e61b19..c8e02a477 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -1869,80 +1869,6 @@ static Error resolve_union_type(CodeGen *g, ZigType *union_type) { return ErrorNone; } -static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) { - if (frame_type->data.frame.locals_struct != nullptr) - return ErrorNone; - - ZigFn *fn = frame_type->data.frame.fn; - switch (fn->anal_state) { - case FnAnalStateInvalid: - return ErrorSemanticAnalyzeFail; - case FnAnalStateComplete: - break; - case FnAnalStateReady: - analyze_fn_body(g, fn); - if (fn->anal_state == FnAnalStateInvalid) - return ErrorSemanticAnalyzeFail; - break; - case FnAnalStateProbing: - add_node_error(g, fn->proto_node, - buf_sprintf("cannot resolve '%s': function not fully analyzed yet", - buf_ptr(&frame_type->name))); - return ErrorSemanticAnalyzeFail; - } - // TODO iterate over fn->alloca_gen_list - ZigList field_types = {}; - ZigList field_names = {}; - - field_names.append("resume_index"); - field_types.append(g->builtin_types.entry_usize); - - FnTypeId *fn_type_id = &fn->type_entry->data.fn.fn_type_id; - field_names.append("result"); - field_types.append(fn_type_id->return_type); - - for (size_t arg_i = 0; arg_i < fn_type_id->param_count; arg_i += 1) { - FnTypeParamInfo *param_info = &fn_type_id->param_info[arg_i]; - AstNode *param_decl_node = get_param_decl_node(fn, arg_i); - Buf *param_name; - bool is_var_args = param_decl_node && param_decl_node->data.param_decl.is_var_args; - if (param_decl_node && !is_var_args) { - param_name = param_decl_node->data.param_decl.name; - } else { - param_name = buf_sprintf("arg%" ZIG_PRI_usize "", arg_i); - } - ZigType *param_type = param_info->type; - field_names.append(buf_ptr(param_name)); - field_types.append(param_type); - } - - for (size_t alloca_i = 0; alloca_i < fn->alloca_gen_list.length; alloca_i += 1) { - IrInstructionAllocaGen *instruction = fn->alloca_gen_list.at(alloca_i); - ZigType *ptr_type = instruction->base.value.type; - assert(ptr_type->id == ZigTypeIdPointer); - ZigType *child_type = ptr_type->data.pointer.child_type; - if (!type_has_bits(child_type)) - continue; - if (instruction->base.ref_count == 0) - continue; - if (instruction->base.value.special != ConstValSpecialRuntime) { - if (const_ptr_pointee(nullptr, g, &instruction->base.value, nullptr)->special != - ConstValSpecialRuntime) - { - continue; - } - } - field_names.append(instruction->name_hint); - field_types.append(child_type); - } - - - assert(field_names.length == field_types.length); - frame_type->data.frame.locals_struct = get_struct_type(g, buf_ptr(&frame_type->name), - field_names.items, field_types.items, field_names.length); - return ErrorNone; -} - static bool type_is_valid_extern_enum_tag(CodeGen *g, ZigType *ty) { // Only integer types are allowed by the C ABI if(ty->id != ZigTypeIdInt) @@ -3861,18 +3787,24 @@ static void analyze_fn_async(CodeGen *g, ZigFn *fn) { } for (size_t i = 0; i < fn->call_list.length; i += 1) { - FnCall *call = &fn->call_list.at(i); - if (call->callee->type_entry->data.fn.fn_type_id.cc != CallingConventionUnspecified) + IrInstructionCallGen *call = fn->call_list.at(i); + ZigFn *callee = call->fn_entry; + if (callee == nullptr) { + // TODO function pointer call here, could be anything continue; - assert(call->callee->anal_state == FnAnalStateComplete); - analyze_fn_async(g, call->callee); - if (call->callee->anal_state == FnAnalStateInvalid) { + } + + if (callee->type_entry->data.fn.fn_type_id.cc != CallingConventionUnspecified) + continue; + assert(callee->anal_state == FnAnalStateComplete); + analyze_fn_async(g, callee); + if (callee->anal_state == FnAnalStateInvalid) { fn->anal_state = FnAnalStateInvalid; return; } - if (fn_is_async(call->callee)) { - fn->inferred_async_node = call->source_node; - fn->inferred_async_fn = call->callee; + if (fn_is_async(callee)) { + fn->inferred_async_node = call->base.source_node; + fn->inferred_async_fn = callee; if (must_not_be_async) { ErrorMsg *msg = add_node_error(g, fn->proto_node, buf_sprintf("function with calling convention '%s' cannot be async", @@ -5147,6 +5079,127 @@ Error ensure_complete_type(CodeGen *g, ZigType *type_entry) { return type_resolve(g, type_entry, ResolveStatusSizeKnown); } +static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) { + if (frame_type->data.frame.locals_struct != nullptr) + return ErrorNone; + + ZigFn *fn = frame_type->data.frame.fn; + switch (fn->anal_state) { + case FnAnalStateInvalid: + return ErrorSemanticAnalyzeFail; + case FnAnalStateComplete: + break; + case FnAnalStateReady: + analyze_fn_body(g, fn); + if (fn->anal_state == FnAnalStateInvalid) + return ErrorSemanticAnalyzeFail; + break; + case FnAnalStateProbing: + add_node_error(g, fn->proto_node, + buf_sprintf("cannot resolve '%s': function not fully analyzed yet", + buf_ptr(&frame_type->name))); + return ErrorSemanticAnalyzeFail; + } + + for (size_t i = 0; i < fn->call_list.length; i += 1) { + IrInstructionCallGen *call = fn->call_list.at(i); + ZigFn *callee = call->fn_entry; + assert(callee != nullptr); + + analyze_fn_body(g, callee); + if (callee->anal_state == FnAnalStateInvalid) { + frame_type->data.frame.locals_struct = g->builtin_types.entry_invalid; + return ErrorSemanticAnalyzeFail; + } + analyze_fn_async(g, callee); + if (!fn_is_async(callee)) + continue; + + IrBasicBlock *new_resume_block = allocate(1); + new_resume_block->name_hint = "CallResume"; + new_resume_block->resume_index = fn->resume_blocks.length + coro_extra_resume_block_count; + fn->resume_blocks.append(new_resume_block); + call->resume_block = new_resume_block; + fn->analyzed_executable.basic_block_list.append(new_resume_block); + + ZigType *callee_frame_type = get_coro_frame_type(g, callee); + + IrInstructionAllocaGen *alloca_gen = allocate(1); + alloca_gen->base.id = IrInstructionIdAllocaGen; + alloca_gen->base.source_node = call->base.source_node; + alloca_gen->base.scope = call->base.scope; + alloca_gen->base.value.type = get_pointer_to_type(g, callee_frame_type, false); + alloca_gen->base.ref_count = 1; + alloca_gen->name_hint = ""; + fn->alloca_gen_list.append(alloca_gen); + call->frame_result_loc = &alloca_gen->base; + } + + ZigList field_types = {}; + ZigList field_names = {}; + + field_names.append("resume_index"); + field_types.append(g->builtin_types.entry_usize); + + field_names.append("fn_ptr"); + field_types.append(fn->type_entry); + + field_names.append("awaiter"); + field_types.append(g->builtin_types.entry_usize); + + FnTypeId *fn_type_id = &fn->type_entry->data.fn.fn_type_id; + ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false); + field_names.append("result_ptr"); + field_types.append(ptr_return_type); + + field_names.append("result"); + field_types.append(fn_type_id->return_type); + + for (size_t arg_i = 0; arg_i < fn_type_id->param_count; arg_i += 1) { + FnTypeParamInfo *param_info = &fn_type_id->param_info[arg_i]; + AstNode *param_decl_node = get_param_decl_node(fn, arg_i); + Buf *param_name; + bool is_var_args = param_decl_node && param_decl_node->data.param_decl.is_var_args; + if (param_decl_node && !is_var_args) { + param_name = param_decl_node->data.param_decl.name; + } else { + param_name = buf_sprintf("arg%" ZIG_PRI_usize "", arg_i); + } + ZigType *param_type = param_info->type; + field_names.append(buf_ptr(param_name)); + field_types.append(param_type); + } + + for (size_t alloca_i = 0; alloca_i < fn->alloca_gen_list.length; alloca_i += 1) { + IrInstructionAllocaGen *instruction = fn->alloca_gen_list.at(alloca_i); + ZigType *ptr_type = instruction->base.value.type; + assert(ptr_type->id == ZigTypeIdPointer); + ZigType *child_type = ptr_type->data.pointer.child_type; + if (!type_has_bits(child_type)) + continue; + if (instruction->base.ref_count == 0) + continue; + if (instruction->base.value.special != ConstValSpecialRuntime) { + if (const_ptr_pointee(nullptr, g, &instruction->base.value, nullptr)->special != + ConstValSpecialRuntime) + { + continue; + } + } + field_names.append(instruction->name_hint); + field_types.append(child_type); + } + + + assert(field_names.length == field_types.length); + frame_type->data.frame.locals_struct = get_struct_type(g, buf_ptr(&frame_type->name), + field_names.items, field_types.items, field_names.length); + frame_type->abi_size = frame_type->data.frame.locals_struct->abi_size; + frame_type->abi_align = frame_type->data.frame.locals_struct->abi_align; + frame_type->size_in_bits = frame_type->data.frame.locals_struct->size_in_bits; + return ErrorNone; +} + Error type_resolve(CodeGen *g, ZigType *ty, ResolveStatus status) { if (type_is_invalid(ty)) return ErrorSemanticAnalyzeFail; @@ -7343,9 +7396,6 @@ static void resolve_llvm_types_coro_frame(CodeGen *g, ZigType *frame_type, Resol resolve_llvm_types_struct(g, frame_type->data.frame.locals_struct, wanted_resolve_status); frame_type->llvm_type = frame_type->data.frame.locals_struct->llvm_type; frame_type->llvm_di_type = frame_type->data.frame.locals_struct->llvm_di_type; - frame_type->abi_size = frame_type->data.frame.locals_struct->abi_size; - frame_type->abi_align = frame_type->data.frame.locals_struct->abi_align; - frame_type->size_in_bits = frame_type->data.frame.locals_struct->size_in_bits; } static void resolve_llvm_types(CodeGen *g, ZigType *type, ResolveStatus wanted_resolve_status) { diff --git a/src/codegen.cpp b/src/codegen.cpp index 534b97232..34f4aa1cc 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -3324,13 +3324,16 @@ static void set_call_instr_sret(CodeGen *g, LLVMValueRef call_instr) { static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstructionCallGen *instruction) { LLVMValueRef fn_val; ZigType *fn_type; + bool callee_is_async; if (instruction->fn_entry) { fn_val = fn_llvm_value(g, instruction->fn_entry); fn_type = instruction->fn_entry->type_entry; + callee_is_async = fn_is_async(instruction->fn_entry); } else { assert(instruction->fn_ref); fn_val = ir_llvm_value(g, instruction->fn_ref); fn_type = instruction->fn_ref->value.type; + callee_is_async = fn_type->data.fn.fn_type_id.cc == CallingConventionAsync; } FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id; @@ -3345,17 +3348,47 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr bool is_var_args = fn_type_id->is_var_args; ZigList gen_param_values = {}; LLVMValueRef result_loc = instruction->result_loc ? ir_llvm_value(g, instruction->result_loc) : nullptr; + LLVMValueRef zero = LLVMConstNull(g->builtin_types.entry_usize->llvm_type); + LLVMValueRef frame_result_loc; + LLVMValueRef awaiter_init_val; + LLVMValueRef ret_ptr; if (instruction->is_async) { - assert(result_loc != nullptr); + frame_result_loc = result_loc; + awaiter_init_val = zero; + if (ret_has_bits) { + ret_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, coro_arg_start + 1, ""); + } + } else if (callee_is_async) { + frame_result_loc = ir_llvm_value(g, instruction->frame_result_loc); + awaiter_init_val = LLVMBuildPtrToInt(g->builder, g->cur_ret_ptr, + g->builtin_types.entry_usize->llvm_type, ""); // caller's own frame pointer + if (ret_has_bits) { + ret_ptr = result_loc; + } + } + if (instruction->is_async || callee_is_async) { + assert(frame_result_loc != nullptr); assert(instruction->fn_entry != nullptr); - LLVMValueRef resume_index_ptr = LLVMBuildStructGEP(g->builder, result_loc, coro_resume_index_index, ""); - LLVMValueRef zero = LLVMConstNull(g->builtin_types.entry_usize->llvm_type); + LLVMValueRef resume_index_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, coro_resume_index_index, ""); LLVMBuildStore(g->builder, zero, resume_index_ptr); + LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, coro_fn_ptr_index, ""); + LLVMValueRef bitcasted_fn_val = LLVMBuildBitCast(g->builder, fn_val, + LLVMGetElementType(LLVMTypeOf(fn_ptr_ptr)), ""); + LLVMBuildStore(g->builder, bitcasted_fn_val, fn_ptr_ptr); if (prefix_arg_err_ret_stack) { zig_panic("TODO"); } - } else { + + LLVMValueRef awaiter_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, coro_awaiter_index, ""); + LLVMBuildStore(g->builder, awaiter_init_val, awaiter_ptr); + + if (ret_has_bits) { + LLVMValueRef ret_ptr_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, coro_arg_start, ""); + LLVMBuildStore(g->builder, ret_ptr, ret_ptr_ptr); + } + } + if (!instruction->is_async && !callee_is_async) { if (first_arg_ret) { gen_param_values.append(result_loc); } @@ -3386,14 +3419,28 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr LLVMCallConv llvm_cc = get_llvm_cc(g, cc); LLVMValueRef result; - if (instruction->is_async) { - size_t ret_1_or_0 = type_has_bits(fn_type->data.fn.fn_type_id.return_type) ? 1 : 0; + if (instruction->is_async || callee_is_async) { + size_t ret_2_or_0 = type_has_bits(fn_type->data.fn.fn_type_id.return_type) ? 2 : 0; for (size_t arg_i = 0; arg_i < gen_param_values.length; arg_i += 1) { - LLVMValueRef arg_ptr = LLVMBuildStructGEP(g->builder, result_loc, - coro_arg_start + ret_1_or_0 + arg_i, ""); + LLVMValueRef arg_ptr = LLVMBuildStructGEP(g->builder, frame_result_loc, + coro_arg_start + ret_2_or_0 + arg_i, ""); LLVMBuildStore(g->builder, gen_param_values.at(arg_i), arg_ptr); } - ZigLLVMBuildCall(g->builder, fn_val, &result_loc, 1, llvm_cc, fn_inline, ""); + } + if (instruction->is_async) { + ZigLLVMBuildCall(g->builder, fn_val, &frame_result_loc, 1, llvm_cc, fn_inline, ""); + return nullptr; + } else if (callee_is_async) { + LLVMValueRef resume_index_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_resume_index_index, ""); + LLVMValueRef new_resume_index = LLVMConstInt(g->builtin_types.entry_usize->llvm_type, + instruction->resume_block->resume_index, false); + LLVMBuildStore(g->builder, new_resume_index, resume_index_ptr); + + LLVMValueRef call_inst = ZigLLVMBuildCall(g->builder, fn_val, &frame_result_loc, 1, llvm_cc, fn_inline, ""); + ZigLLVMSetTailCall(call_inst); + LLVMBuildRet(g->builder, call_inst); + + LLVMPositionBuilderAtEnd(g->builder, instruction->resume_block->llvm_block); return nullptr; } @@ -6174,7 +6221,7 @@ static void do_code_gen(CodeGen *g) { clear_debug_source_node(g); bool is_async = fn_is_async(fn_table_entry); - size_t async_var_index = coro_arg_start + (type_has_bits(fn_type_id->return_type) ? 1 : 0); + size_t async_var_index = coro_arg_start + (type_has_bits(fn_type_id->return_type) ? 2 : 0); if (want_sret || is_async) { g->cur_ret_ptr = LLVMGetParam(fn, 0); @@ -6385,8 +6432,9 @@ static void do_code_gen(CodeGen *g) { LLVMAddCase(switch_instr, one, get_size_block); for (size_t resume_i = 0; resume_i < fn_table_entry->resume_blocks.length; resume_i += 1) { - LLVMValueRef case_value = LLVMConstInt(usize_type_ref, resume_i + 2, false); - LLVMAddCase(switch_instr, case_value, fn_table_entry->resume_blocks.at(resume_i)->llvm_block); + IrBasicBlock *resume_block = fn_table_entry->resume_blocks.at(resume_i); + LLVMValueRef case_value = LLVMConstInt(usize_type_ref, resume_block->resume_index, false); + LLVMAddCase(switch_instr, case_value, resume_block->llvm_block); } } else { // create debug variable declarations for parameters diff --git a/src/ir.cpp b/src/ir.cpp index 0cc68eaa5..cb4a90c31 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -1385,7 +1385,7 @@ static IrInstruction *ir_build_call_src(IrBuilder *irb, Scope *scope, AstNode *s return &call_instruction->base; } -static IrInstruction *ir_build_call_gen(IrAnalyze *ira, IrInstruction *source_instruction, +static IrInstructionCallGen *ir_build_call_gen(IrAnalyze *ira, IrInstruction *source_instruction, ZigFn *fn_entry, IrInstruction *fn_ref, size_t arg_count, IrInstruction **args, FnInline fn_inline, bool is_async, IrInstruction *new_stack, IrInstruction *result_loc, ZigType *return_type) @@ -1408,7 +1408,7 @@ static IrInstruction *ir_build_call_gen(IrAnalyze *ira, IrInstruction *source_in if (new_stack != nullptr) ir_ref_instruction(new_stack, ira->new_irb.current_basic_block); if (result_loc != nullptr) ir_ref_instruction(result_loc, ira->new_irb.current_basic_block); - return &call_instruction->base; + return call_instruction; } static IrInstruction *ir_build_phi(IrBuilder *irb, Scope *scope, AstNode *source_node, @@ -14650,8 +14650,8 @@ static IrInstruction *ir_analyze_async_call(IrAnalyze *ira, IrInstructionCallSrc if (result_loc != nullptr && (type_is_invalid(result_loc->value.type) || instr_is_unreachable(result_loc))) { return result_loc; } - return ir_build_call_gen(ira, &call_instruction->base, fn_entry, fn_ref, arg_count, - casted_args, FnInlineAuto, true, nullptr, result_loc, frame_type); + return &ir_build_call_gen(ira, &call_instruction->base, fn_entry, fn_ref, arg_count, + casted_args, FnInlineAuto, true, nullptr, result_loc, frame_type)->base; } static bool ir_analyze_fn_call_inline_arg(IrAnalyze *ira, AstNode *fn_proto_node, IrInstruction *arg, Scope **exec_scope, size_t *next_proto_i) @@ -15387,15 +15387,16 @@ static IrInstruction *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCallSrc *c if (impl_fn_type_id->cc == CallingConventionAsync && parent_fn_entry->inferred_async_node == nullptr) { parent_fn_entry->inferred_async_node = fn_ref->source_node; } - parent_fn_entry->call_list.append({call_instruction->base.source_node, impl_fn}); } - IrInstruction *new_call_instruction = ir_build_call_gen(ira, &call_instruction->base, + IrInstructionCallGen *new_call_instruction = ir_build_call_gen(ira, &call_instruction->base, impl_fn, nullptr, impl_param_count, casted_args, fn_inline, call_instruction->is_async, casted_new_stack, result_loc, impl_fn_type_id->return_type); - return ir_finish_anal(ira, new_call_instruction); + parent_fn_entry->call_list.append(new_call_instruction); + + return ir_finish_anal(ira, &new_call_instruction->base); } ZigFn *parent_fn_entry = exec_fn_entry(ira->new_irb.exec); @@ -15469,9 +15470,6 @@ static IrInstruction *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCallSrc *c if (fn_type_id->cc == CallingConventionAsync && parent_fn_entry->inferred_async_node == nullptr) { parent_fn_entry->inferred_async_node = fn_ref->source_node; } - if (fn_entry != nullptr) { - parent_fn_entry->call_list.append({call_instruction->base.source_node, fn_entry}); - } } if (call_instruction->is_async) { @@ -15491,10 +15489,11 @@ static IrInstruction *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCallSrc *c result_loc = nullptr; } - IrInstruction *new_call_instruction = ir_build_call_gen(ira, &call_instruction->base, fn_entry, fn_ref, + IrInstructionCallGen *new_call_instruction = ir_build_call_gen(ira, &call_instruction->base, fn_entry, fn_ref, call_param_count, casted_args, fn_inline, false, casted_new_stack, result_loc, return_type); - return ir_finish_anal(ira, new_call_instruction); + parent_fn_entry->call_list.append(new_call_instruction); + return ir_finish_anal(ira, &new_call_instruction->base); } static IrInstruction *ir_analyze_instruction_call(IrAnalyze *ira, IrInstructionCallSrc *call_instruction) { @@ -24154,8 +24153,7 @@ static IrInstruction *ir_analyze_instruction_suspend_br(IrAnalyze *ira, IrInstru ZigFn *fn_entry = exec_fn_entry(ira->new_irb.exec); ir_assert(fn_entry != nullptr, &instruction->base); - // +2 - one for the GetSize block, one for the Entry block, resume blocks are indexed after that. - new_bb->resume_index = fn_entry->resume_blocks.length + 2; + new_bb->resume_index = fn_entry->resume_blocks.length + coro_extra_resume_block_count; fn_entry->resume_blocks.append(new_bb); if (fn_entry->inferred_async_node == nullptr) { diff --git a/src/zig_llvm.cpp b/src/zig_llvm.cpp index c51c9e1a5..b52edabe6 100644 --- a/src/zig_llvm.cpp +++ b/src/zig_llvm.cpp @@ -898,6 +898,10 @@ LLVMValueRef ZigLLVMBuildAShrExact(LLVMBuilderRef builder, LLVMValueRef LHS, LLV return wrap(unwrap(builder)->CreateAShr(unwrap(LHS), unwrap(RHS), name, true)); } +void ZigLLVMSetTailCall(LLVMValueRef Call) { + unwrap(Call)->setTailCallKind(CallInst::TCK_MustTail); +} + class MyOStream: public raw_ostream { public: diff --git a/src/zig_llvm.h b/src/zig_llvm.h index 8b7b0775f..2a2ab567a 100644 --- a/src/zig_llvm.h +++ b/src/zig_llvm.h @@ -211,6 +211,7 @@ ZIG_EXTERN_C LLVMValueRef ZigLLVMInsertDeclare(struct ZigLLVMDIBuilder *dibuilde ZIG_EXTERN_C struct ZigLLVMDILocation *ZigLLVMGetDebugLoc(unsigned line, unsigned col, struct ZigLLVMDIScope *scope); ZIG_EXTERN_C void ZigLLVMSetFastMath(LLVMBuilderRef builder_wrapped, bool on_state); +ZIG_EXTERN_C void ZigLLVMSetTailCall(LLVMValueRef Call); ZIG_EXTERN_C void ZigLLVMAddFunctionAttr(LLVMValueRef fn, const char *attr_name, const char *attr_value); ZIG_EXTERN_C void ZigLLVMAddFunctionAttrCold(LLVMValueRef fn); diff --git a/test/stage1/behavior/coroutines.zig b/test/stage1/behavior/coroutines.zig index 4f1cc8406..7188e7af8 100644 --- a/test/stage1/behavior/coroutines.zig +++ b/test/stage1/behavior/coroutines.zig @@ -77,6 +77,22 @@ test "local variable in async function" { S.doTheTest(); } +test "calling an inferred async function" { + const S = struct { + fn doTheTest() void { + const p = async first(); + } + + fn first() void { + other(); + } + fn other() void { + suspend; + } + }; + S.doTheTest(); +} + //test "coroutine suspend, resume" { // seq('a'); // const p = try async testAsyncSeq();