diff --git a/src/codegen.cpp b/src/codegen.cpp index eb56d26ca..680f5a9e3 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -3414,17 +3414,34 @@ static LLVMValueRef ir_render_struct_init(CodeGen *g, IrExecutable *executable, static LLVMValueRef ir_render_union_init(CodeGen *g, IrExecutable *executable, IrInstructionUnionInit *instruction) { TypeUnionField *type_union_field = instruction->field; - assert(type_has_bits(type_union_field->type_entry)); - - LLVMValueRef field_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr, (unsigned)0, ""); - LLVMValueRef value = ir_llvm_value(g, instruction->init_value); + if (!type_has_bits(type_union_field->type_entry)) + return nullptr; uint32_t field_align_bytes = get_abi_alignment(g, type_union_field->type_entry); - TypeTableEntry *ptr_type = get_pointer_to_type_extra(g, type_union_field->type_entry, false, false, field_align_bytes, 0, 0); + LLVMValueRef uncasted_union_ptr; + // Even if safety is off in this block, if the union type has the safety field, we have to populate it + // correctly. Otherwise safety code somewhere other than here could fail. + TypeTableEntry *union_type = instruction->union_type; + if (union_type->data.unionation.gen_tag_index != SIZE_MAX) { + LLVMValueRef tag_field_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr, + union_type->data.unionation.gen_tag_index, ""); + LLVMValueRef tag_value = LLVMConstInt(union_type->data.unionation.tag_type->type_ref, + type_union_field->value, false); + gen_store_untyped(g, tag_value, tag_field_ptr, 0, false); + + uncasted_union_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr, + (unsigned)union_type->data.unionation.gen_union_index, ""); + } else { + uncasted_union_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr, (unsigned)0, ""); + } + + LLVMValueRef field_ptr = LLVMBuildBitCast(g->builder, uncasted_union_ptr, ptr_type->type_ref, ""); + LLVMValueRef value = ir_llvm_value(g, instruction->init_value); + gen_assign_raw(g, field_ptr, ptr_type, value); return instruction->tmp_ptr; diff --git a/test/cases/union.zig b/test/cases/union.zig index 1abebb3b3..404472158 100644 --- a/test/cases/union.zig +++ b/test/cases/union.zig @@ -45,6 +45,23 @@ test "basic unions" { assert(foo.float == 12.34); } +test "init union with runtime value" { + var foo: Foo = undefined; + + setFloat(&foo, 12.34); + assert(foo.float == 12.34); + + setInt(&foo, 42); + assert(foo.int == 42); +} + +fn setFloat(foo: &Foo, x: f64) { + *foo = Foo { .float = x }; +} + +fn setInt(foo: &Foo, x: i32) { + *foo = Foo { .int = x }; +} const FooExtern = extern union { float: f64, @@ -57,3 +74,4 @@ test "basic extern unions" { foo.float = 12.34; assert(foo.float == 12.34); } +