# Review summary (free text before the first `diff --git` header; ignored by
# patch tools): import LLVM IR fast-math flags into the MLIR LLVM dialect.
#
# - LLVMOpBase.td: LLVM_IntrOpBase, LLVM_IntrOp and LLVM_OneResultIntrOp take a
#   new `requiresFastmath` bit (default 0). When it is set, the generated
#   mlirBuilder emits `setFastmathFlagsAttr(inst, op);` before binding `$res`.
#   The trait list handed to LLVM_OpBase also changes, apparently adding a
#   trait conditionally on the bit (presumably the fastmath interface -- that
#   text is garbled below, confirm against the original commit).
# - LLVMIntrinsicOps.td: the unary/binary intrinsic base classes take the
#   `requiresFastmath` bit and expose `commonArgs`; the *I variants use only
#   `commonArgs`, the *F variants append a `fastmathFlags` attribute argument.
#   LLVM_TernarySameArgsIntrOpF and LLVM_PowIOp pass /*requiresFastmath=*/1.
# - LLVMOps.td: the binary arithmetic mlirBuilder moves out of the shared base
#   class into LLVM_IntArithmeticOp (unchanged body) and LLVM_FloatArithmeticOp
#   (now sets fastmath flags); the unary arithmetic op and fcmp builders also
#   call setFastmathFlagsAttr, and the fcmp "FIXME: Import fastmath flags." is
#   removed.
# - ConvertFromLLVMIR.cpp: documents the Importer class and adds
#   Importer::setFastmathFlagsAttr, which casts `op` to the fastmath interface
#   (asserting if it is missing), returns early for instructions that carry no
#   fast-math flags (the isa<> target is garbled; the new Operator.h include
#   suggests llvm::FPMathOperator -- confirm), and copies nnan, ninf, nsz,
#   arcp, contract, afn and reassoc into a FastmathFlagsAttr. Imported calls
#   now go through it as well.
# - test/Target/LLVMIR/Import/fastmath.ll (new): covers fadd/fsub/fmul/fdiv/
#   fneg, fcmp, a plain call, and the exp/powi/pow/fmuladd intrinsics.
#
# NOTE(review): this copy of the patch looks damaged. Hunk lines are joined
# onto a few long lines, and text between angle brackets appears to have been
# stripped (e.g. `cast(op)`, `isa(inst)`, `builder.create(loc, ...)`,
# `#llvm.fastmath`, `LLVM_ScalarOrVectorOf:$in`). It cannot be applied as-is;
# regenerate it from the original commit before use.
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -13,52 +13,55 @@ // "intr." to avoid potential name clashes. class LLVM_UnaryIntrOpBase traits = [], - dag addAttrs = (ins)> : + list traits = [], bit requiresFastmath = 0> : LLVM_OneResultIntrOp { - dag args = (ins LLVM_ScalarOrVectorOf:$in); - let arguments = !con(args, addAttrs); + !listconcat([Pure, SameOperandsAndResultType], traits), + requiresFastmath> { + dag commonArgs = (ins LLVM_ScalarOrVectorOf:$in); let assemblyFormat = "`(` operands `)` custom(attr-dict) `:` " "functional-type(operands, results)"; } class LLVM_UnaryIntrOpI traits = []> : - LLVM_UnaryIntrOpBase; + LLVM_UnaryIntrOpBase { + let arguments = commonArgs; +} class LLVM_UnaryIntrOpF traits = []> : - LLVM_UnaryIntrOpBase], - traits), - (ins DefaultValuedAttr:$fastmathFlags)>; + LLVM_UnaryIntrOpBase { + dag fmfArg = ( + ins DefaultValuedAttr:$fastmathFlags); + let arguments = !con(commonArgs, fmfArg); +} class LLVM_BinarySameArgsIntrOpBase traits = [], - dag addAttrs = (ins)> : + list traits = [], bit requiresFastmath = 0> : LLVM_OneResultIntrOp { - dag args = (ins LLVM_ScalarOrVectorOf:$a, - LLVM_ScalarOrVectorOf:$b); - let arguments = !con(args, addAttrs); + !listconcat([Pure, SameOperandsAndResultType], traits), + requiresFastmath> { + dag commonArgs = (ins LLVM_ScalarOrVectorOf:$a, + LLVM_ScalarOrVectorOf:$b); let assemblyFormat = "`(` operands `)` custom(attr-dict) `:` " "functional-type(operands, results)"; } class LLVM_BinarySameArgsIntrOpI traits = []> : - LLVM_BinarySameArgsIntrOpBase; + LLVM_BinarySameArgsIntrOpBase { + let arguments = commonArgs; +} class LLVM_BinarySameArgsIntrOpF traits = []> : - LLVM_BinarySameArgsIntrOpBase], - traits), - (ins DefaultValuedAttr:$fastmathFlags)>; + LLVM_BinarySameArgsIntrOpBase { + dag fmfArg = ( + ins
DefaultValuedAttr:$fastmathFlags); + let arguments = !con(commonArgs, fmfArg); +} class LLVM_TernarySameArgsIntrOpF traits = []> : LLVM_OneResultIntrOp, - Pure, SameOperandsAndResultType], traits)> { + !listconcat([Pure, SameOperandsAndResultType], traits), + /*requiresFastmath=*/1> { let arguments = (ins LLVM_ScalarOrVectorOf:$a, LLVM_ScalarOrVectorOf:$b, LLVM_ScalarOrVectorOf:$c, @@ -106,7 +109,8 @@ def LLVM_SqrtOp : LLVM_UnaryIntrOpF<"sqrt">; def LLVM_PowOp : LLVM_BinarySameArgsIntrOpF<"pow">; def LLVM_PowIOp : LLVM_OneResultIntrOp<"powi", [], [0,1], - [DeclareOpInterfaceMethods, Pure]> { + [DeclareOpInterfaceMethods, Pure], + /*requiresFastmath=*/1> { let arguments = (ins LLVM_ScalarOrVectorOf:$val, AnySignlessInteger:$power, diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -345,8 +345,13 @@ class LLVM_IntrOpBase overloadedResults, list overloadedOperands, list traits, int numResults, - bit requiresAccessGroup = 0, bit requiresAliasScope = 0> - : LLVM_OpBase, + bit requiresAccessGroup = 0, bit requiresAliasScope = 0, + bit requiresFastmath = 0> + : LLVM_OpBase], + []), + traits)>, Results { string resultPattern = !if(!gt(numResults, 1), LLVM_IntrPatterns.structResult, @@ -378,9 +383,11 @@ return failure(); SmallVector resultTypes = }] # !if(!gt(numResults, 0), "{$_resultType};", "{};") # [{ - Operation *op = $_builder.create<$_qualCppClassName>( + auto op = $_builder.create<$_qualCppClassName>( $_location, resultTypes, *mlirOperands); - }] # !if(!gt(numResults, 0), "$res = op->getResult(0);", "(void)op;"); + }] # !if(!gt(requiresFastmath, 0), + "setFastmathFlagsAttr(inst, op);", "") + # !if(!gt(numResults, 0), "$res = op;", "(void)op;"); } // Base class for LLVM intrinsic operations, should not be used directly.
Places @@ -388,10 +395,11 @@ class LLVM_IntrOp overloadedResults, list overloadedOperands, list traits, int numResults, bit requiresAccessGroup = 0, - bit requiresAliasScope = 0> + bit requiresAliasScope = 0, bit requiresFastmath = 0> : LLVM_IntrOpBase; + numResults, requiresAccessGroup, requiresAliasScope, + requiresFastmath>; // Base class for LLVM intrinsic operations returning no results. Places the // intrinsic into the LLVM dialect and prefixes its name with "intr.". @@ -419,8 +427,11 @@ // empty otherwise. class LLVM_OneResultIntrOp overloadedResults = [], list overloadedOperands = [], - list traits = []> - : LLVM_IntrOp; + list traits = [], + bit requiresFastmath = 0> + : LLVM_IntrOp; def LLVM_OneResultOpBuilder : OpBuilder<(ins "Type":$resultType, "ValueRange":$operands, diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -44,14 +44,14 @@ let builders = [LLVM_OneResultOpBuilder]; let assemblyFormat = "$lhs `,` $rhs custom(attr-dict) `:` type($res)"; string llvmInstName = instName; - string mlirBuilder = [{ - $res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); - }]; } class LLVM_IntArithmeticOp traits = []> : LLVM_ArithmeticOpBase { let arguments = commonArgs; + string mlirBuilder = [{ + $res = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + }]; } class LLVM_FloatArithmeticOp traits = []> : @@ -60,6 +60,11 @@ dag fmfArg = ( ins DefaultValuedAttr:$fastmathFlags); let arguments = !con(commonArgs, fmfArg); + string mlirBuilder = [{ + auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs); + setFastmathFlagsAttr(inst, op); + $res = op; + }]; } // Class for arithmetic unary operations.
@@ -76,8 +81,10 @@ let assemblyFormat = "$operand custom(attr-dict) `:` type($res)"; string llvmInstName = instName; string mlirBuilder = [{ - $res = $_builder.create<$_qualCppClassName>($_location, $operand); - }]; + auto op = $_builder.create<$_qualCppClassName>($_location, $operand); + setFastmathFlagsAttr(inst, op); + $res = op; + }]; } // Integer binary operations. @@ -146,11 +153,12 @@ string llvmBuilder = [{ $res = builder.CreateFCmp(getLLVMCmpPredicate($predicate), $lhs, $rhs); }]; - // FIXME: Import fastmath flags. string mlirBuilder = [{ auto *fCmpInst = cast(inst); - $res = $_builder.create<$_qualCppClassName>( + auto op = $_builder.create<$_qualCppClassName>( $_location, getFCmpPredicate(fCmpInst->getPredicate()), $lhs, $rhs); + setFastmathFlagsAttr(inst, op); + $res = op; }]; // Set the $predicate index to -1 to indicate there is no matching operand // and decrement the following indices. diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -38,6 +38,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/Type.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Support/Error.h" @@ -327,8 +328,13 @@ return blocks; } -// Handles importing globals and functions from an LLVM module. namespace { +/// Module import implementation class that provides methods to import globals +/// and functions from an LLVM module into an MLIR module. It holds mappings +/// between the original and translated globals, basic blocks, and values used +/// during the translation. Additionally, it keeps track of the current constant +/// insertion point since LLVM immediate values translate to MLIR operations +/// that are introduced at the beginning of the region.
class Importer { public: Importer(MLIRContext *context, ModuleOp module) @@ -422,6 +428,10 @@ constantInsertionOp = nullptr; } + /// Sets the fastmath flags attribute for the imported operation `op` given + /// the original instruction `inst`. Asserts if the operation does not + /// implement the fastmath interface. + void setFastmathFlagsAttr(llvm::Instruction *inst, Operation *op) const; /// Returns personality of `func` as a FlatSymbolRefAttr. FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *func); /// Imports `bb` into `block`, which must be initially empty. @@ -488,6 +498,31 @@ }; } // namespace +void Importer::setFastmathFlagsAttr(llvm::Instruction *inst, + Operation *op) const { + auto iface = cast(op); + + // Even if the imported operation implements the fastmath interface, the + // original instruction may not have fastmath flags set. Exit if an + // instruction, such as a non floating-point function call, does not have + // fastmath flags. + if (!isa(inst)) + return; + llvm::FastMathFlags flags = inst->getFastMathFlags(); + + // Set the fastmath bits flag-by-flag. + FastmathFlags value = {}; + value = bitEnumSet(value, FastmathFlags::nnan, flags.noNaNs()); + value = bitEnumSet(value, FastmathFlags::ninf, flags.noInfs()); + value = bitEnumSet(value, FastmathFlags::nsz, flags.noSignedZeros()); + value = bitEnumSet(value, FastmathFlags::arcp, flags.allowReciprocal()); + value = bitEnumSet(value, FastmathFlags::contract, flags.allowContract()); + value = bitEnumSet(value, FastmathFlags::afn, flags.approxFunc()); + value = bitEnumSet(value, FastmathFlags::reassoc, flags.allowReassoc()); + FastmathFlagsAttr attr = FastmathFlagsAttr::get(builder.getContext(), value); + iface->setAttr(iface.getFastmathAttrName(), attr); +} + // We only need integers, floats, doubles, and vectors and tensors thereof for // attributes. Scalar and vector types are converted to the standard // equivalents.
Array types are converted to ranked tensors; nested array types @@ -1033,6 +1068,7 @@ } else { callOp = builder.create(loc, types, operands); } + setFastmathFlagsAttr(inst, callOp); if (!callInst->getType()->isVoidTy()) mapValue(inst, callOp.getResult()); return success(); @@ -1118,7 +1154,7 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) { // FIXME: Support uses of SubtargetData. // FIXME: Add support for inbounds GEPs. - // FIXME: Add support for fast-math flags and call / operand attributes. + // FIXME: Add support for call / operand attributes. // FIXME: Add support for the indirectbr, cleanupret, catchret, catchswitch, // callbr, vaarg, landingpad, catchpad, cleanuppad instructions. diff --git a/mlir/test/Target/LLVMIR/Import/fastmath.ll b/mlir/test/Target/LLVMIR/Import/fastmath.ll new file mode 100644 --- /dev/null +++ b/mlir/test/Target/LLVMIR/Import/fastmath.ll @@ -0,0 +1,56 @@ +; RUN: mlir-translate -import-llvm -split-input-file %s | FileCheck %s + +; CHECK-LABEL: @fastmath_inst +define void @fastmath_inst(float %arg1, float %arg2) { + ; CHECK: llvm.fadd %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath} : f32 + %1 = fadd nnan ninf float %arg1, %arg2 + ; CHECK: llvm.fsub %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath} : f32 + %2 = fsub nsz float %arg1, %arg2 + ; CHECK: llvm.fmul %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath} : f32 + %3 = fmul arcp contract float %arg1, %arg2 + ; CHECK: llvm.fdiv %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath} : f32 + %4 = fdiv afn reassoc float %arg1, %arg2 + ; CHECK: llvm.fneg %{{.*}} {fastmathFlags = #llvm.fastmath} : f32 + %5 = fneg fast float %arg1 + ret void +} + +; // ----- + +; CHECK-LABEL: @fastmath_fcmp +define void @fastmath_fcmp(float %arg1, float %arg2) { + ; CHECK: llvm.fcmp "oge" %{{.*}}, %{{.*}} {fastmathFlags = #llvm.fastmath} : f32 + %1 = fcmp nsz oge float %arg1, %arg2 + ret void +} + +; // ----- + +declare float @fn(float) + +; CHECK-LABEL: @fastmath_call +define void
@fastmath_call(float %arg1) { + ; CHECK: llvm.call @fn(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %1 = call ninf float @fn(float %arg1) + ret void +} + +; // ----- + +declare float @llvm.exp.f32(float) +declare float @llvm.powi.f32.i32(float, i32) +declare float @llvm.pow.f32(float, float) +declare float @llvm.fmuladd.f32(float, float, float) + +; CHECK-LABEL: @fastmath_intr +define void @fastmath_intr(float %arg1, i32 %arg2) { + ; CHECK: llvm.intr.exp(%{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %1 = call nnan ninf float @llvm.exp.f32(float %arg1) + ; CHECK: llvm.intr.powi(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32, i32) -> f32 + %2 = call fast float @llvm.powi.f32.i32(float %arg1, i32 %arg2) + ; CHECK: llvm.intr.pow(%{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32, f32) -> f32 + %3 = call fast float @llvm.pow.f32(float %arg1, float %arg1) + ; CHECK: llvm.intr.fmuladd(%{{.*}}, %{{.*}}, %{{.*}}) {fastmathFlags = #llvm.fastmath} : (f32, f32, f32) -> f32 + %4 = call fast float @llvm.fmuladd.f32(float %arg1, float %arg1, float %arg1) + ret void +}