Index: mlir/include/mlir/Dialect/Math/IR/CMakeLists.txt =================================================================== --- mlir/include/mlir/Dialect/Math/IR/CMakeLists.txt +++ mlir/include/mlir/Dialect/Math/IR/CMakeLists.txt @@ -1,2 +1,16 @@ +set(LLVM_TARGET_DEFINITIONS MathOps.td) +mlir_tablegen(MathOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(MathOpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(MathOpsAttributes.h.inc -gen-attrdef-decls + -attrdefs-dialect=math) +mlir_tablegen(MathOpsAttributes.cpp.inc -gen-attrdef-defs + -attrdefs-dialect=math) +add_public_tablegen_target(MLIRMathOpsAttributesIncGen) + +set(LLVM_TARGET_DEFINITIONS MathOpsInterfaces.td) +mlir_tablegen(MathOpsInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(MathOpsInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRMathOpsInterfacesIncGen) + add_mlir_dialect(MathOps math) add_mlir_doc(MathOps MathOps Dialects/ -gen-dialect-doc) Index: mlir/include/mlir/Dialect/Math/IR/Math.h =================================================================== --- mlir/include/mlir/Dialect/Math/IR/Math.h +++ mlir/include/mlir/Dialect/Math/IR/Math.h @@ -23,6 +23,20 @@ #include "mlir/Dialect/Math/IR/MathOpsDialect.h.inc" +//===----------------------------------------------------------------------===// +// Math Dialect Enums and Attributes +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Math/IR/MathOpsEnums.h.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Math/IR/MathOpsAttributes.h.inc" + +//===----------------------------------------------------------------------===// +// Math Op Interfaces +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Math/IR/MathOpsInterfaces.h.inc" + //===----------------------------------------------------------------------===// // Math Dialect Operations //===----------------------------------------------------------------------===// 
Index: mlir/include/mlir/Dialect/Math/IR/MathBase.td =================================================================== --- mlir/include/mlir/Dialect/Math/IR/MathBase.td +++ mlir/include/mlir/Dialect/Math/IR/MathBase.td @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #ifndef MATH_BASE #define MATH_BASE +include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" def Math_Dialect : Dialect { let name = "math"; @@ -30,5 +31,34 @@ ``` }]; let hasConstantMaterializer = 1; + let useDefaultAttributePrinterParser = 1; +} + +def FASTMATH_NONE : I32BitEnumAttrCaseNone<"none" >; +def FASTMATH_REASSOC : I32BitEnumAttrCaseBit<"reassoc", 0>; +def FASTMATH_NO_NANS : I32BitEnumAttrCaseBit<"nnan", 1>; +def FASTMATH_NO_INFS : I32BitEnumAttrCaseBit<"ninf", 2>; +def FASTMATH_NO_SIGNED_ZEROS : I32BitEnumAttrCaseBit<"nsz", 3>; +def FASTMATH_ALLOW_RECIP : I32BitEnumAttrCaseBit<"arcp", 4>; +def FASTMATH_ALLOW_CONTRACT : I32BitEnumAttrCaseBit<"contract", 5>; +def FASTMATH_APPROX_FUNC : I32BitEnumAttrCaseBit<"afn", 6>; +def FASTMATH_FAST : I32BitEnumAttrCaseGroup< + "fast", + [ + FASTMATH_REASSOC, FASTMATH_NO_NANS, FASTMATH_NO_INFS, + FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP, FASTMATH_ALLOW_CONTRACT, + FASTMATH_APPROX_FUNC]>; + +def FastMathFlags : I32BitEnumAttr< + "FastMathFlags", + "Floating point fast math flags", + [ + FASTMATH_NONE, FASTMATH_REASSOC, FASTMATH_NO_NANS, + FASTMATH_NO_INFS, FASTMATH_NO_SIGNED_ZEROS, FASTMATH_ALLOW_RECIP, + FASTMATH_ALLOW_CONTRACT, FASTMATH_APPROX_FUNC, FASTMATH_FAST]> { + let separator = ","; + let cppNamespace = "::mlir::math"; + let genSpecializedAttr = 0; + let printBitEnumPrimaryGroups = 1; } #endif // MATH_BASE Index: mlir/include/mlir/Dialect/Math/IR/MathOps.td =================================================================== --- mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -10,10 +10,16 @@ #define MATH_OPS include 
"mlir/Dialect/Math/IR/MathBase.td" +include "mlir/Dialect/Math/IR/MathOpsInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/VectorInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +def Math_FastMathAttr : + EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + // Base class for math dialect ops. Ops in this dialect have no side effects and // can be applied element-wise to vectors and tensors. class Math_Op traits = []> : @@ -36,11 +42,14 @@ // operand and result of the same type. This type can be a floating point type, // vector or tensor thereof. class Math_FloatUnaryOp traits = []> : - Math_Op { - let arguments = (ins FloatLike:$operand); + Math_Op]> { + let arguments = (ins FloatLike:$operand, + DefaultValuedAttr:$fastmath); let results = (outs FloatLike:$result); - let assemblyFormat = "$operand attr-dict `:` type($result)"; + let assemblyFormat = "$operand custom($fastmath) attr-dict `:` type($result)"; } // Base class for binary math operations on integer types. Require two @@ -58,22 +67,28 @@ // operands and one result of the same type. This type can be a floating point // type, vector or tensor thereof. class Math_FloatBinaryOp traits = []> : - Math_Op { - let arguments = (ins FloatLike:$lhs, FloatLike:$rhs); + Math_Op]> { + let arguments = (ins FloatLike:$lhs, FloatLike:$rhs, + DefaultValuedAttr:$fastmath); let results = (outs FloatLike:$result); - let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; + let assemblyFormat = "$lhs `,` $rhs custom($fastmath) attr-dict `:` type($result)"; } // Base class for floating point ternary operations. Require three operands and // one result of the same type. This type can be a floating point type, vector // or tensor thereof. 
class Math_FloatTernaryOp traits = []> : - Math_Op { - let arguments = (ins FloatLike:$a, FloatLike:$b, FloatLike:$c); + Math_Op]> { + let arguments = (ins FloatLike:$a, FloatLike:$b, FloatLike:$c, + DefaultValuedAttr:$fastmath); let results = (outs FloatLike:$result); - let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($result)"; + let assemblyFormat = "$a `,` $b `,` $c custom($fastmath) attr-dict `:` type($result)"; } //===----------------------------------------------------------------------===// Index: mlir/include/mlir/Dialect/Math/IR/MathOpsInterfaces.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Math/IR/MathOpsInterfaces.td @@ -0,0 +1,51 @@ +//===-- MathOpsInterfaces.td - math op interfaces ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the op interfaces of the Math dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MATH_OPS_INTERFACES +#define MATH_OPS_INTERFACES + +include "mlir/IR/OpBase.td" + +def MathFastMathInterface : OpInterface<"MathFastMathInterface"> { + let description = [{ + Access to operation fastmath flags. 
+ }]; + + let cppNamespace = "::mlir::math"; + + let methods = [ + InterfaceMethod< + /*desc=*/ "Returns a FastMathFlagsAttr attribute for the operation", + /*returnType=*/ "FastMathFlagsAttr", + /*methodName=*/ "getFastMathFlagsAttr", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + ConcreteOp op = cast(this->getOperation()); + return op.getFastmathAttr(); + }] + >, + StaticInterfaceMethod< + /*desc=*/ [{Returns the name of the FastMathFlagsAttr attribute + for the operation}], + /*returnType=*/ "StringRef", + /*methodName=*/ "getFastMathAttrName", + /*args=*/ (ins), + /*methodBody=*/ [{}], + /*defaultImpl=*/ [{ + return "fastmath"; + }] + > + ]; +} + +#endif // MATH_OPS_INTERFACES Index: mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp =================================================================== --- mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp +++ mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp @@ -24,31 +24,93 @@ using namespace mlir; namespace { -using AbsFOpLowering = VectorConvertToLLVMPattern; -using CeilOpLowering = VectorConvertToLLVMPattern; +// Map each Math dialect fastmath flag bit to the equivalent LLVM dialect flag.
+static LLVM::FastmathFlags +convertMathFastMathFlagsToLLVM(math::FastMathFlags mathFMF) { + LLVM::FastmathFlags llvmFMF{}; + const std::pair flags[] = { + {math::FastMathFlags::nnan, LLVM::FastmathFlags::nnan}, + {math::FastMathFlags::ninf, LLVM::FastmathFlags::ninf}, + {math::FastMathFlags::nsz, LLVM::FastmathFlags::nsz}, + {math::FastMathFlags::arcp, LLVM::FastmathFlags::arcp}, + {math::FastMathFlags::contract, LLVM::FastmathFlags::contract}, + {math::FastMathFlags::afn, LLVM::FastmathFlags::afn}, + {math::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}}; + for (auto fmfMap : flags) { + if (bitEnumContainsAny(mathFMF, fmfMap.first)) + llvmFMF = llvmFMF | fmfMap.second; + } + return llvmFMF; +} + +// Create an LLVM fastmath attribute from a given Math dialect fastmath +// attribute. +static LLVM::FastmathFlagsAttr +convertMathFastMathAttr(math::FastMathFlagsAttr fmfAttr) { + auto mathFMF = fmfAttr.getValue(); + return LLVM::FastmathFlagsAttr::get(fmfAttr.getContext(), + convertMathFastMathFlagsToLLVM(mathFMF)); +} + +// Attribute converter that populates a NamedAttrList by removing the fastmath +// attribute from the source operation attributes, and replacing it with an +// equivalent LLVM fastmath attribute when the source op carries one. +template +class ConvertFastMathHelper { +public: + template + ConvertFastMathHelper(SourceOp srcOp) { + // Copy the source attributes. + convertedAttr = NamedAttrList{srcOp->getAttrs()}; + // Get the name of the Math dialect's fastmath attribute. + llvm::StringRef mathFMFAttrName = SourceOp::getFastMathAttrName(); + // Remove the source fastmath attribute.
+ auto mathFMFAttr = convertedAttr.erase(mathFMFAttrName) + .dyn_cast_or_null(); + if (mathFMFAttr) { + llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName(); + convertedAttr.set(targetAttrName, convertMathFastMathAttr(mathFMFAttr)); + } + } + + ArrayRef getAttrs() const { return convertedAttr.getAttrs(); } + +private: + NamedAttrList convertedAttr; +}; + +template +using ConvertFastMath = ConvertFastMathHelper; + +template +using ConvertFMFMathToLLVMPattern = + VectorConvertToLLVMPattern; + +using AbsFOpLowering = ConvertFMFMathToLLVMPattern; +using CeilOpLowering = ConvertFMFMathToLLVMPattern; using CopySignOpLowering = - VectorConvertToLLVMPattern; -using CosOpLowering = VectorConvertToLLVMPattern; + ConvertFMFMathToLLVMPattern; +using CosOpLowering = ConvertFMFMathToLLVMPattern; using CtPopFOpLowering = VectorConvertToLLVMPattern; -using Exp2OpLowering = VectorConvertToLLVMPattern; -using ExpOpLowering = VectorConvertToLLVMPattern; +using Exp2OpLowering = ConvertFMFMathToLLVMPattern; +using ExpOpLowering = ConvertFMFMathToLLVMPattern; using FloorOpLowering = - VectorConvertToLLVMPattern; -using FmaOpLowering = VectorConvertToLLVMPattern; + ConvertFMFMathToLLVMPattern; +using FmaOpLowering = ConvertFMFMathToLLVMPattern; using Log10OpLowering = - VectorConvertToLLVMPattern; -using Log2OpLowering = VectorConvertToLLVMPattern; -using LogOpLowering = VectorConvertToLLVMPattern; -using PowFOpLowering = VectorConvertToLLVMPattern; + ConvertFMFMathToLLVMPattern; +using Log2OpLowering = ConvertFMFMathToLLVMPattern; +using LogOpLowering = ConvertFMFMathToLLVMPattern; +using PowFOpLowering = ConvertFMFMathToLLVMPattern; using RoundEvenOpLowering = - VectorConvertToLLVMPattern; + ConvertFMFMathToLLVMPattern; using RoundOpLowering = - VectorConvertToLLVMPattern; + ConvertFMFMathToLLVMPattern; -using SinOpLowering = VectorConvertToLLVMPattern; -using SqrtOpLowering = VectorConvertToLLVMPattern; +using SinOpLowering = ConvertFMFMathToLLVMPattern; +using 
SqrtOpLowering = ConvertFMFMathToLLVMPattern; using FTruncOpLowering = - VectorConvertToLLVMPattern; + ConvertFMFMathToLLVMPattern; // A `CtLz/CtTz/absi(a)` is converted into `CtLz/CtTz/absi(a, false)`. template @@ -113,6 +175,8 @@ auto resultType = op.getResult().getType(); auto floatType = getElementTypeOrSelf(resultType).cast(); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); + ConvertFastMathHelper expAttrs(op); + ConvertFastMathHelper subAttrs(op); if (!operandType.isa()) { LLVM::ConstantOp one; @@ -123,8 +187,10 @@ } else { one = rewriter.create(loc, operandType, floatOne); } - auto exp = rewriter.create(loc, adaptor.getOperand()); - rewriter.replaceOpWithNewOp(op, operandType, exp, one); + auto exp = rewriter.create(loc, adaptor.getOperand(), + expAttrs.getAttrs()); + rewriter.replaceOpWithNewOp( + op, operandType, ValueRange{exp, one}, subAttrs.getAttrs()); return success(); } @@ -142,9 +208,10 @@ floatOne); auto one = rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto exp = - rewriter.create(loc, llvm1DVectorTy, operands[0]); - return rewriter.create(loc, llvm1DVectorTy, exp, one); + auto exp = rewriter.create( + loc, llvm1DVectorTy, operands[0], expAttrs.getAttrs()); + return rewriter.create( + loc, llvm1DVectorTy, ValueRange{exp, one}, subAttrs.getAttrs()); }, rewriter); } @@ -166,6 +233,8 @@ auto resultType = op.getResult().getType(); auto floatType = getElementTypeOrSelf(resultType).cast(); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); + ConvertFastMathHelper addAttrs(op); + ConvertFastMathHelper logAttrs(op); if (!operandType.isa()) { LLVM::ConstantOp one = @@ -176,9 +245,11 @@ floatOne)) : rewriter.create(loc, operandType, floatOne); - auto add = rewriter.create(loc, operandType, one, - adaptor.getOperand()); - rewriter.replaceOpWithNewOp(op, operandType, add); + auto add = rewriter.create( + loc, operandType, ValueRange{one, adaptor.getOperand()}, + addAttrs.getAttrs()); + rewriter.replaceOpWithNewOp(op, operandType, 
ValueRange{add}, + logAttrs.getAttrs()); return success(); } @@ -196,9 +267,11 @@ floatOne); auto one = rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto add = rewriter.create(loc, llvm1DVectorTy, one, - operands[0]); - return rewriter.create(loc, llvm1DVectorTy, add); + auto add = rewriter.create(loc, llvm1DVectorTy, + ValueRange{one, operands[0]}, + addAttrs.getAttrs()); + return rewriter.create( + loc, llvm1DVectorTy, ValueRange{add}, logAttrs.getAttrs()); }, rewriter); } @@ -220,6 +293,8 @@ auto resultType = op.getResult().getType(); auto floatType = getElementTypeOrSelf(resultType).cast(); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); + ConvertFastMathHelper sqrtAttrs(op); + ConvertFastMathHelper divAttrs(op); if (!operandType.isa()) { LLVM::ConstantOp one; @@ -230,8 +305,10 @@ } else { one = rewriter.create(loc, operandType, floatOne); } - auto sqrt = rewriter.create(loc, adaptor.getOperand()); - rewriter.replaceOpWithNewOp(op, operandType, one, sqrt); + auto sqrt = rewriter.create(loc, adaptor.getOperand(), + sqrtAttrs.getAttrs()); + rewriter.replaceOpWithNewOp( + op, operandType, ValueRange{one, sqrt}, divAttrs.getAttrs()); return success(); } @@ -249,9 +326,10 @@ floatOne); auto one = rewriter.create(loc, llvm1DVectorTy, splatAttr); - auto sqrt = - rewriter.create(loc, llvm1DVectorTy, operands[0]); - return rewriter.create(loc, llvm1DVectorTy, one, sqrt); + auto sqrt = rewriter.create( + loc, llvm1DVectorTy, operands[0], sqrtAttrs.getAttrs()); + return rewriter.create( + loc, llvm1DVectorTy, ValueRange{one, sqrt}, divAttrs.getAttrs()); }, rewriter); } Index: mlir/lib/Dialect/Math/IR/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/Math/IR/CMakeLists.txt +++ mlir/lib/Dialect/Math/IR/CMakeLists.txt @@ -6,7 +6,9 @@ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Math DEPENDS + MLIRMathOpsAttributesIncGen MLIRMathOpsIncGen + MLIRMathOpsInterfacesIncGen LINK_LIBS PUBLIC MLIRArithDialect Index:
mlir/lib/Dialect/Math/IR/MathDialect.cpp =================================================================== --- mlir/lib/Dialect/Math/IR/MathDialect.cpp +++ mlir/lib/Dialect/Math/IR/MathDialect.cpp @@ -9,10 +9,16 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Transforms/InliningUtils.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + using namespace mlir; using namespace mlir::math; #include "mlir/Dialect/Math/IR/MathOpsDialect.cpp.inc" +#include "mlir/Dialect/Math/IR/MathOpsInterfaces.cpp.inc" +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Math/IR/MathOpsAttributes.cpp.inc" namespace { /// This class defines the interface for handling inlining with math @@ -33,5 +39,9 @@ #define GET_OP_LIST #include "mlir/Dialect/Math/IR/MathOps.cpp.inc" >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/Math/IR/MathOpsAttributes.cpp.inc" + >(); addInterfaces(); } Index: mlir/lib/Dialect/Math/IR/MathOps.cpp =================================================================== --- mlir/lib/Dialect/Math/IR/MathOps.cpp +++ mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -14,6 +14,29 @@ using namespace mlir; using namespace mlir::math; +//===----------------------------------------------------------------------===// +// Floating point op parse/print helpers +//===----------------------------------------------------------------------===// +static ParseResult parseMathFastMathAttr(OpAsmParser &parser, Attribute &attr) { + if (succeeded( + parser.parseOptionalKeyword(FastMathFlagsAttr::getMnemonic()))) { + attr = FastMathFlagsAttr::parse(parser, Type{}); + return success(static_cast(attr)); + } + // No fastmath attribute mnemonic present - defer attribute creation and use + // the default value.
+ return success(); +} + +static void printMathFastMathAttr(OpAsmPrinter &printer, Operation *op, + FastMathFlagsAttr fmAttr) { + // Elide the fastmath attribute when it is absent or equal to `none`. + if (fmAttr && (fmAttr.getValue() != FastMathFlags::none)) { + printer << " " << FastMathFlagsAttr::getMnemonic(); + fmAttr.print(printer); + } +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// @@ -21,6 +44,12 @@ #define GET_OP_CLASSES #include "mlir/Dialect/Math/IR/MathOps.cpp.inc" +//===----------------------------------------------------------------------===// +// TableGen'd enum attribute definitions +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Math/IR/MathOpsEnums.cpp.inc" + //===----------------------------------------------------------------------===// // AbsFOp folder //===----------------------------------------------------------------------===// Index: mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir =================================================================== --- mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir +++ mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir @@ -36,6 +36,18 @@ // ----- +// CHECK-LABEL: func @log1p_fmf( +// CHECK-SAME: f32 +func.func @log1p_fmf(%arg0 : f32) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 + // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %arg0 {fastmathFlags = #llvm.fastmath} : f32 + // CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]]) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %0 = math.log1p %arg0 fastmath : f32 + func.return +} + +// ----- + // CHECK-LABEL: func @log1p_2dvector( func.func @log1p_2dvector(%arg0 : vector<4x3xf32>) { // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> @@ -49,6 +61,19 @@ // ----- +// CHECK-LABEL: func
@log1p_2dvector_fmf( +func.func @log1p_2dvector_fmf(%arg0 : vector<4x3xf32>) { + // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<3xf32>) : vector<3xf32> + // CHECK: %[[ADD:.*]] = llvm.fadd %[[ONE]], %[[EXTRACT]] {fastmathFlags = #llvm.fastmath} : vector<3xf32> + // CHECK: %[[LOG:.*]] = llvm.intr.log(%[[ADD]]) {fastmathFlags = #llvm.fastmath} : (vector<3xf32>) -> vector<3xf32> + // CHECK: %[[INSERT:.*]] = llvm.insertvalue %[[LOG]], %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> + %0 = math.log1p %arg0 fastmath : vector<4x3xf32> + func.return +} + +// ----- + // CHECK-LABEL: func @expm1( // CHECK-SAME: f32 func.func @expm1(%arg0 : f32) { @@ -61,6 +86,42 @@ // ----- +// CHECK-LABEL: func @expm1_fmf( +// CHECK-SAME: f32 +func.func @expm1_fmf(%arg0 : f32) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 + // CHECK: %[[EXP:.*]] = llvm.intr.exp(%arg0) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] {fastmathFlags = #llvm.fastmath} : f32 + %0 = math.expm1 %arg0 fastmath : f32 + func.return +} + +// ----- + +// CHECK-LABEL: func @expm1_vector( +// CHECK-SAME: vector<4xf32> +func.func @expm1_vector(%arg0 : vector<4xf32>) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32> + // CHECK: %[[EXP:.*]] = llvm.intr.exp(%arg0) : (vector<4xf32>) -> vector<4xf32> + // CHECK: %[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] : vector<4xf32> + %0 = math.expm1 %arg0 : vector<4xf32> + func.return +} + +// ----- + +// CHECK-LABEL: func @expm1_vector_fmf( +// CHECK-SAME: vector<4xf32> +func.func @expm1_vector_fmf(%arg0 : vector<4xf32>) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32> + // CHECK: %[[EXP:.*]] = llvm.intr.exp(%arg0) {fastmathFlags = #llvm.fastmath} : (vector<4xf32>) -> vector<4xf32> + // CHECK: 
%[[SUB:.*]] = llvm.fsub %[[EXP]], %[[ONE]] {fastmathFlags = #llvm.fastmath} : vector<4xf32> + %0 = math.expm1 %arg0 fastmath : vector<4xf32> + func.return +} + +// ----- + // CHECK-LABEL: func @rsqrt( // CHECK-SAME: f32 func.func @rsqrt(%arg0 : f32) { @@ -148,6 +209,18 @@ // ----- +// CHECK-LABEL: func @rsqrt_double_fmf( +// CHECK-SAME: f64 +func.func @rsqrt_double_fmf(%arg0 : f64) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1.000000e+00 : f64) : f64 + // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%arg0) {fastmathFlags = #llvm.fastmath} : (f64) -> f64 + // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] {fastmathFlags = #llvm.fastmath} : f64 + %0 = math.rsqrt %arg0 fastmath : f64 + func.return +} + +// ----- + // CHECK-LABEL: func @rsqrt_vector( // CHECK-SAME: vector<4xf32> func.func @rsqrt_vector(%arg0 : vector<4xf32>) { @@ -160,6 +233,18 @@ // ----- +// CHECK-LABEL: func @rsqrt_vector_fmf( +// CHECK-SAME: vector<4xf32> +func.func @rsqrt_vector_fmf(%arg0 : vector<4xf32>) { + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<4xf32>) : vector<4xf32> + // CHECK: %[[SQRT:.*]] = llvm.intr.sqrt(%arg0) {fastmathFlags = #llvm.fastmath} : (vector<4xf32>) -> vector<4xf32> + // CHECK: %[[DIV:.*]] = llvm.fdiv %[[ONE]], %[[SQRT]] {fastmathFlags = #llvm.fastmath} : vector<4xf32> + %0 = math.rsqrt %arg0 fastmath : vector<4xf32> + func.return +} + +// ----- + // CHECK-LABEL: func @rsqrt_multidim_vector( func.func @rsqrt_multidim_vector(%arg0 : vector<4x3xf32>) { // CHECK: %[[EXTRACT:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.array<4 x vector<3xf32>> @@ -210,3 +295,19 @@ %0 = math.trunc %arg0 : f32 func.return } + +// ----- + +// CHECK-LABEL: func @fastmath( +// CHECK-SAME: f32 +func.func @fastmath(%arg0 : f32, %arg1 : vector<4xf32>) { + // CHECK: llvm.intr.trunc(%arg0) {fastmathFlags = #llvm.fastmath} : (f32) -> f32 + %0 = math.trunc %arg0 fastmath : f32 + // CHECK: llvm.intr.pow(%arg0, %arg0) {fastmathFlags = #llvm.fastmath} : (f32, f32) -> f32 + %1 = 
math.powf %arg0, %arg0 fastmath : f32 + // CHECK: llvm.intr.sqrt(%arg0) : (f32) -> f32 + %2 = math.sqrt %arg0 fastmath : f32 + // CHECK: llvm.intr.fma(%arg0, %arg0, %arg0) {fastmathFlags = #llvm.fastmath} : (f32, f32, f32) -> f32 + %3 = math.fma %arg0, %arg0, %arg0 fastmath : f32 + func.return +} Index: mlir/test/Dialect/Math/ops.mlir =================================================================== --- mlir/test/Dialect/Math/ops.mlir +++ mlir/test/Dialect/Math/ops.mlir @@ -269,3 +269,15 @@ %2 = math.trunc %t : tensor<4x4x?xf32> return } + +// CHECK-LABEL: func @fastmath( +// CHECK-SAME: %[[F:.*]]: f32, %[[V:.*]]: vector<4xf32>, %[[T:.*]]: tensor<4x4x?xf32>) +func.func @fastmath(%f: f32, %v: vector<4xf32>, %t: tensor<4x4x?xf32>) { + // CHECK: %{{.*}} = math.trunc %[[F]] fastmath : f32 + %0 = math.trunc %f fastmath : f32 + // CHECK: %{{.*}} = math.powf %[[V]], %[[V]] fastmath : vector<4xf32> + %1 = math.powf %v, %v fastmath : vector<4xf32> + // CHECK: %{{.*}} = math.fma %[[T]], %[[T]], %[[T]] : tensor<4x4x?xf32> + %2 = math.fma %t, %t, %t fastmath : tensor<4x4x?xf32> + return +}