Index: mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -48,7 +48,32 @@
 struct LLVMDialectImpl;
 } // namespace detail
 
+/// Sets (`value` == true) or clears (`value` == false) the fastmath `flag` on
+/// `op`. Flags are stored as string entries of the op's "fastmathflags" array
+/// attribute.
+void setFastmathFlag(Operation *op, FastmathFlags flag, bool value);
+/// Returns true if the fastmath `flag` is set on `op`.
+bool hasFastmathFlag(Operation *op, FastmathFlags flag);
 } // namespace LLVM
+namespace OpTrait {
+namespace LLVM {
+/// Op trait for LLVM dialect operations that may carry fastmath flags. The
+/// member accessors forward to mlir::LLVM::setFastmathFlag/hasFastmathFlag;
+/// ModuleTranslation applies the flags to the IRBuilder for ops with it.
+template <typename ConcreteType>
+class FastmathFlagsInterface
+    : public OpTrait::TraitBase<ConcreteType, FastmathFlagsInterface> {
+public:
+  void setFastmathFlag(::mlir::LLVM::FastmathFlags flag, bool value) {
+    ::mlir::LLVM::setFastmathFlag(this->getOperation(), flag, value);
+  }
+  bool hasFastmathFlag(::mlir::LLVM::FastmathFlags flag) {
+    return ::mlir::LLVM::hasFastmathFlag(this->getOperation(), flag);
+  }
+};
+
+} // namespace LLVM
+} // namespace OpTrait
 } // namespace mlir
 
 ///// Ops /////
Index: mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -18,6 +18,27 @@
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
+// Fastmath flags mirroring llvm::FastMathFlags. Ops with the
+// FastmathFlagsInterface trait keep the set flags in an array attribute.
+def FMFnnan     : StrEnumAttrCase<"nnan">;
+def FMFninf     : StrEnumAttrCase<"ninf">;
+def FMFnsz      : StrEnumAttrCase<"nsz">;
+def FMFarcp     : StrEnumAttrCase<"arcp">;
+def FMFcontract : StrEnumAttrCase<"contract">;
+def FMFafn      : StrEnumAttrCase<"afn">;
+def FMFreassoc  : StrEnumAttrCase<"reassoc">;
+def FMFfast     : StrEnumAttrCase<"fast">;
+
+def FastmathFlags : StrEnumAttr<
+    "FastmathFlags",
+    "fastmath flags",
+    [FMFnnan, FMFninf, FMFnsz, FMFarcp, FMFcontract, FMFafn, FMFreassoc, FMFfast
+    ]> {
+  let cppNamespace = "::mlir::LLVM";
+}
+
+def FastmathFlagsInterface : NativeOpTrait<"LLVM::FastmathFlagsInterface">;
+
 class LLVM_Builder<string builder> {
   string llvmBuilder = builder;
 }
@@ -200,7 +221,8 @@
 }
 
 // Other integer operations.
-def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp", [NoSideEffect]> {
+def LLVM_FCmpOp : LLVM_OneResultOp<"fcmp",
+                                   [NoSideEffect, FastmathFlagsInterface]> {
   let arguments = (ins FCmpPredicate:$predicate,
                    LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$lhs,
                    LLVM_ScalarOrVectorOf<LLVM_AnyFloat>:$rhs);
@@ -218,13 +240,19 @@
 }
 
 // Floating point binary operations.
-def LLVM_FAddOp : LLVM_FloatArithmeticOp<"fadd", "CreateFAdd">;
-def LLVM_FSubOp : LLVM_FloatArithmeticOp<"fsub", "CreateFSub">;
-def LLVM_FMulOp : LLVM_FloatArithmeticOp<"fmul", "CreateFMul">;
-def LLVM_FDivOp : LLVM_FloatArithmeticOp<"fdiv", "CreateFDiv">;
-def LLVM_FRemOp : LLVM_FloatArithmeticOp<"frem", "CreateFRem">;
+def LLVM_FAddOp : LLVM_FloatArithmeticOp<"fadd", "CreateFAdd",
+                                         [FastmathFlagsInterface]>;
+def LLVM_FSubOp : LLVM_FloatArithmeticOp<"fsub", "CreateFSub",
+                                         [FastmathFlagsInterface]>;
+def LLVM_FMulOp : LLVM_FloatArithmeticOp<"fmul", "CreateFMul",
+                                         [FastmathFlagsInterface]>;
+def LLVM_FDivOp : LLVM_FloatArithmeticOp<"fdiv", "CreateFDiv",
+                                         [FastmathFlagsInterface]>;
+def LLVM_FRemOp : LLVM_FloatArithmeticOp<"frem", "CreateFRem",
+                                         [FastmathFlagsInterface]>;
 def LLVM_FNegOp : LLVM_UnaryArithmeticOp<LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
-                                         "fneg", "CreateFNeg">;
+                                         "fneg", "CreateFNeg",
+                                         [FastmathFlagsInterface]>;
 
 // Common code definition that is used to verify and set the alignment attribute
 // of LLVM ops that accept such an attribute.
