improve support for anyframe and anyframe->T

* add implicit cast from `*@Frame(func)` to `anyframe->T` or `anyframe`.
 * add implicit cast from `anyframe->T` to `anyframe`.
 * `resume` works on `anyframe->T` and `anyframe` types.
This commit is contained in:
Andrew Kelley 2019-07-29 19:32:49 -04:00
parent ee64a22045
commit dbdc4d62d0
No known key found for this signature in database
GPG Key ID: 7C5F548F728501A9
5 changed files with 97 additions and 29 deletions

View File

@ -1726,6 +1726,7 @@ struct CodeGen {
LLVMValueRef err_name_table;
LLVMValueRef safety_crash_err_fn;
LLVMValueRef return_err_fn;
LLVMTypeRef async_fn_llvm_type;
// reminder: hash tables must be initialized before use
HashMap<Buf *, ZigType *, buf_hash, buf_eql_buf> import_table;
@ -1793,7 +1794,6 @@ struct CodeGen {
ZigType *entry_global_error_set;
ZigType *entry_arg_tuple;
ZigType *entry_enum_literal;
ZigType *entry_frame_header;
ZigType *entry_any_frame;
} builtin_types;
ZigType *align_amt_type;

View File

@ -7348,19 +7348,13 @@ static void resolve_llvm_types_fn_type(CodeGen *g, ZigType *fn_type) {
if (is_async) {
fn_type->data.fn.gen_param_info = allocate<FnGenParamInfo>(1);
ZigType *frame_type = g->builtin_types.entry_frame_header;
Error err;
if ((err = type_resolve(g, frame_type, ResolveStatusSizeKnown))) {
zig_unreachable();
}
ZigType *ptr_type = get_pointer_to_type(g, frame_type, false);
gen_param_types.append(get_llvm_type(g, ptr_type));
param_di_types.append(get_llvm_di_type(g, ptr_type));
ZigType *frame_type = get_any_frame_type(g, fn_type_id->return_type);
gen_param_types.append(get_llvm_type(g, frame_type));
param_di_types.append(get_llvm_di_type(g, frame_type));
fn_type->data.fn.gen_param_info[0].src_index = 0;
fn_type->data.fn.gen_param_info[0].gen_index = 0;
fn_type->data.fn.gen_param_info[0].type = ptr_type;
fn_type->data.fn.gen_param_info[0].type = frame_type;
} else {
fn_type->data.fn.gen_param_info = allocate<FnGenParamInfo>(fn_type_id->param_count);
for (size_t i = 0; i < fn_type_id->param_count; i += 1) {

View File

@ -4902,14 +4902,28 @@ static LLVMValueRef ir_render_suspend_br(CodeGen *g, IrExecutable *executable,
return nullptr;
}
static LLVMTypeRef async_fn_llvm_type(CodeGen *g) {
if (g->async_fn_llvm_type != nullptr)
return g->async_fn_llvm_type;
ZigType *anyframe_type = get_any_frame_type(g, nullptr);
LLVMTypeRef param_type = get_llvm_type(g, anyframe_type);
LLVMTypeRef return_type = LLVMVoidType();
LLVMTypeRef fn_type = LLVMFunctionType(return_type, &param_type, 1, false);
g->async_fn_llvm_type = LLVMPointerType(fn_type, 0);
return g->async_fn_llvm_type;
}
static LLVMValueRef ir_render_coro_resume(CodeGen *g, IrExecutable *executable,
IrInstructionCoroResume *instruction)
{
LLVMValueRef frame = ir_llvm_value(g, instruction->frame);
ZigType *frame_type = instruction->frame->value.type;
assert(frame_type->id == ZigTypeIdCoroFrame);
ZigFn *fn = frame_type->data.frame.fn;
LLVMValueRef fn_val = fn_llvm_value(g, fn);
assert(frame_type->id == ZigTypeIdAnyFrame);
LLVMValueRef fn_ptr_ptr = LLVMBuildStructGEP(g->builder, frame, coro_fn_ptr_index, "");
LLVMValueRef uncasted_fn_val = LLVMBuildLoad(g->builder, fn_ptr_ptr, "");
LLVMValueRef fn_val = LLVMBuildIntToPtr(g->builder, uncasted_fn_val, async_fn_llvm_type(g), "");
ZigLLVMBuildCall(g->builder, fn_val, &frame, 1, LLVMFastCallConv, ZigLLVM_FnInlineAuto, "");
return nullptr;
}
@ -6746,11 +6760,6 @@ static void define_builtin_types(CodeGen *g) {
g->primitive_type_table.put(&entry->name, entry);
}
{
const char *field_names[] = {"resume_index"};
ZigType *field_types[] = {g->builtin_types.entry_usize};
g->builtin_types.entry_frame_header = get_struct_type(g, "(frame header)", field_names, field_types, 1);
}
}
static BuiltinFnEntry *create_builtin_fn(CodeGen *g, BuiltinFnId id, const char *name, size_t count) {

View File

@ -7764,7 +7764,7 @@ static IrInstruction *ir_gen_cancel(IrBuilder *irb, Scope *scope, AstNode *node)
static IrInstruction *ir_gen_resume(IrBuilder *irb, Scope *scope, AstNode *node) {
assert(node->type == NodeTypeResume);
IrInstruction *target_inst = ir_gen_node(irb, node->data.resume_expr.expr, scope);
IrInstruction *target_inst = ir_gen_node_extra(irb, node->data.resume_expr.expr, scope, LValPtr, nullptr);
if (target_inst == irb->codegen->invalid_instruction)
return irb->codegen->invalid_instruction;
@ -10882,6 +10882,33 @@ static IrInstruction *ir_analyze_err_set_cast(IrAnalyze *ira, IrInstruction *sou
return result;
}
static IrInstruction *ir_analyze_frame_ptr_to_anyframe(IrAnalyze *ira, IrInstruction *source_instr,
IrInstruction *value, ZigType *wanted_type)
{
if (instr_is_comptime(value)) {
zig_panic("TODO comptime frame pointer");
}
IrInstruction *result = ir_build_cast(&ira->new_irb, source_instr->scope, source_instr->source_node,
wanted_type, value, CastOpBitCast);
result->value.type = wanted_type;
return result;
}
static IrInstruction *ir_analyze_anyframe_to_anyframe(IrAnalyze *ira, IrInstruction *source_instr,
IrInstruction *value, ZigType *wanted_type)
{
if (instr_is_comptime(value)) {
zig_panic("TODO comptime anyframe->T to anyframe");
}
IrInstruction *result = ir_build_cast(&ira->new_irb, source_instr->scope, source_instr->source_node,
wanted_type, value, CastOpBitCast);
result->value.type = wanted_type;
return result;
}
static IrInstruction *ir_analyze_err_wrap_code(IrAnalyze *ira, IrInstruction *source_instr, IrInstruction *value,
ZigType *wanted_type, ResultLoc *result_loc)
{
@ -11978,6 +12005,29 @@ static IrInstruction *ir_analyze_cast(IrAnalyze *ira, IrInstruction *source_inst
}
}
// *@Frame(func) to anyframe->T or anyframe
if (actual_type->id == ZigTypeIdPointer && actual_type->data.pointer.ptr_len == PtrLenSingle &&
actual_type->data.pointer.child_type->id == ZigTypeIdCoroFrame && wanted_type->id == ZigTypeIdAnyFrame)
{
bool ok = true;
if (wanted_type->data.any_frame.result_type != nullptr) {
ZigFn *fn = actual_type->data.pointer.child_type->data.frame.fn;
ZigType *fn_return_type = fn->type_entry->data.fn.fn_type_id.return_type;
if (wanted_type->data.any_frame.result_type != fn_return_type) {
ok = false;
}
}
if (ok) {
return ir_analyze_frame_ptr_to_anyframe(ira, source_instr, value, wanted_type);
}
}
// anyframe->T to anyframe
if (actual_type->id == ZigTypeIdAnyFrame && actual_type->data.any_frame.result_type != nullptr &&
wanted_type->id == ZigTypeIdAnyFrame && wanted_type->data.any_frame.result_type == nullptr)
{
return ir_analyze_anyframe_to_anyframe(ira, source_instr, value, wanted_type);
}
// cast from null literal to maybe type
if (wanted_type->id == ZigTypeIdOptional &&
@ -24323,17 +24373,27 @@ static IrInstruction *ir_analyze_instruction_suspend_br(IrAnalyze *ira, IrInstru
}
static IrInstruction *ir_analyze_instruction_coro_resume(IrAnalyze *ira, IrInstructionCoroResume *instruction) {
IrInstruction *frame = instruction->frame->child;
if (type_is_invalid(frame->value.type))
IrInstruction *frame_ptr = instruction->frame->child;
if (type_is_invalid(frame_ptr->value.type))
return ira->codegen->invalid_instruction;
if (frame->value.type->id != ZigTypeIdCoroFrame) {
ir_add_error(ira, instruction->frame,
buf_sprintf("expected frame, found '%s'", buf_ptr(&frame->value.type->name)));
return ira->codegen->invalid_instruction;
IrInstruction *frame;
if (frame_ptr->value.type->id == ZigTypeIdPointer &&
frame_ptr->value.type->data.pointer.ptr_len == PtrLenSingle &&
frame_ptr->value.type->data.pointer.is_const &&
frame_ptr->value.type->data.pointer.child_type->id == ZigTypeIdAnyFrame)
{
frame = ir_get_deref(ira, &instruction->base, frame_ptr, nullptr);
} else {
frame = frame_ptr;
}
return ir_build_coro_resume(&ira->new_irb, instruction->base.scope, instruction->base.source_node, frame);
ZigType *any_frame_type = get_any_frame_type(ira->codegen, nullptr);
IrInstruction *casted_frame = ir_implicit_cast(ira, frame, any_frame_type);
if (type_is_invalid(casted_frame->value.type))
return ira->codegen->invalid_instruction;
return ir_build_coro_resume(&ira->new_irb, instruction->base.scope, instruction->base.source_node, casted_frame);
}
static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction *instruction) {

View File

@ -5,15 +5,20 @@ const expect = std.testing.expect;
var global_x: i32 = 1;
test "simple coroutine suspend and resume" {
const p = async simpleAsyncFn();
const frame = async simpleAsyncFn();
expect(global_x == 2);
resume p;
resume frame;
expect(global_x == 3);
const af: anyframe->void = &frame;
resume frame;
expect(global_x == 4);
}
fn simpleAsyncFn() void {
global_x += 1;
suspend;
global_x += 1;
suspend;
global_x += 1;
}
var global_y: i32 = 1;