diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -122,16 +122,29 @@ const APFloat &rhs); /// Returns the identity value attribute associated with an AtomicRMWKind op. +/// \p useOnlyFiniteValue defines whether the identity value should steer away +/// from infinity representations or anything that is not a proper finite +/// number. +/// E.g., the identity value for maxf is in theory `-Inf`, but if we want to +/// stay in the finite range, it is the lowest finite value of the type, i.e. +/// `APFloat::getLargest(sem, /*Negative=*/true)`. The purpose of this boolean +/// is to offer constants that will play nice with fast-math optimizations. TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, - OpBuilder &builder, Location loc); + OpBuilder &builder, Location loc, + bool useOnlyFiniteValue); /// Return the identity numeric value associated to the give op. Return /// std::nullopt if there is no known neutral element. -std::optional<TypedAttr> getNeutralElement(Operation *op); +/// \see getIdentityValueAttr for a description of what \p useOnlyFiniteValue +/// does. +std::optional<TypedAttr> getNeutralElement(Operation *op, + bool useOnlyFiniteValue); /// Returns the identity value associated with an AtomicRMWKind op. +/// \see getIdentityValueAttr for a description of what \p useOnlyFiniteValue +/// does. Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, - Location loc); + Location loc, bool useOnlyFiniteValue); /// Returns the value obtained by applying the reduction operation kind /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -230,7 +230,8 @@ assert(reductionOp && "Reduction operation cannot be of None Type"); arith::AtomicRMWKind reductionOpValue = *reductionOp; identityVals.push_back( - arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc)); + arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc, + /*useOnlyFiniteValue=*/false)); } parOp = rewriter.create<scf::ParallelOp>( loc, lowerBoundTuple, upperBoundTuple, steps, identityVals, diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -989,8 +989,9 @@ if (!VectorType::isValidElementType(scalarTy)) return nullptr; - Attribute valueAttr = getIdentityValueAttr( - reductionKind, scalarTy, state.builder, oldOperand.getLoc()); + Attribute valueAttr = + getIdentityValueAttr(reductionKind, scalarTy, state.builder, + oldOperand.getLoc(), /*useOnlyFiniteValue=*/false); auto vecTy = getVectorType(scalarTy, state.strategy); auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr); auto newConstOp = @@ -1261,8 +1262,9 @@ Type scalarTy = value.getType(); if (!VectorType::isValidElementType(scalarTy)) return false; - Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy, - state.builder, value.getLoc()); + Attribute valueAttr = + getIdentityValueAttr(reductionKind, scalarTy, state.builder, + value.getLoc(), /*useOnlyFiniteValue=*/false); if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(value.getDefiningOp())) return constOp.getValue() == valueAttr; return false; diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp ---
a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2350,13 +2350,17 @@ /// Returns the identity value attribute associated with an AtomicRMWKind op. TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, - OpBuilder &builder, Location loc) { + OpBuilder &builder, Location loc, + bool useOnlyFiniteValue) { switch (kind) { - case AtomicRMWKind::maxf: - return builder.getFloatAttr( - resultType, - APFloat::getInf(llvm::cast<FloatType>(resultType).getFloatSemantics(), - /*Negative=*/true)); + case AtomicRMWKind::maxf: { + const llvm::fltSemantics &semantic = + llvm::cast<FloatType>(resultType).getFloatSemantics(); + APFloat identity = useOnlyFiniteValue + ? APFloat::getLargest(semantic, /*Negative=*/true) + : APFloat::getInf(semantic, /*Negative=*/true); + return builder.getFloatAttr(resultType, identity); + } case AtomicRMWKind::addf: case AtomicRMWKind::addi: case AtomicRMWKind::maxu: @@ -2370,11 +2374,15 @@ return builder.getIntegerAttr( resultType, APInt::getSignedMinValue( llvm::cast<IntegerType>(resultType).getWidth())); - case AtomicRMWKind::minf: - return builder.getFloatAttr( - resultType, - APFloat::getInf(llvm::cast<FloatType>(resultType).getFloatSemantics(), - /*Negative=*/false)); + case AtomicRMWKind::minf: { + const llvm::fltSemantics &semantic = + llvm::cast<FloatType>(resultType).getFloatSemantics(); + APFloat identity = useOnlyFiniteValue + ? APFloat::getLargest(semantic, /*Negative=*/false) + : APFloat::getInf(semantic, /*Negative=*/false); + + return builder.getFloatAttr(resultType, identity); + } case AtomicRMWKind::mins: return builder.getIntegerAttr( resultType, APInt::getSignedMaxValue( @@ -2396,7 +2404,8 @@ } /// Return the identity numeric value associated to the give op. -std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) { +std::optional<TypedAttr> +mlir::arith::getNeutralElement(Operation *op, bool useOnlyFiniteValue) { std::optional<AtomicRMWKind> maybeKind = llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op) // Floating-point operations.
@@ -2424,13 +2433,16 @@ OpBuilder b(op->getContext()); Type resultType = op->getResult(0).getType(); - return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc()); + return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(), + useOnlyFiniteValue); } /// Returns the identity value associated with an AtomicRMWKind op. Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, - OpBuilder &builder, Location loc) { - auto attr = getIdentityValueAttr(op, resultType, builder, loc); + OpBuilder &builder, Location loc, + bool useOnlyFiniteValue) { + auto attr = + getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue); return builder.create<arith::ConstantOp>(loc, attr); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2470,7 +2470,8 @@ // Step 1: Compute max along dim. Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType); Value neutralForMaxF = - arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc); + arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc, + /*useOnlyFiniteValue=*/true); Value neutralForMaxFInit = b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce) .result(); @@ -2481,8 +2482,8 @@ Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim); // Step 3: Compute sum along dim.
- Value zero = - arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc); + Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, + b, loc, /*useOnlyFiniteValue=*/true); Value zeroInit = b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result(); Value denominator = diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -66,7 +66,8 @@ return b.notifyMatchFailure(op, "Cannot match the reduction pattern"); Operation *reductionOp = combinerOps[0]; - std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp); + std::optional<TypedAttr> identity = + arith::getNeutralElement(reductionOp, /*useOnlyFiniteValue=*/false); if (!identity.has_value()) return b.notifyMatchFailure(op, "Unknown identity value for the reduction"); @@ -275,7 +276,7 @@ SmallVector<TypedAttr> neutralElements; for (Operation *reductionOp : combinerOps) { std::optional<TypedAttr> neutralElement = - arith::getNeutralElement(reductionOp); + arith::getNeutralElement(reductionOp, /*useOnlyFiniteValue=*/false); if (!neutralElement.has_value()) return b.notifyMatchFailure(op, "cannot find neutral element."); neutralElements.push_back(*neutralElement); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -271,7 +271,8 @@ return op->emitOpError("Failed to anaysis the reduction operation."); Operation *reductionOp = combinerOps[0]; - std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp); + std::optional<TypedAttr> identity = + arith::getNeutralElement(reductionOp, /*useOnlyFiniteValue=*/false); if (!identity.has_value()) return op->emitOpError( "Failed to get an identity value for the reduction operation."); diff
--git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir --- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir @@ -210,7 +210,7 @@ // CHECK-LABEL: func.func @softmax( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> { // CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32> -// CHECK-DAG: %[[CST:.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: %[[CST:.+]] = arith.constant -3.40282347E+38 : f32 // CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32> // CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", // CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {