diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -23,6 +23,33 @@
                               ValueRange operands,
                               LLVMTypeConverter &typeConverter,
                               ConversionPatternRewriter &rewriter);
+
+/// Class to manage named attributes for lowering ops to LLVM. For ops that
+/// require attribute modification, an internal vector of named attributes
+/// will be populated with attributes for the destination LLVMIR operation.
+/// For those ops that do not require modification, a reference to the original
+/// source attributes will be returned by the getAttrs() function. (This class
+/// is currently only suitable for "universal" rules, as conversion decisions
+/// are made in the class constructor.)
+///
+/// getAttrs() returns a non-owning view into storage held either by this
+/// object or by the source operation; neither may be destroyed, nor the source
+/// operation's attributes modified, while that view is in use.
+class ConvertToLLVMAttributes {
+public:
+  explicit ConvertToLLVMAttributes(Operation *op);
+
+  ArrayRef<NamedAttribute> getAttrs() const {
+    if (convertAttrs)
+      return *convertAttrs;
+    return srcAttrs;
+  }
+
+private:
+  using AttrVector = SmallVector<NamedAttribute>;
+  ArrayRef<NamedAttribute> srcAttrs;
+  Optional<AttrVector> convertAttrs;
+};
 } // namespace detail
 } // namespace LLVM
 
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h
@@ -15,6 +15,7 @@
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/VectorInterfaces.h"
+#include "llvm/ADT/StringExtras.h"
 
 //===----------------------------------------------------------------------===//
 // ArithmeticDialect
@@ -27,6 +28,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.h.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsAttributes.h.inc"
+
+//===----------------------------------------------------------------------===//
+// Arithmetic Interfaces
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsInterfaces.h.inc"
 
 //===----------------------------------------------------------------------===//
 // Arithmetic Dialect Operations
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td
@@ -24,6 +24,7 @@
 
   let hasConstantMaterializer = 1;
   let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
+  let useDefaultAttributePrinterParser = 1;
 }
 
 // The predicate indicates the type of the comparison to perform:
@@ -93,4 +94,34 @@
   let cppNamespace = "::mlir::arith";
 }
 
+// Floating-point fast-math flags. Each bit case has the same name as the
+// corresponding LLVM IR fast-math flag; `fast` is the group of all of them.
+def FASTMATH_NONE            : I32BitEnumAttrCaseNone<"none">;
+def FASTMATH_REASSOC         : I32BitEnumAttrCaseBit<"reassoc", 0>;
+def FASTMATH_NO_NANS         : I32BitEnumAttrCaseBit<"nnan", 1>;
+def FASTMATH_NO_INFS         : I32BitEnumAttrCaseBit<"ninf", 2>;
+def FASTMATH_NO_SIGNED_ZEROS : I32BitEnumAttrCaseBit<"nsz", 3>;
+def FASTMATH_ALLOW_RECIP     : I32BitEnumAttrCaseBit<"arcp", 4>;
+def FASTMATH_ALLOW_CONTRACT  : I32BitEnumAttrCaseBit<"contract", 5>;
+def FASTMATH_APPROX_FUNC     : I32BitEnumAttrCaseBit<"afn", 6>;
+def FASTMATH_FAST            : I32BitEnumAttrCaseGroup<
+  "fast",
+  [
+    FASTMATH_REASSOC, FASTMATH_NO_NANS, FASTMATH_NO_INFS,
+    FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP, FASTMATH_ALLOW_CONTRACT,
+    FASTMATH_APPROX_FUNC]>;
+
+def FastMathFlags : I32BitEnumAttr<
+    "FastMathFlags",
+    "Floating point fast math flags",
+    [
+      FASTMATH_NONE, FASTMATH_REASSOC, FASTMATH_NO_NANS,
+      FASTMATH_NO_INFS, FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP,
+      FASTMATH_ALLOW_CONTRACT, FASTMATH_APPROX_FUNC, FASTMATH_FAST]> {
+  let separator = ",";
+  let cppNamespace = "::mlir::arith";
+  let genSpecializedAttr = 0;
+  let printBitEnumPrimaryGroups = 1;
+}
+
 #endif // ARITHMETIC_BASE
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -10,11 +10,19 @@
 #define ARITHMETIC_OPS
 
 include "mlir/Dialect/Arithmetic/IR/ArithmeticBase.td"
+include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsInterfaces.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/VectorInterfaces.td"
 include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/EnumAttr.td"
