diff --git a/src/all_types.hpp b/src/all_types.hpp index de99f7855..ecd34a10b 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1104,6 +1104,7 @@ enum BuiltinFnId { BuiltinFnIdAddWithOverflow, BuiltinFnIdSubWithOverflow, BuiltinFnIdMulWithOverflow, + BuiltinFnIdShlWithOverflow, BuiltinFnIdCInclude, BuiltinFnIdCDefine, BuiltinFnIdCUndef, diff --git a/src/analyze.cpp b/src/analyze.cpp index 7f02cc9af..62938ad0f 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -4546,6 +4546,7 @@ static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry case BuiltinFnIdAddWithOverflow: case BuiltinFnIdSubWithOverflow: case BuiltinFnIdMulWithOverflow: + case BuiltinFnIdShlWithOverflow: { AstNode *type_node = node->data.fn_call_expr.params.at(0); TypeTableEntry *int_type = analyze_type_expr(g, import, context, type_node); diff --git a/src/codegen.cpp b/src/codegen.cpp index b3046e747..9e2f2bc91 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -459,6 +459,34 @@ static LLVMValueRef gen_fence(CodeGen *g, AstNode *node) { return nullptr; } +static LLVMValueRef gen_shl_with_overflow(CodeGen *g, AstNode *node) { + assert(node->type == NodeTypeFnCallExpr); + + int fn_call_param_count = node->data.fn_call_expr.params.length; + assert(fn_call_param_count == 4); + + TypeTableEntry *int_type = get_type_for_type_node(node->data.fn_call_expr.params.at(0)); + assert(int_type->id == TypeTableEntryIdInt); + + LLVMValueRef val1 = gen_expr(g, node->data.fn_call_expr.params.at(1)); + LLVMValueRef val2 = gen_expr(g, node->data.fn_call_expr.params.at(2)); + LLVMValueRef ptr_result = gen_expr(g, node->data.fn_call_expr.params.at(3)); + + set_debug_source_node(g, node); + LLVMValueRef result = LLVMBuildShl(g->builder, val1, val2, ""); + LLVMValueRef orig_val; + if (int_type->data.integral.is_signed) { + orig_val = LLVMBuildAShr(g->builder, result, val2, ""); + } else { + orig_val = LLVMBuildLShr(g->builder, result, val2, ""); + } + LLVMValueRef overflow_bit = LLVMBuildICmp(g->builder, LLVMIntNE, val1, orig_val, ""); + + LLVMBuildStore(g->builder, result, ptr_result); + + return overflow_bit; +} + static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) { assert(node->type == NodeTypeFnCallExpr); AstNode *fn_ref_expr = node->data.fn_call_expr.fn_ref_expr; @@ -527,6 +555,8 @@ static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) { return overflow_bit; } + case BuiltinFnIdShlWithOverflow: + return gen_shl_with_overflow(g, node); case BuiltinFnIdMemcpy: { int fn_call_param_count = node->data.fn_call_expr.params.length; @@ -4357,6 +4387,7 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn_with_arg_count(g, BuiltinFnIdAddWithOverflow, "add_with_overflow", 4); create_builtin_fn_with_arg_count(g, BuiltinFnIdSubWithOverflow, "sub_with_overflow", 4); create_builtin_fn_with_arg_count(g, BuiltinFnIdMulWithOverflow, "mul_with_overflow", 4); + create_builtin_fn_with_arg_count(g, BuiltinFnIdShlWithOverflow, "shl_with_overflow", 4); create_builtin_fn_with_arg_count(g, BuiltinFnIdCInclude, "c_include", 1); create_builtin_fn_with_arg_count(g, BuiltinFnIdCDefine, "c_define", 2); create_builtin_fn_with_arg_count(g, BuiltinFnIdCUndef, "c_undef", 1); diff --git a/src/eval.cpp b/src/eval.cpp index 1253125c5..a2bea04ab 100644 --- a/src/eval.cpp +++ b/src/eval.cpp @@ -707,6 +707,7 @@ static bool eval_fn_call_builtin(EvalFn *ef, AstNode *node, ConstExprValue *out_ case BuiltinFnIdErrName: case BuiltinFnIdEmbedFile: case BuiltinFnIdCmpExchange: + case BuiltinFnIdShlWithOverflow: zig_panic("TODO"); case BuiltinFnIdBreakpoint: case BuiltinFnIdInvalid: diff --git a/test/self_hosted.zig b/test/self_hosted.zig index f7a19a2d3..2e6cc6a3e 100644 --- a/test/self_hosted.zig +++ b/test/self_hosted.zig @@ -1522,3 +1522,11 @@ fn test_shl_wrapping_noeval(x: u16w) { x_u16 <<= 1; assert(x_u16 == 65534); } + +#attribute("test") +fn shl_with_overflow() { + var result: u16 = undefined; + assert(@shl_with_overflow(u16, 0b0010111111111111, 3, &result)); + assert(!@shl_with_overflow(u16, 0b0010111111111111, 2, &result)); + assert(result == 0b1011111111111100); +}