fix try in an async function with error union and non-zero-bit payload

This commit is contained in:
Andrew Kelley 2019-08-10 15:20:08 -04:00
parent b9d1d45dfd
commit 22428a7546
No known key found for this signature in database
GPG Key ID: 7C5F548F728501A9
6 changed files with 187 additions and 57 deletions

View File

@ -74,6 +74,7 @@ struct IrExecutable {
bool invalid;
bool is_inline;
bool is_generic_instantiation;
bool need_err_code_spill;
};
enum OutType {
@ -1384,6 +1385,7 @@ struct ZigFn {
size_t prealloc_backward_branch_quota;
AstNode **param_source_nodes;
Buf **param_names;
IrInstruction *err_code_spill;
AstNode *fn_no_inline_set_node;
AstNode *fn_static_eval_set_node;
@ -2366,7 +2368,8 @@ enum IrInstructionId {
IrInstructionIdAwaitGen,
IrInstructionIdCoroResume,
IrInstructionIdTestCancelRequested,
IrInstructionIdSpill,
IrInstructionIdSpillBegin,
IrInstructionIdSpillEnd,
};
struct IrInstruction {
@ -3649,13 +3652,19 @@ enum SpillId {
SpillIdRetErrCode,
};
struct IrInstructionSpill {
struct IrInstructionSpillBegin {
IrInstruction base;
SpillId spill_id;
IrInstruction *operand;
};
struct IrInstructionSpillEnd {
IrInstruction base;
IrInstructionSpillBegin *begin;
};
enum ResultLocId {
ResultLocIdInvalid,
ResultLocIdNone,

View File

@ -5190,6 +5190,18 @@ static Error resolve_coro_frame(CodeGen *g, ZigType *frame_type) {
}
ZigType *fn_type = get_async_fn_type(g, fn->type_entry);
if (fn->analyzed_executable.need_err_code_spill) {
IrInstructionAllocaGen *alloca_gen = allocate<IrInstructionAllocaGen>(1);
alloca_gen->base.id = IrInstructionIdAllocaGen;
alloca_gen->base.source_node = fn->proto_node;
alloca_gen->base.scope = fn->child_scope;
alloca_gen->base.value.type = get_pointer_to_type(g, g->builtin_types.entry_global_error_set, false);
alloca_gen->base.ref_count = 1;
alloca_gen->name_hint = "";
fn->alloca_gen_list.append(alloca_gen);
fn->err_code_spill = &alloca_gen->base;
}
for (size_t i = 0; i < fn->call_list.length; i += 1) {
IrInstructionCallGen *call = fn->call_list.at(i);
ZigFn *callee = call->fn_entry;

View File

@ -2274,16 +2274,16 @@ static LLVMValueRef gen_maybe_atomic_op(CodeGen *g, LLVMAtomicRMWBinOp op, LLVMV
static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable,
IrInstructionReturnBegin *instruction)
{
bool ret_type_has_bits = instruction->operand != nullptr &&
type_has_bits(instruction->operand->value.type);
ZigType *operand_type = (instruction->operand != nullptr) ? instruction->operand->value.type : nullptr;
bool operand_has_bits = (operand_type != nullptr) && type_has_bits(operand_type);
if (!fn_is_async(g->cur_fn)) {
return ret_type_has_bits ? ir_llvm_value(g, instruction->operand) : nullptr;
return operand_has_bits ? ir_llvm_value(g, instruction->operand) : nullptr;
}
ZigType *ret_type = g->cur_fn->type_entry->data.fn.fn_type_id.return_type;
bool ret_type_has_bits = type_has_bits(ret_type);
LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
ZigType *ret_type = ret_type_has_bits ? instruction->operand->value.type : nullptr;
if (ret_type_has_bits && !handle_is_ptr(ret_type)) {
// It's a scalar, so it didn't get written to the result ptr. Do that before the atomic rmw.
LLVMBuildStore(g->builder, ir_llvm_value(g, instruction->operand), g->cur_ret_ptr);
@ -2333,11 +2333,11 @@ static LLVMValueRef ir_render_return_begin(CodeGen *g, IrExecutable *executable,
g->cur_is_after_return = true;
LLVMBuildStore(g->builder, g->cur_async_prev_val, g->cur_async_prev_val_field_ptr);
if (!ret_type_has_bits) {
if (!operand_has_bits) {
return nullptr;
}
return get_handle_value(g, g->cur_ret_ptr, ret_type, get_pointer_to_type(g, ret_type, true));
return get_handle_value(g, g->cur_ret_ptr, operand_type, get_pointer_to_type(g, operand_type, true));
}
static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrInstructionReturn *instruction) {
@ -5113,18 +5113,6 @@ static LLVMValueRef ir_render_test_err(CodeGen *g, IrExecutable *executable, IrI
return LLVMBuildICmp(g->builder, LLVMIntNE, err_val, zero, "");
}
static LLVMValueRef gen_unwrap_err_code(CodeGen *g, LLVMValueRef err_union_ptr, ZigType *ptr_type) {
ZigType *err_union_type = ptr_type->data.pointer.child_type;
ZigType *payload_type = err_union_type->data.error_union.payload_type;
if (!type_has_bits(payload_type)) {
return err_union_ptr;
} else {
// TODO assign undef to the payload
LLVMValueRef err_union_handle = get_handle_value(g, err_union_ptr, err_union_type, ptr_type);
return LLVMBuildStructGEP(g->builder, err_union_handle, err_union_err_index, "");
}
}
static LLVMValueRef ir_render_unwrap_err_code(CodeGen *g, IrExecutable *executable,
IrInstructionUnwrapErrCode *instruction)
{
@ -5133,8 +5121,16 @@ static LLVMValueRef ir_render_unwrap_err_code(CodeGen *g, IrExecutable *executab
ZigType *ptr_type = instruction->err_union_ptr->value.type;
assert(ptr_type->id == ZigTypeIdPointer);
ZigType *err_union_type = ptr_type->data.pointer.child_type;
ZigType *payload_type = err_union_type->data.error_union.payload_type;
LLVMValueRef err_union_ptr = ir_llvm_value(g, instruction->err_union_ptr);
return gen_unwrap_err_code(g, err_union_ptr, ptr_type);
if (!type_has_bits(payload_type)) {
return err_union_ptr;
} else {
// TODO assign undef to the payload
LLVMValueRef err_union_handle = get_handle_value(g, err_union_ptr, err_union_type, ptr_type);
return LLVMBuildStructGEP(g->builder, err_union_handle, err_union_err_index, "");
}
}
static LLVMValueRef ir_render_unwrap_err_payload(CodeGen *g, IrExecutable *executable,
@ -5615,21 +5611,36 @@ static LLVMValueRef ir_render_test_cancel_requested(CodeGen *g, IrExecutable *ex
}
}
static LLVMValueRef ir_render_spill(CodeGen *g, IrExecutable *executable, IrInstructionSpill *instruction) {
static LLVMValueRef ir_render_spill_begin(CodeGen *g, IrExecutable *executable,
IrInstructionSpillBegin *instruction)
{
if (!fn_is_async(g->cur_fn))
return ir_llvm_value(g, instruction->operand);
return nullptr;
switch (instruction->spill_id) {
case SpillIdInvalid:
zig_unreachable();
case SpillIdRetErrCode: {
LLVMValueRef ret_ptr = LLVMBuildLoad(g->builder, g->cur_ret_ptr, "");
ZigType *ret_type = g->cur_fn->type_entry->data.fn.fn_type_id.return_type;
if (ret_type->id == ZigTypeIdErrorUnion) {
return gen_unwrap_err_code(g, ret_ptr, get_pointer_to_type(g, ret_type, true));
} else {
zig_unreachable();
}
LLVMValueRef operand = ir_llvm_value(g, instruction->operand);
LLVMValueRef ptr = ir_llvm_value(g, g->cur_fn->err_code_spill);
LLVMBuildStore(g->builder, operand, ptr);
return nullptr;
}
}
zig_unreachable();
}
static LLVMValueRef ir_render_spill_end(CodeGen *g, IrExecutable *executable, IrInstructionSpillEnd *instruction) {
if (!fn_is_async(g->cur_fn))
return ir_llvm_value(g, instruction->begin->operand);
switch (instruction->begin->spill_id) {
case SpillIdInvalid:
zig_unreachable();
case SpillIdRetErrCode: {
LLVMValueRef ptr = ir_llvm_value(g, g->cur_fn->err_code_spill);
return LLVMBuildLoad(g->builder, ptr, "");
}
}
@ -5891,8 +5902,10 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
return ir_render_await(g, executable, (IrInstructionAwaitGen *)instruction);
case IrInstructionIdTestCancelRequested:
return ir_render_test_cancel_requested(g, executable, (IrInstructionTestCancelRequested *)instruction);
case IrInstructionIdSpill:
return ir_render_spill(g, executable, (IrInstructionSpill *)instruction);
case IrInstructionIdSpillBegin:
return ir_render_spill_begin(g, executable, (IrInstructionSpillBegin *)instruction);
case IrInstructionIdSpillEnd:
return ir_render_spill_end(g, executable, (IrInstructionSpillEnd *)instruction);
}
zig_unreachable();
}

View File

@ -1066,8 +1066,12 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionTestCancelReques
return IrInstructionIdTestCancelRequested;
}
static constexpr IrInstructionId ir_instruction_id(IrInstructionSpill *) {
return IrInstructionIdSpill;
static constexpr IrInstructionId ir_instruction_id(IrInstructionSpillBegin *) {
return IrInstructionIdSpillBegin;
}
static constexpr IrInstructionId ir_instruction_id(IrInstructionSpillEnd *) {
return IrInstructionIdSpillEnd;
}
template<typename T>
@ -3336,15 +3340,28 @@ static IrInstruction *ir_build_test_cancel_requested(IrBuilder *irb, Scope *scop
return &instruction->base;
}
static IrInstruction *ir_build_spill(IrBuilder *irb, Scope *scope, AstNode *source_node,
static IrInstructionSpillBegin *ir_build_spill_begin(IrBuilder *irb, Scope *scope, AstNode *source_node,
IrInstruction *operand, SpillId spill_id)
{
IrInstructionSpill *instruction = ir_build_instruction<IrInstructionSpill>(irb, scope, source_node);
IrInstructionSpillBegin *instruction = ir_build_instruction<IrInstructionSpillBegin>(irb, scope, source_node);
instruction->base.value.special = ConstValSpecialStatic;
instruction->base.value.type = irb->codegen->builtin_types.entry_void;
instruction->operand = operand;
instruction->spill_id = spill_id;
ir_ref_instruction(operand, irb->current_basic_block);
return instruction;
}
static IrInstruction *ir_build_spill_end(IrBuilder *irb, Scope *scope, AstNode *source_node,
IrInstructionSpillBegin *begin)
{
IrInstructionSpillEnd *instruction = ir_build_instruction<IrInstructionSpillEnd>(irb, scope, source_node);
instruction->begin = begin;
ir_ref_instruction(&begin->base, irb->current_basic_block);
return &instruction->base;
}
@ -3602,14 +3619,15 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node,
IrInstruction *err_val_ptr = ir_build_unwrap_err_code(irb, scope, node, err_union_ptr);
IrInstruction *err_val = ir_build_load_ptr(irb, scope, node, err_val_ptr);
ir_mark_gen(ir_build_add_implicit_return_type(irb, scope, node, err_val));
err_val = ir_build_return_begin(irb, scope, node, err_val);
IrInstructionSpillBegin *spill_begin = ir_build_spill_begin(irb, scope, node, err_val,
SpillIdRetErrCode);
ir_build_return_begin(irb, scope, node, err_val);
err_val = ir_build_spill_end(irb, scope, node, spill_begin);
ResultLocReturn *result_loc_ret = allocate<ResultLocReturn>(1);
result_loc_ret->base.id = ResultLocIdReturn;
ir_build_reset_result(irb, scope, node, &result_loc_ret->base);
ir_build_end_expr(irb, scope, node, err_val, &result_loc_ret->base);
if (!ir_gen_defers_for_block(irb, scope, outer_scope, true)) {
ResultLocReturn *result_loc_ret = allocate<ResultLocReturn>(1);
result_loc_ret->base.id = ResultLocIdReturn;
ir_build_reset_result(irb, scope, node, &result_loc_ret->base);
err_val = ir_build_spill(irb, scope, node, err_val, SpillIdRetErrCode);
ir_build_end_expr(irb, scope, node, err_val, &result_loc_ret->base);
if (irb->codegen->have_err_ret_tracing && !should_inline) {
ir_build_save_err_ret_addr(irb, scope, node);
}
@ -12778,8 +12796,21 @@ static IrInstruction *ir_analyze_instruction_return(IrAnalyze *ira, IrInstructio
return ir_finish_anal(ira, result);
}
// This cast might have been already done from IrInstructionReturnBegin but it also
// might not have, in the case of `try`.
IrInstruction *casted_operand = ir_implicit_cast(ira, operand, ira->explicit_return_type);
if (type_is_invalid(casted_operand->value.type)) {
AstNode *source_node = ira->explicit_return_type_source_node;
if (source_node != nullptr) {
ErrorMsg *msg = ira->codegen->errors.last();
add_error_note(ira->codegen, msg, source_node,
buf_sprintf("return type declared here"));
}
return ir_unreach_error(ira);
}
IrInstruction *result = ir_build_return(&ira->new_irb, instruction->base.scope,
instruction->base.source_node, operand);
instruction->base.source_node, casted_operand);
result->value.type = ira->codegen->builtin_types.entry_unreachable;
return ir_finish_anal(ira, result);
}
@ -24742,15 +24773,38 @@ static IrInstruction *ir_analyze_instruction_test_cancel_requested(IrAnalyze *ir
return ir_build_test_cancel_requested(&ira->new_irb, instruction->base.scope, instruction->base.source_node);
}
static IrInstruction *ir_analyze_instruction_spill(IrAnalyze *ira, IrInstructionSpill *instruction) {
static IrInstruction *ir_analyze_instruction_spill_begin(IrAnalyze *ira, IrInstructionSpillBegin *instruction) {
if (ir_should_inline(ira->new_irb.exec, instruction->base.scope))
return ir_const_void(ira, &instruction->base);
IrInstruction *operand = instruction->operand->child;
if (type_is_invalid(operand->value.type))
return ira->codegen->invalid_instruction;
if (ir_should_inline(ira->new_irb.exec, instruction->base.scope)) {
if (!type_has_bits(operand->value.type))
return ir_const_void(ira, &instruction->base);
ir_assert(instruction->spill_id == SpillIdRetErrCode, &instruction->base);
ira->new_irb.exec->need_err_code_spill = true;
IrInstructionSpillBegin *result = ir_build_spill_begin(&ira->new_irb, instruction->base.scope,
instruction->base.source_node, operand, instruction->spill_id);
return &result->base;
}
static IrInstruction *ir_analyze_instruction_spill_end(IrAnalyze *ira, IrInstructionSpillEnd *instruction) {
IrInstruction *operand = instruction->begin->operand->child;
if (type_is_invalid(operand->value.type))
return ira->codegen->invalid_instruction;
if (ir_should_inline(ira->new_irb.exec, instruction->base.scope) || !type_has_bits(operand->value.type))
return operand;
}
IrInstruction *result = ir_build_spill(&ira->new_irb, instruction->base.scope, instruction->base.source_node,
operand, instruction->spill_id);
ir_assert(instruction->begin->base.child->id == IrInstructionIdSpillBegin, &instruction->base);
IrInstructionSpillBegin *begin = reinterpret_cast<IrInstructionSpillBegin *>(instruction->begin->base.child);
IrInstruction *result = ir_build_spill_end(&ira->new_irb, instruction->base.scope,
instruction->base.source_node, begin);
result->value.type = operand->value.type;
return result;
}
@ -25054,8 +25108,10 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction
return ir_analyze_instruction_await(ira, (IrInstructionAwaitSrc *)instruction);
case IrInstructionIdTestCancelRequested:
return ir_analyze_instruction_test_cancel_requested(ira, (IrInstructionTestCancelRequested *)instruction);
case IrInstructionIdSpill:
return ir_analyze_instruction_spill(ira, (IrInstructionSpill *)instruction);
case IrInstructionIdSpillBegin:
return ir_analyze_instruction_spill_begin(ira, (IrInstructionSpillBegin *)instruction);
case IrInstructionIdSpillEnd:
return ir_analyze_instruction_spill_end(ira, (IrInstructionSpillEnd *)instruction);
}
zig_unreachable();
}
@ -25193,6 +25249,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
case IrInstructionIdCoroResume:
case IrInstructionIdAwaitSrc:
case IrInstructionIdAwaitGen:
case IrInstructionIdSpillBegin:
return true;
case IrInstructionIdPhi:
@ -25291,7 +25348,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
case IrInstructionIdAllocaSrc:
case IrInstructionIdAllocaGen:
case IrInstructionIdTestCancelRequested:
case IrInstructionIdSpill:
case IrInstructionIdSpillEnd:
return false;
case IrInstructionIdAsm:

View File

@ -1554,12 +1554,18 @@ static void ir_print_test_cancel_requested(IrPrint *irp, IrInstructionTestCancel
fprintf(irp->f, "@testCancelRequested()");
}
static void ir_print_spill(IrPrint *irp, IrInstructionSpill *instruction) {
fprintf(irp->f, "@spill(");
static void ir_print_spill_begin(IrPrint *irp, IrInstructionSpillBegin *instruction) {
fprintf(irp->f, "@spillBegin(");
ir_print_other_instruction(irp, instruction->operand);
fprintf(irp->f, ")");
}
static void ir_print_spill_end(IrPrint *irp, IrInstructionSpillEnd *instruction) {
fprintf(irp->f, "@spillEnd(");
ir_print_other_instruction(irp, &instruction->begin->base);
fprintf(irp->f, ")");
}
static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
ir_print_prefix(irp, instruction);
switch (instruction->id) {
@ -2045,8 +2051,11 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
case IrInstructionIdTestCancelRequested:
ir_print_test_cancel_requested(irp, (IrInstructionTestCancelRequested *)instruction);
break;
case IrInstructionIdSpill:
ir_print_spill(irp, (IrInstructionSpill *)instruction);
case IrInstructionIdSpillBegin:
ir_print_spill_begin(irp, (IrInstructionSpillBegin *)instruction);
break;
case IrInstructionIdSpillEnd:
ir_print_spill_end(irp, (IrInstructionSpillEnd *)instruction);
break;
}
fprintf(irp->f, "\n");

View File

@ -642,3 +642,33 @@ test "combining try with errdefer cancel" {
};
S.doTheTest();
}
test "try in an async function with error union and non-zero-bit payload" {
const S = struct {
var frame: anyframe = undefined;
var ok = false;
fn doTheTest() void {
_ = async amain();
resume frame;
expect(ok);
}
fn amain() void {
std.testing.expectError(error.Bad, theProblem());
ok = true;
}
fn theProblem() ![]u8 {
frame = @frame();
suspend;
const result = try other();
return result;
}
fn other() ![]u8 {
return error.Bad;
}
};
S.doTheTest();
}