+
+// Arith attribute wrapping FastMathFlags; mnemonic `fastmath`, body `<flags>`.
+def Arith_FastMathAttr :
+    EnumAttr<Arithmetic_Dialect, FastMathFlags, "fastmath"> {
+  let assemblyFormat = "`<` $value `>`";
+}
 
 // Base class for Arithmetic dialect ops. Ops in this dialect have no side
 // effects and can be applied element-wise to vectors and tensors.
@@ -55,15 +63,27 @@
 
 // Base class for floating point unary operations.
 class Arith_FloatUnaryOp<string mnemonic, list<Trait> traits = []> :
-    Arith_UnaryOp<mnemonic, traits>,
-    Arguments<(ins FloatLike:$operand)>,
-    Results<(outs FloatLike:$result)>;
+    Arith_UnaryOp<mnemonic,
+      !listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
+      traits)>,
+    Arguments<(ins FloatLike:$operand,
+      DefaultValuedAttr<Arith_FastMathAttr, "FastMathFlags::none">:$fastmath)>,
+    Results<(outs FloatLike:$result)> {
+  let assemblyFormat = [{ $operand `` custom<ArithFastMathAttr>($fastmath)
+                          attr-dict `:` type($result) }];
+}
 
 // Base class for floating point binary operations.
 class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
-    Arith_BinaryOp<mnemonic, traits>,
-    Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>,
-    Results<(outs FloatLike:$result)>;
+    Arith_BinaryOp<mnemonic,
+      !listconcat([DeclareOpInterfaceMethods<ArithFastMathInterface>],
+      traits)>,
+    Arguments<(ins FloatLike:$lhs, FloatLike:$rhs,
+      DefaultValuedAttr<Arith_FastMathAttr, "FastMathFlags::none">:$fastmath)>,
+    Results<(outs FloatLike:$result)> {
+  let assemblyFormat = [{ $lhs `,` $rhs `` custom<ArithFastMathAttr>($fastmath)
+                          attr-dict `:` type($result) }];
+}
 
 // Base class for arithmetic cast operations. Requires a single operand and
 // result. If either is a shaped type, then the other must be of the same shape.
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOpsInterfaces.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOpsInterfaces.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOpsInterfaces.td
@@ -0,0 +1,40 @@
+//===-- ArithmeticOpsInterfaces.td - arith op interfaces ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the Arithmetic interfaces definition file.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ARITHMETIC_OPS_INTERFACES
+#define ARITHMETIC_OPS_INTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
+  let description = [{
+    Access to operation fastmath flags.
+  }];
+
+  let cppNamespace = "::mlir::arith";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/        "Returns a FastMathFlagsAttr attribute for the operation",
+      /*returnType=*/  "FastMathFlagsAttr",
+      /*methodName=*/  "getFastMathFlagsAttr",
+      /*args=*/        (ins),
+      /*methodBody=*/  [{}],
+      /*defaultImpl=*/ [{
+        ConcreteOp op = cast<ConcreteOp>(this->getOperation());
+        return op.getFastmathAttr();
+      }]
+      >
+  ];
+}
+
+#endif // ARITHMETIC_OPS_INTERFACES
diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/CMakeLists.txt
@@ -1,5 +1,14 @@
 set(LLVM_TARGET_DEFINITIONS ArithmeticOps.td)
 mlir_tablegen(ArithmeticOpsEnums.h.inc -gen-enum-decls)
 mlir_tablegen(ArithmeticOpsEnums.cpp.inc -gen-enum-defs)
+mlir_tablegen(ArithmeticOpsAttributes.h.inc -gen-attrdef-decls
+              -attrdefs-dialect=arith)
+mlir_tablegen(ArithmeticOpsAttributes.cpp.inc -gen-attrdef-defs
+              -attrdefs-dialect=arith)
 add_mlir_dialect(ArithmeticOps arith)
 add_mlir_doc(ArithmeticOps ArithmeticOps Dialects/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS ArithmeticOpsInterfaces.td)
+mlir_tablegen(ArithmeticOpsInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(ArithmeticOpsInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRArithmeticOpsInterfacesIncGen)
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -15,4 +15,5 @@
   MLIRLLVMIR
   MLIRSupport
   MLIRTransforms
+  MLIRArithmetic
   )
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -15,6 +16,32 @@
 
 using namespace mlir;
 
