zig/src/codegen.cpp

441 lines
16 KiB
C++
Raw Normal View History

/*
* Copyright (c) 2015 Andrew Kelley
*
* This file is part of zig, which is MIT licensed.
* See http://opensource.org/licenses/MIT
*/
#include "codegen.hpp"
#include "hash_map.hpp"
#include "zig_llvm.hpp"
2015-11-25 04:51:36 +08:00
#include "os.hpp"
#include <stdio.h>
2015-11-24 17:43:45 +08:00
// One entry of CodeGen::fn_table: the LLVM function value that calls are
// emitted against, together with the AST prototype it was declared from.
// (In this file only extern-block declarations are registered, see the
// NodeTypeExternBlock case of analyze_node.)
struct FnTableEntry {
    LLVMValueRef fn_value; // callee handed to LLVMBuildCall
    AstNode *proto_node;   // NodeTypeFnProto node; source of param count / return type
};
struct CodeGen {
2015-11-24 17:43:45 +08:00
LLVMModuleRef mod;
AstNode *root;
2015-11-24 17:43:45 +08:00
HashMap<Buf *, AstNode *, buf_hash, buf_eql_buf> fn_defs;
ZigList<ErrorMsg> errors;
2015-11-24 15:35:23 +08:00
LLVMBuilderRef builder;
2015-11-24 17:43:45 +08:00
HashMap<Buf *, FnTableEntry *, buf_hash, buf_eql_buf> fn_table;
HashMap<Buf *, LLVMValueRef, buf_hash, buf_eql_buf> str_table;
2015-11-24 15:35:23 +08:00
};
// Type information that analyze_node resolves for a NodeTypeType AST node.
struct TypeNode {
    LLVMTypeRef type_ref; // resolved LLVM type; `void`, `unreachable` and unknown names all map to void
    bool is_unreachable;  // set only for the primitive type `unreachable`
};
2015-11-24 15:35:23 +08:00
struct CodeGenNode {
union {
TypeNode type_node; // for NodeTypeType
2015-11-24 15:35:23 +08:00
} data;
};
// Allocates a CodeGen for the AST rooted at `root` and initializes its three
// lookup tables with init(32). The LLVM module and builder are not created
// here; semantic_analyze and code_gen create them.
CodeGen *create_codegen(AstNode *root) {
    CodeGen *codegen = allocate<CodeGen>(1);
    codegen->root = root;

    codegen->fn_defs.init(32);
    codegen->fn_table.init(32);
    codegen->str_table.init(32);

    return codegen;
}
// Records a diagnostic positioned at the start of `node`. The end
// line/column are unknown here and are stored as -1. The `msg` pointer is
// stored as-is, without being copied.
static void add_node_error(CodeGen *g, AstNode *node, Buf *msg) {
    g->errors.add_one();
    ErrorMsg *err = &g->errors.last();

    err->msg = msg;
    err->line_start = node->line;
    err->column_start = node->column;
    err->line_end = -1;
    err->column_end = -1;
}
2015-11-24 17:43:45 +08:00
static LLVMTypeRef to_llvm_type(AstNode *type_node) {
assert(type_node->type == NodeTypeType);
assert(type_node->codegen_node);
assert(type_node->codegen_node->data.type_node.type_ref);
2015-11-24 17:43:45 +08:00
return type_node->codegen_node->data.type_node.type_ref;
}
// True exactly when `type_node` names the primitive type "unreachable".
// Pointer types are never unreachable.
static bool type_is_unreachable(AstNode *type_node) {
    assert(type_node->type == NodeTypeType);
    if (type_node->data.type.type != AstNodeTypeTypePrimitive)
        return false;
    return buf_eql_str(&type_node->data.type.primitive_name, "unreachable");
}
// Semantic analysis pass: visits `node` and recurses into its children.
//
//  - NodeTypeRoot:        analyzes every top-level declaration.
//  - NodeTypeExternBlock: analyzes each extern fn declaration, then declares
//                         it in g->mod as an externally linked,
//                         C-calling-convention function (noreturn when the
//                         return type is `unreachable`), and registers it in
//                         g->fn_table under its name.
//  - NodeTypeFnDef:       records the definition in g->fn_defs and analyzes
//                         its prototype. A second definition with the same
//                         name is reported as a redefinition. The body is
//                         not visited here.
//  - NodeTypeType:        allocates the node's CodeGenNode and resolves its
//                         LLVM type (u8, i32, void, unreachable, pointers).
//  - other node kinds:    recurse into their child nodes, if any.
//
// Problems are reported through add_node_error, and analysis keeps going
// afterwards.
static void analyze_node(CodeGen *g, AstNode *node) {
    switch (node->type) {
        case NodeTypeRoot:
            for (int i = 0; i < node->data.root.top_level_decls.length; i += 1) {
                AstNode *child = node->data.root.top_level_decls.at(i);
                analyze_node(g, child);
            }
            break;
        case NodeTypeExternBlock:
            for (int fn_decl_i = 0; fn_decl_i < node->data.extern_block.fn_decls.length; fn_decl_i += 1) {
                AstNode *fn_decl = node->data.extern_block.fn_decls.at(fn_decl_i);
                // Resolve the prototype's type nodes first: to_llvm_type below
                // reads the codegen_node that this creates for each of them.
                analyze_node(g, fn_decl);

                AstNode *fn_proto = fn_decl->data.fn_decl.fn_proto;
                Buf *name = &fn_proto->data.fn_proto.name;
                ZigList<AstNode *> *params = &fn_proto->data.fn_proto.params;
                LLVMTypeRef *fn_param_values = allocate<LLVMTypeRef>(params->length);
                for (int param_i = 0; param_i < params->length; param_i += 1) {
                    AstNode *param_node = params->at(param_i);
                    assert(param_node->type == NodeTypeParamDecl);
                    AstNode *param_type = param_node->data.param_decl.type;
                    fn_param_values[param_i] = to_llvm_type(param_type);
                }
                AstNode *return_type_node = fn_proto->data.fn_proto.return_type;
                LLVMTypeRef return_type = to_llvm_type(return_type_node);

                LLVMTypeRef fn_type = LLVMFunctionType(return_type, fn_param_values, params->length, 0);
                LLVMValueRef fn_val = LLVMAddFunction(g->mod, buf_ptr(name), fn_type);
                LLVMSetLinkage(fn_val, LLVMExternalLinkage);
                LLVMSetFunctionCallConv(fn_val, LLVMCCallConv);
                if (type_is_unreachable(return_type_node)) {
                    // An `unreachable` return type means calls never return.
                    LLVMAddFunctionAttr(fn_val, LLVMNoReturnAttribute);
                }

                FnTableEntry *fn_table_entry = allocate<FnTableEntry>(1);
                fn_table_entry->fn_value = fn_val;
                fn_table_entry->proto_node = fn_proto;
                g->fn_table.put(name, fn_table_entry);
            }
            break;
        case NodeTypeFnDef:
            {
                AstNode *proto_node = node->data.fn_def.fn_proto;
                // Compare the tag. The assert must never assign it: an
                // assignment would mutate the node in debug builds and
                // vanish entirely under NDEBUG.
                assert(proto_node->type == NodeTypeFnProto);
                Buf *proto_name = &proto_node->data.fn_proto.name;
                auto entry = g->fn_defs.maybe_get(proto_name);
                if (entry) {
                    add_node_error(g, node,
                            buf_sprintf("redefinition of '%s'", buf_ptr(proto_name)));
                } else {
                    g->fn_defs.put(proto_name, node);
                    analyze_node(g, proto_node);
                }
                break;
            }
        case NodeTypeFnDecl:
            {
                AstNode *proto_node = node->data.fn_decl.fn_proto;
                assert(proto_node->type == NodeTypeFnProto);
                analyze_node(g, proto_node);
                break;
            }
        case NodeTypeFnProto:
            {
                for (int i = 0; i < node->data.fn_proto.params.length; i += 1) {
                    AstNode *child = node->data.fn_proto.params.at(i);
                    analyze_node(g, child);
                }
                analyze_node(g, node->data.fn_proto.return_type);
                break;
            }
        case NodeTypeParamDecl:
            analyze_node(g, node->data.param_decl.type);
            break;
        case NodeTypeType:
            {
                node->codegen_node = allocate<CodeGenNode>(1);
                TypeNode *type_node = &node->codegen_node->data.type_node;
                // is_unreachable is only ever set to true below; for every
                // other type it keeps the value allocate<> left
                // (presumably zero -- confirm allocate zero-fills).
                switch (node->data.type.type) {
                    case AstNodeTypeTypePrimitive:
                        {
                            Buf *name = &node->data.type.primitive_name;
                            if (buf_eql_str(name, "u8")) {
                                type_node->type_ref = LLVMInt8Type();
                            } else if (buf_eql_str(name, "i32")) {
                                type_node->type_ref = LLVMInt32Type();
                            } else if (buf_eql_str(name, "void")) {
                                type_node->type_ref = LLVMVoidType();
                            } else if (buf_eql_str(name, "unreachable")) {
                                type_node->type_ref = LLVMVoidType();
                                type_node->is_unreachable = true;
                            } else {
                                add_node_error(g, node,
                                        buf_sprintf("invalid type name: '%s'", buf_ptr(name)));
                                // Placeholder so later lookups still find a type.
                                type_node->type_ref = LLVMVoidType();
                            }
                            break;
                        }
                    case AstNodeTypeTypePointer:
                        {
                            analyze_node(g, node->data.type.child_type);
                            TypeNode *child_type_node = &node->data.type.child_type->codegen_node->data.type_node;
                            if (child_type_node->is_unreachable) {
                                add_node_error(g, node,
                                        buf_create_from_str("pointer to unreachable not allowed"));
                            }
                            // NOTE(review): when the pointee resolved to void, this asks
                            // LLVM for a pointer-to-void, which LLVM does not treat as a
                            // valid pointer element type -- confirm the intended handling.
                            type_node->type_ref = LLVMPointerType(child_type_node->type_ref, 0);
                            break;
                        }
                }
                break;
            }
        case NodeTypeBlock:
            for (int i = 0; i < node->data.block.statements.length; i += 1) {
                AstNode *child = node->data.block.statements.at(i);
                analyze_node(g, child);
            }
            break;
        case NodeTypeStatement:
            switch (node->data.statement.type) {
                case AstNodeStatementTypeExpression:
                    analyze_node(g, node->data.statement.data.expr.expression);
                    break;
                case AstNodeStatementTypeReturn:
                    analyze_node(g, node->data.statement.data.retrn.expression);
                    break;
            }
            break;
        case NodeTypeExpression:
            switch (node->data.expression.type) {
                case AstNodeExpressionTypeNumber:
                    break;
                case AstNodeExpressionTypeString:
                    break;
                case AstNodeExpressionTypeFnCall:
                    analyze_node(g, node->data.expression.data.fn_call);
                    break;
                case AstNodeExpressionTypeUnreachable:
                    break;
            }
            break;
        case NodeTypeFnCall:
            for (int i = 0; i < node->data.fn_call.params.length; i += 1) {
                AstNode *child = node->data.fn_call.params.at(i);
                analyze_node(g, child);
            }
            break;
    }
}
2015-11-24 15:35:23 +08:00
void semantic_analyze(CodeGen *g) {
2015-11-24 17:43:45 +08:00
g->mod = LLVMModuleCreateWithName("ZigModule");
// Pass 1.
analyze_node(g, g->root);
}
2015-11-24 17:43:45 +08:00
// Forward declaration: gen_fn_call (below) emits its argument expressions
// with gen_expr, which is defined after it.
static LLVMValueRef gen_expr(CodeGen *g, AstNode *expr_node);
2015-11-24 15:35:23 +08:00
static LLVMValueRef gen_fn_call(CodeGen *g, AstNode *fn_call_node) {
assert(fn_call_node->type == NodeTypeFnCall);
2015-11-24 17:43:45 +08:00
Buf *name = &fn_call_node->data.fn_call.name;
2015-11-24 15:35:23 +08:00
2015-11-24 17:43:45 +08:00
auto entry = g->fn_table.maybe_get(name);
if (!entry) {
add_node_error(g, fn_call_node,
buf_sprintf("undefined function: '%s'", buf_ptr(name)));
return LLVMConstNull(LLVMInt32Type());
}
FnTableEntry *fn_table_entry = entry->value;
assert(fn_table_entry->proto_node->type == NodeTypeFnProto);
int expected_param_count = fn_table_entry->proto_node->data.fn_proto.params.length;
int actual_param_count = fn_call_node->data.fn_call.params.length;
if (expected_param_count != actual_param_count) {
add_node_error(g, fn_call_node,
buf_sprintf("wrong number of arguments. Expected %d, got %d.",
expected_param_count, actual_param_count));
return LLVMConstNull(LLVMInt32Type());
}
2015-11-24 15:35:23 +08:00
2015-11-24 17:43:45 +08:00
LLVMValueRef *param_values = allocate<LLVMValueRef>(actual_param_count);
for (int i = 0; i < actual_param_count; i += 1) {
AstNode *expr_node = fn_call_node->data.fn_call.params.at(i);
param_values[i] = gen_expr(g, expr_node);
}
2015-11-24 15:35:23 +08:00
2015-11-24 17:43:45 +08:00
LLVMValueRef result = LLVMBuildCall(g->builder, fn_table_entry->fn_value,
param_values, actual_param_count, "");
if (type_is_unreachable(fn_table_entry->proto_node->data.fn_proto.return_type)) {
return LLVMBuildUnreachable(g->builder);
} else {
return result;
}
2015-11-24 17:43:45 +08:00
}
// Returns the global holding the null-terminated bytes of `str`.
// On the first request the global is created (private linkage, constant,
// unnamed_addr) and cached in g->str_table. Later requests whose key matches
// under buf_hash / buf_eql_buf reuse that global.
static LLVMValueRef find_or_create_string(CodeGen *g, Buf *str) {
    auto cached = g->str_table.maybe_get(str);
    if (cached)
        return cached->value;

    LLVMValueRef init = LLVMConstString(buf_ptr(str), buf_len(str), false);
    LLVMValueRef global = LLVMAddGlobal(g->mod, LLVMTypeOf(init), "");
    LLVMSetLinkage(global, LLVMPrivateLinkage);
    LLVMSetInitializer(global, init);
    LLVMSetGlobalConstant(global, true);
    LLVMSetUnnamedAddr(global, true);

    g->str_table.put(str, global);
    return global;
}
// Emits IR for a single expression node and returns the resulting value:
//  - number literal: i32 constant parsed from its decimal text
//  - string literal: in-bounds GEP (0, 0) to the first byte of the interned
//                    string global
//  - function call:  whatever gen_fn_call produces
//  - `unreachable`:  an unreachable terminator
static LLVMValueRef gen_expr(CodeGen *g, AstNode *expr_node) {
    assert(expr_node->type == NodeTypeExpression);
    auto *expr = &expr_node->data.expression;

    switch (expr->type) {
        case AstNodeExpressionTypeFnCall:
            return gen_fn_call(g, expr->data.fn_call);
        case AstNodeExpressionTypeUnreachable:
            return LLVMBuildUnreachable(g->builder);
        case AstNodeExpressionTypeNumber:
            {
                Buf *digits = &expr->data.number;
                return LLVMConstIntOfStringAndSize(LLVMInt32Type(),
                        buf_ptr(digits), buf_len(digits), 10);
            }
        case AstNodeExpressionTypeString:
            {
                LLVMValueRef global = find_or_create_string(g, &expr->data.string);
                LLVMValueRef zero = LLVMConstInt(LLVMInt32Type(), 0, false);
                LLVMValueRef indices[] = { zero, zero };
                return LLVMBuildInBoundsGEP(g->builder, global, indices, 2, "");
            }
    }
    zig_unreachable();
}
// Emits every statement of a block, in order, at the builder's current
// position. A `return` statement emits a `ret` of its evaluated expression.
// An expression statement is evaluated and its value discarded.
// NOTE(review): statements that follow a terminator are still emitted into
// the same basic block.
static void gen_block(CodeGen *g, AstNode *block_node) {
    assert(block_node->type == NodeTypeBlock);
    auto *statements = &block_node->data.block.statements;

    for (int stmt_i = 0; stmt_i < statements->length; stmt_i += 1) {
        AstNode *stmt = statements->at(stmt_i);
        assert(stmt->type == NodeTypeStatement);
        auto *stmt_data = &stmt->data.statement;

        if (stmt_data->type == AstNodeStatementTypeReturn) {
            LLVMValueRef ret_value = gen_expr(g, stmt_data->data.retrn.expression);
            LLVMBuildRet(g->builder, ret_value);
        } else if (stmt_data->type == AstNodeStatementTypeExpression) {
            gen_expr(g, stmt_data->data.expr.expression);
        }
    }
}
// Generates LLVM IR for every function definition recorded in g->fn_defs,
// then dumps the module and verifies it. Verification uses
// LLVMAbortProcessAction, so an invalid module aborts the process.
//
// Each definition becomes an LLVM function with a single "entry" block into
// which its body is emitted. Functions whose return type is `unreachable`
// get the noreturn attribute. A void function whose body falls off the end
// gets an implicit `ret void`.
void code_gen(CodeGen *g) {
    g->builder = LLVMCreateBuilder();

    auto it = g->fn_defs.entry_iterator();
    for (;;) {
        auto *entry = it.next();
        if (!entry)
            break;

        AstNode *fn_def_node = entry->value;
        AstNodeFnDef *fn_def = &fn_def_node->data.fn_def;
        assert(fn_def->fn_proto->type == NodeTypeFnProto);
        AstNodeFnProto *fn_proto = &fn_def->fn_proto->data.fn_proto;

        LLVMTypeRef ret_type = to_llvm_type(fn_proto->return_type);
        LLVMTypeRef *param_types = allocate<LLVMTypeRef>(fn_proto->params.length);
        for (int param_decl_i = 0; param_decl_i < fn_proto->params.length; param_decl_i += 1) {
            AstNode *param_node = fn_proto->params.at(param_decl_i);
            assert(param_node->type == NodeTypeParamDecl);
            AstNode *type_node = param_node->data.param_decl.type;
            param_types[param_decl_i] = to_llvm_type(type_node);
        }
        LLVMTypeRef function_type = LLVMFunctionType(ret_type, param_types, fn_proto->params.length, 0);
        LLVMValueRef fn = LLVMAddFunction(g->mod, buf_ptr(&fn_proto->name), function_type);

        bool is_noreturn = type_is_unreachable(fn_proto->return_type);
        if (is_noreturn) {
            LLVMAddFunctionAttr(fn, LLVMNoReturnAttribute);
        }

        LLVMBasicBlockRef entry_block = LLVMAppendBasicBlock(fn, "entry");
        LLVMPositionBuilderAtEnd(g->builder, entry_block);

        gen_block(g, fn_def->body);

        // A void function may end without an explicit `return`. Its final
        // block would then have no terminator, and verification would abort.
        // Supply the implicit `ret void`. Functions returning `unreachable`
        // are left as written.
        LLVMBasicBlockRef end_block = LLVMGetInsertBlock(g->builder);
        if (!is_noreturn && LLVMGetTypeKind(ret_type) == LLVMVoidTypeKind &&
            LLVMGetBasicBlockTerminator(end_block) == nullptr)
        {
            LLVMBuildRetVoid(g->builder);
        }
    }

    LLVMDumpModule(g->mod);

    char *error = nullptr;
    LLVMVerifyModule(g->mod, LLVMAbortProcessAction, &error);
    // The message handed back through `error` must be released by the caller.
    LLVMDisposeMessage(error);
}
// Gives callers access to the diagnostics collected by add_node_error.
// The returned pointer refers to storage inside `g`.
ZigList<ErrorMsg> *codegen_error_messages(CodeGen *g) {
    ZigList<ErrorMsg> *errors = &g->errors;
    return errors;
}
// Compiles g->mod for the host target into the object file "<out_file>.o",
// then links that object into `out_file` by spawning `ld -o <out_file>
// <out_file>.o -lc`.
// `is_static` selects static relocation; otherwise PIC is used. No
// optimization is applied. Failure to look up the target or to write the
// object file is fatal (zig_panic).
void code_gen_link(CodeGen *g, bool is_static, const char *out_file) {
    LLVMInitializeAllTargets();
    LLVMInitializeAllTargetMCs();
    LLVMInitializeAllAsmPrinters();
    LLVMInitializeAllAsmParsers();
    LLVMInitializeNativeTarget();

    LLVMPassRegistryRef registry = LLVMGetGlobalPassRegistry();
    LLVMInitializeCore(registry);
    LLVMInitializeCodeGen(registry);
    LLVMZigInitializeLoopStrengthReducePass(registry);
    LLVMZigInitializeLowerIntrinsicsPass(registry);
    LLVMZigInitializeUnreachableBlockElimPass(registry);

    char *native_triple = LLVMGetDefaultTargetTriple();
    LLVMTargetRef target_ref;
    char *err_msg = nullptr;
    if (LLVMGetTargetFromTriple(native_triple, &target_ref, &err_msg)) {
        zig_panic("unable to get target from triple: %s", err_msg);
    }

    char *native_cpu = LLVMZigGetHostCPUName();
    char *native_features = LLVMZigGetNativeFeatures();
    LLVMCodeGenOptLevel opt_level = LLVMCodeGenLevelNone;
    LLVMRelocMode reloc_mode = is_static ? LLVMRelocStatic : LLVMRelocPIC;
    LLVMTargetMachineRef target_machine = LLVMCreateTargetMachine(target_ref, native_triple,
            native_cpu, native_features, opt_level, reloc_mode, LLVMCodeModelDefault);
    // The triple string from LLVMGetDefaultTargetTriple is owned by us and
    // must be released with LLVMDisposeMessage. The target machine has
    // already copied it.
    LLVMDisposeMessage(native_triple);

    Buf out_file_o = {0};
    buf_init_from_str(&out_file_o, out_file);
    buf_append_str(&out_file_o, ".o");
    if (LLVMTargetMachineEmitToFile(target_machine, g->mod, buf_ptr(&out_file_o), LLVMObjectFile, &err_msg)) {
        zig_panic("unable to write object file: %s", err_msg);
    }
    LLVMDisposeTargetMachine(target_machine);

    ZigList<const char *> args = {0};
    args.append("-o");
    args.append(out_file);
    args.append((const char *)buf_ptr(&out_file_o));
    args.append("-lc");
    os_spawn_process("ld", args, false);
}