diff --git a/src/all_types.hpp b/src/all_types.hpp index bc6594857..1dad546a7 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1718,7 +1718,7 @@ struct CodeGen { LLVMTargetMachineRef target_machine; ZigLLVMDIFile *dummy_di_file; LLVMValueRef cur_ret_ptr; - LLVMValueRef cur_ret_ptr_ptr; + LLVMValueRef cur_frame_ptr; LLVMValueRef cur_fn_val; LLVMValueRef cur_async_switch_instr; LLVMValueRef cur_async_resume_index_ptr; diff --git a/src/analyze.cpp b/src/analyze.cpp index 36eeaeac9..764b28ed4 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -5160,6 +5160,8 @@ static ZigType *get_async_fn_type(CodeGen *g, ZigType *orig_fn_type) { } static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) { + Error err; + if (frame_type->data.frame.locals_struct != nullptr) return ErrorNone; @@ -5286,6 +5288,9 @@ static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) { continue; } } + if ((err = type_resolve(g, child_type, ResolveStatusSizeKnown))) { + return err; + } const char *name; if (*instruction->name_hint == 0) { name = buf_ptr(buf_sprintf("@local%" ZIG_PRI_usize, alloca_i)); diff --git a/src/codegen.cpp b/src/codegen.cpp index cf846d99e..d1b5ebeda 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -2088,17 +2088,19 @@ static LLVMValueRef gen_resume(CodeGen *g, LLVMValueRef fn_val, LLVMValueRef tar static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable, IrInstructionReturnBegin *instruction) { - if (!fn_is_async(g->cur_fn)) return nullptr; + bool ret_type_has_bits = instruction->operand != nullptr && + type_has_bits(instruction->operand->value.type); + + if (!fn_is_async(g->cur_fn)) { + return ret_type_has_bits ? ir_llvm_value(g, instruction->operand) : nullptr; + } LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; - bool ret_type_has_bits = instruction->operand != nullptr && - type_has_bits(instruction->operand->value.type); ZigType *ret_type = ret_type_has_bits ? instruction->operand->value.type : nullptr; if (ret_type_has_bits && !handle_is_ptr(ret_type)) { // It's a scalar, so it didn't get written to the result ptr. Do that before the atomic rmw. - LLVMValueRef result_ptr = LLVMBuildLoad(g->builder, g->cur_ret_ptr_ptr, ""); - LLVMBuildStore(g->builder, ir_llvm_value(g, instruction->operand), result_ptr); + LLVMBuildStore(g->builder, ir_llvm_value(g, instruction->operand), g->cur_ret_ptr); } // Prepare to be suspended. We might end up not having to suspend though. @@ -2147,7 +2149,11 @@ static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable, LLVMBasicBlockRef incoming_blocks[] = { after_resume_block, switch_bb }; LLVMAddIncoming(g->cur_async_prev_val, incoming_values, incoming_blocks, 2); - return nullptr; + if (!ret_type_has_bits) { + return nullptr; + } + + return get_handle_value(g, g->cur_ret_ptr, ret_type, get_pointer_to_type(g, ret_type, true)); } static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrInstructionReturn *instruction) { @@ -2166,17 +2172,16 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrIns // If the awaiter result pointer is non-null, we need to copy the result to there. LLVMBasicBlockRef copy_block = LLVMAppendBasicBlock(g->cur_fn_val, "CopyResult"); LLVMBasicBlockRef copy_end_block = LLVMAppendBasicBlock(g->cur_fn_val, "CopyResultEnd"); - LLVMValueRef awaiter_ret_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_ret_start + 1, ""); + LLVMValueRef awaiter_ret_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr, coro_ret_start + 1, ""); LLVMValueRef awaiter_ret_ptr = LLVMBuildLoad(g->builder, awaiter_ret_ptr_ptr, ""); LLVMValueRef zero_ptr = LLVMConstNull(LLVMTypeOf(awaiter_ret_ptr)); LLVMValueRef need_copy_bit = LLVMBuildICmp(g->builder, LLVMIntNE, awaiter_ret_ptr, zero_ptr, ""); LLVMBuildCondBr(g->builder, need_copy_bit, copy_block, copy_end_block); LLVMPositionBuilderAtEnd(g->builder, copy_block); - LLVMValueRef ret_ptr = LLVMBuildLoad(g->builder, g->cur_ret_ptr_ptr, ""); LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0); LLVMValueRef dest_ptr_casted = LLVMBuildBitCast(g->builder, awaiter_ret_ptr, ptr_u8, ""); - LLVMValueRef src_ptr_casted = LLVMBuildBitCast(g->builder, ret_ptr, ptr_u8, ""); + LLVMValueRef src_ptr_casted = LLVMBuildBitCast(g->builder, g->cur_ret_ptr, ptr_u8, ""); bool is_volatile = false; uint32_t abi_align = get_abi_alignment(g, ret_type); LLVMValueRef byte_count_val = LLVMConstInt(usize_type_ref, type_size(g, ret_type), false); @@ -3385,10 +3390,6 @@ static LLVMValueRef ir_render_return_ptr(CodeGen *g, IrExecutable *executable, if (!type_has_bits(instruction->base.value.type)) return nullptr; src_assert(g->cur_ret_ptr != nullptr, instruction->base.source_node); - if (fn_is_async(g->cur_fn)) { - LLVMValueRef ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_ret_start, ""); - return LLVMBuildLoad(g->builder, ptr_ptr, ""); - } return g->cur_ret_ptr; } @@ -3547,7 +3548,7 @@ static void render_async_spills(CodeGen *g) { continue; } - var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, async_var_index, + var->value_ref = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr, async_var_index, buf_ptr(&var->name)); async_var_index += 1; if (var->decl_node) { @@ -3578,7 +3579,7 @@ static void render_async_spills(CodeGen *g) { continue; } } - instruction->base.llvm_value = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, async_var_index, + instruction->base.llvm_value = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr, async_var_index, instruction->name_hint); async_var_index += 1; } @@ -3697,7 +3698,7 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr // initialization. } else if (callee_is_async) { frame_result_loc = ir_llvm_value(g, instruction->frame_result_loc); - awaiter_init_val = LLVMBuildPtrToInt(g->builder, g->cur_ret_ptr, usize_type_ref, ""); // caller's own frame pointer + awaiter_init_val = LLVMBuildPtrToInt(g->builder, g->cur_frame_ptr, usize_type_ref, ""); // caller's own frame pointer if (ret_has_bits) { if (result_loc == nullptr) { // return type is a scalar, but we still need a pointer to it. Use the async fn frame. @@ -4850,7 +4851,7 @@ static LLVMValueRef ir_render_frame_address(CodeGen *g, IrExecutable *executable } static LLVMValueRef ir_render_handle(CodeGen *g, IrExecutable *executable, IrInstructionFrameHandle *instruction) { - return g->cur_ret_ptr; + return g->cur_frame_ptr; } static LLVMValueRef render_shl_with_overflow(CodeGen *g, IrInstructionOverflowOp *instruction) { @@ -5335,7 +5336,7 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst } // caller's own frame pointer - LLVMValueRef awaiter_init_val = LLVMBuildPtrToInt(g->builder, g->cur_ret_ptr, usize_type_ref, ""); + LLVMValueRef awaiter_init_val = LLVMBuildPtrToInt(g->builder, g->cur_frame_ptr, usize_type_ref, ""); LLVMValueRef awaiter_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, coro_awaiter_index, ""); LLVMValueRef prev_val = LLVMBuildAtomicRMW(g->builder, LLVMAtomicRMWBinOpXchg, awaiter_ptr, awaiter_init_val, LLVMAtomicOrderingRelease, g->is_single_threaded); @@ -6710,13 +6711,17 @@ static void do_code_gen(CodeGen *g) { bool is_async = fn_is_async(fn_table_entry); - if (want_sret || is_async) { - g->cur_ret_ptr = LLVMGetParam(fn, 0); - } else if (handle_is_ptr(fn_type_id->return_type)) { - g->cur_ret_ptr = build_alloca(g, fn_type_id->return_type, "result", 0); - // TODO add debug info variable for this + if (is_async) { + g->cur_frame_ptr = LLVMGetParam(fn, 0); } else { - g->cur_ret_ptr = nullptr; + if (want_sret) { + g->cur_ret_ptr = LLVMGetParam(fn, 0); + } else if (handle_is_ptr(fn_type_id->return_type)) { + g->cur_ret_ptr = build_alloca(g, fn_type_id->return_type, "result", 0); + // TODO add debug info variable for this + } else { + g->cur_ret_ptr = nullptr; + } } uint32_t err_ret_trace_arg_index = get_err_ret_trace_arg_index(g, fn_table_entry); @@ -6870,21 +6875,22 @@ static void do_code_gen(CodeGen *g) { LLVMPositionBuilderAtEnd(g->builder, g->cur_preamble_llvm_block); render_async_spills(g); - g->cur_async_awaiter_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_awaiter_index, ""); - LLVMValueRef resume_index_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_resume_index, ""); + g->cur_async_awaiter_ptr = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr, coro_awaiter_index, ""); + LLVMValueRef resume_index_ptr = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr, coro_resume_index, ""); g->cur_async_resume_index_ptr = resume_index_ptr; if (type_has_bits(fn_type_id->return_type)) { - g->cur_ret_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_ret_start, ""); + LLVMValueRef cur_ret_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr, coro_ret_start, ""); + g->cur_ret_ptr = LLVMBuildLoad(g->builder, cur_ret_ptr_ptr, ""); } if (codegen_fn_has_err_ret_tracing_arg(g, fn_type_id->return_type)) { uint32_t trace_field_index = frame_index_trace_arg(g, fn_type_id->return_type); - g->cur_err_ret_trace_val_arg = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, trace_field_index, ""); + g->cur_err_ret_trace_val_arg = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr, trace_field_index, ""); } uint32_t trace_field_index_stack = UINT32_MAX; if (codegen_fn_has_err_ret_tracing_stack(g, fn_table_entry, true)) { trace_field_index_stack = frame_index_trace_stack(g, fn_type_id); - g->cur_err_ret_trace_val_stack = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, + g->cur_err_ret_trace_val_stack = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr, trace_field_index_stack, ""); } @@ -6898,9 +6904,9 @@ static void do_code_gen(CodeGen *g) { g->cur_resume_block_count += 1; LLVMPositionBuilderAtEnd(g->builder, entry_block->llvm_block); if (trace_field_index_stack != UINT32_MAX) { - LLVMValueRef trace_field_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, + LLVMValueRef trace_field_ptr = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr, trace_field_index_stack, ""); - LLVMValueRef trace_field_addrs = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, + LLVMValueRef trace_field_addrs = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr, trace_field_index_stack + 1, ""); LLVMValueRef index_ptr = LLVMBuildStructGEP(g->builder, trace_field_ptr, 0, ""); diff --git a/src/ir.cpp b/src/ir.cpp index 64e5e31a1..7cb868cab 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -1129,8 +1129,6 @@ static IrInstruction *ir_build_return_begin(IrBuilder *irb, Scope *scope, AstNod IrInstruction *operand) { IrInstructionReturnBegin *return_instruction = ir_build_instruction(irb, scope, source_node); - return_instruction->base.value.type = irb->codegen->builtin_types.entry_void; - return_instruction->base.value.special = ConstValSpecialStatic; return_instruction->operand = operand; ir_ref_instruction(operand, irb->current_basic_block); @@ -3480,7 +3478,8 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node, return_value = ir_build_const_void(irb, scope, node); } - ir_build_return_begin(irb, scope, node, return_value); + ir_mark_gen(ir_build_add_implicit_return_type(irb, scope, node, return_value)); + return_value = ir_build_return_begin(irb, scope, node, return_value); size_t defer_counts[2]; ir_count_defers(irb, scope, outer_scope, defer_counts); @@ -3514,14 +3513,12 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node, ir_build_br(irb, scope, node, ret_stmt_block, is_comptime); ir_set_cursor_at_end_and_append_block(irb, ret_stmt_block); - ir_mark_gen(ir_build_add_implicit_return_type(irb, scope, node, return_value)); IrInstruction *result = ir_build_return(irb, scope, node, return_value); result_loc_ret->base.source_instruction = result; return result; } else { // generate unconditional defers ir_gen_defers_for_block(irb, scope, outer_scope, false); - ir_mark_gen(ir_build_add_implicit_return_type(irb, scope, node, return_value)); IrInstruction *result = ir_build_return(irb, scope, node, return_value); result_loc_ret->base.source_instruction = result; return result; @@ -3549,7 +3546,8 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node, ir_set_cursor_at_end_and_append_block(irb, return_block); IrInstruction *err_val_ptr = ir_build_unwrap_err_code(irb, scope, node, err_union_ptr); IrInstruction *err_val = ir_build_load_ptr(irb, scope, node, err_val_ptr); - ir_build_return_begin(irb, scope, node, err_val); + ir_mark_gen(ir_build_add_implicit_return_type(irb, scope, node, err_val)); + err_val = ir_build_return_begin(irb, scope, node, err_val); if (!ir_gen_defers_for_block(irb, scope, outer_scope, true)) { ResultLocReturn *result_loc_ret = allocate(1); result_loc_ret->base.id = ResultLocIdReturn; @@ -3559,7 +3557,6 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node, if (irb->codegen->have_err_ret_tracing && !should_inline) { ir_build_save_err_ret_addr(irb, scope, node); } - ir_mark_gen(ir_build_add_implicit_return_type(irb, scope, node, err_val)); IrInstruction *ret_inst = ir_build_return(irb, scope, node, err_val); result_loc_ret->base.source_instruction = ret_inst; } @@ -4972,7 +4969,7 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo return ir_lval_wrap(irb, scope, ir_build_frame_address(irb, scope, node), lval, result_loc); case BuiltinFnIdFrameHandle: if (!irb->exec->fn_entry) { - add_node_error(irb->codegen, node, buf_sprintf("@handle() called outside of function definition")); + add_node_error(irb->codegen, node, buf_sprintf("@frame() called outside of function definition")); return irb->codegen->invalid_instruction; } return ir_lval_wrap(irb, scope, ir_build_handle(irb, scope, node), lval, result_loc); @@ -8101,9 +8098,9 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec return false; if (!instr_is_unreachable(result)) { - ir_mark_gen(ir_build_return_begin(irb, scope, node, result)); - // no need for save_err_ret_addr because this cannot return error ir_mark_gen(ir_build_add_implicit_return_type(irb, scope, result->source_node, result)); + result = ir_mark_gen(ir_build_return_begin(irb, scope, node, result)); + // no need for save_err_ret_addr because this cannot return error ir_mark_gen(ir_build_return(irb, scope, result->source_node, result)); } @@ -9789,6 +9786,8 @@ static ZigType *ir_resolve_peer_types(IrAnalyze *ira, AstNode *source_node, ZigT ZigType *prev_err_set_type = (err_set_type == nullptr) ? prev_type->data.error_union.err_set_type : err_set_type; ZigType *cur_err_set_type = cur_type->data.error_union.err_set_type; + if (prev_err_set_type == cur_err_set_type) + continue; if (!resolve_inferred_error_set(ira->codegen, prev_err_set_type, cur_inst->source_node)) { return ira->codegen->builtin_types.entry_invalid; @@ -12614,6 +12613,14 @@ static IrInstruction *ir_analyze_instruction_return_begin(IrAnalyze *ira, IrInst if (type_is_invalid(operand->value.type)) return ira->codegen->invalid_instruction; + if (!instr_is_comptime(operand) && handle_is_ptr(ira->explicit_return_type)) { + // result location mechanism took care of it. + IrInstruction *result = ir_build_return_begin(&ira->new_irb, instruction->base.scope, + instruction->base.source_node, operand); + copy_const_val(&result->value, &operand->value, true); + return result; + } + IrInstruction *casted_operand = ir_implicit_cast(ira, operand, ira->explicit_return_type); if (type_is_invalid(casted_operand->value.type)) { AstNode *source_node = ira->explicit_return_type_source_node; @@ -12625,8 +12632,18 @@ static IrInstruction *ir_analyze_instruction_return_begin(IrAnalyze *ira, IrInst return ir_unreach_error(ira); } - return ir_build_return_begin(&ira->new_irb, instruction->base.scope, instruction->base.source_node, - casted_operand); + if (casted_operand->value.special == ConstValSpecialRuntime && + casted_operand->value.type->id == ZigTypeIdPointer && + casted_operand->value.data.rh_ptr == RuntimeHintPtrStack) + { + ir_add_error(ira, casted_operand, buf_sprintf("function returns address of local variable")); + return ir_unreach_error(ira); + } + + IrInstruction *result = ir_build_return_begin(&ira->new_irb, instruction->base.scope, + instruction->base.source_node, casted_operand); + copy_const_val(&result->value, &casted_operand->value, true); + return result; } static IrInstruction *ir_analyze_instruction_return(IrAnalyze *ira, IrInstructionReturn *instruction) { @@ -12642,21 +12659,8 @@ static IrInstruction *ir_analyze_instruction_return(IrAnalyze *ira, IrInstructio return ir_finish_anal(ira, result); } - IrInstruction *casted_operand = ir_implicit_cast(ira, operand, ira->explicit_return_type); - if (type_is_invalid(casted_operand->value.type)) { - // error already reported by IrInstructionReturnBegin - return ir_unreach_error(ira); - } - - if (casted_operand->value.special == ConstValSpecialRuntime && - casted_operand->value.type->id == ZigTypeIdPointer && - casted_operand->value.data.rh_ptr == RuntimeHintPtrStack) - { - ir_add_error(ira, casted_operand, buf_sprintf("function returns address of local variable")); - return ir_unreach_error(ira); - } IrInstruction *result = ir_build_return(&ira->new_irb, instruction->base.scope, - instruction->base.source_node, casted_operand); + instruction->base.source_node, operand); result->value.type = ira->codegen->builtin_types.entry_unreachable; return ir_finish_anal(ira, result); } @@ -14612,8 +14616,12 @@ static IrInstruction *ir_resolve_result_raw(IrAnalyze *ira, IrInstruction *suspe if ((err = type_resolve(ira->codegen, ira->explicit_return_type, ResolveStatusZeroBitsKnown))) { return ira->codegen->invalid_instruction; } - if (!type_has_bits(ira->explicit_return_type) || !handle_is_ptr(ira->explicit_return_type)) - return nullptr; + if (!type_has_bits(ira->explicit_return_type) || !handle_is_ptr(ira->explicit_return_type)) { + ZigFn *fn_entry = exec_fn_entry(ira->new_irb.exec); + if (fn_entry == nullptr || fn_entry->inferred_async_node == nullptr) { + return nullptr; + } + } ZigType *ptr_return_type = get_pointer_to_type(ira->codegen, ira->explicit_return_type, false); result_loc->written = true; @@ -24510,7 +24518,7 @@ static IrInstruction *ir_analyze_instruction_await(IrAnalyze *ira, IrInstruction IrInstruction *result_loc; if (type_has_bits(result_type)) { result_loc = ir_resolve_result(ira, &instruction->base, instruction->result_loc, - result_type, nullptr, true, false, true); + result_type, nullptr, true, true, true); if (result_loc != nullptr && (type_is_invalid(result_loc->value.type) || instr_is_unreachable(result_loc))) return result_loc; } else { diff --git a/test/stage1/behavior/coroutines.zig b/test/stage1/behavior/coroutines.zig index 4cea8d150..7a8edd793 100644 --- a/test/stage1/behavior/coroutines.zig +++ b/test/stage1/behavior/coroutines.zig @@ -334,40 +334,40 @@ test "async fn with inferred error set" { S.doTheTest(); } -//test "error return trace across suspend points - early return" { -// const p = nonFailing(); -// resume p; -// const p2 = async printTrace(p); -//} -// -//test "error return trace across suspend points - async return" { -// const p = nonFailing(); -// const p2 = async printTrace(p); -// resume p; -//} -// -//fn nonFailing() (anyframe->anyerror!void) { -// const Static = struct { -// var frame: @Frame(suspendThenFail) = undefined; -// }; -// Static.frame = async suspendThenFail(); -// return &Static.frame; -//} -//async fn suspendThenFail() anyerror!void { -// suspend; -// return error.Fail; -//} -//async fn printTrace(p: anyframe->(anyerror!void)) void { -// (await p) catch |e| { -// std.testing.expect(e == error.Fail); -// if (@errorReturnTrace()) |trace| { -// expect(trace.index == 1); -// } else switch (builtin.mode) { -// .Debug, .ReleaseSafe => @panic("expected return trace"), -// .ReleaseFast, .ReleaseSmall => {}, -// } -// }; -//} +test "error return trace across suspend points - early return" { + const p = nonFailing(); + resume p; + const p2 = async printTrace(p); +} + +test "error return trace across suspend points - async return" { + const p = nonFailing(); + const p2 = async printTrace(p); + resume p; +} + +fn nonFailing() (anyframe->anyerror!void) { + const Static = struct { + var frame: @Frame(suspendThenFail) = undefined; + }; + Static.frame = async suspendThenFail(); + return &Static.frame; +} +async fn suspendThenFail() anyerror!void { + suspend; + return error.Fail; +} +async fn printTrace(p: anyframe->(anyerror!void)) void { + (await p) catch |e| { + std.testing.expect(e == error.Fail); + if (@errorReturnTrace()) |trace| { + expect(trace.index == 1); + } else switch (builtin.mode) { + .Debug, .ReleaseSafe => @panic("expected return trace"), + .ReleaseFast, .ReleaseSmall => {}, + } + }; +} test "break from suspend" { var my_result: i32 = 1;