+// Map arithmetic fastmath enum values to LLVMIR enum values.
+static LLVM::FastmathFlags
+convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
+  LLVM::FastmathFlags llvmFMF{};
+  const std::pair<arith::FastMathFlags, LLVM::FastmathFlags> flags[] = {
+      {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan},
+      {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf},
+      {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz},
+      {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp},
+      {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
+      {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
+      {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
+  for (const auto &fmfMap : flags)
+    if (bitEnumContains(arithFMF, fmfMap.first))
+      llvmFMF = llvmFMF | fmfMap.second;
+  return llvmFMF;
+}
+
+// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
+static LLVM::FMFAttr
+convertArithFastMathAttr(arith::FastMathFlagsAttr fmfAttr) {
+  LLVM::FastmathFlags llvmFMF =
+      convertArithFastMathFlagsToLLVM(fmfAttr.getValue());
+  return LLVM::FMFAttr::get(fmfAttr.getContext(), llvmFMF);
+}
+
 //===----------------------------------------------------------------------===//
 // ConvertToLLVMPattern
 //===----------------------------------------------------------------------===//
@@ -319,10 +346,13 @@
     return failure();
   }
 
+  // Lower source op attributes to LLVM equivalents.
+  LLVM::detail::ConvertToLLVMAttributes attrConvert(op);
+
   // Create the operation through state since we don't know its C++ type.
   Operation *newOp =
       rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
-                      packedType, op->getAttrs());
+                      packedType, attrConvert.getAttrs());
 
   // If the operation produced 0 or 1 result, return them immediately.
   if (numResults == 0)
@@ -342,3 +372,36 @@
   rewriter.replaceOp(op, results);
   return success();
 }
+
+/// Constructor for class to manage attributes when lowering to LLVMIR.
+/// Currently converts arithmetic fastmath flags to LLVM fastmath flags, and
+/// passes all other attributes through unmodified.
+LLVM::detail::ConvertToLLVMAttributes::ConvertToLLVMAttributes(Operation *op)
+    : srcAttrs(op->getAttrs()) {
+  auto fmi = dyn_cast<arith::ArithFastMathInterface>(op);
+  if (!fmi)
+    return;
+
+  // The fastmath attribute is default-valued and may be absent; in that case
+  // there is nothing to convert and the source attributes are used as-is.
+  arith::FastMathFlagsAttr arithFMFAttr = fmi.getFastMathFlagsAttr();
+  if (!arithFMFAttr)
+    return;
+
+  // Locate the fastmath entry by name (the `$fastmath` ODS argument) rather
+  // than by value: attributes are uniqued, so another entry holding an equal
+  // value would otherwise also be renamed, yielding duplicate names.
+  MLIRContext *ctx = op->getContext();
+  StringAttr arithFMFName = StringAttr::get(ctx, "fastmath");
+  StringAttr llvmFMFName = StringAttr::get(ctx, "fastmathFlags");
+
+  convertAttrs = AttrVector{};
+  convertAttrs->reserve(srcAttrs.size());
+  for (const NamedAttribute &srcAttr : srcAttrs) {
+    if (srcAttr.getName() == arithFMFName)
+      convertAttrs->emplace_back(llvmFMFName,
+                                 convertArithFastMathAttr(arithFMFAttr));
+    else
+      convertAttrs->push_back(srcAttr);
+  }
+}
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -128,11 +128,14 @@
   if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
     return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
 
-  auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
-                                            ValueRange operands) {
+  // Lower source op attributes to LLVM equivalents.
+  LLVM::detail::ConvertToLLVMAttributes attrConvert(op);
+
+  auto callback = [op, targetOp, &attrConvert, &rewriter](Type llvm1DVectorTy,
+                                                          ValueRange operands) {
     return rewriter
         .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
-                llvm1DVectorTy, op->getAttrs())
+                llvm1DVectorTy, attrConvert.getAttrs())
         ->getResult(0);
   };
 
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticDialect.cpp
@@ -8,12 +8,17 @@
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::arith;
 
 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsDialect.cpp.inc"
+#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsInterfaces.cpp.inc"
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsAttributes.cpp.inc"
 
 namespace {
 /// This class defines the interface for handling inlining for arithmetic
@@ -34,6 +39,10 @@
 #define GET_OP_LIST
 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
       >();
+  addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsAttributes.cpp.inc"
+      >();
   addInterfaces<ArithmeticInlinerInterface>();
 }
 
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -22,6 +22,29 @@
 using namespace mlir;
 using namespace mlir::arith;
 
