support recursive async and non-async functions
which heap allocate their own frames related: #1006
This commit is contained in:
parent
2148943fff
commit
6ab8b2aab4
|
@ -627,7 +627,7 @@ struct AstNodeParamDecl {
|
|||
AstNode *type;
|
||||
Token *var_token;
|
||||
bool is_noalias;
|
||||
bool is_inline;
|
||||
bool is_comptime;
|
||||
bool is_var_args;
|
||||
};
|
||||
|
||||
|
|
|
@ -1556,7 +1556,7 @@ static ZigType *analyze_fn_type(CodeGen *g, AstNode *proto_node, Scope *child_sc
|
|||
AstNode *param_node = fn_proto->params.at(fn_type_id.next_param_index);
|
||||
assert(param_node->type == NodeTypeParamDecl);
|
||||
|
||||
bool param_is_comptime = param_node->data.param_decl.is_inline;
|
||||
bool param_is_comptime = param_node->data.param_decl.is_comptime;
|
||||
bool param_is_var_args = param_node->data.param_decl.is_var_args;
|
||||
|
||||
if (param_is_comptime) {
|
||||
|
@ -8234,6 +8234,10 @@ static void resolve_llvm_types_anyerror(CodeGen *g) {
|
|||
}
|
||||
|
||||
static void resolve_llvm_types_async_frame(CodeGen *g, ZigType *frame_type, ResolveStatus wanted_resolve_status) {
|
||||
Error err;
|
||||
if ((err = type_resolve(g, frame_type, ResolveStatusSizeKnown)))
|
||||
zig_unreachable();
|
||||
|
||||
ZigType *passed_frame_type = fn_is_async(frame_type->data.frame.fn) ? frame_type : nullptr;
|
||||
resolve_llvm_types_struct(g, frame_type->data.frame.locals_struct, wanted_resolve_status, passed_frame_type);
|
||||
frame_type->llvm_type = frame_type->data.frame.locals_struct->llvm_type;
|
||||
|
@ -8375,7 +8379,6 @@ static void resolve_llvm_types_any_frame(CodeGen *g, ZigType *any_frame_type, Re
|
|||
}
|
||||
|
||||
static void resolve_llvm_types(CodeGen *g, ZigType *type, ResolveStatus wanted_resolve_status) {
|
||||
assert(type->id == ZigTypeIdOpaque || type_is_resolved(type, ResolveStatusSizeKnown));
|
||||
assert(wanted_resolve_status > ResolveStatusSizeKnown);
|
||||
switch (type->id) {
|
||||
case ZigTypeIdInvalid:
|
||||
|
|
|
@ -448,7 +448,7 @@ static void render_node_extra(AstRender *ar, AstNode *node, bool grouped) {
|
|||
assert(param_decl->type == NodeTypeParamDecl);
|
||||
if (param_decl->data.param_decl.name != nullptr) {
|
||||
const char *noalias_str = param_decl->data.param_decl.is_noalias ? "noalias " : "";
|
||||
const char *inline_str = param_decl->data.param_decl.is_inline ? "inline " : "";
|
||||
const char *inline_str = param_decl->data.param_decl.is_comptime ? "comptime " : "";
|
||||
fprintf(ar->f, "%s%s", noalias_str, inline_str);
|
||||
print_symbol(ar, param_decl->data.param_decl.name);
|
||||
fprintf(ar->f, ": ");
|
||||
|
|
|
@ -6340,9 +6340,12 @@ static LLVMValueRef gen_const_val(CodeGen *g, ConstExprValue *const_val, const c
|
|||
ZigType *type_entry = const_val->type;
|
||||
assert(type_has_bits(type_entry));
|
||||
|
||||
switch (const_val->special) {
|
||||
check: switch (const_val->special) {
|
||||
case ConstValSpecialLazy:
|
||||
zig_unreachable();
|
||||
if ((err = ir_resolve_lazy(g, nullptr, const_val))) {
|
||||
report_errors_and_exit(g);
|
||||
}
|
||||
goto check;
|
||||
case ConstValSpecialRuntime:
|
||||
zig_unreachable();
|
||||
case ConstValSpecialUndef:
|
||||
|
|
83
src/ir.cpp
83
src/ir.cpp
|
@ -9012,7 +9012,42 @@ static bool ir_num_lit_fits_in_other_type(IrAnalyze *ira, IrInstruction *instruc
|
|||
return false;
|
||||
}
|
||||
|
||||
ConstExprValue *const_val = ir_resolve_const(ira, instruction, UndefBad);
|
||||
ConstExprValue *const_val = ir_resolve_const(ira, instruction, LazyOkNoUndef);
|
||||
if (const_val == nullptr)
|
||||
return false;
|
||||
|
||||
if (const_val->special == ConstValSpecialLazy) {
|
||||
switch (const_val->data.x_lazy->id) {
|
||||
case LazyValueIdAlignOf: {
|
||||
// This is guaranteed to fit into a u29
|
||||
if (other_type->id == ZigTypeIdComptimeInt)
|
||||
return true;
|
||||
size_t align_bits = get_align_amt_type(ira->codegen)->data.integral.bit_count;
|
||||
if (other_type->id == ZigTypeIdInt && !other_type->data.integral.is_signed &&
|
||||
other_type->data.integral.bit_count >= align_bits)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case LazyValueIdSizeOf: {
|
||||
// This is guaranteed to fit into a usize
|
||||
if (other_type->id == ZigTypeIdComptimeInt)
|
||||
return true;
|
||||
size_t usize_bits = ira->codegen->builtin_types.entry_usize->data.integral.bit_count;
|
||||
if (other_type->id == ZigTypeIdInt && !other_type->data.integral.is_signed &&
|
||||
other_type->data.integral.bit_count >= usize_bits)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const_val = ir_resolve_const(ira, instruction, UndefBad);
|
||||
if (const_val == nullptr)
|
||||
return false;
|
||||
|
||||
|
@ -10262,7 +10297,7 @@ static void copy_const_val(ConstExprValue *dest, ConstExprValue *src, bool same_
|
|||
memcpy(dest, src, sizeof(ConstExprValue));
|
||||
if (!same_global_refs) {
|
||||
dest->global_refs = global_refs;
|
||||
if (src->special == ConstValSpecialUndef)
|
||||
if (src->special != ConstValSpecialStatic)
|
||||
return;
|
||||
if (dest->type->id == ZigTypeIdStruct) {
|
||||
dest->data.x_struct.fields = create_const_vals(dest->type->data.structure.src_field_count);
|
||||
|
@ -11213,7 +11248,7 @@ static IrInstruction *ir_get_ref(IrAnalyze *ira, IrInstruction *source_instructi
|
|||
return ira->codegen->invalid_instruction;
|
||||
|
||||
if (instr_is_comptime(value)) {
|
||||
ConstExprValue *val = ir_resolve_const(ira, value, UndefOk);
|
||||
ConstExprValue *val = ir_resolve_const(ira, value, LazyOk);
|
||||
if (!val)
|
||||
return ira->codegen->invalid_instruction;
|
||||
return ir_get_const_ptr(ira, source_instruction, val, value->value.type,
|
||||
|
@ -12125,7 +12160,8 @@ static IrInstruction *ir_analyze_cast(IrAnalyze *ira, IrInstruction *source_inst
|
|||
if (wanted_type->id == ZigTypeIdComptimeInt || wanted_type->id == ZigTypeIdInt) {
|
||||
IrInstruction *result = ir_const(ira, source_instr, wanted_type);
|
||||
if (actual_type->id == ZigTypeIdComptimeInt || actual_type->id == ZigTypeIdInt) {
|
||||
bigint_init_bigint(&result->value.data.x_bigint, &value->value.data.x_bigint);
|
||||
copy_const_val(&result->value, &value->value, false);
|
||||
result->value.type = wanted_type;
|
||||
} else {
|
||||
float_init_bigint(&result->value.data.x_bigint, &value->value);
|
||||
}
|
||||
|
@ -15301,7 +15337,7 @@ static bool ir_analyze_fn_call_generic_arg(IrAnalyze *ira, AstNode *fn_proto_nod
|
|||
}
|
||||
}
|
||||
|
||||
bool comptime_arg = param_decl_node->data.param_decl.is_inline ||
|
||||
bool comptime_arg = param_decl_node->data.param_decl.is_comptime ||
|
||||
casted_arg->value.type->id == ZigTypeIdComptimeInt || casted_arg->value.type->id == ZigTypeIdComptimeFloat;
|
||||
|
||||
ConstExprValue *arg_val;
|
||||
|
@ -17594,6 +17630,11 @@ static IrInstruction *ir_analyze_instruction_field_ptr(IrAnalyze *ira, IrInstruc
|
|||
ConstExprValue *child_val = const_ptr_pointee(ira, ira->codegen, container_ptr_val, source_node);
|
||||
if (child_val == nullptr)
|
||||
return ira->codegen->invalid_instruction;
|
||||
if ((err = ir_resolve_const_val(ira->codegen, ira->new_irb.exec,
|
||||
field_ptr_instruction->base.source_node, child_val, UndefBad)))
|
||||
{
|
||||
return ira->codegen->invalid_instruction;
|
||||
}
|
||||
ZigType *child_type = child_val->data.x_type;
|
||||
|
||||
if (type_is_invalid(child_type)) {
|
||||
|
@ -21293,8 +21334,10 @@ static IrInstruction *ir_analyze_instruction_from_bytes(IrAnalyze *ira, IrInstru
|
|||
src_ptr_align = get_abi_alignment(ira->codegen, target->value.type);
|
||||
}
|
||||
|
||||
if ((err = type_resolve(ira->codegen, dest_child_type, ResolveStatusSizeKnown)))
|
||||
return ira->codegen->invalid_instruction;
|
||||
if (src_ptr_align != 0) {
|
||||
if ((err = type_resolve(ira->codegen, dest_child_type, ResolveStatusAlignmentKnown)))
|
||||
return ira->codegen->invalid_instruction;
|
||||
}
|
||||
|
||||
ZigType *dest_ptr_type = get_pointer_to_type_extra(ira->codegen, dest_child_type,
|
||||
src_ptr_const, src_ptr_volatile, PtrLenUnknown,
|
||||
|
@ -21337,6 +21380,8 @@ static IrInstruction *ir_analyze_instruction_from_bytes(IrAnalyze *ira, IrInstru
|
|||
}
|
||||
|
||||
if (have_known_len) {
|
||||
if ((err = type_resolve(ira->codegen, dest_child_type, ResolveStatusSizeKnown)))
|
||||
return ira->codegen->invalid_instruction;
|
||||
uint64_t child_type_size = type_size(ira->codegen, dest_child_type);
|
||||
uint64_t remainder = known_len % child_type_size;
|
||||
if (remainder != 0) {
|
||||
|
@ -23963,15 +24008,23 @@ static IrInstruction *ir_analyze_instruction_ptr_type(IrAnalyze *ira, IrInstruct
|
|||
}
|
||||
|
||||
static IrInstruction *ir_analyze_instruction_align_cast(IrAnalyze *ira, IrInstructionAlignCast *instruction) {
|
||||
uint32_t align_bytes;
|
||||
IrInstruction *align_bytes_inst = instruction->align_bytes->child;
|
||||
if (!ir_resolve_align(ira, align_bytes_inst, nullptr, &align_bytes))
|
||||
return ira->codegen->invalid_instruction;
|
||||
|
||||
IrInstruction *target = instruction->target->child;
|
||||
if (type_is_invalid(target->value.type))
|
||||
return ira->codegen->invalid_instruction;
|
||||
|
||||
ZigType *elem_type = nullptr;
|
||||
if (is_slice(target->value.type)) {
|
||||
ZigType *slice_ptr_type = target->value.type->data.structure.fields[slice_ptr_index].type_entry;
|
||||
elem_type = slice_ptr_type->data.pointer.child_type;
|
||||
} else if (target->value.type->id == ZigTypeIdPointer) {
|
||||
elem_type = target->value.type->data.pointer.child_type;
|
||||
}
|
||||
|
||||
uint32_t align_bytes;
|
||||
IrInstruction *align_bytes_inst = instruction->align_bytes->child;
|
||||
if (!ir_resolve_align(ira, align_bytes_inst, elem_type, &align_bytes))
|
||||
return ira->codegen->invalid_instruction;
|
||||
|
||||
IrInstruction *result = ir_align_cast(ira, target, align_bytes, true);
|
||||
if (type_is_invalid(result->value.type))
|
||||
return ira->codegen->invalid_instruction;
|
||||
|
@ -25644,7 +25697,7 @@ static Error ir_resolve_lazy_raw(AstNode *source_node, ConstExprValue *val) {
|
|||
}
|
||||
|
||||
val->special = ConstValSpecialStatic;
|
||||
assert(val->type->id == ZigTypeIdComptimeInt);
|
||||
assert(val->type->id == ZigTypeIdComptimeInt || val->type->id == ZigTypeIdInt);
|
||||
bigint_init_unsigned(&val->data.x_bigint, align_in_bytes);
|
||||
return ErrorNone;
|
||||
}
|
||||
|
@ -25699,7 +25752,7 @@ static Error ir_resolve_lazy_raw(AstNode *source_node, ConstExprValue *val) {
|
|||
}
|
||||
|
||||
val->special = ConstValSpecialStatic;
|
||||
assert(val->type->id == ZigTypeIdComptimeInt);
|
||||
assert(val->type->id == ZigTypeIdComptimeInt || val->type->id == ZigTypeIdInt);
|
||||
bigint_init_unsigned(&val->data.x_bigint, abi_size);
|
||||
return ErrorNone;
|
||||
}
|
||||
|
@ -25885,7 +25938,7 @@ static Error ir_resolve_lazy_raw(AstNode *source_node, ConstExprValue *val) {
|
|||
Error ir_resolve_lazy(CodeGen *codegen, AstNode *source_node, ConstExprValue *val) {
|
||||
Error err;
|
||||
if ((err = ir_resolve_lazy_raw(source_node, val))) {
|
||||
if (codegen->trace_err != nullptr && !source_node->already_traced_this_node) {
|
||||
if (codegen->trace_err != nullptr && source_node != nullptr && !source_node->already_traced_this_node) {
|
||||
source_node->already_traced_this_node = true;
|
||||
codegen->trace_err = add_error_note(codegen, codegen->trace_err, source_node,
|
||||
buf_create_from_str("referenced here"));
|
||||
|
|
|
@ -2075,7 +2075,7 @@ static AstNode *ast_parse_param_decl(ParseContext *pc) {
|
|||
res->column = first->start_column;
|
||||
res->data.param_decl.name = token_buf(name);
|
||||
res->data.param_decl.is_noalias = first->id == TokenIdKeywordNoAlias;
|
||||
res->data.param_decl.is_inline = first->id == TokenIdKeywordCompTime;
|
||||
res->data.param_decl.is_comptime = first->id == TokenIdKeywordCompTime;
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
10
std/mem.zig
10
std/mem.zig
|
@ -117,7 +117,15 @@ pub const Allocator = struct {
|
|||
const byte_slice = try self.reallocFn(self, ([*]u8)(undefined)[0..0], undefined, byte_count, a);
|
||||
assert(byte_slice.len == byte_count);
|
||||
@memset(byte_slice.ptr, undefined, byte_slice.len);
|
||||
return @bytesToSlice(T, @alignCast(a, byte_slice));
|
||||
if (alignment == null) {
|
||||
// TODO This is a workaround for zig not being able to successfully do
|
||||
// @bytesToSlice(T, @alignCast(a, byte_slice)) without resolving alignment of T,
|
||||
// which causes a circular dependency in async functions which try to heap-allocate
|
||||
// their own frame with @Frame(func).
|
||||
return @intToPtr([*]T, @ptrToInt(byte_slice.ptr))[0..n];
|
||||
} else {
|
||||
return @bytesToSlice(T, @alignCast(a, byte_slice));
|
||||
}
|
||||
}
|
||||
|
||||
/// This function requests a new byte size for an existing allocation,
|
||||
|
|
|
@ -1051,6 +1051,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
|
|||
\\const Foo = struct {};
|
||||
\\export fn a() void {
|
||||
\\ const T = [*c]Foo;
|
||||
\\ var t: T = undefined;
|
||||
\\}
|
||||
,
|
||||
"tmp.zig:3:19: error: C pointers cannot point to non-C-ABI-compatible type 'Foo'",
|
||||
|
@ -2290,6 +2291,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
|
|||
"error union operator with non error set LHS",
|
||||
\\comptime {
|
||||
\\ const z = i32!i32;
|
||||
\\ var x: z = undefined;
|
||||
\\}
|
||||
,
|
||||
"tmp.zig:2:15: error: expected error set type, found type 'i32'",
|
||||
|
|
|
@ -854,3 +854,68 @@ test "await does not force async if callee is blocking" {
|
|||
var x = async S.simple();
|
||||
expect(await x == 1234);
|
||||
}
|
||||
|
||||
test "recursive async function" {
|
||||
expect(recursiveAsyncFunctionTest(false).doTheTest() == 55);
|
||||
expect(recursiveAsyncFunctionTest(true).doTheTest() == 55);
|
||||
}
|
||||
|
||||
fn recursiveAsyncFunctionTest(comptime suspending_implementation: bool) type {
|
||||
return struct {
|
||||
fn fib(allocator: *std.mem.Allocator, x: u32) error{OutOfMemory}!u32 {
|
||||
if (x <= 1) return x;
|
||||
|
||||
if (suspending_implementation) {
|
||||
suspend {
|
||||
resume @frame();
|
||||
}
|
||||
}
|
||||
|
||||
const f1 = try allocator.create(@Frame(fib));
|
||||
defer allocator.destroy(f1);
|
||||
|
||||
const f2 = try allocator.create(@Frame(fib));
|
||||
defer allocator.destroy(f2);
|
||||
|
||||
f1.* = async fib(allocator, x - 1);
|
||||
var f1_awaited = false;
|
||||
errdefer if (!f1_awaited) {
|
||||
_ = await f1;
|
||||
};
|
||||
|
||||
f2.* = async fib(allocator, x - 2);
|
||||
var f2_awaited = false;
|
||||
errdefer if (!f2_awaited) {
|
||||
_ = await f2;
|
||||
};
|
||||
|
||||
var sum: u32 = 0;
|
||||
|
||||
f1_awaited = true;
|
||||
const result_f1 = await f1; // TODO https://github.com/ziglang/zig/issues/3077
|
||||
sum += try result_f1;
|
||||
|
||||
f2_awaited = true;
|
||||
const result_f2 = await f2; // TODO https://github.com/ziglang/zig/issues/3077
|
||||
sum += try result_f2;
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
fn doTheTest() u32 {
|
||||
if (suspending_implementation) {
|
||||
var result: u32 = undefined;
|
||||
_ = async amain(&result);
|
||||
return result;
|
||||
} else {
|
||||
return fib(std.heap.direct_allocator, 10) catch unreachable;
|
||||
}
|
||||
}
|
||||
|
||||
fn amain(result: *u32) void {
|
||||
var x = async fib(std.heap.direct_allocator, 10);
|
||||
const res = await x; // TODO https://github.com/ziglang/zig/issues/3077
|
||||
result.* = res catch unreachable;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user