diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -24,6 +24,26 @@ ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter); + +/// Class to manage named attributes for lowering ops to LLVM. For ops that +/// require attribute modification, an internal vector of named attributes +/// will be populated with attributes for the destination LLVMIR operation. +/// For those ops that do not require modification, a reference to the original +/// source attributes will be returned by the getAttrs() function. (This class +/// is currently only suitable for "universal" rules, as conversion decisions +/// are made in the class constructor.) +class ConvertToLLVMAttributes { +public: + ConvertToLLVMAttributes(Operation *op); + ArrayRef getAttrs() const { + return convertedAttrs ? convertedAttrs.value() : srcAttrs; + } + +private: + using AttrVector = SmallVector; + ArrayRef srcAttrs; + Optional convertedAttrs; +}; } // namespace detail } // namespace LLVM diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -17,6 +17,7 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" +#include "llvm/ADT/StringExtras.h" //===----------------------------------------------------------------------===// // ArithDialect @@ -29,6 +30,13 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/ArithOpsEnums.h.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.h.inc" + +//===----------------------------------------------------------------------===// +// Arith Interfaces +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.h.inc" //===----------------------------------------------------------------------===// // Arith Dialect Operations diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td @@ -23,6 +23,7 @@ }]; let hasConstantMaterializer = 1; + let useDefaultAttributePrinterParser = 1; } // The predicate indicates the type of the comparison to perform: @@ -92,4 +93,32 @@ let cppNamespace = "::mlir::arith"; } +def FASTMATH_NONE : I32BitEnumAttrCaseNone<"none" >; +def FASTMATH_REASSOC : I32BitEnumAttrCaseBit<"reassoc", 0>; +def FASTMATH_NO_NANS : I32BitEnumAttrCaseBit<"nnan", 1>; +def FASTMATH_NO_INFS : I32BitEnumAttrCaseBit<"ninf", 2>; +def FASTMATH_NO_SIGNED_ZEROS : I32BitEnumAttrCaseBit<"nsz", 3>; +def FASTMATH_ALLOW_RECIP : I32BitEnumAttrCaseBit<"arcp", 4>; +def FASTMATH_ALLOW_CONTRACT : I32BitEnumAttrCaseBit<"contract", 5>; +def FASTMATH_APPROX_FUNC : I32BitEnumAttrCaseBit<"afn", 6>; +def FASTMATH_FAST : I32BitEnumAttrCaseGroup< + "fast", + [ + FASTMATH_REASSOC, FASTMATH_NO_NANS, FASTMATH_NO_INFS, + FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP, FASTMATH_ALLOW_CONTRACT, + FASTMATH_APPROX_FUNC]>; + +def FastMathFlags : I32BitEnumAttr< + "FastMathFlags", + "Floating point fast math flags", + [ + FASTMATH_NONE, FASTMATH_REASSOC, FASTMATH_NO_NANS, + FASTMATH_NO_INFS, FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP, + FASTMATH_ALLOW_CONTRACT, FASTMATH_APPROX_FUNC, FASTMATH_FAST]> { + let separator = ","; + let cppNamespace = "::mlir::arith"; + let genSpecializedAttr = 0; + let printBitEnumPrimaryGroups = 1; +} + #endif // ARITH_BASE diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -10,6 +10,7 @@ #define ARITH_OPS include "mlir/Dialect/Arith/IR/ArithBase.td" +include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -17,6 +18,12 @@ include "mlir/Interfaces/VectorInterfaces.td" include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/EnumAttr.td" + +def Arith_FastMathAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} // Base class for Arith dialect ops. Ops in this dialect have no side // effects and can be applied element-wise to vectors and tensors. @@ -58,15 +65,27 @@ // Base class for floating point unary operations. class Arith_FloatUnaryOp traits = []> : - Arith_UnaryOp, - Arguments<(ins FloatLike:$operand)>, - Results<(outs FloatLike:$result)>; + Arith_UnaryOp], + traits)>, + Arguments<(ins FloatLike:$operand, + DefaultValuedAttr:$fastmath)>, + Results<(outs FloatLike:$result)> { + let assemblyFormat = [{ $operand custom($fastmath) + attr-dict `:` type($result) }]; +} // Base class for floating point binary operations. class Arith_FloatBinaryOp traits = []> : - Arith_BinaryOp, - Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>, - Results<(outs FloatLike:$result)>; + Arith_BinaryOp], + traits)>, + Arguments<(ins FloatLike:$lhs, FloatLike:$rhs, + DefaultValuedAttr:$fastmath)>, + Results<(outs FloatLike:$result)> { + let assemblyFormat = [{ $lhs `,` $rhs `` custom($fastmath) + attr-dict `:` type($result) }]; +} // Base class for arithmetic cast operations. Requires a single operand and // result. If either is a shaped type, then the other must be of the same shape. diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td @@ -0,0 +1,40 @@ +//===-- ArithOpsInterfaces.td - arith op interfaces ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the Arith interfaces definition file. +// +//===----------------------------------------------------------------------===// + +#ifndef ARITH_OPS_INTERFACES +#define ARITH_OPS_INTERFACES + +include "mlir/IR/OpBase.td" + +def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> { + let description = [{ + Access to operation fastmath flags. + }]; + + let cppNamespace = "::mlir::arith"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Returns a FastMathFlagsAttr attribute for the operation", + /*returnType=*/ "FastMathFlagsAttr", + /*methodName=*/ "getFastMathFlagsAttr", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + ConcreteOp op = cast(this->getOperation()); + return op.getFastmathAttr(); + }] + > + ]; +} + +#endif // ARITH_OPS_INTERFACES diff --git a/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt @@ -1,5 +1,14 @@ set(LLVM_TARGET_DEFINITIONS ArithOps.td) mlir_tablegen(ArithOpsEnums.h.inc -gen-enum-decls) mlir_tablegen(ArithOpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(ArithOpsAttributes.h.inc -gen-attrdef-decls + -attrdefs-dialect=arith) +mlir_tablegen(ArithOpsAttributes.cpp.inc -gen-attrdef-defs + -attrdefs-dialect=arith) add_mlir_dialect(ArithOps arith) add_mlir_doc(ArithOps ArithOps Dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS ArithOpsInterfaces.td) +mlir_tablegen(ArithOpsInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(ArithOpsInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRArithOpsInterfacesIncGen) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -47,6 +47,15 @@ } // namespace LLVM } // namespace mlir +namespace mlir { +namespace LLVM { +class LoopOptionsAttrBuilder; +} // namespace LLVM +} // namespace mlir + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc" + #include "mlir/Dialect/LLVMIR/LLVMOpsInterfaces.h.inc" namespace llvm { @@ -61,7 +70,6 @@ namespace mlir { namespace LLVM { class LLVMDialect; -class LoopOptionsAttrBuilder; namespace detail { struct LLVMTypeStorage; @@ -70,9 +78,6 @@ } // namespace LLVM } // namespace mlir -#define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc" - namespace mlir { namespace LLVM { template diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td @@ -26,6 +26,19 @@ InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags", "getFastmathFlags">, ]; + let methods = [ + InterfaceMethod< + /*desc=*/ "Returns a FastmathFlagsAttr attribute for the operation", + /*returnType=*/ "FastmathFlagsAttr", + /*methodName=*/ "getFastmathAttr", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + ConcreteOp op = cast(this->getOperation()); + return op.getFastmathFlagsAttr(); + }] + > + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt --- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt +++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt @@ -15,4 +15,5 @@ MLIRLLVMDialect MLIRSupport MLIRTransforms + MLIRArithDialect ) diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" @@ -15,6 +16,33 @@ using namespace mlir; +// Map arithmetic fastmath enum values to LLVMIR enum values. +static LLVM::FastmathFlags +convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) { + LLVM::FastmathFlags llvmFMF{}; + const std::pair flags[] = { + {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan}, + {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf}, + {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz}, + {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp}, + {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract}, + {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn}, + {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}}; + for (auto fmfMap : flags) { + if (bitEnumContainsAny(arithFMF, fmfMap.first)) + llvmFMF = llvmFMF | fmfMap.second; + } + return llvmFMF; +} + +// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute. +static LLVM::FastmathFlagsAttr +convertArithFastMathAttr(arith::FastMathFlagsAttr fmfAttr) { + auto arithFMF = fmfAttr.getValue(); + return LLVM::FastmathFlagsAttr::get( + fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF)); +} + //===----------------------------------------------------------------------===// // ConvertToLLVMPattern //===----------------------------------------------------------------------===// @@ -319,10 +347,13 @@ return failure(); } + // Lower source op attributes to LLVM equivalents. + LLVM::detail::ConvertToLLVMAttributes attrConvert(op); + // Create the operation through state since we don't know its C++ type. Operation *newOp = rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, - resultTypes, op->getAttrs()); + resultTypes, attrConvert.getAttrs()); // If the operation produced 0 or 1 result, return them immediately. if (numResults == 0) @@ -341,3 +372,22 @@ rewriter.replaceOp(op, results); return success(); } + +/// Constructor for class to manage attributes when lowering to LLVMIR. +/// Currently converts arithmetic fastmath flags to LLVM fastmath flags, and +/// passes all other attributes through unmodified. +LLVM::detail::ConvertToLLVMAttributes::ConvertToLLVMAttributes(Operation *op) + : srcAttrs(op->getAttrs()) { + if (auto fmi = dyn_cast(*op)) { + auto arithFMFAttr = fmi.getFastMathFlagsAttr(); + convertedAttrs = AttrVector{}; + for (auto &srcAttr : srcAttrs) { + convertedAttrs.value().push_back( + (arithFMFAttr == srcAttr.getValue()) + ? NamedAttribute( + StringAttr::get(arithFMFAttr.getContext(), "fastmathFlags"), + convertArithFastMathAttr(arithFMFAttr)) + : srcAttr); + } + } +} diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -116,11 +116,14 @@ if (!llvmNDVectorTy.isa()) return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); - auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy, - ValueRange operands) { + // Lower source op attributes to LLVM equivalents. + LLVM::detail::ConvertToLLVMAttributes attrConvert(op); + + auto callback = [op, targetOp, &attrConvert, &rewriter](Type llvm1DVectorTy, + ValueRange operands) { return rewriter .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, - llvm1DVectorTy, op->getAttrs()) + llvm1DVectorTy, attrConvert.getAttrs()) ->getResult(0); }; diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -215,17 +215,21 @@ //===----------------------------------------------------------------------===// // mulf(negf(x), negf(y)) -> mulf(x,y) +// (retain fastmath flags of original mulf) def MulFOfNegF : - Pat<(Arith_MulFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_MulFOp $x, $y), - [(Constraint> $x, $y)]>; + Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf), + (Arith_MulFOp $x, $y, $fmf), + [(Constraint> $x, $y)]>; //===----------------------------------------------------------------------===// // DivFOp //===----------------------------------------------------------------------===// // divf(negf(x), negf(y)) -> divf(x,y) +// (retain fastmath flags of original divf) def DivFOfNegF : - Pat<(Arith_DivFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_DivFOp $x, $y), - [(Constraint> $x, $y)]>; + Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf), + (Arith_DivFOp $x, $y, $fmf), + [(Constraint> $x, $y)]>; #endif // ARITH_PATTERNS diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp @@ -8,12 +8,17 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::arith; #include "mlir/Dialect/Arith/IR/ArithOpsDialect.cpp.inc" +#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.cpp.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.cpp.inc" namespace { /// This class defines the interface for handling inlining for arithmetic @@ -34,6 +39,10 @@ #define GET_OP_LIST #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc" >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.cpp.inc" + >(); addInterfaces(); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -23,6 +23,31 @@ using namespace mlir; using namespace mlir::arith; +//===----------------------------------------------------------------------===// +// Floating point op parse/print helpers +//===----------------------------------------------------------------------===// +static ParseResult parseArithFastMathAttr(OpAsmParser &parser, + Attribute &attr) { + if (succeeded( + parser.parseOptionalKeyword(FastMathFlagsAttr::getMnemonic()))) { + attr = FastMathFlagsAttr::parse(parser, Type{}); + return success(static_cast(attr)); + } else { + // No fastmath attribute mnemonic present - defer attribute creation and use + // the default value. + return success(); + } +} + +static void printArithFastMathAttr(OpAsmPrinter &printer, Operation *op, + FastMathFlagsAttr fmAttr) { + // Elide printing the fastmath attribute when fastmath=none + if (fmAttr && (fmAttr.getValue() != FastMathFlags::none)) { + printer << " " << FastMathFlagsAttr::getMnemonic(); + fmAttr.print(printer); + } +} + //===----------------------------------------------------------------------===// // Pattern helpers //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt @@ -12,6 +12,7 @@ DEPENDS MLIRArithOpsIncGen + MLIRArithOpsInterfacesIncGen LINK_LIBS PUBLIC MLIRDialect diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -160,9 +160,9 @@ // clang-format on }; llvm::FastMathFlags ret; - auto fmf = op.getFastmathFlags(); + ::mlir::LLVM::FastmathFlags fmfMlir = op.getFastmathAttr().getValue(); for (auto it : handlers) - if (bitEnumContainsAll(fmf, it.first)) + if (bitEnumContainsAll(fmfMlir, it.first)) (ret.*(it.second))(true); return ret; } diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -448,3 +448,20 @@ %1 = arith.maxf %arg0, %arg1 : f32 return %0 : f32 } + +// ----- + +// CHECK-LABEL: @fastmath +func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) { +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: {{.*}} = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32 +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 + %0 = arith.addf %arg0, %arg1 fastmath : f32 + %1 = arith.mulf %arg0, %arg1 fastmath : f32 + %2 = arith.negf %arg0 fastmath : f32 + %3 = arith.addf %arg0, %arg1 fastmath : f32 + %4 = arith.addf %arg0, %arg1 fastmath : f32 + return +} diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -1031,3 +1031,27 @@ %min_unsigned = arith.minui %i1, %i2 : i32 return } + +// CHECK-LABEL: @fastmath +func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) { +// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.subf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.divf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.remf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.negf %arg0 fastmath : f32 + %0 = arith.addf %arg0, %arg1 fastmath : f32 + %1 = arith.subf %arg0, %arg1 fastmath : f32 + %2 = arith.mulf %arg0, %arg1 fastmath : f32 + %3 = arith.divf %arg0, %arg1 fastmath : f32 + %4 = arith.remf %arg0, %arg1 fastmath : f32 + %5 = arith.negf %arg0 fastmath : f32 +// CHECK: {{.*}} = arith.addf %arg0, %arg1 : f32 + %6 = arith.addf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath : f32 + %7 = arith.addf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath : f32 + %8 = arith.mulf %arg0, %arg1 fastmath : f32 + + return +} diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -270,7 +270,7 @@ // ----- func.func @generic(%arg0: memref) { - // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32}} + // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) {fastmath = #arith.fastmath} : (f32, f32) -> f32}} linalg.generic { indexing_maps = [ affine_map<(i, j) -> (i, j)> ], iterator_types = ["parallel", "parallel"]} diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1001,7 +1001,7 @@ for (FormatElement *param : dir->getArguments()) { if (auto *attr = dyn_cast(param)) { const NamedAttribute *var = attr->getVar(); - if (var->attr.isOptional()) + if (var->attr.isOptional() || var->attr.hasDefaultValue()) body << llvm::formatv(" if ({0}Attr)\n ", var->name); body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",