stage1: add support for @mulAdd fused-multiply-add for floats and vectors of floats
Not all of the softfloat library is being built.... Vector support is very buggy at the moment, but should work when the bugs are fixed. (as I had the same code working with another vector function, that hasn't been merged yet).
This commit is contained in:
parent
bbfb53d524
commit
fce2d2d18b
|
@ -389,6 +389,8 @@ set(EMBEDDED_SOFTFLOAT_SOURCES
|
|||
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/s_subMagsF32.c"
|
||||
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/s_subMagsF64.c"
|
||||
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/s_tryPropagateNaNF128M.c"
|
||||
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/f16_mulAdd.c"
|
||||
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/f128M_mulAdd.c"
|
||||
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/softfloat_state.c"
|
||||
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/ui32_to_f128M.c"
|
||||
"${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/ui64_to_f128M.c"
|
||||
|
|
|
@ -6259,6 +6259,13 @@ comptime {
|
|||
This function is only valid within function scope.
|
||||
</p>
|
||||
|
||||
{#header_close#}
|
||||
{#header_open|@mulAdd#}
|
||||
<pre>{#syntax#}@mulAdd(comptime T: type, a: T, b: T, c: T) T{#endsyntax#}</pre>
|
||||
<p>
|
||||
Fused multiply add (for floats), similar to {#syntax#}(a * b) + c{#endsyntax#}, except
|
||||
only rounds once, and is thus more accurate.
|
||||
</p>
|
||||
{#header_close#}
|
||||
|
||||
{#header_open|@byteSwap#}
|
||||
|
|
|
@ -1406,6 +1406,7 @@ enum BuiltinFnId {
|
|||
BuiltinFnIdSubWithOverflow,
|
||||
BuiltinFnIdMulWithOverflow,
|
||||
BuiltinFnIdShlWithOverflow,
|
||||
BuiltinFnIdMulAdd,
|
||||
BuiltinFnIdCInclude,
|
||||
BuiltinFnIdCDefine,
|
||||
BuiltinFnIdCUndef,
|
||||
|
@ -1554,6 +1555,7 @@ enum ZigLLVMFnId {
|
|||
ZigLLVMFnIdClz,
|
||||
ZigLLVMFnIdPopCount,
|
||||
ZigLLVMFnIdOverflowArithmetic,
|
||||
ZigLLVMFnIdFMA,
|
||||
ZigLLVMFnIdFloor,
|
||||
ZigLLVMFnIdCeil,
|
||||
ZigLLVMFnIdSqrt,
|
||||
|
@ -1584,6 +1586,7 @@ struct ZigLLVMFnKey {
|
|||
} pop_count;
|
||||
struct {
|
||||
uint32_t bit_count;
|
||||
uint32_t vector_len; // 0 means not a vector
|
||||
} floating;
|
||||
struct {
|
||||
AddSubMul add_sub_mul;
|
||||
|
@ -2235,6 +2238,7 @@ enum IrInstructionId {
|
|||
IrInstructionIdHandle,
|
||||
IrInstructionIdAlignOf,
|
||||
IrInstructionIdOverflowOp,
|
||||
IrInstructionIdMulAdd,
|
||||
IrInstructionIdTestErr,
|
||||
IrInstructionIdUnwrapErrCode,
|
||||
IrInstructionIdUnwrapErrPayload,
|
||||
|
@ -3038,6 +3042,15 @@ struct IrInstructionOverflowOp {
|
|||
ZigType *result_ptr_type;
|
||||
};
|
||||
|
||||
struct IrInstructionMulAdd {
|
||||
IrInstruction base;
|
||||
|
||||
IrInstruction *type_value;
|
||||
IrInstruction *op1;
|
||||
IrInstruction *op2;
|
||||
IrInstruction *op3;
|
||||
};
|
||||
|
||||
struct IrInstructionAlignOf {
|
||||
IrInstruction base;
|
||||
|
||||
|
|
|
@ -5737,11 +5737,11 @@ uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey x) {
|
|||
case ZigLLVMFnIdPopCount:
|
||||
return (uint32_t)(x.data.clz.bit_count) * (uint32_t)101195049;
|
||||
case ZigLLVMFnIdFloor:
|
||||
return (uint32_t)(x.data.floating.bit_count) * (uint32_t)1899859168;
|
||||
case ZigLLVMFnIdCeil:
|
||||
return (uint32_t)(x.data.floating.bit_count) * (uint32_t)1953839089;
|
||||
case ZigLLVMFnIdSqrt:
|
||||
return (uint32_t)(x.data.floating.bit_count) * (uint32_t)2225366385;
|
||||
case ZigLLVMFnIdFMA:
|
||||
return (uint32_t)(x.data.floating.bit_count) * ((uint32_t)x.id + 1025) +
|
||||
(uint32_t)(x.data.floating.vector_len) * (((uint32_t)x.id << 5) + 1025);
|
||||
case ZigLLVMFnIdBswap:
|
||||
return (uint32_t)(x.data.bswap.bit_count) * (uint32_t)3661994335;
|
||||
case ZigLLVMFnIdBitReverse:
|
||||
|
@ -5772,6 +5772,7 @@ bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) {
|
|||
case ZigLLVMFnIdFloor:
|
||||
case ZigLLVMFnIdCeil:
|
||||
case ZigLLVMFnIdSqrt:
|
||||
case ZigLLVMFnIdFMA:
|
||||
return a.data.floating.bit_count == b.data.floating.bit_count;
|
||||
case ZigLLVMFnIdOverflowArithmetic:
|
||||
return (a.data.overflow_arithmetic.bit_count == b.data.overflow_arithmetic.bit_count) &&
|
||||
|
|
|
@ -807,31 +807,51 @@ static LLVMValueRef get_int_overflow_fn(CodeGen *g, ZigType *operand_type, AddSu
|
|||
}
|
||||
|
||||
static LLVMValueRef get_float_fn(CodeGen *g, ZigType *type_entry, ZigLLVMFnId fn_id) {
|
||||
assert(type_entry->id == ZigTypeIdFloat);
|
||||
assert(type_entry->id == ZigTypeIdFloat ||
|
||||
type_entry->id == ZigTypeIdVector);
|
||||
|
||||
bool is_vector = (type_entry->id == ZigTypeIdVector);
|
||||
ZigType *float_type = is_vector ? type_entry->data.vector.elem_type : type_entry;
|
||||
|
||||
ZigLLVMFnKey key = {};
|
||||
key.id = fn_id;
|
||||
key.data.floating.bit_count = (uint32_t)type_entry->data.floating.bit_count;
|
||||
key.data.floating.bit_count = (uint32_t)float_type->data.floating.bit_count;
|
||||
key.data.floating.vector_len = is_vector ? (uint32_t)type_entry->data.vector.len : 0;
|
||||
|
||||
auto existing_entry = g->llvm_fn_table.maybe_get(key);
|
||||
if (existing_entry)
|
||||
return existing_entry->value;
|
||||
|
||||
const char *name;
|
||||
uint32_t num_args;
|
||||
if (fn_id == ZigLLVMFnIdFloor) {
|
||||
name = "floor";
|
||||
num_args = 1;
|
||||
} else if (fn_id == ZigLLVMFnIdCeil) {
|
||||
name = "ceil";
|
||||
num_args = 1;
|
||||
} else if (fn_id == ZigLLVMFnIdSqrt) {
|
||||
name = "sqrt";
|
||||
num_args = 1;
|
||||
} else if (fn_id == ZigLLVMFnIdFMA) {
|
||||
name = "fma";
|
||||
num_args = 3;
|
||||
} else {
|
||||
zig_unreachable();
|
||||
}
|
||||
|
||||
char fn_name[64];
|
||||
sprintf(fn_name, "llvm.%s.f%" ZIG_PRI_usize "", name, type_entry->data.floating.bit_count);
|
||||
if (is_vector)
|
||||
sprintf(fn_name, "llvm.%s.v%" PRIu32 "f%" PRIu32, name, key.data.floating.vector_len, key.data.floating.bit_count);
|
||||
else
|
||||
sprintf(fn_name, "llvm.%s.f%" PRIu32, name, key.data.floating.bit_count);
|
||||
LLVMTypeRef float_type_ref = get_llvm_type(g, type_entry);
|
||||
LLVMTypeRef fn_type = LLVMFunctionType(float_type_ref, &float_type_ref, 1, false);
|
||||
LLVMTypeRef return_elem_types[3] = {
|
||||
float_type_ref,
|
||||
float_type_ref,
|
||||
float_type_ref,
|
||||
};
|
||||
LLVMTypeRef fn_type = LLVMFunctionType(float_type_ref, return_elem_types, num_args, false);
|
||||
LLVMValueRef fn_val = LLVMAddFunction(g->module, fn_name, fn_type);
|
||||
assert(LLVMGetIntrinsicID(fn_val));
|
||||
|
||||
|
@ -5437,6 +5457,21 @@ static LLVMValueRef ir_render_sqrt(CodeGen *g, IrExecutable *executable, IrInstr
|
|||
return LLVMBuildCall(g->builder, fn_val, &op, 1, "");
|
||||
}
|
||||
|
||||
static LLVMValueRef ir_render_mul_add(CodeGen *g, IrExecutable *executable, IrInstructionMulAdd *instruction) {
|
||||
LLVMValueRef op1 = ir_llvm_value(g, instruction->op1);
|
||||
LLVMValueRef op2 = ir_llvm_value(g, instruction->op2);
|
||||
LLVMValueRef op3 = ir_llvm_value(g, instruction->op3);
|
||||
assert(instruction->base.value.type->id == ZigTypeIdFloat ||
|
||||
instruction->base.value.type->id == ZigTypeIdVector);
|
||||
LLVMValueRef fn_val = get_float_fn(g, instruction->base.value.type, ZigLLVMFnIdFMA);
|
||||
LLVMValueRef args[3] = {
|
||||
op1,
|
||||
op2,
|
||||
op3,
|
||||
};
|
||||
return LLVMBuildCall(g->builder, fn_val, args, 3, "");
|
||||
}
|
||||
|
||||
static LLVMValueRef ir_render_bswap(CodeGen *g, IrExecutable *executable, IrInstructionBswap *instruction) {
|
||||
LLVMValueRef op = ir_llvm_value(g, instruction->op);
|
||||
ZigType *int_type = instruction->base.value.type;
|
||||
|
@ -5781,6 +5816,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
|
|||
return ir_render_mark_err_ret_trace_ptr(g, executable, (IrInstructionMarkErrRetTracePtr *)instruction);
|
||||
case IrInstructionIdSqrt:
|
||||
return ir_render_sqrt(g, executable, (IrInstructionSqrt *)instruction);
|
||||
case IrInstructionIdMulAdd:
|
||||
return ir_render_mul_add(g, executable, (IrInstructionMulAdd *)instruction);
|
||||
case IrInstructionIdArrayToVector:
|
||||
return ir_render_array_to_vector(g, executable, (IrInstructionArrayToVector *)instruction);
|
||||
case IrInstructionIdVectorToArray:
|
||||
|
@ -7398,6 +7435,7 @@ static void define_builtin_fns(CodeGen *g) {
|
|||
create_builtin_fn(g, BuiltinFnIdRem, "rem", 2);
|
||||
create_builtin_fn(g, BuiltinFnIdMod, "mod", 2);
|
||||
create_builtin_fn(g, BuiltinFnIdSqrt, "sqrt", 2);
|
||||
create_builtin_fn(g, BuiltinFnIdMulAdd, "mulAdd", 4);
|
||||
create_builtin_fn(g, BuiltinFnIdInlineCall, "inlineCall", SIZE_MAX);
|
||||
create_builtin_fn(g, BuiltinFnIdNoInlineCall, "noInlineCall", SIZE_MAX);
|
||||
create_builtin_fn(g, BuiltinFnIdNewStackCall, "newStackCall", SIZE_MAX);
|
||||
|
|
171
src/ir.cpp
171
src/ir.cpp
|
@ -747,6 +747,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionTestErr *) {
|
|||
return IrInstructionIdTestErr;
|
||||
}
|
||||
|
||||
static constexpr IrInstructionId ir_instruction_id(IrInstructionMulAdd *) {
|
||||
return IrInstructionIdMulAdd;
|
||||
}
|
||||
|
||||
static constexpr IrInstructionId ir_instruction_id(IrInstructionUnwrapErrCode *) {
|
||||
return IrInstructionIdUnwrapErrCode;
|
||||
}
|
||||
|
@ -2308,6 +2312,22 @@ static IrInstruction *ir_build_overflow_op(IrBuilder *irb, Scope *scope, AstNode
|
|||
return &instruction->base;
|
||||
}
|
||||
|
||||
static IrInstruction *ir_build_mul_add(IrBuilder *irb, Scope *scope, AstNode *source_node,
|
||||
IrInstruction *type_value, IrInstruction *op1, IrInstruction *op2, IrInstruction *op3) {
|
||||
IrInstructionMulAdd *instruction = ir_build_instruction<IrInstructionMulAdd>(irb, scope, source_node);
|
||||
instruction->type_value = type_value;
|
||||
instruction->op1 = op1;
|
||||
instruction->op2 = op2;
|
||||
instruction->op3 = op3;
|
||||
|
||||
ir_ref_instruction(type_value, irb->current_basic_block);
|
||||
ir_ref_instruction(op1, irb->current_basic_block);
|
||||
ir_ref_instruction(op2, irb->current_basic_block);
|
||||
ir_ref_instruction(op3, irb->current_basic_block);
|
||||
|
||||
return &instruction->base;
|
||||
}
|
||||
|
||||
static IrInstruction *ir_build_align_of(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *type_value) {
|
||||
IrInstructionAlignOf *instruction = ir_build_instruction<IrInstructionAlignOf>(irb, scope, source_node);
|
||||
instruction->type_value = type_value;
|
||||
|
@ -4028,6 +4048,33 @@ static IrInstruction *ir_gen_overflow_op(IrBuilder *irb, Scope *scope, AstNode *
|
|||
return ir_build_overflow_op(irb, scope, node, op, type_value, op1, op2, result_ptr, nullptr);
|
||||
}
|
||||
|
||||
static IrInstruction *ir_gen_mul_add(IrBuilder *irb, Scope *scope, AstNode *node) {
|
||||
assert(node->type == NodeTypeFnCallExpr);
|
||||
|
||||
AstNode *type_node = node->data.fn_call_expr.params.at(0);
|
||||
AstNode *op1_node = node->data.fn_call_expr.params.at(1);
|
||||
AstNode *op2_node = node->data.fn_call_expr.params.at(2);
|
||||
AstNode *op3_node = node->data.fn_call_expr.params.at(3);
|
||||
|
||||
IrInstruction *type_value = ir_gen_node(irb, type_node, scope);
|
||||
if (type_value == irb->codegen->invalid_instruction)
|
||||
return irb->codegen->invalid_instruction;
|
||||
|
||||
IrInstruction *op1 = ir_gen_node(irb, op1_node, scope);
|
||||
if (op1 == irb->codegen->invalid_instruction)
|
||||
return irb->codegen->invalid_instruction;
|
||||
|
||||
IrInstruction *op2 = ir_gen_node(irb, op2_node, scope);
|
||||
if (op2 == irb->codegen->invalid_instruction)
|
||||
return irb->codegen->invalid_instruction;
|
||||
|
||||
IrInstruction *op3 = ir_gen_node(irb, op3_node, scope);
|
||||
if (op3 == irb->codegen->invalid_instruction)
|
||||
return irb->codegen->invalid_instruction;
|
||||
|
||||
return ir_build_mul_add(irb, scope, node, type_value, op1, op2, op3);
|
||||
}
|
||||
|
||||
static IrInstruction *ir_gen_this(IrBuilder *irb, Scope *orig_scope, AstNode *node) {
|
||||
for (Scope *it_scope = orig_scope; it_scope != nullptr; it_scope = it_scope->parent) {
|
||||
if (it_scope->id == ScopeIdDecls) {
|
||||
|
@ -4687,6 +4734,8 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
|
|||
return ir_lval_wrap(irb, scope, ir_gen_overflow_op(irb, scope, node, IrOverflowOpMul), lval);
|
||||
case BuiltinFnIdShlWithOverflow:
|
||||
return ir_lval_wrap(irb, scope, ir_gen_overflow_op(irb, scope, node, IrOverflowOpShl), lval);
|
||||
case BuiltinFnIdMulAdd:
|
||||
return ir_lval_wrap(irb, scope, ir_gen_mul_add(irb, scope, node), lval);
|
||||
case BuiltinFnIdTypeName:
|
||||
{
|
||||
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
|
||||
|
@ -21185,6 +21234,125 @@ static IrInstruction *ir_analyze_instruction_overflow_op(IrAnalyze *ira, IrInstr
|
|||
return result;
|
||||
}
|
||||
|
||||
static void ir_eval_mul_add(IrAnalyze *ira, IrInstructionMulAdd *source_instr, ZigType *float_type,
|
||||
ConstExprValue *op1, ConstExprValue *op2, ConstExprValue *op3, ConstExprValue *out_val) {
|
||||
if (float_type->id == ZigTypeIdComptimeFloat) {
|
||||
f128M_mulAdd(&out_val->data.x_bigfloat.value, &op1->data.x_bigfloat.value, &op2->data.x_bigfloat.value,
|
||||
&op3->data.x_bigfloat.value);
|
||||
} else if (float_type->id == ZigTypeIdFloat) {
|
||||
switch (float_type->data.floating.bit_count) {
|
||||
case 16:
|
||||
out_val->data.x_f16 = f16_mulAdd(op1->data.x_f16, op2->data.x_f16, op3->data.x_f16);
|
||||
break;
|
||||
case 32:
|
||||
out_val->data.x_f32 = fmaf(op1->data.x_f32, op2->data.x_f32, op3->data.x_f32);
|
||||
break;
|
||||
case 64:
|
||||
out_val->data.x_f64 = fma(op1->data.x_f64, op2->data.x_f64, op3->data.x_f64);
|
||||
break;
|
||||
case 128:
|
||||
f128M_mulAdd(&op1->data.x_f128, &op2->data.x_f128, &op3->data.x_f128, &out_val->data.x_f128);
|
||||
break;
|
||||
default:
|
||||
zig_unreachable();
|
||||
}
|
||||
} else {
|
||||
zig_unreachable();
|
||||
}
|
||||
}
|
||||
|
||||
static IrInstruction *ir_analyze_instruction_mul_add(IrAnalyze *ira, IrInstructionMulAdd *instruction) {
|
||||
IrInstruction *type_value = instruction->type_value->child;
|
||||
if (type_is_invalid(type_value->value.type))
|
||||
return ira->codegen->invalid_instruction;
|
||||
|
||||
ZigType *expr_type = ir_resolve_type(ira, type_value);
|
||||
if (type_is_invalid(expr_type))
|
||||
return ira->codegen->invalid_instruction;
|
||||
|
||||
// Only allow float types, and vectors of floats.
|
||||
ZigType *float_type = (expr_type->id == ZigTypeIdVector) ? expr_type->data.vector.elem_type : expr_type;
|
||||
if (float_type->id != ZigTypeIdFloat) {
|
||||
ir_add_error(ira, type_value,
|
||||
buf_sprintf("expected float or vector of float type, found '%s'", buf_ptr(&float_type->name)));
|
||||
return ira->codegen->invalid_instruction;
|
||||
}
|
||||
|
||||
IrInstruction *op1 = instruction->op1->child;
|
||||
if (type_is_invalid(op1->value.type))
|
||||
return ira->codegen->invalid_instruction;
|
||||
|
||||
IrInstruction *casted_op1 = ir_implicit_cast(ira, op1, expr_type);
|
||||
if (type_is_invalid(casted_op1->value.type))
|
||||
return ira->codegen->invalid_instruction;
|
||||
|
||||
IrInstruction *op2 = instruction->op2->child;
|
||||
if (type_is_invalid(op2->value.type))
|
||||
return ira->codegen->invalid_instruction;
|
||||
|
||||
IrInstruction *casted_op2 = ir_implicit_cast(ira, op2, expr_type);
|
||||
if (type_is_invalid(casted_op2->value.type))
|
||||
return ira->codegen->invalid_instruction;
|
||||
|
||||
IrInstruction *op3 = instruction->op3->child;
|
||||
if (type_is_invalid(op3->value.type))
|
||||
return ira->codegen->invalid_instruction;
|
||||
|
||||
IrInstruction *casted_op3 = ir_implicit_cast(ira, op3, expr_type);
|
||||
if (type_is_invalid(casted_op3->value.type))
|
||||
return ira->codegen->invalid_instruction;
|
||||
|
||||
if (instr_is_comptime(casted_op1) &&
|
||||
instr_is_comptime(casted_op2) &&
|
||||
instr_is_comptime(casted_op3)) {
|
||||
ConstExprValue *op1_const = ir_resolve_const(ira, casted_op1, UndefBad);
|
||||
if (!op1_const)
|
||||
return ira->codegen->invalid_instruction;
|
||||
ConstExprValue *op2_const = ir_resolve_const(ira, casted_op2, UndefBad);
|
||||
if (!op2_const)
|
||||
return ira->codegen->invalid_instruction;
|
||||
ConstExprValue *op3_const = ir_resolve_const(ira, casted_op3, UndefBad);
|
||||
if (!op3_const)
|
||||
return ira->codegen->invalid_instruction;
|
||||
|
||||
IrInstruction *result = ir_const(ira, &instruction->base, expr_type);
|
||||
ConstExprValue *out_val = &result->value;
|
||||
|
||||
if (expr_type->id == ZigTypeIdVector) {
|
||||
expand_undef_array(ira->codegen, op1_const);
|
||||
expand_undef_array(ira->codegen, op2_const);
|
||||
expand_undef_array(ira->codegen, op3_const);
|
||||
out_val->special = ConstValSpecialUndef;
|
||||
expand_undef_array(ira->codegen, out_val);
|
||||
size_t len = expr_type->data.vector.len;
|
||||
for (size_t i = 0; i < len; i += 1) {
|
||||
ConstExprValue *float_operand_op1 = &op1_const->data.x_array.data.s_none.elements[i];
|
||||
ConstExprValue *float_operand_op2 = &op2_const->data.x_array.data.s_none.elements[i];
|
||||
ConstExprValue *float_operand_op3 = &op3_const->data.x_array.data.s_none.elements[i];
|
||||
ConstExprValue *float_out_val = &out_val->data.x_array.data.s_none.elements[i];
|
||||
assert(float_operand_op1->type == float_type);
|
||||
assert(float_operand_op2->type == float_type);
|
||||
assert(float_operand_op3->type == float_type);
|
||||
assert(float_out_val->type == float_type);
|
||||
ir_eval_mul_add(ira, instruction, float_type,
|
||||
op1_const, op2_const, op3_const, float_out_val);
|
||||
float_out_val->type = float_type;
|
||||
}
|
||||
out_val->type = expr_type;
|
||||
out_val->special = ConstValSpecialStatic;
|
||||
} else {
|
||||
ir_eval_mul_add(ira, instruction, float_type, op1_const, op2_const, op3_const, out_val);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
IrInstruction *result = ir_build_mul_add(&ira->new_irb,
|
||||
instruction->base.scope, instruction->base.source_node,
|
||||
type_value, casted_op1, casted_op2, casted_op3);
|
||||
result->value.type = expr_type;
|
||||
return result;
|
||||
}
|
||||
|
||||
static IrInstruction *ir_analyze_instruction_test_err(IrAnalyze *ira, IrInstructionTestErr *instruction) {
|
||||
IrInstruction *value = instruction->value->child;
|
||||
if (type_is_invalid(value->value.type))
|
||||
|
@ -23596,6 +23764,8 @@ static IrInstruction *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructio
|
|||
return ir_analyze_instruction_mark_err_ret_trace_ptr(ira, (IrInstructionMarkErrRetTracePtr *)instruction);
|
||||
case IrInstructionIdSqrt:
|
||||
return ir_analyze_instruction_sqrt(ira, (IrInstructionSqrt *)instruction);
|
||||
case IrInstructionIdMulAdd:
|
||||
return ir_analyze_instruction_mul_add(ira, (IrInstructionMulAdd *)instruction);
|
||||
case IrInstructionIdIntToErr:
|
||||
return ir_analyze_instruction_int_to_err(ira, (IrInstructionIntToErr *)instruction);
|
||||
case IrInstructionIdErrToInt:
|
||||
|
@ -23835,6 +24005,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
|
|||
case IrInstructionIdCoroPromise:
|
||||
case IrInstructionIdPromiseResultType:
|
||||
case IrInstructionIdSqrt:
|
||||
case IrInstructionIdMulAdd:
|
||||
case IrInstructionIdAtomicLoad:
|
||||
case IrInstructionIdIntCast:
|
||||
case IrInstructionIdFloatCast:
|
||||
|
|
|
@ -1439,6 +1439,22 @@ static void ir_print_sqrt(IrPrint *irp, IrInstructionSqrt *instruction) {
|
|||
fprintf(irp->f, ")");
|
||||
}
|
||||
|
||||
static void ir_print_mul_add(IrPrint *irp, IrInstructionMulAdd *instruction) {
|
||||
fprintf(irp->f, "@mulAdd(");
|
||||
if (instruction->type_value != nullptr) {
|
||||
ir_print_other_instruction(irp, instruction->type_value);
|
||||
} else {
|
||||
fprintf(irp->f, "null");
|
||||
}
|
||||
fprintf(irp->f, ",");
|
||||
ir_print_other_instruction(irp, instruction->op1);
|
||||
fprintf(irp->f, ",");
|
||||
ir_print_other_instruction(irp, instruction->op2);
|
||||
fprintf(irp->f, ",");
|
||||
ir_print_other_instruction(irp, instruction->op3);
|
||||
fprintf(irp->f, ")");
|
||||
}
|
||||
|
||||
static void ir_print_decl_var_gen(IrPrint *irp, IrInstructionDeclVarGen *decl_var_instruction) {
|
||||
ZigVar *var = decl_var_instruction->var;
|
||||
const char *var_or_const = decl_var_instruction->var->gen_is_const ? "const" : "var";
|
||||
|
@ -1905,6 +1921,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
|
|||
case IrInstructionIdSqrt:
|
||||
ir_print_sqrt(irp, (IrInstructionSqrt *)instruction);
|
||||
break;
|
||||
case IrInstructionIdMulAdd:
|
||||
ir_print_mul_add(irp, (IrInstructionMulAdd *)instruction);
|
||||
break;
|
||||
case IrInstructionIdAtomicLoad:
|
||||
ir_print_atomic_load(irp, (IrInstructionAtomicLoad *)instruction);
|
||||
break;
|
||||
|
|
34
test/stage1/behavior/muladd.zig
Normal file
34
test/stage1/behavior/muladd.zig
Normal file
|
@ -0,0 +1,34 @@
|
|||
const expect = @import("std").testing.expect;
|
||||
|
||||
test "@mulAdd" {
|
||||
comptime testMulAdd();
|
||||
testMulAdd();
|
||||
}
|
||||
|
||||
fn testMulAdd() void {
|
||||
{
|
||||
var a: f16 = 5.5;
|
||||
var b: f16 = 2.5;
|
||||
var c: f16 = 6.25;
|
||||
expect(@mulAdd(f16, a, b, c) == 20);
|
||||
}
|
||||
{
|
||||
var a: f32 = 5.5;
|
||||
var b: f32 = 2.5;
|
||||
var c: f32 = 6.25;
|
||||
expect(@mulAdd(f32, a, b, c) == 20);
|
||||
}
|
||||
{
|
||||
var a: f64 = 5.5;
|
||||
var b: f64 = 2.5;
|
||||
var c: f64 = 6.25;
|
||||
expect(@mulAdd(f64, a, b, c) == 20);
|
||||
}
|
||||
// Awaits implementation in libm.zig
|
||||
//{
|
||||
// var a: f16 = 5.5;
|
||||
// var b: f128 = 2.5;
|
||||
// var c: f128 = 6.25;
|
||||
// expect(@mulAdd(f128, a, b, c) == 20);
|
||||
//}
|
||||
}
|
Loading…
Reference in New Issue
Block a user