diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h @@ -0,0 +1,81 @@ +//===- AttrToLLVMConverter.h - Arith attributes conversion ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H +#define MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +//===----------------------------------------------------------------------===// +// Support for converting Arith FastMathFlags to LLVM FastmathFlags +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace arith { +// Map arithmetic fastmath enum values to LLVMIR enum values. +LLVM::FastmathFlags +convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF); + +// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute. +LLVM::FastmathFlagsAttr +convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr); + +// Attribute converter that populates a NamedAttrList by removing the fastmath +// attribute from the source operation attributes, and replacing it with an +// equivalent LLVM fastmath attribute. +template +class AttrConvertFastMathToLLVM { +public: + AttrConvertFastMathToLLVM(SourceOp srcOp) { + // Copy the source attributes. + convertedAttr = NamedAttrList{srcOp->getAttrs()}; + // Get the name of the arith fastmath attribute. + llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); + // Remove the source fastmath attribute. + auto arithFMFAttr = + convertedAttr.erase(arithFMFAttrName) + .template dyn_cast_or_null(); + if (arithFMFAttr) { + llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName(); + convertedAttr.set(targetAttrName, + convertArithFastMathAttrToLLVM(arithFMFAttr)); + } + } + + ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } + +private: + NamedAttrList convertedAttr; +}; + +// Attribute converter that populates a NamedAttrList by removing the fastmath +// attribute from the source operation attributes. This may be useful for +// target operations that do not require the fastmath attribute, or for targets +// that do not yet support the LLVM fastmath attribute. +template +class AttrDropFastMath { +public: + AttrDropFastMath(SourceOp srcOp) { + // Copy the source attributes. + convertedAttr = NamedAttrList{srcOp->getAttrs()}; + // Get the name of the arith fastmath attribute. + llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); + // Remove the source fastmath attribute. + convertedAttr.erase(arithFMFAttrName); + } + + ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } + +private: + NamedAttrList convertedAttr; +}; +} // namespace arith +} // namespace mlir + +#endif // MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithBase.td @@ -121,4 +121,9 @@ let printBitEnumPrimaryGroups = 1; } +def Arith_FastMathAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + #endif // ARITH_BASE diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -20,11 +20,6 @@ include "mlir/IR/OpAsmInterface.td" include "mlir/IR/EnumAttr.td" -def Arith_FastMathAttr : - EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - // Base class for Arith dialect ops. Ops in this dialect have no side // effects and can be applied element-wise to vectors and tensors. class Arith_Op traits = []> : diff --git a/mlir/include/mlir/Dialect/Math/IR/Math.h b/mlir/include/mlir/Dialect/Math/IR/Math.h --- a/mlir/include/mlir/Dialect/Math/IR/Math.h +++ b/mlir/include/mlir/Dialect/Math/IR/Math.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_MATH_IR_MATH_H_ #define MLIR_DIALECT_MATH_IR_MATH_H_ +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -9,6 +9,8 @@ #ifndef MATH_OPS #define MATH_OPS +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" include "mlir/Dialect/Math/IR/MathBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/VectorInterfaces.td" @@ -36,11 +38,16 @@ // operand and result of the same type. This type can be a floating point type, // vector or tensor thereof. class Math_FloatUnaryOp traits = []> : - Math_Op { - let arguments = (ins FloatLike:$operand); + Math_Op]> { + let arguments = (ins FloatLike:$operand, + DefaultValuedAttr:$fastmath); let results = (outs FloatLike:$result); - let assemblyFormat = "$operand attr-dict `:` type($result)"; + let assemblyFormat = [{ $operand (`fastmath` `` $fastmath^)? + attr-dict `:` type($result) }]; } // Base class for binary math operations on integer types. Require two @@ -58,22 +65,32 @@ // operands and one result of the same type. This type can be a floating point // type, vector or tensor thereof. class Math_FloatBinaryOp traits = []> : - Math_Op { - let arguments = (ins FloatLike:$lhs, FloatLike:$rhs); + Math_Op]> { + let arguments = (ins FloatLike:$lhs, FloatLike:$rhs, + DefaultValuedAttr:$fastmath); let results = (outs FloatLike:$result); - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; + let assemblyFormat = [{ $lhs `,` $rhs (`fastmath` `` $fastmath^)? + attr-dict `:` type($result) }]; } // Base class for floating point ternary operations. Require three operands and // one result of the same type. This type can be a floating point type, vector // or tensor thereof. class Math_FloatTernaryOp traits = []> : - Math_Op { - let arguments = (ins FloatLike:$a, FloatLike:$b, FloatLike:$c); + Math_Op]> { + let arguments = (ins FloatLike:$a, FloatLike:$b, FloatLike:$c, + DefaultValuedAttr:$fastmath); let results = (outs FloatLike:$result); - let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($result)"; + let assemblyFormat = [{ $a `,` $b `,` $c (`fastmath` `` $fastmath^)? + attr-dict `:` type($result) }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp @@ -0,0 +1,38 @@ +//===- AttrToLLVMConverter.cpp - Arith attributes conversion to LLVM ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" + +using namespace mlir; + +// Map arithmetic fastmath enum values to LLVMIR enum values. +LLVM::FastmathFlags +mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) { + LLVM::FastmathFlags llvmFMF{}; + const std::pair flags[] = { + {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan}, + {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf}, + {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz}, + {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp}, + {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract}, + {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn}, + {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}}; + for (auto fmfMap : flags) { + if (bitEnumContainsAny(arithFMF, fmfMap.first)) + llvmFMF = llvmFMF | fmfMap.second; + } + return llvmFMF; +} + +// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute. +LLVM::FastmathFlagsAttr +mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) { + arith::FastMathFlags arithFMF = fmfAttr.getValue(); + return LLVM::FastmathFlagsAttr::get( + fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF)); +} diff --git a/mlir/lib/Conversion/ArithCommon/CMakeLists.txt b/mlir/lib/Conversion/ArithCommon/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ArithCommon/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_conversion_library(MLIRArithAttrToLLVMConversion + AttrToLLVMConverter.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRLLVMDialect + ) diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -8,6 +8,7 @@ #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -24,93 +25,20 @@ namespace { -// Map arithmetic fastmath enum values to LLVMIR enum values. -static LLVM::FastmathFlags -convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) { - LLVM::FastmathFlags llvmFMF{}; - const std::pair flags[] = { - {arith::FastMathFlags::nnan, LLVM::FastmathFlags::nnan}, - {arith::FastMathFlags::ninf, LLVM::FastmathFlags::ninf}, - {arith::FastMathFlags::nsz, LLVM::FastmathFlags::nsz}, - {arith::FastMathFlags::arcp, LLVM::FastmathFlags::arcp}, - {arith::FastMathFlags::contract, LLVM::FastmathFlags::contract}, - {arith::FastMathFlags::afn, LLVM::FastmathFlags::afn}, - {arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}}; - for (auto fmfMap : flags) { - if (bitEnumContainsAny(arithFMF, fmfMap.first)) - llvmFMF = llvmFMF | fmfMap.second; - } - return llvmFMF; -} - -// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute. -static LLVM::FastmathFlagsAttr -convertArithFastMathAttr(arith::FastMathFlagsAttr fmfAttr) { - arith::FastMathFlags arithFMF = fmfAttr.getValue(); - return LLVM::FastmathFlagsAttr::get( - fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF)); -} - -// Attribute converter that populates a NamedAttrList by removing the fastmath -// attribute from the source operation attributes, and replacing it with an -// equivalent LLVM fastmath attribute. -template -class AttrConvertFastMath { -public: - AttrConvertFastMath(SourceOp srcOp) { - // Copy the source attributes. - convertedAttr = NamedAttrList{srcOp->getAttrs()}; - // Get the name of the arith fastmath attribute. - llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); - // Remove the source fastmath attribute. - auto arithFMFAttr = convertedAttr.erase(arithFMFAttrName) - .template dyn_cast_or_null(); - if (arithFMFAttr) { - llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName(); - convertedAttr.set(targetAttrName, convertArithFastMathAttr(arithFMFAttr)); - } - } - - ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } - -private: - NamedAttrList convertedAttr; -}; - -// Attribute converter that populates a NamedAttrList by removing the fastmath -// attribute from the source operation attributes. This may be useful for -// target operations that do not require the fastmath attribute, or for targets -// that do not yet support the LLVM fastmath attribute. -template -class AttrDropFastMath { -public: - AttrDropFastMath(SourceOp srcOp) { - // Copy the source attributes. - convertedAttr = NamedAttrList{srcOp->getAttrs()}; - // Get the name of the arith fastmath attribute. - llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); - // Remove the source fastmath attribute. - convertedAttr.erase(arithFMFAttrName); - } - - ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } - -private: - NamedAttrList convertedAttr; -}; - //===----------------------------------------------------------------------===// // Straightforward Op Lowerings //===----------------------------------------------------------------------===// -using AddFOpLowering = VectorConvertToLLVMPattern; +using AddFOpLowering = + VectorConvertToLLVMPattern; using AddIOpLowering = VectorConvertToLLVMPattern; using AndIOpLowering = VectorConvertToLLVMPattern; using BitcastOpLowering = VectorConvertToLLVMPattern; -using DivFOpLowering = VectorConvertToLLVMPattern; +using DivFOpLowering = + VectorConvertToLLVMPattern; using DivSIOpLowering = VectorConvertToLLVMPattern; using DivUIOpLowering = @@ -125,28 +53,30 @@ using FPToUIOpLowering = VectorConvertToLLVMPattern; // TODO: Add LLVM intrinsic support for fastmath -using MaxFOpLowering = - VectorConvertToLLVMPattern; +using MaxFOpLowering = VectorConvertToLLVMPattern; using MaxSIOpLowering = VectorConvertToLLVMPattern; using MaxUIOpLowering = VectorConvertToLLVMPattern; // TODO: Add LLVM intrinsic support for fastmath -using MinFOpLowering = - VectorConvertToLLVMPattern; +using MinFOpLowering = VectorConvertToLLVMPattern; using MinSIOpLowering = VectorConvertToLLVMPattern; using MinUIOpLowering = VectorConvertToLLVMPattern; -using MulFOpLowering = VectorConvertToLLVMPattern; +using MulFOpLowering = + VectorConvertToLLVMPattern; using MulIOpLowering = VectorConvertToLLVMPattern; -using NegFOpLowering = VectorConvertToLLVMPattern; +using NegFOpLowering = + VectorConvertToLLVMPattern; using OrIOpLowering = VectorConvertToLLVMPattern; // TODO: Add LLVM intrinsic support for fastmath -using RemFOpLowering = - VectorConvertToLLVMPattern; +using RemFOpLowering = VectorConvertToLLVMPattern; using RemSIOpLowering = VectorConvertToLLVMPattern; using RemUIOpLowering = @@ -160,8 +90,9 @@ VectorConvertToLLVMPattern; using SIToFPOpLowering = VectorConvertToLLVMPattern; -using SubFOpLowering = VectorConvertToLLVMPattern; +using SubFOpLowering = + VectorConvertToLLVMPattern; using SubIOpLowering = VectorConvertToLLVMPattern; using TruncFOpLowering = VectorConvertToLLVMPattern; diff --git a/mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/ArithToLLVM/CMakeLists.txt @@ -11,6 +11,7 @@ Core LINK_LIBS PUBLIC + MLIRArithAttrToLLVMConversion MLIRArithDialect MLIRLLVMCommonConversion MLIRLLVMDialect diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(AffineToStandard) add_subdirectory(AMDGPUToROCDL) +add_subdirectory(ArithCommon) add_subdirectory(ArithToLLVM) add_subdirectory(ArithToSPIRV) add_subdirectory(ArmNeon2dToIntr) diff --git a/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt b/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/MathToLLVM/CMakeLists.txt @@ -11,6 +11,7 @@ Core LINK_LIBS PUBLIC + MLIRArithAttrToLLVMConversion MLIRLLVMCommonConversion MLIRLLVMDialect MLIRMathDialect diff --git a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp --- a/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ b/mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -8,6 +8,7 @@ #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" @@ -24,31 +25,39 @@ using namespace mlir; namespace { -using AbsFOpLowering = VectorConvertToLLVMPattern; -using CeilOpLowering = VectorConvertToLLVMPattern; + +template +using ConvertFastMath = arith::AttrConvertFastMathToLLVM; + +template +using ConvertFMFMathToLLVMPattern = + VectorConvertToLLVMPattern; + +using AbsFOpLowering = ConvertFMFMathToLLVMPattern; +using CeilOpLowering = ConvertFMFMathToLLVMPattern; using CopySignOpLowering = - VectorConvertToLLVMPattern; -using CosOpLowering = VectorConvertToLLVMPattern; + ConvertFMFMathToLLVMPattern; +using CosOpLowering = ConvertFMFMathToLLVMPattern; using CtPopFOpLowering = VectorConvertToLLVMPattern; -using Exp2OpLowering = VectorConvertToLLVMPattern; -using ExpOpLowering = VectorConvertToLLVMPattern; +using Exp2OpLowering = ConvertFMFMathToLLVMPattern; +using ExpOpLowering = ConvertFMFMathToLLVMPattern; using FloorOpLowering = - VectorConvertToLLVMPattern; -using FmaOpLowering = VectorConvertToLLVMPattern; + ConvertFMFMathToLLVMPattern; +using FmaOpLowering = ConvertFMFMathToLLVMPattern; using Log10OpLowering = - VectorConvertToLLVMPattern; -using Log2OpLowering = VectorConvertToLLVMPattern; -using LogOpLowering = VectorConvertToLLVMPattern; -using PowFOpLowering = VectorConvertToLLVMPattern; + ConvertFMFMathToLLVMPattern; +using Log2OpLowering = ConvertFMFMathToLLVMPattern; +using LogOpLowering = ConvertFMFMathToLLVMPattern; +using PowFOpLowering = ConvertFMFMathToLLVMPattern; using RoundEvenOpLowering = - VectorConvertToLLVMPattern; + ConvertFMFMathToLLVMPattern; using RoundOpLowering = - VectorConvertToLLVMPattern; -using SinOpLowering = VectorConvertToLLVMPattern; -using SqrtOpLowering = VectorConvertToLLVMPattern; + ConvertFMFMathToLLVMPattern; +using SinOpLowering = ConvertFMFMathToLLVMPattern; +using SqrtOpLowering = ConvertFMFMathToLLVMPattern; using FTruncOpLowering = - VectorConvertToLLVMPattern; + ConvertFMFMathToLLVMPattern; // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`. template @@ -113,6 +122,8 @@ auto resultType = op.getResult().getType(); auto floatType = getElementTypeOrSelf(resultType).cast(); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); + ConvertFastMath expAttrs(op); + ConvertFastMath subAttrs(op); if (!operandType.isa()) { LLVM::ConstantOp one; @@ -123,8 +134,10 @@ } else { one = rewriter.create(loc, operandType, floatOne); } - auto exp = rewriter.create(loc, adaptor.getOperand()); - rewriter.replaceOpWithNewOp(op, operandType, exp, one); + auto exp = rewriter.create(loc, adaptor.getOperand(), + expAttrs.getAttrs()); + rewriter.replaceOpWithNewOp( + op, operandType, ValueRange{exp, one}, subAttrs.getAttrs()); return success(); } @@ -142,9 +155,10 @@ floatOne); auto one = rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto exp = - rewriter.create(loc, llvm1DVectorTy, operands[0]); - return rewriter.create(loc, llvm1DVectorTy, exp, one); + auto exp = rewriter.create( + loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs()); + return rewriter.create( + loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs()); }, rewriter); } @@ -166,6 +180,8 @@ auto resultType = op.getResult().getType(); auto floatType = getElementTypeOrSelf(resultType).cast(); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); + ConvertFastMath addAttrs(op); + ConvertFastMath logAttrs(op); if (!operandType.isa()) { LLVM::ConstantOp one = @@ -176,9 +192,11 @@ floatOne)) : rewriter.create(loc, operandType, floatOne); - auto add = rewriter.create(loc, operandType, one, - adaptor.getOperand()); - rewriter.replaceOpWithNewOp(op, operandType, add); + auto add = rewriter.create( + loc, operandType, ValueRange{one, adaptor.getOperand()}, + addAttrs.getAttrs()); + rewriter.replaceOpWithNewOp(op, operandType, ValueRange{add}, + logAttrs.getAttrs()); return success(); } @@ -196,9 +214,11 @@ floatOne); auto one = rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto add = rewriter.create(loc, llvm1DVectorTy, one, - operands[0]); - return rewriter.create(loc, llvm1DVectorTy, add); + auto add = rewriter.create(loc, llvm1DVectorTy, + ValueRange{one, operands[0]}, + addAttrs.getAttrs()); + return rewriter.create( + loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs()); }, rewriter); } @@ -220,6 +240,8 @@ auto resultType = op.getResult().getType(); auto floatType = getElementTypeOrSelf(resultType).cast(); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); + ConvertFastMath sqrtAttrs(op); + ConvertFastMath divAttrs(op); if (!operandType.isa()) { LLVM::ConstantOp one; @@ -230,8 +252,10 @@ } else { one = rewriter.create(loc, operandType, floatOne); } - auto sqrt = rewriter.create(loc, adaptor.getOperand()); - rewriter.replaceOpWithNewOp(op, operandType, one, sqrt); + auto sqrt = rewriter.create(loc, adaptor.getOperand(), + sqrtAttrs.getAttrs()); + rewriter.replaceOpWithNewOp( + op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs()); return success(); } @@ -249,9 +273,10 @@ floatOne); auto one = rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto sqrt = - rewriter.create(loc, llvm1DVectorTy, operands[0]); - return rewriter.create(loc, llvm1DVectorTy, one, sqrt); + auto sqrt = rewriter.create( + loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs()); + return rewriter.create( + loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs()); }, rewriter); } diff --git a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir --- a/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir +++ b/mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir @@ -36,6 +36,18 @@ // ----- +// CHECK-LABEL: func @log1p_fmf( +// CHECK-SAME: f32 +func.func @log1p_fmf(%arg0 : f32) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 + // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %arg0 {fastmathFlags = #llvm.fastmath} : f32 + // CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]]) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %0 = math.log1p %arg0 fastmath : f32 + func.return +} + +// ----- + // CHECK-LABEL: func @log1p_2dvector( func.func @log1p_2dvector(%arg0 : vector<4x3xf32>) { // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> @@ -49,6 +61,19 @@ // ----- +// CHECK-LABEL: func @log1p_2dvector_fmf( +func.func @log1p_2dvector_fmf(%arg0 : vector<4x3xf32>) { + // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : vector<3xf32> + // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[EXTRACT]] {fastmathFlags = #llvm.fastmath} : vector<3xf32> + // CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]]) {fastmathFlags = #llvm.fastmath} : (vector<3xf32>) -> vector<3xf32> + // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[LOG]], %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> + %0 = math.log1p %arg0 fastmath : vector<4x3xf32> + func.return +} + +// ----- + // CHECK-LABEL: func @expm1( // CHECK-SAME: f32 func.func @expm1(%arg0 : f32) { @@ -61,6 +86,42 @@ // ----- +// CHECK-LABEL: func @expm1_fmf( +// CHECK-SAME: f32 +func.func @expm1_fmf(%arg0 : f32) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 + // CHECK: %[[EXP:.*]] = llvm.intr.exp(%arg0) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] {fastmathFlags = #llvm.fastmath} : f32 + %0 = math.expm1 %arg0 fastmath : f32 + func.return +} + +// ----- + +// CHECK-LABEL: func @expm1_vector( +// CHECK-SAME: vector<4xf32> +func.func @expm1_vector(%arg0 : vector<4xf32>) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32> + // CHECK: %[[EXP:.*]] = llvm.intr.exp(%arg0) : (vector<4xf32>) -> vector<4xf32> + // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : vector<4xf32> + %0 = math.expm1 %arg0 : vector<4xf32> + func.return +} + +// ----- + +// CHECK-LABEL: func @expm1_vector_fmf( +// CHECK-SAME: vector<4xf32> +func.func @expm1_vector_fmf(%arg0 : vector<4xf32>) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32> + // CHECK: %[[EXP:.*]] = llvm.intr.exp(%arg0) {fastmathFlags = #llvm.fastmath} : (vector<4xf32>) -> vector<4xf32> + // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] {fastmathFlags = #llvm.fastmath} : vector<4xf32> + %0 = math.expm1 %arg0 fastmath : vector<4xf32> + func.return +} + +// ----- + // CHECK-LABEL: func @rsqrt( // CHECK-SAME: f32 func.func @rsqrt(%arg0 : f32) { @@ -148,6 +209,18 @@ // ----- +// CHECK-LABEL: func @rsqrt_double_fmf( +// CHECK-SAME: f64 +func.func @rsqrt_double_fmf(%arg0 : f64) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f64) : f64 + // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%arg0) {fastmathFlags = #llvm.fastmath} : (f64) -> f64 + // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] {fastmathFlags = #llvm.fastmath} : f64 + %0 = math.rsqrt %arg0 fastmath : f64 + func.return +} + +// ----- + // CHECK-LABEL: func @rsqrt_vector( // CHECK-SAME: vector<4xf32> func.func @rsqrt_vector(%arg0 : vector<4xf32>) { @@ -160,6 +233,18 @@ // ----- +// CHECK-LABEL: func @rsqrt_vector_fmf( +// CHECK-SAME: vector<4xf32> +func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32> + // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%arg0) {fastmathFlags = #llvm.fastmath} : (vector<4xf32>) -> vector<4xf32> + // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] {fastmathFlags = #llvm.fastmath} : vector<4xf32> + %0 = math.rsqrt %arg0 fastmath : vector<4xf32> + func.return +} + +// ----- + // CHECK-LABEL: func @rsqrt_multidim_vector( func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) { // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> @@ -210,3 +295,19 @@ %0 = math.trunc %arg0 : f32 func.return } + +// ----- + +// CHECK-LABEL: func @fastmath( +// CHECK-SAME: f32 +func.func @fastmath(%arg0 : f32, %arg1 : vector<4xf32>) { + // CHECK: llvm.intr.trunc(%arg0) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %0 = math.trunc %arg0 fastmath : f32 + // CHECK: llvm.intr.pow(%arg0, %arg0) {fastmathFlags = #llvm.fastmath} : (f32, f32) -> f32 + %1 = math.powf %arg0, %arg0 fastmath : f32 + // CHECK: llvm.intr.sqrt(%arg0) : (f32) -> f32 + %2 = math.sqrt %arg0 fastmath : f32 + // CHECK: llvm.intr.fma(%arg0, %arg0, %arg0) {fastmathFlags = #llvm.fastmath} : (f32, f32, f32) -> f32 + %3 = math.fma %arg0, %arg0, %arg0 fastmath : f32 + func.return +} diff --git a/mlir/test/Dialect/Math/ops.mlir b/mlir/test/Dialect/Math/ops.mlir --- a/mlir/test/Dialect/Math/ops.mlir +++ b/mlir/test/Dialect/Math/ops.mlir @@ -269,3 +269,17 @@ %2 = math.trunc %t : tensor<4x4x?xf32> return } + +// CHECK-LABEL: func @fastmath( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func.func @fastmath(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.trunc %[[F]] fastmath : f32 + %0 = math.trunc %f fastmath : f32 + // CHECK: %{{.*}} = math.powf %[[V]], %[[V]] fastmath : vector<4xf32> + %1 = math.powf %v, %v fastmath : vector<4xf32> + // CHECK: %{{.*}} = math.fma %[[T]], %[[T]], %[[T]] : tensor<4x4x?xf32> + %2 = math.fma %t, %t, %t fastmath : tensor<4x4x?xf32> + // CHECK: %{{.*}} = math.absf %[[F]] fastmath : f32 + %3 = math.absf %f fastmath : f32 + return +}