diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -22,8 +22,10 @@ /// and given operands. LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, + ArrayRef targetAttrs, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter); + } // namespace detail } // namespace LLVM @@ -197,7 +199,7 @@ matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(), - adaptor.getOperands(), + adaptor.getOperands(), op->getAttrs(), *this->getTypeConverter(), rewriter); } }; diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -56,14 +56,34 @@ LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, + ArrayRef targetAttrs, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter); } // namespace detail } // namespace LLVM +// Default attribute conversion class, which passes all source attributes +// through to the target op, unmodified. +template +class AttrConvertPassThrough { +public: + AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {} + + ArrayRef getAttrs() const { return srcAttrs; } + +private: + ArrayRef srcAttrs; +}; + /// Basic lowering implementation to rewrite Ops with just one result to the /// LLVM Dialect. This supports higher-dimensional vector types. -template +/// The AttrConvert template template parameter should be a template class +/// with SourceOp and TargetOp type parameters, a constructor that takes +/// a SourceOp instance, and a getAttrs() method that returns +/// ArrayRef. +template typename AttrConvert = + AttrConvertPassThrough> class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -75,9 +95,12 @@ static_assert( std::is_base_of, SourceOp>::value, "expected single result op"); + // Determine attributes for the target op + AttrConvert attrConvert(op); + return LLVM::detail::vectorOneToOneRewrite( op, TargetOp::getOperationName(), adaptor.getOperands(), - *this->getTypeConverter(), rewriter); + attrConvert.getAttrs(), *this->getTypeConverter(), rewriter); } }; } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -17,6 +17,7 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" +#include "llvm/ADT/StringExtras.h" //===----------------------------------------------------------------------===// // ArithDialect @@ -29,6 +30,13 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/ArithOpsEnums.h.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.h.inc" + +//===----------------------------------------------------------------------===// +// Arith Interfaces +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.h.inc" //===----------------------------------------------------------------------===// // Arith Dialect Operations diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td @@ -23,6 +23,7 @@ }]; let hasConstantMaterializer = 1; + let useDefaultAttributePrinterParser = 1; } // The predicate indicates the type of the comparison to perform: @@ -92,4 +93,32 @@ let cppNamespace = "::mlir::arith"; } +def FASTMATH_NONE : I32BitEnumAttrCaseNone<"none" >; +def FASTMATH_REASSOC : I32BitEnumAttrCaseBit<"reassoc", 0>; +def FASTMATH_NO_NANS : I32BitEnumAttrCaseBit<"nnan", 1>; +def FASTMATH_NO_INFS : I32BitEnumAttrCaseBit<"ninf", 2>; +def FASTMATH_NO_SIGNED_ZEROS : I32BitEnumAttrCaseBit<"nsz", 3>; +def FASTMATH_ALLOW_RECIP : I32BitEnumAttrCaseBit<"arcp", 4>; +def FASTMATH_ALLOW_CONTRACT : I32BitEnumAttrCaseBit<"contract", 5>; +def FASTMATH_APPROX_FUNC : I32BitEnumAttrCaseBit<"afn", 6>; +def FASTMATH_FAST : I32BitEnumAttrCaseGroup< + "fast", + [ + FASTMATH_REASSOC, FASTMATH_NO_NANS, FASTMATH_NO_INFS, + FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP, FASTMATH_ALLOW_CONTRACT, + FASTMATH_APPROX_FUNC]>; + +def FastMathFlags : I32BitEnumAttr< + "FastMathFlags", + "Floating point fast math flags", + [ + FASTMATH_NONE, FASTMATH_REASSOC, FASTMATH_NO_NANS, + FASTMATH_NO_INFS, FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP, + FASTMATH_ALLOW_CONTRACT, FASTMATH_APPROX_FUNC, FASTMATH_FAST]> { + let separator = ","; + let cppNamespace = "::mlir::arith"; + let genSpecializedAttr = 0; + let printBitEnumPrimaryGroups = 1; +} + #endif // ARITH_BASE diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -10,6 +10,7 @@ #define ARITH_OPS include "mlir/Dialect/Arith/IR/ArithBase.td" +include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -17,6 +18,12 @@ include "mlir/Interfaces/VectorInterfaces.td" include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/EnumAttr.td" + +def Arith_FastMathAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} // Base class for Arith dialect ops. Ops in this dialect have no side // effects and can be applied element-wise to vectors and tensors. @@ -58,15 +65,27 @@ // Base class for floating point unary operations. class Arith_FloatUnaryOp traits = []> : - Arith_UnaryOp, - Arguments<(ins FloatLike:$operand)>, - Results<(outs FloatLike:$result)>; + Arith_UnaryOp], + traits)>, + Arguments<(ins FloatLike:$operand, + DefaultValuedAttr:$fastmath)>, + Results<(outs FloatLike:$result)> { + let assemblyFormat = [{ $operand custom($fastmath) + attr-dict `:` type($result) }]; +} // Base class for floating point binary operations. class Arith_FloatBinaryOp traits = []> : - Arith_BinaryOp, - Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>, - Results<(outs FloatLike:$result)>; + Arith_BinaryOp], + traits)>, + Arguments<(ins FloatLike:$lhs, FloatLike:$rhs, + DefaultValuedAttr:$fastmath)>, + Results<(outs FloatLike:$result)> { + let assemblyFormat = [{ $lhs `,` $rhs `` custom($fastmath) + attr-dict `:` type($result) }]; +} // Base class for arithmetic cast operations. Requires a single operand and // result. If either is a shaped type, then the other must be of the same shape. diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td @@ -0,0 +1,52 @@ +//===-- ArithOpsInterfaces.td - arith op interfaces ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the Arith interfaces definition file. +// +//===----------------------------------------------------------------------===// + +#ifndef ARITH_OPS_INTERFACES +#define ARITH_OPS_INTERFACES + +include "mlir/IR/OpBase.td" + +def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> { + let description = [{ + Access to operation fastmath flags. + }]; + + let cppNamespace = "::mlir::arith"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Returns a FastMathFlagsAttr attribute for the operation", + /*returnType=*/ "FastMathFlagsAttr", + /*methodName=*/ "getFastMathFlagsAttr", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + ConcreteOp op = cast(this->getOperation()); + return op.getFastmathAttr(); + }] + >, + StaticInterfaceMethod< + /*desc=*/ [{Returns the name of the FastMathFlagsAttr attribute + for the operation}], + /*returnType=*/ "StringRef", + /*methodName=*/ "getFastMathAttrName", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + return "fastmath"; + }] + > + + ]; +} + +#endif // ARITH_OPS_INTERFACES diff --git a/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Arith/IR/CMakeLists.txt @@ -1,5 +1,14 @@ set(LLVM_TARGET_DEFINITIONS ArithOps.td) mlir_tablegen(ArithOpsEnums.h.inc -gen-enum-decls) mlir_tablegen(ArithOpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(ArithOpsAttributes.h.inc -gen-attrdef-decls + -attrdefs-dialect=arith) +mlir_tablegen(ArithOpsAttributes.cpp.inc -gen-attrdef-defs + -attrdefs-dialect=arith) add_mlir_dialect(ArithOps arith) add_mlir_doc(ArithOps ArithOps Dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS ArithOpsInterfaces.td) +mlir_tablegen(ArithOpsInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(ArithOpsInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRArithOpsInterfacesIncGen) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -48,7 +48,6 @@ namespace mlir { namespace LLVM { class LLVMDialect; -class LoopOptionsAttrBuilder; namespace detail { struct LLVMTypeStorage; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpsInterfaces.td @@ -23,8 +23,28 @@ let cppNamespace = "::mlir::LLVM"; let methods = [ - InterfaceMethod<"Get fastmath flags", "::mlir::LLVM::FastmathFlags", - "getFastmathFlags">, + InterfaceMethod< + /*desc=*/ "Returns a FastmathFlagsAttr attribute for the operation", + /*returnType=*/ "FastmathFlagsAttr", + /*methodName=*/ "getFastmathAttr", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + ConcreteOp op = cast(this->getOperation()); + return op.getFastmathFlagsAttr(); + }] + >, + StaticInterfaceMethod< + /*desc=*/ [{Returns the name of the FastmathFlagsAttr attribute + for the operation}], + /*returnType=*/ "StringRef", + /*methodName=*/ "getFastmathAttrName", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + return "fastmathFlags"; + }] + > ]; } diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -24,16 +24,93 @@ namespace { +// Map arithmetic fastmath enum values to LLVMIR enum values. +static LLVM::FastmathFlags +convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) { + LLVM::FastmathFlags llvmFMF{}; + const std::pair flags[] = { + {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan}, + {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf}, + {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz}, + {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp}, + {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract}, + {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn}, + {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}}; + for (auto fmfMap : flags) { + if (bitEnumContainsAny(arithFMF, fmfMap.first)) + llvmFMF = llvmFMF | fmfMap.second; + } + return llvmFMF; +} + +// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute. +static LLVM::FastmathFlagsAttr +convertArithFastMathAttr(arith::FastMathFlagsAttr fmfAttr) { + arith::FastMathFlags arithFMF = fmfAttr.getValue(); + return LLVM::FastmathFlagsAttr::get( + fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF)); +} + +// Attribute converter that populates a NamedAttrList by removing the fastmath +// attribute from the source operation attributes, and replacing it with an +// equivalent LLVM fastmath attribute. +template +class AttrConvertFastMath { +public: + AttrConvertFastMath(SourceOp srcOp) { + // Copy the source attributes. + convertedAttr = NamedAttrList{srcOp->getAttrs()}; + // Get the name of the arith fastmath attribute. + llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); + // Remove the source fastmath attribute. + auto arithFMFAttr = convertedAttr.erase(arithFMFAttrName) + .dyn_cast_or_null(); + if (arithFMFAttr) { + llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName(); + convertedAttr.set(targetAttrName, convertArithFastMathAttr(arithFMFAttr)); + } + } + + ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } + +private: + NamedAttrList convertedAttr; +}; + +// Attribute converter that populates a NamedAttrList by removing the fastmath +// attribute from the source operation attributes. This may be useful for +// target operations that do not require the fastmath attribute, or for targets +// that do not yet support the LLVM fastmath attribute. +template +class AttrDropFastMath { +public: + AttrDropFastMath(SourceOp srcOp) { + // Copy the source attributes. + convertedAttr = NamedAttrList{srcOp->getAttrs()}; + // Get the name of the arith fastmath attribute. + llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); + // Remove the source fastmath attribute. + convertedAttr.erase(arithFMFAttrName); + } + + ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } + +private: + NamedAttrList convertedAttr; +}; + //===----------------------------------------------------------------------===// // Straightforward Op Lowerings //===----------------------------------------------------------------------===// -using AddFOpLowering = VectorConvertToLLVMPattern; +using AddFOpLowering = VectorConvertToLLVMPattern; using AddIOpLowering = VectorConvertToLLVMPattern; using AndIOpLowering = VectorConvertToLLVMPattern; using BitcastOpLowering = VectorConvertToLLVMPattern; -using DivFOpLowering = VectorConvertToLLVMPattern; +using DivFOpLowering = VectorConvertToLLVMPattern; using DivSIOpLowering = VectorConvertToLLVMPattern; using DivUIOpLowering = @@ -47,23 +124,29 @@ VectorConvertToLLVMPattern; using FPToUIOpLowering = VectorConvertToLLVMPattern; +// TODO: Add LLVM intrinsic support for fastmath using MaxFOpLowering = - VectorConvertToLLVMPattern; + VectorConvertToLLVMPattern; using MaxSIOpLowering = VectorConvertToLLVMPattern; using MaxUIOpLowering = VectorConvertToLLVMPattern; +// TODO: Add LLVM intrinsic support for fastmath using MinFOpLowering = - VectorConvertToLLVMPattern; + VectorConvertToLLVMPattern; using MinSIOpLowering = VectorConvertToLLVMPattern; using MinUIOpLowering = VectorConvertToLLVMPattern; -using MulFOpLowering = VectorConvertToLLVMPattern; +using MulFOpLowering = VectorConvertToLLVMPattern; using MulIOpLowering = VectorConvertToLLVMPattern; -using NegFOpLowering = VectorConvertToLLVMPattern; +using NegFOpLowering = VectorConvertToLLVMPattern; using OrIOpLowering = VectorConvertToLLVMPattern; -using RemFOpLowering = VectorConvertToLLVMPattern; +// TODO: Add LLVM intrinsic support for fastmath +using RemFOpLowering = + VectorConvertToLLVMPattern; using RemSIOpLowering = VectorConvertToLLVMPattern; using RemUIOpLowering = @@ -77,7 +160,8 @@ VectorConvertToLLVMPattern; using SIToFPOpLowering = VectorConvertToLLVMPattern; -using SubFOpLowering = VectorConvertToLLVMPattern; +using SubFOpLowering = VectorConvertToLLVMPattern; using SubIOpLowering = VectorConvertToLLVMPattern; using TruncFOpLowering = VectorConvertToLLVMPattern; @@ -153,7 +237,7 @@ ConstantOpLowering::matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), - adaptor.getOperands(), + adaptor.getOperands(), op->getAttrs(), *getTypeConverter(), rewriter); } diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -90,7 +90,7 @@ ConversionPatternRewriter &rewriter) const override { return LLVM::detail::oneToOneRewrite( op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), - *getTypeConverter(), rewriter); + op->getAttrs(), *getTypeConverter(), rewriter); } }; diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -308,7 +308,8 @@ /// and given operands. LogicalResult LLVM::detail::oneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, - LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { + ArrayRef targetAttrs, LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { unsigned numResults = op->getNumResults(); SmallVector resultTypes; @@ -322,7 +323,7 @@ // Create the operation through state since we don't know its C++ type. Operation *newOp = rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, - resultTypes, op->getAttrs()); + resultTypes, targetAttrs); // If the operation produced 0 or 1 result, return them immediately. if (numResults == 0) diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -105,7 +105,8 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, - LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { + ArrayRef targetAttrs, LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { assert(!operands.empty()); // Cannot convert ops if their operands are not of LLVM type. @@ -114,13 +115,14 @@ auto llvmNDVectorTy = operands[0].getType(); if (!llvmNDVectorTy.isa()) - return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); + return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter, + rewriter); - auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy, - ValueRange operands) { + auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy, + ValueRange operands) { return rewriter .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands, - llvm1DVectorTy, op->getAttrs()) + llvm1DVectorTy, targetAttrs) ->getResult(0); }; diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -215,17 +215,21 @@ //===----------------------------------------------------------------------===// // mulf(negf(x), negf(y)) -> mulf(x,y) +// (retain fastmath flags of original mulf) def MulFOfNegF : - Pat<(Arith_MulFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_MulFOp $x, $y), - [(Constraint> $x, $y)]>; + Pat<(Arith_MulFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf), + (Arith_MulFOp $x, $y, $fmf), + [(Constraint> $x, $y)]>; //===----------------------------------------------------------------------===// // DivFOp //===----------------------------------------------------------------------===// // divf(negf(x), negf(y)) -> divf(x,y) +// (retain fastmath flags of original divf) def DivFOfNegF : - Pat<(Arith_DivFOp (Arith_NegFOp $x), (Arith_NegFOp $y)), (Arith_DivFOp $x, $y), - [(Constraint> $x, $y)]>; + Pat<(Arith_DivFOp (Arith_NegFOp $x, $_), (Arith_NegFOp $y, $_), $fmf), + (Arith_DivFOp $x, $y, $fmf), + [(Constraint> $x, $y)]>; #endif // ARITH_PATTERNS diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp @@ -8,12 +8,17 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" #include "mlir/Transforms/InliningUtils.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::arith; #include "mlir/Dialect/Arith/IR/ArithOpsDialect.cpp.inc" +#include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.cpp.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.cpp.inc" namespace { /// This class defines the interface for handling inlining for arithmetic @@ -34,6 +39,10 @@ #define GET_OP_LIST #include "mlir/Dialect/Arith/IR/ArithOps.cpp.inc" >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/Arith/IR/ArithOpsAttributes.cpp.inc" + >(); addInterfaces(); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -23,6 +23,31 @@ using namespace mlir; using namespace mlir::arith; +//===----------------------------------------------------------------------===// +// Floating point op parse/print helpers +//===----------------------------------------------------------------------===// +static ParseResult parseArithFastMathAttr(OpAsmParser &parser, + Attribute &attr) { + if (succeeded( + parser.parseOptionalKeyword(FastMathFlagsAttr::getMnemonic()))) { + attr = FastMathFlagsAttr::parse(parser, Type{}); + return success(static_cast(attr)); + } else { + // No fastmath attribute mnemonic present - defer attribute creation and use + // the default value. + return success(); + } +} + +static void printArithFastMathAttr(OpAsmPrinter &printer, Operation *op, + FastMathFlagsAttr fmAttr) { + // Elide printing the fastmath attribute when fastmath=none + if (fmAttr && (fmAttr.getValue() != FastMathFlags::none)) { + printer << " " << FastMathFlagsAttr::getMnemonic(); + fmAttr.print(printer); + } +} + //===----------------------------------------------------------------------===// // Pattern helpers //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Arith/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/IR/CMakeLists.txt @@ -12,6 +12,7 @@ DEPENDS MLIRArithOpsIncGen + MLIRArithOpsInterfacesIncGen LINK_LIBS PUBLIC MLIRDialect diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -51,13 +51,13 @@ Type elementType = getSrcVectorElementType(op); unsigned bitwidth = elementType.getIntOrFloatBitWidth(); if (bitwidth == 32) - return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(), - adaptor.getOperands(), - getTypeConverter(), rewriter); + return LLVM::detail::oneToOneRewrite( + op, Intr32OpTy::getOperationName(), adaptor.getOperands(), + op->getAttrs(), getTypeConverter(), rewriter); if (bitwidth == 64) - return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(), - adaptor.getOperands(), - getTypeConverter(), rewriter); + return LLVM::detail::oneToOneRewrite( + op, Intr64OpTy::getOperationName(), adaptor.getOperands(), + op->getAttrs(), getTypeConverter(), rewriter); return rewriter.notifyMatchFailure( op, "expected 'src' to be either f32 or f64"); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -160,9 +160,9 @@ // clang-format on }; llvm::FastMathFlags ret; - auto fmf = op.getFastmathFlags(); + ::mlir::LLVM::FastmathFlags fmfMlir = op.getFastmathAttr().getValue(); for (auto it : handlers) - if (bitEnumContainsAll(fmf, it.first)) + if (bitEnumContainsAll(fmfMlir, it.first)) (ret.*(it.second))(true); return ret; } diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -309,7 +309,7 @@ // clang-format off // CHECK-LABEL: @stats // CHECK: Number of operations: 12 - // CHECK: Number of attributes: 4 + // CHECK: Number of attributes: 5 // CHECK: Number of blocks: 3 // CHECK: Number of regions: 3 // CHECK: Number of values: 9 diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -448,3 +448,20 @@ %1 = arith.maxf %arg0, %arg1 : f32 return %0 : f32 } + +// ----- + +// CHECK-LABEL: @fastmath +func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) { +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: {{.*}} = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath} : f32 +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32 +// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath} : f32 + %0 = arith.addf %arg0, %arg1 fastmath : f32 + %1 = arith.mulf %arg0, %arg1 fastmath : f32 + %2 = arith.negf %arg0 fastmath : f32 + %3 = arith.addf %arg0, %arg1 fastmath : f32 + %4 = arith.addf %arg0, %arg1 fastmath : f32 + return +} diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir --- a/mlir/test/Dialect/Arith/ops.mlir +++ b/mlir/test/Dialect/Arith/ops.mlir @@ -1031,3 +1031,27 @@ %min_unsigned = arith.minui %i1, %i2 : i32 return } + +// CHECK-LABEL: @fastmath +func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) { +// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.subf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.divf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.remf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.negf %arg0 fastmath : f32 + %0 = arith.addf %arg0, %arg1 fastmath : f32 + %1 = arith.subf %arg0, %arg1 fastmath : f32 + %2 = arith.mulf %arg0, %arg1 fastmath : f32 + %3 = arith.divf %arg0, %arg1 fastmath : f32 + %4 = arith.remf %arg0, %arg1 fastmath : f32 + %5 = arith.negf %arg0 fastmath : f32 +// CHECK: {{.*}} = arith.addf %arg0, %arg1 : f32 + %6 = arith.addf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath : f32 + %7 = arith.addf %arg0, %arg1 fastmath : f32 +// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath : f32 + %8 = arith.mulf %arg0, %arg1 fastmath : f32 + + return +} diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -270,7 +270,7 @@ // ----- func.func @generic(%arg0: memref) { - // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) : (f32, f32) -> f32}} + // expected-error @+6 {{block with no terminator, has %0 = "arith.addf"(%arg1, %arg1) {fastmath = #arith.fastmath} : (f32, f32) -> f32}} linalg.generic { indexing_maps = [ affine_map<(i, j) -> (i, j)> ], iterator_types = ["parallel", "parallel"]} diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1001,7 +1001,7 @@ for (FormatElement *param : dir->getArguments()) { if (auto *attr = dyn_cast(param)) { const NamedAttribute *var = attr->getVar(); - if (var->attr.isOptional()) + if (var->attr.isOptional() || var->attr.hasDefaultValue()) body << llvm::formatv(" if ({0}Attr)\n ", var->name); body << llvm::formatv(" result.addAttribute(\"{0}\", {0}Attr);\n",