diff --git a/BRANCH_TODO b/BRANCH_TODO index f5db81a08..6d6ae4252 100644 --- a/BRANCH_TODO +++ b/BRANCH_TODO @@ -1,4 +1,5 @@ - * await + * compile error for error: expected anyframe->T, found 'anyframe' + * compile error for error: expected anyframe->T, found 'i32' * await of a non async function * await in single-threaded mode * async call on a non async function @@ -13,3 +14,6 @@ * @typeInfo for @Frame(func) * peer type resolution of *@Frame(func) and anyframe * peer type resolution of *@Frame(func) and anyframe->T when the return type matches + * returning a value from within a suspend block + * struct types as the return type of an async function. make sure it works with return result locations. + * make resuming inside a suspend block, with nothing after it, a must-tail call. diff --git a/src/all_types.hpp b/src/all_types.hpp index e66c9aebf..9ab90b278 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1550,6 +1550,8 @@ enum PanicMsgId { PanicMsgIdFloatToInt, PanicMsgIdPtrCastNull, PanicMsgIdBadResume, + PanicMsgIdBadAwait, + PanicMsgIdBadReturn, PanicMsgIdCount, }; @@ -1795,7 +1797,6 @@ struct CodeGen { ZigType *entry_arg_tuple; ZigType *entry_enum_literal; ZigType *entry_any_frame; - ZigType *entry_async_fn; } builtin_types; ZigType *align_amt_type; @@ -2348,6 +2349,7 @@ enum IrInstructionId { IrInstructionIdUnionInitNamedField, IrInstructionIdSuspendBegin, IrInstructionIdSuspendBr, + IrInstructionIdAwait, IrInstructionIdCoroResume, }; @@ -3600,6 +3602,12 @@ struct IrInstructionSuspendBr { IrBasicBlock *resume_block; }; +struct IrInstructionAwait { + IrInstruction base; + + IrInstruction *frame; +}; + struct IrInstructionCoroResume { IrInstruction base; diff --git a/src/analyze.cpp b/src/analyze.cpp index 99caf9688..5af9698dd 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -3807,6 +3807,9 @@ static void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn) { } else if (fn->inferred_async_node->type == NodeTypeSuspend) { add_error_note(g, msg, fn->inferred_async_node, buf_sprintf("suspends here")); + } else if (fn->inferred_async_node->type == NodeTypeAwaitExpr) { + add_error_note(g, msg, fn->inferred_async_node, + buf_sprintf("await is a suspend point")); } else { zig_unreachable(); } @@ -7361,7 +7364,7 @@ static void resolve_llvm_types_fn_type(CodeGen *g, ZigType *fn_type) { param_di_types.append(get_llvm_di_type(g, gen_type)); } if (is_async) { - fn_type->data.fn.gen_param_info = allocate(1); + fn_type->data.fn.gen_param_info = allocate(2); ZigType *frame_type = get_any_frame_type(g, fn_type_id->return_type); gen_param_types.append(get_llvm_type(g, frame_type)); @@ -7370,6 +7373,13 @@ static void resolve_llvm_types_fn_type(CodeGen *g, ZigType *fn_type) { fn_type->data.fn.gen_param_info[0].src_index = 0; fn_type->data.fn.gen_param_info[0].gen_index = 0; fn_type->data.fn.gen_param_info[0].type = frame_type; + + gen_param_types.append(get_llvm_type(g, g->builtin_types.entry_usize)); + param_di_types.append(get_llvm_di_type(g, g->builtin_types.entry_usize)); + + fn_type->data.fn.gen_param_info[1].src_index = 1; + fn_type->data.fn.gen_param_info[1].gen_index = 1; + fn_type->data.fn.gen_param_info[1].type = g->builtin_types.entry_usize; } else { fn_type->data.fn.gen_param_info = allocate(fn_type_id->param_count); for (size_t i = 0; i < fn_type_id->param_count; i += 1) { @@ -7434,15 +7444,21 @@ void resolve_llvm_types_fn(CodeGen *g, ZigFn *fn) { ZigType *gen_return_type = g->builtin_types.entry_void; ZigList param_di_types = {}; + ZigList gen_param_types = {}; // first "parameter" is return value param_di_types.append(get_llvm_di_type(g, gen_return_type)); ZigType *frame_type = get_coro_frame_type(g, fn); ZigType *ptr_type = get_pointer_to_type(g, frame_type, false); - LLVMTypeRef gen_param_type = get_llvm_type(g, ptr_type); + gen_param_types.append(get_llvm_type(g, ptr_type)); param_di_types.append(get_llvm_di_type(g, ptr_type)); - fn->raw_type_ref = LLVMFunctionType(get_llvm_type(g, gen_return_type), &gen_param_type, 1, false); + // this parameter is used to pass the result pointer when await completes + gen_param_types.append(get_llvm_type(g, g->builtin_types.entry_usize)); + param_di_types.append(get_llvm_di_type(g, g->builtin_types.entry_usize)); + + fn->raw_type_ref = LLVMFunctionType(get_llvm_type(g, gen_return_type), + gen_param_types.items, gen_param_types.length, false); fn->raw_di_type = ZigLLVMCreateSubroutineType(g->dbuilder, param_di_types.items, (int)param_di_types.length, 0); } @@ -7493,7 +7509,8 @@ static void resolve_llvm_types_any_frame(CodeGen *g, ZigType *any_frame_type, Re 8*g->pointer_size_bytes, 8*g->builtin_types.entry_usize->abi_align, buf_ptr(&any_frame_type->name)); LLVMTypeRef llvm_void = LLVMVoidType(); - LLVMTypeRef fn_type = LLVMFunctionType(llvm_void, &any_frame_type->llvm_type, 1, false); + LLVMTypeRef arg_types[] = {any_frame_type->llvm_type, g->builtin_types.entry_usize->llvm_type}; + LLVMTypeRef fn_type = LLVMFunctionType(llvm_void, arg_types, 2, false); LLVMTypeRef usize_type_ref = get_llvm_type(g, g->builtin_types.entry_usize); ZigLLVMDIType *usize_di_type = get_llvm_di_type(g, g->builtin_types.entry_usize); ZigLLVMDIScope *compile_unit_scope = ZigLLVMCompileUnitToScope(g->compile_unit); diff --git a/src/ast_render.cpp b/src/ast_render.cpp index 4d6bae311..98e11e24c 100644 --- a/src/ast_render.cpp +++ b/src/ast_render.cpp @@ -1149,9 +1149,11 @@ static void render_node_extra(AstRender *ar, AstNode *node, bool grouped) { } case NodeTypeSuspend: { - fprintf(ar->f, "suspend"); if (node->data.suspend.block != nullptr) { + fprintf(ar->f, "suspend "); render_node_grouped(ar, node->data.suspend.block); + } else { + fprintf(ar->f, "suspend\n"); } break; } diff --git a/src/codegen.cpp b/src/codegen.cpp index d0aadaabe..6fe46acbb 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -873,6 +873,10 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) { return buf_create_from_str("cast causes pointer to be null"); case PanicMsgIdBadResume: return buf_create_from_str("invalid resume of async function"); + case PanicMsgIdBadAwait: + return buf_create_from_str("async function awaited twice"); + case PanicMsgIdBadReturn: + return buf_create_from_str("async function returned twice"); } zig_unreachable(); } @@ -1991,14 +1995,66 @@ static LLVMValueRef ir_render_save_err_ret_addr(CodeGen *g, IrExecutable *execut static LLVMValueRef ir_render_return(CodeGen *g, IrExecutable *executable, IrInstructionReturn *return_instruction) { if (fn_is_async(g->cur_fn)) { + LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; + LLVMValueRef locals_ptr = g->cur_ret_ptr; + bool ret_type_has_bits = return_instruction->value != nullptr && + type_has_bits(return_instruction->value->value.type); + ZigType *ret_type = ret_type_has_bits ? return_instruction->value->value.type : nullptr; + if (ir_want_runtime_safety(g, &return_instruction->base)) { - LLVMValueRef locals_ptr = g->cur_ret_ptr; LLVMValueRef resume_index_ptr = LLVMBuildStructGEP(g->builder, locals_ptr, coro_fn_ptr_index, ""); LLVMValueRef new_resume_fn = g->cur_fn->resume_blocks.last()->split_llvm_fn; LLVMBuildStore(g->builder, new_resume_fn, resume_index_ptr); } + LLVMValueRef awaiter_ptr = LLVMBuildStructGEP(g->builder, locals_ptr, coro_awaiter_index, ""); + LLVMValueRef result_ptr_as_usize; + if (ret_type_has_bits) { + LLVMValueRef result_ptr_ptr = LLVMBuildStructGEP(g->builder, locals_ptr, coro_arg_start, ""); + LLVMValueRef result_ptr = LLVMBuildLoad(g->builder, result_ptr_ptr, ""); + if (!handle_is_ptr(ret_type)) { + // It's a scalar, so it didn't get written to the result ptr. Do that now. + LLVMBuildStore(g->builder, ir_llvm_value(g, return_instruction->value), result_ptr); + } + result_ptr_as_usize = LLVMBuildPtrToInt(g->builder, result_ptr, usize_type_ref, ""); + } else { + result_ptr_as_usize = LLVMGetUndef(usize_type_ref); + } + LLVMValueRef zero = LLVMConstNull(usize_type_ref); + LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref); + LLVMValueRef prev_val = LLVMBuildAtomicRMW(g->builder, LLVMAtomicRMWBinOpXchg, awaiter_ptr, + all_ones, LLVMAtomicOrderingSequentiallyConsistent, g->is_single_threaded); + + LLVMBasicBlockRef bad_return_block = LLVMAppendBasicBlock(g->cur_fn_val, "BadReturn"); + LLVMBasicBlockRef early_return_block = LLVMAppendBasicBlock(g->cur_fn_val, "EarlyReturn"); + LLVMBasicBlockRef resume_them_block = LLVMAppendBasicBlock(g->cur_fn_val, "ResumeThem"); + + LLVMValueRef switch_instr = LLVMBuildSwitch(g->builder, prev_val, resume_them_block, 2); + + LLVMAddCase(switch_instr, zero, early_return_block); + LLVMAddCase(switch_instr, all_ones, bad_return_block); + + // Something has gone horribly wrong, and this is an invalid second return. + LLVMPositionBuilderAtEnd(g->builder, bad_return_block); + gen_assertion(g, PanicMsgIdBadReturn, &return_instruction->base); + + // The caller will deal with fetching the result - we're done. + LLVMPositionBuilderAtEnd(g->builder, early_return_block); LLVMBuildRetVoid(g->builder); + + // We need to resume the caller by tail calling them. + LLVMPositionBuilderAtEnd(g->builder, resume_them_block); + ZigType *any_frame_type = get_any_frame_type(g, ret_type); + LLVMValueRef their_frame_ptr = LLVMBuildIntToPtr(g->builder, prev_val, + get_llvm_type(g, any_frame_type), ""); + LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, their_frame_ptr, coro_fn_ptr_index, ""); + LLVMValueRef awaiter_fn = LLVMBuildLoad(g->builder, fn_ptr_ptr, ""); + LLVMValueRef args[] = {their_frame_ptr, result_ptr_as_usize}; + LLVMValueRef call_inst = ZigLLVMBuildCall(g->builder, awaiter_fn, args, 2, LLVMFastCallConv, + ZigLLVM_FnInlineAuto, ""); + ZigLLVMSetTailCall(call_inst); + LLVMBuildRetVoid(g->builder); + return nullptr; } if (want_first_arg_sret(g, &g->cur_fn->type_entry->data.fn.fn_type_id)) { @@ -3514,14 +3570,17 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr } } if (instruction->is_async) { - ZigLLVMBuildCall(g->builder, fn_val, &frame_result_loc, 1, llvm_cc, fn_inline, ""); + LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(g->builtin_types.entry_usize->llvm_type)}; + ZigLLVMBuildCall(g->builder, fn_val, args, 2, llvm_cc, fn_inline, ""); return nullptr; } else if (callee_is_async) { + ZigType *ptr_result_type = get_pointer_to_type(g, src_return_type, true); LLVMValueRef split_llvm_fn = make_fn_llvm_value(g, g->cur_fn); LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_fn_ptr_index, ""); LLVMBuildStore(g->builder, split_llvm_fn, fn_ptr_ptr); - LLVMValueRef call_inst = ZigLLVMBuildCall(g->builder, fn_val, &frame_result_loc, 1, llvm_cc, fn_inline, ""); + LLVMValueRef args[] = {frame_result_loc, LLVMGetUndef(g->builtin_types.entry_usize->llvm_type)}; + LLVMValueRef call_inst = ZigLLVMBuildCall(g->builder, fn_val, args, 2, llvm_cc, fn_inline, ""); ZigLLVMSetTailCall(call_inst); LLVMBuildRetVoid(g->builder); @@ -3530,7 +3589,15 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(split_llvm_fn, "CallResume"); LLVMPositionBuilderAtEnd(g->builder, call_bb); render_async_var_decls(g, instruction->base.scope); - return nullptr; + + if (type_has_bits(src_return_type)) { + LLVMValueRef spilled_result_ptr = LLVMGetParam(g->cur_fn_val, 1); + LLVMValueRef casted_spilled_result_ptr = LLVMBuildIntToPtr(g->builder, spilled_result_ptr, + get_llvm_type(g, ptr_result_type), ""); + return get_handle_value(g, casted_spilled_result_ptr, src_return_type, ptr_result_type); + } else { + return nullptr; + } } if (instruction->new_stack == nullptr) { @@ -4829,7 +4896,7 @@ static LLVMValueRef ir_render_atomic_rmw(CodeGen *g, IrExecutable *executable, LLVMValueRef operand = ir_llvm_value(g, instruction->operand); if (get_codegen_ptr_type(operand_type) == nullptr) { - return LLVMBuildAtomicRMW(g->builder, op, ptr, operand, ordering, false); + return LLVMBuildAtomicRMW(g->builder, op, ptr, operand, ordering, g->is_single_threaded); } // it's a pointer but we need to treat it as an int @@ -4990,14 +5057,89 @@ static LLVMValueRef ir_render_suspend_br(CodeGen *g, IrExecutable *executable, return nullptr; } +static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInstructionAwait *instruction) { + LLVMValueRef target_frame_ptr = ir_llvm_value(g, instruction->frame); + ZigType *result_type = instruction->base.value.type; + ZigType *ptr_result_type = get_pointer_to_type(g, result_type, true); + + // Prepare to be suspended + LLVMValueRef split_llvm_fn = make_fn_llvm_value(g, g->cur_fn); + LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_ret_ptr, coro_fn_ptr_index, ""); + LLVMBuildStore(g->builder, split_llvm_fn, fn_ptr_ptr); + + // At this point resuming the function will do the correct thing. + // This code is as if it is running inside the suspend block. + + LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; + // caller's own frame pointer + LLVMValueRef awaiter_init_val = LLVMBuildPtrToInt(g->builder, g->cur_ret_ptr, usize_type_ref, ""); + LLVMValueRef awaiter_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, coro_awaiter_index, ""); + LLVMValueRef result_ptr_as_usize; + if (type_has_bits(result_type)) { + LLVMValueRef result_ptr_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, coro_arg_start, ""); + LLVMValueRef result_ptr = LLVMBuildLoad(g->builder, result_ptr_ptr, ""); + result_ptr_as_usize = LLVMBuildPtrToInt(g->builder, result_ptr, usize_type_ref, ""); + } else { + result_ptr_as_usize = LLVMGetUndef(usize_type_ref); + } + LLVMValueRef prev_val = LLVMBuildAtomicRMW(g->builder, LLVMAtomicRMWBinOpXchg, awaiter_ptr, awaiter_init_val, + LLVMAtomicOrderingSequentiallyConsistent, g->is_single_threaded); + + LLVMBasicBlockRef bad_await_block = LLVMAppendBasicBlock(g->cur_fn_val, "BadAwait"); + LLVMBasicBlockRef complete_suspend_block = LLVMAppendBasicBlock(g->cur_fn_val, "CompleteSuspend"); + LLVMBasicBlockRef early_return_block = LLVMAppendBasicBlock(g->cur_fn_val, "EarlyReturn"); + + LLVMValueRef zero = LLVMConstNull(usize_type_ref); + LLVMValueRef all_ones = LLVMConstAllOnes(usize_type_ref); + LLVMValueRef switch_instr = LLVMBuildSwitch(g->builder, prev_val, bad_await_block, 2); + + LLVMAddCase(switch_instr, zero, complete_suspend_block); + LLVMAddCase(switch_instr, all_ones, early_return_block); + + // We discovered that another awaiter was already here. + LLVMPositionBuilderAtEnd(g->builder, bad_await_block); + gen_assertion(g, PanicMsgIdBadAwait, &instruction->base); + + // Rely on the target to resume us from suspension. + LLVMPositionBuilderAtEnd(g->builder, complete_suspend_block); + LLVMBuildRetVoid(g->builder); + + // The async function has already completed. So we use a tail call to resume ourselves. + LLVMPositionBuilderAtEnd(g->builder, early_return_block); + LLVMValueRef args[] = {g->cur_ret_ptr, result_ptr_as_usize}; + LLVMValueRef call_inst = ZigLLVMBuildCall(g->builder, split_llvm_fn, args, 2, LLVMFastCallConv, + ZigLLVM_FnInlineAuto, ""); + ZigLLVMSetTailCall(call_inst); + LLVMBuildRetVoid(g->builder); + + g->cur_fn_val = split_llvm_fn; + g->cur_ret_ptr = LLVMGetParam(split_llvm_fn, 0); + LLVMBasicBlockRef call_bb = LLVMAppendBasicBlock(split_llvm_fn, "AwaitResume"); + LLVMPositionBuilderAtEnd(g->builder, call_bb); + render_async_var_decls(g, instruction->base.scope); + + if (type_has_bits(result_type)) { + LLVMValueRef spilled_result_ptr = LLVMGetParam(g->cur_fn_val, 1); + LLVMValueRef casted_spilled_result_ptr = LLVMBuildIntToPtr(g->builder, spilled_result_ptr, + get_llvm_type(g, ptr_result_type), ""); + return get_handle_value(g, casted_spilled_result_ptr, result_type, ptr_result_type); + } else { + return nullptr; + } +} + static LLVMTypeRef anyframe_fn_type(CodeGen *g) { if (g->anyframe_fn_type != nullptr) return g->anyframe_fn_type; + LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; ZigType *anyframe_type = get_any_frame_type(g, nullptr); - LLVMTypeRef param_type = get_llvm_type(g, anyframe_type); LLVMTypeRef return_type = LLVMVoidType(); - LLVMTypeRef fn_type = LLVMFunctionType(return_type, ¶m_type, 1, false); + LLVMTypeRef param_types[] = { + get_llvm_type(g, anyframe_type), + usize_type_ref, + }; + LLVMTypeRef fn_type = LLVMFunctionType(return_type, param_types, 2, false); g->anyframe_fn_type = LLVMPointerType(fn_type, 0); return g->anyframe_fn_type; @@ -5006,13 +5148,15 @@ static LLVMTypeRef anyframe_fn_type(CodeGen *g) { static LLVMValueRef ir_render_coro_resume(CodeGen *g, IrExecutable *executable, IrInstructionCoroResume *instruction) { + LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type; LLVMValueRef frame = ir_llvm_value(g, instruction->frame); ZigType *frame_type = instruction->frame->value.type; assert(frame_type->id == ZigTypeIdAnyFrame); LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, frame, coro_fn_ptr_index, ""); LLVMValueRef uncasted_fn_val = LLVMBuildLoad(g->builder, fn_ptr_ptr, ""); LLVMValueRef fn_val = LLVMBuildIntToPtr(g->builder, uncasted_fn_val, anyframe_fn_type(g), ""); - ZigLLVMBuildCall(g->builder, fn_val, &frame, 1, LLVMFastCallConv, ZigLLVM_FnInlineAuto, ""); + LLVMValueRef args[] = {frame, LLVMGetUndef(usize_type_ref)}; + ZigLLVMBuildCall(g->builder, fn_val, args, 2, LLVMFastCallConv, ZigLLVM_FnInlineAuto, ""); return nullptr; } @@ -5279,6 +5423,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_coro_resume(g, executable, (IrInstructionCoroResume *)instruction); case IrInstructionIdFrameSizeGen: return ir_render_frame_size(g, executable, (IrInstructionFrameSizeGen *)instruction); + case IrInstructionIdAwait: + return ir_render_await(g, executable, (IrInstructionAwait *)instruction); } zig_unreachable(); } diff --git a/src/ir.cpp b/src/ir.cpp index 3a6853b03..ecd2cd6f8 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -1052,6 +1052,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionSuspendBr *) { return IrInstructionIdSuspendBr; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionAwait *) { + return IrInstructionIdAwait; +} + static constexpr IrInstructionId ir_instruction_id(IrInstructionCoroResume *) { return IrInstructionIdCoroResume; } @@ -3274,6 +3278,17 @@ static IrInstruction *ir_build_suspend_br(IrBuilder *irb, Scope *scope, AstNode return &instruction->base; } +static IrInstruction *ir_build_await(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *frame) +{ + IrInstructionAwait *instruction = ir_build_instruction(irb, scope, source_node); + instruction->frame = frame; + + ir_ref_instruction(frame, irb->current_basic_block); + + return &instruction->base; +} + static IrInstruction *ir_build_coro_resume(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *frame) { @@ -7774,11 +7789,26 @@ static IrInstruction *ir_gen_resume(IrBuilder *irb, Scope *scope, AstNode *node) static IrInstruction *ir_gen_await_expr(IrBuilder *irb, Scope *scope, AstNode *node) { assert(node->type == NodeTypeAwaitExpr); - IrInstruction *target_inst = ir_gen_node(irb, node->data.await_expr.expr, scope); + ZigFn *fn_entry = exec_fn_entry(irb->exec); + if (!fn_entry) { + add_node_error(irb->codegen, node, buf_sprintf("await outside function definition")); + return irb->codegen->invalid_instruction; + } + ScopeSuspend *existing_suspend_scope = get_scope_suspend(scope); + if (existing_suspend_scope) { + if (!existing_suspend_scope->reported_err) { + ErrorMsg *msg = add_node_error(irb->codegen, node, buf_sprintf("cannot await inside suspend block")); + add_error_note(irb->codegen, msg, existing_suspend_scope->base.source_node, buf_sprintf("suspend block here")); + existing_suspend_scope->reported_err = true; + } + return irb->codegen->invalid_instruction; + } + + IrInstruction *target_inst = ir_gen_node_extra(irb, node->data.await_expr.expr, scope, LValPtr, nullptr); if (target_inst == irb->codegen->invalid_instruction) return irb->codegen->invalid_instruction; - zig_panic("TODO ir_gen_await_expr"); + return ir_build_await(irb, scope, node, target_inst); } static IrInstruction *ir_gen_suspend(IrBuilder *irb, Scope *parent_scope, AstNode *node) { @@ -7789,15 +7819,6 @@ static IrInstruction *ir_gen_suspend(IrBuilder *irb, Scope *parent_scope, AstNod add_node_error(irb->codegen, node, buf_sprintf("suspend outside function definition")); return irb->codegen->invalid_instruction; } - ScopeDeferExpr *scope_defer_expr = get_scope_defer_expr(parent_scope); - if (scope_defer_expr) { - if (!scope_defer_expr->reported_err) { - ErrorMsg *msg = add_node_error(irb->codegen, node, buf_sprintf("cannot suspend inside defer expression")); - add_error_note(irb->codegen, msg, scope_defer_expr->base.source_node, buf_sprintf("defer here")); - scope_defer_expr->reported_err = true; - } - return irb->codegen->invalid_instruction; - } ScopeSuspend *existing_suspend_scope = get_scope_suspend(parent_scope); if (existing_suspend_scope) { if (!existing_suspend_scope->reported_err) { @@ -7808,7 +7829,7 @@ static IrInstruction *ir_gen_suspend(IrBuilder *irb, Scope *parent_scope, AstNod return irb->codegen->invalid_instruction; } - IrBasicBlock *resume_block = ir_create_basic_block(irb, parent_scope, "Resume"); + IrBasicBlock *resume_block = ir_create_basic_block(irb, parent_scope, "SuspendResume"); ir_build_suspend_begin(irb, parent_scope, node, resume_block); if (node->data.suspend.block != nullptr) { @@ -24372,6 +24393,49 @@ static IrInstruction *ir_analyze_instruction_suspend_br(IrAnalyze *ira, IrInstru return ir_finish_anal(ira, result); } +static IrInstruction *ir_analyze_instruction_await(IrAnalyze *ira, IrInstructionAwait *instruction) { + IrInstruction *frame_ptr = instruction->frame->child; + if (type_is_invalid(frame_ptr->value.type)) + return ira->codegen->invalid_instruction; + + ZigType *result_type; + IrInstruction *frame; + if (frame_ptr->value.type->id == ZigTypeIdPointer && + frame_ptr->value.type->data.pointer.ptr_len == PtrLenSingle && + frame_ptr->value.type->data.pointer.child_type->id == ZigTypeIdCoroFrame) + { + result_type = frame_ptr->value.type->data.pointer.child_type->data.frame.fn->type_entry->data.fn.fn_type_id.return_type; + frame = frame_ptr; + } else { + frame = ir_get_deref(ira, &instruction->base, frame_ptr, nullptr); + if (frame->value.type->id != ZigTypeIdAnyFrame || + frame->value.type->data.any_frame.result_type == nullptr) + { + ir_add_error(ira, &instruction->base, + buf_sprintf("expected anyframe->T, found '%s'", buf_ptr(&frame->value.type->name))); + return ira->codegen->invalid_instruction; + } + result_type = frame->value.type->data.any_frame.result_type; + } + + ZigType *any_frame_type = get_any_frame_type(ira->codegen, result_type); + IrInstruction *casted_frame = ir_implicit_cast(ira, frame, any_frame_type); + if (type_is_invalid(casted_frame->value.type)) + return ira->codegen->invalid_instruction; + + ZigFn *fn_entry = exec_fn_entry(ira->new_irb.exec); + ir_assert(fn_entry != nullptr, &instruction->base); + + if (fn_entry->inferred_async_node == nullptr) { + fn_entry->inferred_async_node = instruction->base.source_node; + } + + IrInstruction *result = ir_build_await(&ira->new_irb, + instruction->base.scope, instruction->base.source_node, frame); + result->value.type = result_type; + return ir_finish_anal(ira, result); +} + static IrInstruction *ir_analyze_instruction_coro_resume(IrAnalyze *ira, IrInstructionCoroResume *instruction) { IrInstruction *frame_ptr = instruction->frame->child; if (type_is_invalid(frame_ptr->value.type)) @@ -24380,11 +24444,11 @@ static IrInstruction *ir_analyze_instruction_coro_resume(IrAnalyze *ira, IrInstr IrInstruction *frame; if (frame_ptr->value.type->id == ZigTypeIdPointer && frame_ptr->value.type->data.pointer.ptr_len == PtrLenSingle && - frame_ptr->value.type->data.pointer.child_type->id == ZigTypeIdAnyFrame) + frame_ptr->value.type->data.pointer.child_type->id == ZigTypeIdCoroFrame) { - frame = ir_get_deref(ira, &instruction->base, frame_ptr, nullptr); - } else { frame = frame_ptr; + } else { + frame = ir_get_deref(ira, &instruction->base, frame_ptr, nullptr); } ZigType *any_frame_type = get_any_frame_type(ira->codegen, nullptr); @@ -24691,6 +24755,8 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction return ir_analyze_instruction_suspend_br(ira, (IrInstructionSuspendBr *)instruction); case IrInstructionIdCoroResume: return ir_analyze_instruction_coro_resume(ira, (IrInstructionCoroResume *)instruction); + case IrInstructionIdAwait: + return ir_analyze_instruction_await(ira, (IrInstructionAwait *)instruction); } zig_unreachable(); } @@ -24826,6 +24892,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdSuspendBegin: case IrInstructionIdSuspendBr: case IrInstructionIdCoroResume: + case IrInstructionIdAwait: return true; case IrInstructionIdPhi: diff --git a/src/ir_print.cpp b/src/ir_print.cpp index 284ebed2f..46d2906d3 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -1546,6 +1546,12 @@ static void ir_print_coro_resume(IrPrint *irp, IrInstructionCoroResume *instruct fprintf(irp->f, ")"); } +static void ir_print_await(IrPrint *irp, IrInstructionAwait *instruction) { + fprintf(irp->f, "@await("); + ir_print_other_instruction(irp, instruction->frame); + fprintf(irp->f, ")"); +} + static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { ir_print_prefix(irp, instruction); switch (instruction->id) { @@ -2025,6 +2031,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdCoroResume: ir_print_coro_resume(irp, (IrInstructionCoroResume *)instruction); break; + case IrInstructionIdAwait: + ir_print_await(irp, (IrInstructionAwait *)instruction); + break; } fprintf(irp->f, "\n"); } diff --git a/test/stage1/behavior/coroutine_await_struct.zig b/test/stage1/behavior/coroutine_await_struct.zig index 66ff8bb49..a649b0a39 100644 --- a/test/stage1/behavior/coroutine_await_struct.zig +++ b/test/stage1/behavior/coroutine_await_struct.zig @@ -6,12 +6,12 @@ const Foo = struct { x: i32, }; -var await_a_promise: promise = undefined; +var await_a_promise: anyframe = undefined; var await_final_result = Foo{ .x = 0 }; test "coroutine await struct" { await_seq('a'); - const p = async await_amain() catch unreachable; + const p = async await_amain(); await_seq('f'); resume await_a_promise; await_seq('i'); @@ -20,7 +20,7 @@ test "coroutine await struct" { } async fn await_amain() void { await_seq('b'); - const p = async await_another() catch unreachable; + const p = async await_another(); await_seq('e'); await_final_result = await p; await_seq('h'); @@ -29,7 +29,7 @@ async fn await_another() Foo { await_seq('c'); suspend { await_seq('d'); - await_a_promise = @handle(); + await_a_promise = @frame(); } await_seq('g'); return Foo{ .x = 1234 }; diff --git a/test/stage1/behavior/coroutines.zig b/test/stage1/behavior/coroutines.zig index 01237ed1c..28dd834bf 100644 --- a/test/stage1/behavior/coroutines.zig +++ b/test/stage1/behavior/coroutines.zig @@ -180,97 +180,85 @@ async fn testSuspendBlock() void { result = true; } -//var await_a_promise: anyframe = undefined; -//var await_final_result: i32 = 0; -// -//test "coroutine await" { -// await_seq('a'); -// const p = async await_amain() catch unreachable; -// await_seq('f'); -// resume await_a_promise; -// await_seq('i'); -// expect(await_final_result == 1234); -// expect(std.mem.eql(u8, await_points, "abcdefghi")); -//} -//async fn await_amain() void { -// await_seq('b'); -// const p = async await_another() catch unreachable; -// await_seq('e'); -// await_final_result = await p; -// await_seq('h'); -//} -//async fn await_another() i32 { -// await_seq('c'); -// suspend { -// await_seq('d'); -// await_a_promise = @frame(); -// } -// await_seq('g'); -// return 1234; -//} -// -//var await_points = [_]u8{0} ** "abcdefghi".len; -//var await_seq_index: usize = 0; -// -//fn await_seq(c: u8) void { -// await_points[await_seq_index] = c; -// await_seq_index += 1; -//} -// -//var early_final_result: i32 = 0; -// -//test "coroutine await early return" { -// early_seq('a'); -// const p = async early_amain() catch @panic("out of memory"); -// early_seq('f'); -// expect(early_final_result == 1234); -// expect(std.mem.eql(u8, early_points, "abcdef")); -//} -//async fn early_amain() void { -// early_seq('b'); -// const p = async early_another() catch @panic("out of memory"); -// early_seq('d'); -// early_final_result = await p; -// early_seq('e'); -//} -//async fn early_another() i32 { -// early_seq('c'); -// return 1234; -//} -// -//var early_points = [_]u8{0} ** "abcdef".len; -//var early_seq_index: usize = 0; -// -//fn early_seq(c: u8) void { -// early_points[early_seq_index] = c; -// early_seq_index += 1; -//} -// -//test "coro allocation failure" { -// var failing_allocator = std.debug.FailingAllocator.init(std.debug.global_allocator, 0); -// if (async<&failing_allocator.allocator> asyncFuncThatNeverGetsRun()) { -// @panic("expected allocation failure"); -// } else |err| switch (err) { -// error.OutOfMemory => {}, -// } -//} -//async fn asyncFuncThatNeverGetsRun() void { -// @panic("coro frame allocation should fail"); -//} -// -//test "async function with dot syntax" { -// const S = struct { -// var y: i32 = 1; -// async fn foo() void { -// y += 1; -// suspend; -// } -// }; -// const p = try async S.foo(); -// cancel p; -// expect(S.y == 2); -//} -// +var await_a_promise: anyframe = undefined; +var await_final_result: i32 = 0; + +test "coroutine await" { + await_seq('a'); + const p = async await_amain(); + await_seq('f'); + resume await_a_promise; + await_seq('i'); + expect(await_final_result == 1234); + expect(std.mem.eql(u8, await_points, "abcdefghi")); +} +async fn await_amain() void { + await_seq('b'); + const p = async await_another(); + await_seq('e'); + await_final_result = await p; + await_seq('h'); +} +async fn await_another() i32 { + await_seq('c'); + suspend { + await_seq('d'); + await_a_promise = @frame(); + } + await_seq('g'); + return 1234; +} + +var await_points = [_]u8{0} ** "abcdefghi".len; +var await_seq_index: usize = 0; + +fn await_seq(c: u8) void { + await_points[await_seq_index] = c; + await_seq_index += 1; +} + +var early_final_result: i32 = 0; + +test "coroutine await early return" { + early_seq('a'); + const p = async early_amain(); + early_seq('f'); + expect(early_final_result == 1234); + expect(std.mem.eql(u8, early_points, "abcdef")); +} +async fn early_amain() void { + early_seq('b'); + const p = async early_another(); + early_seq('d'); + early_final_result = await p; + early_seq('e'); +} +async fn early_another() i32 { + early_seq('c'); + return 1234; +} + +var early_points = [_]u8{0} ** "abcdef".len; +var early_seq_index: usize = 0; + +fn early_seq(c: u8) void { + early_points[early_seq_index] = c; + early_seq_index += 1; +} + +test "async function with dot syntax" { + const S = struct { + var y: i32 = 1; + async fn foo() void { + y += 1; + suspend; + } + }; + const p = async S.foo(); + // can't cancel in tests because they are non-async functions + expect(S.y == 2); +} + //test "async fn pointer in a struct field" { // var data: i32 = 1; // const Foo = struct { @@ -287,18 +275,17 @@ async fn testSuspendBlock() void { // y.* += 1; // suspend; //} -// + //test "async fn with inferred error set" { -// const p = (async failing()) catch unreachable; +// const p = async failing(); // resume p; -// cancel p; //} // //async fn failing() !void { // suspend; // return error.Fail; //} -// + //test "error return trace across suspend points - early return" { // const p = nonFailing(); // resume p; @@ -331,20 +318,18 @@ async fn testSuspendBlock() void { // } // }; //} -// -//test "break from suspend" { -// var buf: [500]u8 = undefined; -// var a = &std.heap.FixedBufferAllocator.init(buf[0..]).allocator; -// var my_result: i32 = 1; -// const p = try async testBreakFromSuspend(&my_result); -// cancel p; -// std.testing.expect(my_result == 2); -//} -//async fn testBreakFromSuspend(my_result: *i32) void { -// suspend { -// resume @frame(); -// } -// my_result.* += 1; -// suspend; -// my_result.* += 1; -//} + +test "break from suspend" { + var my_result: i32 = 1; + const p = async testBreakFromSuspend(&my_result); + // can't cancel here + std.testing.expect(my_result == 2); +} +async fn testBreakFromSuspend(my_result: *i32) void { + suspend { + resume @frame(); + } + my_result.* += 1; + suspend; + my_result.* += 1; +}