From 22428a75462e01877181501801dce4c090a87e9c Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 10 Aug 2019 15:20:08 -0400 Subject: [PATCH] fix try in an async function with error union and non-zero-bit payload --- src/all_types.hpp | 13 +++- src/analyze.cpp | 12 ++++ src/codegen.cpp | 75 +++++++++++++--------- src/ir.cpp | 97 +++++++++++++++++++++++------ src/ir_print.cpp | 17 +++-- test/stage1/behavior/coroutines.zig | 30 +++++++++ 6 files changed, 187 insertions(+), 57 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 45182f3db..8b4d1e6d7 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -74,6 +74,7 @@ struct IrExecutable { bool invalid; bool is_inline; bool is_generic_instantiation; + bool need_err_code_spill; }; enum OutType { @@ -1384,6 +1385,7 @@ struct ZigFn { size_t prealloc_backward_branch_quota; AstNode **param_source_nodes; Buf **param_names; + IrInstruction *err_code_spill; AstNode *fn_no_inline_set_node; AstNode *fn_static_eval_set_node; @@ -2366,7 +2368,8 @@ enum IrInstructionId { IrInstructionIdAwaitGen, IrInstructionIdCoroResume, IrInstructionIdTestCancelRequested, - IrInstructionIdSpill, + IrInstructionIdSpillBegin, + IrInstructionIdSpillEnd, }; struct IrInstruction { @@ -3649,13 +3652,19 @@ enum SpillId { SpillIdRetErrCode, }; -struct IrInstructionSpill { +struct IrInstructionSpillBegin { IrInstruction base; SpillId spill_id; IrInstruction *operand; }; +struct IrInstructionSpillEnd { + IrInstruction base; + + IrInstructionSpillBegin *begin; +}; + enum ResultLocId { ResultLocIdInvalid, ResultLocIdNone, diff --git a/src/analyze.cpp b/src/analyze.cpp index a09ba582c..7482ba92b 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -5190,6 +5190,18 @@ static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) { } ZigType *fn_type = get_async_fn_type(g, fn->type_entry); + if (fn->analyzed_executable.need_err_code_spill) { + IrInstructionAllocaGen *alloca_gen = allocate(1); + alloca_gen->base.id = IrInstructionIdAllocaGen; + alloca_gen->base.source_node = fn->proto_node; + alloca_gen->base.scope = fn->child_scope; + alloca_gen->base.value.type = get_pointer_to_type(g, g->builtin_types.entry_global_error_set, false); + alloca_gen->base.ref_count = 1; + alloca_gen->name_hint = ""; + fn->alloca_gen_list.append(alloca_gen); + fn->err_code_spill = &alloca_gen->base; + } + for (size_t i = 0; i < fn->call_list.length; i += 1) { IrInstructionCallGen *call = fn->call_list.at(i); ZigFn *callee = call->fn_entry; diff --git a/src/codegen.cpp b/src/codegen.cpp index 976ee4181..2f07fcd71 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -2274,16 +2274,16 @@ static LLVMValueRef gen_maybe_atomic_op(CodeGen *g, LLVMAtomicRMWBinOp op, LLVMV static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable, IrInstructionReturnBegin *instruction) { - bool ret_type_has_bits = instruction->operand != nullptr && - type_has_bits(instruction->operand->value.type); - + ZigType *operand_type = (instruction->operand != nullptr) ? instruction->operand->value.type : nullptr; + bool operand_has_bits = (operand_type != nullptr) && type_has_bits(operand_type); if (!fn_is_async(g->cur_fn)) { - return ret_type_has_bits ? ir_llvm_value(g, instruction->operand) : nullptr; + return operand_has_bits ? ir_llvm_value(g, instruction->operand) : nullptr; } + ZigType *ret_type = g->cur_fn->type_entry->data.fn.fn_type_id.return_type; + bool ret_type_has_bits = type_has_bits(ret_type); LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; - ZigType *ret_type = ret_type_has_bits ? instruction->operand->value.type : nullptr; if (ret_type_has_bits && !handle_is_ptr(ret_type)) { // It's a scalar, so it didn't get written to the result ptr. Do that before the atomic rmw. LLVMBuildStore(g->builder, ir_llvm_value(g, instruction->operand), g->cur_ret_ptr); @@ -2333,11 +2333,11 @@ static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable, g->cur_is_after_return = true; LLVMBuildStore(g->builder, g->cur_async_prev_val, g->cur_async_prev_val_field_ptr); - if (!ret_type_has_bits) { + if (!operand_has_bits) { return nullptr; } - return get_handle_value(g, g->cur_ret_ptr, ret_type, get_pointer_to_type(g, ret_type, true)); + return get_handle_value(g, g->cur_ret_ptr, operand_type, get_pointer_to_type(g, operand_type, true)); } static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrInstructionReturn *instruction) { @@ -5113,18 +5113,6 @@ static LLVMValueRef ir_render_test_err(CodeGen *g, IrExecutable *executable, IrI return LLVMBuildICmp(g->builder, LLVMIntNE, err_val, zero, ""); } -static LLVMValueRef gen_unwrap_err_code(CodeGen *g, LLVMValueRef err_union_ptr, ZigType *ptr_type) { - ZigType *err_union_type = ptr_type->data.pointer.child_type; - ZigType *payload_type = err_union_type->data.error_union.payload_type; - if (!type_has_bits(payload_type)) { - return err_union_ptr; - } else { - // TODO assign undef to the payload - LLVMValueRef err_union_handle = get_handle_value(g, err_union_ptr, err_union_type, ptr_type); - return LLVMBuildStructGEP(g->builder, err_union_handle, err_union_err_index, ""); - } -} - static LLVMValueRef ir_render_unwrap_err_code(CodeGen *g, IrExecutable *executable, IrInstructionUnwrapErrCode *instruction) { @@ -5133,8 +5121,16 @@ static LLVMValueRef ir_render_unwrap_err_code(CodeGen *g, IrExecutable *executab ZigType *ptr_type = instruction->err_union_ptr->value.type; assert(ptr_type->id == ZigTypeIdPointer); + ZigType *err_union_type = ptr_type->data.pointer.child_type; + ZigType *payload_type = err_union_type->data.error_union.payload_type; LLVMValueRef err_union_ptr = ir_llvm_value(g, instruction->err_union_ptr); - return gen_unwrap_err_code(g, err_union_ptr, ptr_type); + if (!type_has_bits(payload_type)) { + return err_union_ptr; + } else { + // TODO assign undef to the payload + LLVMValueRef err_union_handle = get_handle_value(g, err_union_ptr, err_union_type, ptr_type); + return LLVMBuildStructGEP(g->builder, err_union_handle, err_union_err_index, ""); + } } static LLVMValueRef ir_render_unwrap_err_payload(CodeGen *g, IrExecutable *executable, @@ -5615,21 +5611,36 @@ static LLVMValueRef ir_render_test_cancel_requested(CodeGen *g, IrExecutable *ex } } -static LLVMValueRef ir_render_spill(CodeGen *g, IrExecutable *executable, IrInstructionSpill *instruction) { +static LLVMValueRef ir_render_spill_begin(CodeGen *g, IrExecutable *executable, + IrInstructionSpillBegin *instruction) +{ if (!fn_is_async(g->cur_fn)) - return ir_llvm_value(g, instruction->operand); + return nullptr; switch (instruction->spill_id) { case SpillIdInvalid: zig_unreachable(); case SpillIdRetErrCode: { - LLVMValueRef ret_ptr = LLVMBuildLoad(g->builder, g->cur_ret_ptr, ""); - ZigType *ret_type = g->cur_fn->type_entry->data.fn.fn_type_id.return_type; - if (ret_type->id == ZigTypeIdErrorUnion) { - return gen_unwrap_err_code(g, ret_ptr, get_pointer_to_type(g, ret_type, true)); - } else { - zig_unreachable(); - } + LLVMValueRef operand = ir_llvm_value(g, instruction->operand); + LLVMValueRef ptr = ir_llvm_value(g, g->cur_fn->err_code_spill); + LLVMBuildStore(g->builder, operand, ptr); + return nullptr; + } + + } + zig_unreachable(); +} + +static LLVMValueRef ir_render_spill_end(CodeGen *g, IrExecutable *executable, IrInstructionSpillEnd *instruction) { + if (!fn_is_async(g->cur_fn)) + return ir_llvm_value(g, instruction->begin->operand); + + switch (instruction->begin->spill_id) { + case SpillIdInvalid: + zig_unreachable(); + case SpillIdRetErrCode: { + LLVMValueRef ptr = ir_llvm_value(g, g->cur_fn->err_code_spill); + return LLVMBuildLoad(g->builder, ptr, ""); } } @@ -5891,8 +5902,10 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_await(g, executable, (IrInstructionAwaitGen *)instruction); case IrInstructionIdTestCancelRequested: return ir_render_test_cancel_requested(g, executable, (IrInstructionTestCancelRequested *)instruction); - case IrInstructionIdSpill: - return ir_render_spill(g, executable, (IrInstructionSpill *)instruction); + case IrInstructionIdSpillBegin: + return ir_render_spill_begin(g, executable, (IrInstructionSpillBegin *)instruction); + case IrInstructionIdSpillEnd: + return ir_render_spill_end(g, executable, (IrInstructionSpillEnd *)instruction); } zig_unreachable(); } diff --git a/src/ir.cpp b/src/ir.cpp index 845ee0375..97971efd5 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -1066,8 +1066,12 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionTestCancelReques return IrInstructionIdTestCancelRequested; } -static constexpr IrInstructionId ir_instruction_id(IrInstructionSpill *) { - return IrInstructionIdSpill; +static constexpr IrInstructionId ir_instruction_id(IrInstructionSpillBegin *) { + return IrInstructionIdSpillBegin; +} + +static constexpr IrInstructionId ir_instruction_id(IrInstructionSpillEnd *) { + return IrInstructionIdSpillEnd; } template @@ -3336,15 +3340,28 @@ static IrInstruction *ir_build_test_cancel_requested(IrBuilder *irb, Scope *scop return &instruction->base; } -static IrInstruction *ir_build_spill(IrBuilder *irb, Scope *scope, AstNode *source_node, +static IrInstructionSpillBegin *ir_build_spill_begin(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *operand, SpillId spill_id) { - IrInstructionSpill *instruction = ir_build_instruction(irb, scope, source_node); + IrInstructionSpillBegin *instruction = ir_build_instruction(irb, scope, source_node); + instruction->base.value.special = ConstValSpecialStatic; + instruction->base.value.type = irb->codegen->builtin_types.entry_void; instruction->operand = operand; instruction->spill_id = spill_id; ir_ref_instruction(operand, irb->current_basic_block); + return instruction; +} + +static IrInstruction *ir_build_spill_end(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstructionSpillBegin *begin) +{ + IrInstructionSpillEnd *instruction = ir_build_instruction(irb, scope, source_node); + instruction->begin = begin; + + ir_ref_instruction(&begin->base, irb->current_basic_block); + return &instruction->base; } @@ -3602,14 +3619,15 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node, IrInstruction *err_val_ptr = ir_build_unwrap_err_code(irb, scope, node, err_union_ptr); IrInstruction *err_val = ir_build_load_ptr(irb, scope, node, err_val_ptr); ir_mark_gen(ir_build_add_implicit_return_type(irb, scope, node, err_val)); - err_val = ir_build_return_begin(irb, scope, node, err_val); + IrInstructionSpillBegin *spill_begin = ir_build_spill_begin(irb, scope, node, err_val, + SpillIdRetErrCode); + ir_build_return_begin(irb, scope, node, err_val); + err_val = ir_build_spill_end(irb, scope, node, spill_begin); + ResultLocReturn *result_loc_ret = allocate(1); + result_loc_ret->base.id = ResultLocIdReturn; + ir_build_reset_result(irb, scope, node, &result_loc_ret->base); + ir_build_end_expr(irb, scope, node, err_val, &result_loc_ret->base); if (!ir_gen_defers_for_block(irb, scope, outer_scope, true)) { - ResultLocReturn *result_loc_ret = allocate(1); - result_loc_ret->base.id = ResultLocIdReturn; - ir_build_reset_result(irb, scope, node, &result_loc_ret->base); - err_val = ir_build_spill(irb, scope, node, err_val, SpillIdRetErrCode); - ir_build_end_expr(irb, scope, node, err_val, &result_loc_ret->base); - if (irb->codegen->have_err_ret_tracing && !should_inline) { ir_build_save_err_ret_addr(irb, scope, node); } @@ -12778,8 +12796,21 @@ static IrInstruction *ir_analyze_instruction_return(IrAnalyze *ira, IrInstructio return ir_finish_anal(ira, result); } + // This cast might have been already done from IrInstructionReturnBegin but it also + // might not have, in the case of `try`. + IrInstruction *casted_operand = ir_implicit_cast(ira, operand, ira->explicit_return_type); + if (type_is_invalid(casted_operand->value.type)) { + AstNode *source_node = ira->explicit_return_type_source_node; + if (source_node != nullptr) { + ErrorMsg *msg = ira->codegen->errors.last(); + add_error_note(ira->codegen, msg, source_node, + buf_sprintf("return type declared here")); + } + return ir_unreach_error(ira); + } + IrInstruction *result = ir_build_return(&ira->new_irb, instruction->base.scope, - instruction->base.source_node, operand); + instruction->base.source_node, casted_operand); result->value.type = ira->codegen->builtin_types.entry_unreachable; return ir_finish_anal(ira, result); } @@ -24742,15 +24773,38 @@ static IrInstruction *ir_analyze_instruction_test_cancel_requested(IrAnalyze *ir return ir_build_test_cancel_requested(&ira->new_irb, instruction->base.scope, instruction->base.source_node); } -static IrInstruction *ir_analyze_instruction_spill(IrAnalyze *ira, IrInstructionSpill *instruction) { +static IrInstruction *ir_analyze_instruction_spill_begin(IrAnalyze *ira, IrInstructionSpillBegin *instruction) { + if (ir_should_inline(ira->new_irb.exec, instruction->base.scope)) + return ir_const_void(ira, &instruction->base); + IrInstruction *operand = instruction->operand->child; if (type_is_invalid(operand->value.type)) return ira->codegen->invalid_instruction; - if (ir_should_inline(ira->new_irb.exec, instruction->base.scope)) { + + if (!type_has_bits(operand->value.type)) + return ir_const_void(ira, &instruction->base); + + ir_assert(instruction->spill_id == SpillIdRetErrCode, &instruction->base); + ira->new_irb.exec->need_err_code_spill = true; + + IrInstructionSpillBegin *result = ir_build_spill_begin(&ira->new_irb, instruction->base.scope, + instruction->base.source_node, operand, instruction->spill_id); + return &result->base; +} + +static IrInstruction *ir_analyze_instruction_spill_end(IrAnalyze *ira, IrInstructionSpillEnd *instruction) { + IrInstruction *operand = instruction->begin->operand->child; + if (type_is_invalid(operand->value.type)) + return ira->codegen->invalid_instruction; + + if (ir_should_inline(ira->new_irb.exec, instruction->base.scope) || !type_has_bits(operand->value.type)) return operand; - } - IrInstruction *result = ir_build_spill(&ira->new_irb, instruction->base.scope, instruction->base.source_node, - operand, instruction->spill_id); + + ir_assert(instruction->begin->base.child->id == IrInstructionIdSpillBegin, &instruction->base); + IrInstructionSpillBegin *begin = reinterpret_cast(instruction->begin->base.child); + + IrInstruction *result = ir_build_spill_end(&ira->new_irb, instruction->base.scope, + instruction->base.source_node, begin); result->value.type = operand->value.type; return result; } @@ -25054,8 +25108,10 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction return ir_analyze_instruction_await(ira, (IrInstructionAwaitSrc *)instruction); case IrInstructionIdTestCancelRequested: return ir_analyze_instruction_test_cancel_requested(ira, (IrInstructionTestCancelRequested *)instruction); - case IrInstructionIdSpill: - return ir_analyze_instruction_spill(ira, (IrInstructionSpill *)instruction); + case IrInstructionIdSpillBegin: + return ir_analyze_instruction_spill_begin(ira, (IrInstructionSpillBegin *)instruction); + case IrInstructionIdSpillEnd: + return ir_analyze_instruction_spill_end(ira, (IrInstructionSpillEnd *)instruction); } zig_unreachable(); } @@ -25193,6 +25249,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdCoroResume: case IrInstructionIdAwaitSrc: case IrInstructionIdAwaitGen: + case IrInstructionIdSpillBegin: return true; case IrInstructionIdPhi: @@ -25291,7 +25348,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdAllocaSrc: case IrInstructionIdAllocaGen: case IrInstructionIdTestCancelRequested: - case IrInstructionIdSpill: + case IrInstructionIdSpillEnd: return false; case IrInstructionIdAsm: diff --git a/src/ir_print.cpp b/src/ir_print.cpp index 39e781e4f..9d4570d79 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -1554,12 +1554,18 @@ static void ir_print_test_cancel_requested(IrPrint *irp, IrInstructionTestCancel fprintf(irp->f, "@testCancelRequested()"); } -static void ir_print_spill(IrPrint *irp, IrInstructionSpill *instruction) { - fprintf(irp->f, "@spill("); +static void ir_print_spill_begin(IrPrint *irp, IrInstructionSpillBegin *instruction) { + fprintf(irp->f, "@spillBegin("); ir_print_other_instruction(irp, instruction->operand); fprintf(irp->f, ")"); } +static void ir_print_spill_end(IrPrint *irp, IrInstructionSpillEnd *instruction) { + fprintf(irp->f, "@spillEnd("); + ir_print_other_instruction(irp, &instruction->begin->base); + fprintf(irp->f, ")"); +} + static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { ir_print_prefix(irp, instruction); switch (instruction->id) { @@ -2045,8 +2051,11 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdTestCancelRequested: ir_print_test_cancel_requested(irp, (IrInstructionTestCancelRequested *)instruction); break; - case IrInstructionIdSpill: - ir_print_spill(irp, (IrInstructionSpill *)instruction); + case IrInstructionIdSpillBegin: + ir_print_spill_begin(irp, (IrInstructionSpillBegin *)instruction); + break; + case IrInstructionIdSpillEnd: + ir_print_spill_end(irp, (IrInstructionSpillEnd *)instruction); break; } fprintf(irp->f, "\n"); diff --git a/test/stage1/behavior/coroutines.zig b/test/stage1/behavior/coroutines.zig index c92cca957..a1828a662 100644 --- a/test/stage1/behavior/coroutines.zig +++ b/test/stage1/behavior/coroutines.zig @@ -642,3 +642,33 @@ test "combining try with errdefer cancel" { }; S.doTheTest(); } + +test "try in an async function with error union and non-zero-bit payload" { + const S = struct { + var frame: anyframe = undefined; + var ok = false; + + fn doTheTest() void { + _ = async amain(); + resume frame; + expect(ok); + } + + fn amain() void { + std.testing.expectError(error.Bad, theProblem()); + ok = true; + } + + fn theProblem() ![]u8 { + frame = @frame(); + suspend; + const result = try other(); + return result; + } + + fn other() ![]u8 { + return error.Bad; + } + }; + S.doTheTest(); +}