@@ -405,7 +433,7 @@
   let printer = [{ printLandingpadOp(p, *this); }];
 }
 
-def LLVM_CallOp : LLVM_Op<"call">,
+def LLVM_CallOp : LLVM_Op<"call", [FastmathFlagsInterface]>,
                   Results<(outs Variadic<LLVM_Type>)> {
   let arguments = (ins OptionalAttr<FlatSymbolRefAttr>:$callee,
                    Variadic<LLVM_Type>);
Index: mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -34,6 +34,7 @@
 
 static constexpr const char kVolatileAttrName[] = "volatile_";
 static constexpr const char kNonTemporalAttrName[] = "nontemporal";
+static constexpr const char kFastmathFlagsAttrName[] = "fastmathflags";
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
 
@@ -1881,3 +1882,44 @@
   return op->hasTrait<OpTrait::SymbolTable>() &&
          op->hasTrait<OpTrait::IsIsolatedFromAbove>();
 }
+
+/// Returns the "fastmathflags" array attribute of `op`, or an empty array if
+/// the op has no such attribute.
+static ArrayAttr getFastmathFlagsVal(Operation *op) {
+  if (auto attr = op->getAttrOfType<ArrayAttr>(kFastmathFlagsAttrName))
+    return attr;
+  return ArrayAttr::get({}, op->getContext());
+}
+
+/// Replaces the "fastmathflags" attribute of `op` with an array of `value`.
+static void setFastmathFlagsVal(Operation *op, ArrayRef<Attribute> value) {
+  op->setAttr(kFastmathFlagsAttrName, ArrayAttr::get(value, op->getContext()));
+}
+
+void mlir::LLVM::setFastmathFlag(Operation *op, FastmathFlags flag,
+                                 bool value) {
+  auto flagName =
+      StringAttr::get(stringifyFastmathFlags(flag), op->getContext());
+  auto values = getFastmathFlagsVal(op);
+  bool isSet = llvm::find(values, flagName) != values.end();
+  // Leave the op untouched when the flag is already in the requested state.
+  // In particular, clearing a flag on an op without the attribute must not
+  // attach an empty "fastmathflags" array to it.
+  if (isSet == value)
+    return;
+
+  SmallVector<Attribute, 8> newValues(values.begin(), values.end());
+  if (value)
+    newValues.push_back(flagName);
+  else
+    newValues.erase(std::remove(newValues.begin(), newValues.end(), flagName),
+                    newValues.end());
+  setFastmathFlagsVal(op, newValues);
+}
+
+bool mlir::LLVM::hasFastmathFlag(Operation *op, FastmathFlags flag) {
+  auto flagName =
+      StringAttr::get(stringifyFastmathFlags(flag), op->getContext());
+  auto values = getFastmathFlagsVal(op);
+  return llvm::find(values, flagName) != values.end();
+}
Index: mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
===================================================================
--- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -511,6 +511,31 @@
   });
 }
 
+/// Translates the fastmath flags attached to `op` into llvm::FastMathFlags.
+static llvm::FastMathFlags getFastmathFlags(Operation &op) {
+  using FMF = ::llvm::FastMathFlags;
+  using mlirFMF = ::mlir::LLVM::FastmathFlags;
+  using FuncT = void (FMF::*)(bool);
+  const std::pair<mlirFMF, FuncT> handlers[] = {
+      // clang-format off
+      {mlirFMF::nnan,     &FMF::setNoNaNs},
+      {mlirFMF::ninf,     &FMF::setNoInfs},
+      {mlirFMF::nsz,      &FMF::setNoSignedZeros},
+      {mlirFMF::arcp,     &FMF::setAllowReciprocal},
+      {mlirFMF::contract, &FMF::setAllowContract},
+      {mlirFMF::afn,      &FMF::setApproxFunc},
+      {mlirFMF::reassoc,  &FMF::setAllowReassoc},
+      {mlirFMF::fast,     &FMF::setFast},
+      // clang-format on
+  };
+  FMF ret;
+  for (const auto &handler : handlers) {
+    if (hasFastmathFlag(&op, handler.first))
+      (ret.*(handler.second))(true);
+  }
+  return ret;
+}
+
 /// Given a single MLIR operation, create the corresponding LLVM IR operation
 /// using the `builder`. LLVM IR Builder does not have a generic interface so
 /// this has to be a long chain of `if`s calling different functions with a
@@ -525,6 +550,12 @@
     return position;
   };
 
+  // Build the op's instructions with its fastmath flags (if any); the guard
+  // restores the builder's previous flags when this function returns.
+  llvm::IRBuilder<>::FastMathFlagGuard FMFGuard(builder);
+  if (opInst.hasTrait<::mlir::OpTrait::LLVM::FastmathFlagsInterface>())
+    builder.setFastMathFlags(getFastmathFlags(opInst));
+
 #include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
 
   // Emit function calls.  If the "callee" attribute is present, this is a