+//===----------------------------------------------------------------------===//
+// Floating point op parse/print helpers
+//===----------------------------------------------------------------------===//
+static ParseResult parseArithFastMathAttr(OpAsmParser &parser,
+                                          Attribute &attr) {
+  // No fastmath mnemonic present: leave `attr` null so the op is built
+  // without the attribute and its default value (none) applies.
+  if (failed(parser.parseOptionalKeyword(FastMathFlagsAttr::getMnemonic())))
+    return success();
+  attr = FastMathFlagsAttr::parse(parser, Type{});
+  return success(static_cast<bool>(attr));
+}
+
+static void printArithFastMathAttr(OpAsmPrinter &printer, Operation *op,
+                                   FastMathFlagsAttr fmAttr) {
+  // Elide printing the fastmath attribute when fastmath=none, so the printed
+  // form round-trips through parseArithFastMathAttr's default path.
+  if (fmAttr && (fmAttr.getValue() != FastMathFlags::none)) {
+    printer << " " << FastMathFlagsAttr::getMnemonic();
+    fmAttr.print(printer);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern helpers
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arithmetic/IR/CMakeLists.txt
@@ -11,6 +11,7 @@
 
   DEPENDS
   MLIRArithmeticOpsIncGen
+  MLIRArithmeticOpsInterfacesIncGen
 
   LINK_LIBS PUBLIC
   MLIRDialect
diff --git a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
--- a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir
@@ -383,3 +383,18 @@
   %0 = arith.select %arg0, %arg1, %arg2 : i32
   return %0 : i32
 }
+
+// CHECK-LABEL: @fastmath
+func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
+// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
+// CHECK: {{.*}} = llvm.fmul %arg0, %arg1 {fastmathFlags = #llvm.fastmath<fast>} : f32
+// CHECK: {{.*}} = llvm.fneg %arg0 {fastmathFlags = #llvm.fastmath<fast>} : f32
+// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 : f32
+// CHECK: {{.*}} = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<nnan, ninf>} : f32
+  %0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
+  %1 = arith.mulf %arg0, %arg1 fastmath<fast> : f32
+  %2 = arith.negf %arg0 fastmath<fast> : f32
+  %3 = arith.addf %arg0, %arg1 fastmath<none> : f32
+  %4 = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
+  return
+}
diff --git a/mlir/test/Dialect/Arithmetic/ops.mlir b/mlir/test/Dialect/Arithmetic/ops.mlir
--- a/mlir/test/Dialect/Arithmetic/ops.mlir
+++ b/mlir/test/Dialect/Arithmetic/ops.mlir
@@ -952,3 +952,27 @@
   %min_unsigned = arith.minui %i1, %i2 : i32
   return
 }
+
+// CHECK-LABEL: @fastmath
+func.func @fastmath(%arg0: f32, %arg1: f32, %arg2: i32) {
+// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: {{.*}} = arith.subf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: {{.*}} = arith.divf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: {{.*}} = arith.remf %arg0, %arg1 fastmath<fast> : f32
+// CHECK: {{.*}} = arith.negf %arg0 fastmath<fast> : f32
+  %0 = arith.addf %arg0, %arg1 fastmath<fast> : f32
+  %1 = arith.subf %arg0, %arg1 fastmath<fast> : f32
+  %2 = arith.mulf %arg0, %arg1 fastmath<fast> : f32
+  %3 = arith.divf %arg0, %arg1 fastmath<fast> : f32
+  %4 = arith.remf %arg0, %arg1 fastmath<fast> : f32
+  %5 = arith.negf %arg0 fastmath<fast> : f32
+// CHECK: {{.*}} = arith.addf %arg0, %arg1 : f32
+  %6 = arith.addf %arg0, %arg1 fastmath<none> : f32
+// CHECK: {{.*}} = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
+  %7 = arith.addf %arg0, %arg1 fastmath<nnan,ninf> : f32
+// CHECK: {{.*}} = arith.mulf %arg0, %arg1 fastmath<fast> : f32
+  %8 = arith.mulf %arg0, %arg1 fastmath<reassoc,nnan,ninf,nsz,arcp,contract,afn> : f32
+
+  return
+}
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -993,7 +993,7 @@
   for (FormatElement *param : dir->getArguments()) {
     if (auto *attr = dyn_cast<AttributeVariable>(param)) {
       const NamedAttribute *var = attr->getVar();
-      if (var->attr.isOptional())
+      if (var->attr.isOptional() || var->attr.hasDefaultValue())
         body << llvm::formatv("    if ({0}Attr)\n  ", var->name);
       body << llvm::formatv("    result.addAttribute(\"{0}\", {0}Attr);\n",
                             var->name);
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -669,10 +669,11 @@
       continue;
     }
 
-    builderLines.push_back(llvm::formatv(attribute->attr.isOptional()
-                                             ? initOptionalAttributeTemplate
-                                             : initAttributeTemplate,
-                                         attribute->name, argNames[i]));
+    builderLines.push_back(llvm::formatv(
+        attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
+            ? initOptionalAttributeTemplate
+            : initAttributeTemplate,
+        attribute->name, argNames[i]));
   }
 }
 