From bd77bc749a1b38ef754fdd5b7e14a8228cb6f72c Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sat, 12 Dec 2015 22:55:29 -0700 Subject: [PATCH] structs are working --- example/structs/structs.zig | 2 - src/analyze.cpp | 82 +++++++++++++------------- src/analyze.hpp | 4 ++ src/codegen.cpp | 111 +++++++++++++++++++----------------- test/run_tests.cpp | 21 +++++++ 5 files changed, 127 insertions(+), 93 deletions(-) diff --git a/example/structs/structs.zig b/example/structs/structs.zig index b274184b8..e3b364b21 100644 --- a/example/structs/structs.zig +++ b/example/structs/structs.zig @@ -2,8 +2,6 @@ export executable "structs"; use "std.zig"; -// Note: this example is not working because codegen is confused about -// how byvalue structs which are in memory on the stack work export fn main(argc : isize, argv : *mut *mut u8, env : *mut *mut u8) -> i32 { let mut foo : Foo; diff --git a/src/analyze.cpp b/src/analyze.cpp index 7b21a4360..3d46dfcdf 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -758,6 +758,47 @@ static bool is_op_allowed(TypeTableEntry *type, BinOpType op) { zig_unreachable(); } +static TypeTableEntry *analyze_cast_expr(CodeGen *g, ImportTableEntry *import, BlockContext *context, + TypeTableEntry *expected_type, AstNode *node) +{ + TypeTableEntry *wanted_type = resolve_type(g, node->data.cast_expr.type); + TypeTableEntry *actual_type = analyze_expression(g, import, context, nullptr, node->data.cast_expr.expr); + + if (wanted_type->id == TypeTableEntryIdInvalid || + actual_type->id == TypeTableEntryIdInvalid) + { + return g->builtin_types.entry_invalid; + } + + CastNode *cast_node = &node->codegen_node->data.cast_node; + + // special casing this for now, TODO think about casting and do a general solution + if (wanted_type == g->builtin_types.entry_isize && + actual_type->id == TypeTableEntryIdPointer) + { + cast_node->op = CastOpPtrToInt; + return wanted_type; + } else if (wanted_type->id == TypeTableEntryIdInt && + actual_type->id == TypeTableEntryIdInt) + { + cast_node->op = CastOpIntWidenOrShorten; + return wanted_type; + } else if (wanted_type == g->builtin_types.entry_string && + actual_type->id == TypeTableEntryIdArray && + actual_type->data.array.child_type == g->builtin_types.entry_u8) + { + cast_node->op = CastOpArrayToString; + context->cast_expr_alloca_list.append(node); + return wanted_type; + } else { + add_node_error(g, node, + buf_sprintf("invalid cast from type '%s' to '%s'", + buf_ptr(&actual_type->name), + buf_ptr(&wanted_type->name))); + return g->builtin_types.entry_invalid; + } +} + static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, BlockContext *context, TypeTableEntry *expected_type, AstNode *node) { @@ -1100,45 +1141,8 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, break; } case NodeTypeCastExpr: - { - TypeTableEntry *wanted_type = resolve_type(g, node->data.cast_expr.type); - TypeTableEntry *actual_type = analyze_expression(g, import, context, nullptr, - node->data.cast_expr.expr); - - if (wanted_type->id == TypeTableEntryIdInvalid || - actual_type->id == TypeTableEntryIdInvalid) - { - return_type = g->builtin_types.entry_invalid; - break; - } - - CastNode *cast_node = &node->codegen_node->data.cast_node; - - // special casing this for now, TODO think about casting and do a general solution - if (wanted_type == g->builtin_types.entry_isize && - actual_type->id == TypeTableEntryIdPointer) - { - cast_node->op = CastOpPtrToInt; - return_type = wanted_type; - } else if (wanted_type->id == TypeTableEntryIdInt && - actual_type->id == TypeTableEntryIdInt) - { - cast_node->op = CastOpIntWidenOrShorten; - return_type = wanted_type; - } else if (wanted_type == g->builtin_types.entry_string && - actual_type->id == TypeTableEntryIdArray && - actual_type->data.array.child_type == g->builtin_types.entry_u8) - { - cast_node->op = CastOpArrayToString; - return_type = wanted_type; - } else { - add_node_error(g, node, - buf_sprintf("TODO handle cast from '%s' to '%s'", - buf_ptr(&actual_type->name), buf_ptr(&wanted_type->name))); - return_type = g->builtin_types.entry_invalid; - } - break; - } + return_type = analyze_cast_expr(g, import, context, expected_type, node); + break; case NodeTypePrefixOpExpr: switch (node->data.prefix_op_expr.prefix_op) { case PrefixOpBoolNot: diff --git a/src/analyze.hpp b/src/analyze.hpp index f7ccbf18d..a7b6ee79c 100644 --- a/src/analyze.hpp +++ b/src/analyze.hpp @@ -193,6 +193,7 @@ struct BlockContext { BlockContext *root; // always points to the BlockContext with the NodeTypeFnDef BlockContext *parent; // nullptr when this is the root HashMap variable_table; + ZigList cast_expr_alloca_list; LLVMZigDIScope *di_scope; }; @@ -244,6 +245,9 @@ enum CastOp { struct CastNode { CastOp op; + // if op is CastOpArrayToString, this will be a pointer to + // the string struct on the stack + LLVMValueRef ptr; }; struct CodeGenNode { diff --git a/src/codegen.cpp b/src/codegen.cpp index 8314e98b9..c38b1bd76 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -59,22 +59,7 @@ void codegen_set_out_name(CodeGen *g, Buf *out_name) { } static LLVMValueRef gen_expr(CodeGen *g, AstNode *expr_node); - -static LLVMTypeRef to_llvm_type(AstNode *type_node) { - assert(type_node->type == NodeTypeType); - assert(type_node->codegen_node); - assert(type_node->codegen_node->data.type_node.entry); - - return type_node->codegen_node->data.type_node.entry->type_ref; -} - -static LLVMZigDIType *to_llvm_debug_type(AstNode *type_node) { - assert(type_node->type == NodeTypeType); - assert(type_node->codegen_node); - assert(type_node->codegen_node->data.type_node.entry); - - return type_node->codegen_node->data.type_node.entry->di_type; -} + static TypeTableEntry *get_type_for_type_node(CodeGen *g, AstNode *type_node) { assert(type_node->type == NodeTypeType); @@ -83,6 +68,22 @@ static TypeTableEntry *get_type_for_type_node(CodeGen *g, AstNode *type_node) { return type_node->codegen_node->data.type_node.entry; } +static LLVMTypeRef fn_proto_type_from_type_node(CodeGen *g, AstNode *type_node) { + TypeTableEntry *type_entry = get_type_for_type_node(g, type_node); + + if (type_entry->id == TypeTableEntryIdStruct || type_entry->id == TypeTableEntryIdArray) { + return get_pointer_to_type(g, type_entry, true)->type_ref; + } else { + return type_entry->type_ref; + } +} + +static LLVMZigDIType *to_llvm_debug_type(CodeGen *g, AstNode *type_node) { + TypeTableEntry *type_entry = get_type_for_type_node(g, type_node); + return type_entry->di_type; +} + + static bool type_is_unreachable(CodeGen *g, AstNode *type_node) { return get_type_for_type_node(g, type_node) == g->builtin_types.entry_unreachable; } @@ -198,20 +199,6 @@ static LLVMValueRef gen_array_ptr(CodeGen *g, AstNode *node) { return LLVMBuildInBoundsGEP(g->builder, array_ref_value, indices, 2, ""); } -static LLVMValueRef gen_field_val(CodeGen *g, AstNode *node) { - assert(node->type == NodeTypeFieldAccessExpr); - - LLVMValueRef struct_val = gen_expr(g, node->data.field_access_expr.struct_expr); - assert(struct_val); - - FieldAccessNode *codegen_field_access = &node->codegen_node->data.field_access_node; - assert(codegen_field_access->field_index >= 0); - - add_debug_source_node(g, node); - return LLVMBuildExtractValue(g->builder, struct_val, codegen_field_access->field_index, ""); -} - -/* static LLVMValueRef gen_field_ptr(CodeGen *g, AstNode *node) { assert(node->type == NodeTypeFieldAccessExpr); @@ -223,9 +210,9 @@ static LLVMValueRef gen_field_ptr(CodeGen *g, AstNode *node) { assert(codegen_field_access->field_index >= 0); + add_debug_source_node(g, node); return LLVMBuildStructGEP(g->builder, struct_ptr, codegen_field_access->field_index, ""); } -*/ static LLVMValueRef gen_array_access_expr(CodeGen *g, AstNode *node) { assert(node->type == NodeTypeArrayAccessExpr); @@ -249,11 +236,8 @@ static LLVMValueRef gen_field_access_expr(CodeGen *g, AstNode *node) { zig_panic("gen_field_access_expr bad array field"); } } else if (struct_type->id == TypeTableEntryIdStruct) { - /* LLVMValueRef ptr = gen_field_ptr(g, node); return LLVMBuildLoad(g->builder, ptr, ""); - */ - return gen_field_val(g, node); } else { zig_panic("gen_field_access_expr bad struct type"); } @@ -311,14 +295,19 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) { } case CastOpArrayToString: { - LLVMValueRef struct_vals[] = { - expr_val, - LLVMConstInt(g->builtin_types.entry_usize->type_ref, actual_type->data.array.len, false) - }; - unsigned field_count = g->builtin_types.entry_string->data.structure.field_count; - assert(field_count == 2); - return LLVMConstNamedStruct(g->builtin_types.entry_string->type_ref, - struct_vals, field_count); + assert(cast_node->ptr); + + add_debug_source_node(g, node); + + LLVMValueRef ptr_ptr = LLVMBuildStructGEP(g->builder, cast_node->ptr, 0, ""); + LLVMBuildStore(g->builder, expr_val, ptr_ptr); + + LLVMValueRef len_ptr = LLVMBuildStructGEP(g->builder, cast_node->ptr, 1, ""); + LLVMValueRef len_val = LLVMConstInt(g->builtin_types.entry_usize->type_ref, + actual_type->data.array.len, false); + LLVMBuildStore(g->builder, len_val, len_ptr); + + return cast_node->ptr; } } zig_unreachable(); @@ -580,6 +569,8 @@ static LLVMValueRef gen_assign_expr(CodeGen *g, AstNode *node) { assert(array_type->id == TypeTableEntryIdArray); op1_type = array_type->data.array.child_type; target_ref = gen_array_ptr(g, lhs_node); + } else if (lhs_node->type == NodeTypeFieldAccessExpr) { + target_ref = gen_field_ptr(g, lhs_node); } else { zig_panic("bad assign target"); } @@ -717,6 +708,7 @@ static LLVMValueRef gen_if_expr(CodeGen *g, AstNode *node) { static LLVMValueRef gen_block(CodeGen *g, AstNode *block_node, TypeTableEntry *implicit_return_type) { assert(block_node->type == NodeTypeBlock); + BlockContext *old_block_context = g->cur_block_context; g->cur_block_context = block_node->codegen_node->data.block_node.block_context; LLVMValueRef return_value; @@ -726,6 +718,7 @@ static LLVMValueRef gen_block(CodeGen *g, AstNode *block_node, TypeTableEntry *i } if (implicit_return_type) { + add_debug_source_node(g, block_node); if (implicit_return_type == g->builtin_types.entry_void) { LLVMBuildRetVoid(g->builder); } else if (implicit_return_type != g->builtin_types.entry_unreachable) { @@ -733,6 +726,8 @@ static LLVMValueRef gen_block(CodeGen *g, AstNode *block_node, TypeTableEntry *i } } + g->cur_block_context = old_block_context; + return return_value; } @@ -934,6 +929,8 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) { } else if (variable->is_ptr) { if (variable->type->id == TypeTableEntryIdArray) { return variable->value_ref; + } else if (variable->type->id == TypeTableEntryIdStruct) { + return variable->value_ref; } else { add_debug_source_node(g, node); return LLVMBuildLoad(g->builder, variable->value_ref, ""); @@ -994,12 +991,12 @@ static LLVMZigDISubroutineType *create_di_function_type(CodeGen *g, AstNodeFnPro LLVMZigDIFile *di_file) { LLVMZigDIType **types = allocate(1 + fn_proto->params.length); - types[0] = to_llvm_debug_type(fn_proto->return_type); + types[0] = to_llvm_debug_type(g, fn_proto->return_type); int types_len = fn_proto->params.length + 1; for (int i = 0; i < fn_proto->params.length; i += 1) { AstNode *param_node = fn_proto->params.at(i); assert(param_node->type == NodeTypeParamDecl); - LLVMZigDIType *param_type = to_llvm_debug_type(param_node->data.param_decl.type); + LLVMZigDIType *param_type = to_llvm_debug_type(g, param_node->data.param_decl.type); types[i + 1] = param_type; } return LLVMZigCreateSubroutineType(g->dbuilder, di_file, types, types_len, 0); @@ -1026,7 +1023,7 @@ static void do_code_gen(CodeGen *g) { assert(proto_node->type == NodeTypeFnProto); AstNodeFnProto *fn_proto = &proto_node->data.fn_proto; - LLVMTypeRef ret_type = to_llvm_type(fn_proto->return_type); + LLVMTypeRef ret_type = fn_proto_type_from_type_node(g, fn_proto->return_type); int param_count = count_non_void_params(g, &fn_proto->params); LLVMTypeRef *param_types = allocate(param_count); int gen_param_index = 0; @@ -1036,7 +1033,7 @@ static void do_code_gen(CodeGen *g) { if (is_param_decl_type_void(g, param_node)) continue; AstNode *type_node = param_node->data.param_decl.type; - param_types[gen_param_index] = to_llvm_type(type_node); + param_types[gen_param_index] = fn_proto_type_from_type_node(g, type_node); gen_param_index += 1; } LLVMTypeRef function_type = LLVMFunctionType(ret_type, param_types, param_count, fn_proto->is_var_args); @@ -1061,8 +1058,8 @@ static void do_code_gen(CodeGen *g) { } // Generate function definitions. - for (int i = 0; i < g->fn_defs.length; i += 1) { - FnTableEntry *fn_table_entry = g->fn_defs.at(i); + for (int fn_i = 0; fn_i < g->fn_defs.length; fn_i += 1) { + FnTableEntry *fn_table_entry = g->fn_defs.at(fn_i); ImportTableEntry *import = fn_table_entry->import_entry; AstNode *fn_def_node = fn_table_entry->fn_def_node; LLVMValueRef fn = fn_table_entry->fn_value; @@ -1101,8 +1098,8 @@ static void do_code_gen(CodeGen *g) { LLVMGetParams(fn, params); int non_void_index = 0; - for (int i = 0; i < fn_proto->params.length; i += 1) { - AstNode *param_decl = fn_proto->params.at(i); + for (int param_i = 0; param_i < fn_proto->params.length; param_i += 1) { + AstNode *param_decl = fn_proto->params.at(param_i); assert(param_decl->type == NodeTypeParamDecl); if (is_param_decl_type_void(g, param_decl)) continue; @@ -1115,8 +1112,8 @@ static void do_code_gen(CodeGen *g) { // Set up debug info for blocks and variables and // allocate all local variables - for (int i = 0; i < codegen_fn_def->all_block_contexts.length; i += 1) { - BlockContext *block_context = codegen_fn_def->all_block_contexts.at(i); + for (int bc_i = 0; bc_i < codegen_fn_def->all_block_contexts.length; bc_i += 1) { + BlockContext *block_context = codegen_fn_def->all_block_contexts.at(bc_i); if (block_context->parent) { LLVMZigDILexicalBlock *di_block = LLVMZigCreateLexicalBlock(g->dbuilder, @@ -1157,6 +1154,16 @@ static void do_code_gen(CodeGen *g) { import->di_file, var->decl_node->line + 1, var->type->di_type, !g->strip_debug_symbols, 0, arg_no); } + + // allocate structs which are the result of casts + for (int cea_i = 0; cea_i < block_context->cast_expr_alloca_list.length; cea_i += 1) { + AstNode *cast_expr_node = block_context->cast_expr_alloca_list.at(cea_i); + assert(cast_expr_node->type == NodeTypeCastExpr); + CastNode *cast_codegen = &cast_expr_node->codegen_node->data.cast_node; + TypeTableEntry *type_entry = get_type_for_type_node(g, cast_expr_node->data.cast_expr.type); + add_debug_source_node(g, cast_expr_node); + cast_codegen->ptr = LLVMBuildAlloca(g->builder, type_entry->type_ref, ""); + } } TypeTableEntry *implicit_return_type = codegen_fn_def->implicit_return_type; diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 1e3b9ed0e..a90bdd0ea 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -477,6 +477,27 @@ export fn main(argc : isize, argv : *mut *mut u8, env : *mut *mut u8) -> i32 { } )SOURCE", "OK\n"); + add_simple_case("structs", R"SOURCE( +use "std.zig"; + +export fn main(argc : isize, argv : *mut *mut u8, env : *mut *mut u8) -> i32 { + let mut foo : Foo; + foo.a = foo.a + 1; + foo.b = foo.a == 1; + test_foo(foo); + return 0; +} +struct Foo { + a : i32, + b : bool, + c : f32, +} +fn test_foo(foo : Foo) { + if foo.b { + print_str("OK\n" as string); + } +} + )SOURCE", "OK\n"); } static void add_compile_failure_test_cases(void) {