diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -122,16 +122,29 @@ const APFloat &rhs); /// Returns the identity value attribute associated with an AtomicRMWKind op. +/// `useOnlyFiniteValue` defines whether the identity value should steer away +/// from infinity representations or anything that is not a proper finite +/// number. +/// E.g., the identity value for maxf is in theory `-Inf`, but to stay in the +/// finite range it is the most negative finite value (`-FLT_MAX` for f32). +/// Likewise, minf then uses the largest finite value instead of `+Inf`. This +/// gives constants that play nice with fast-math related optimizations. TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, - OpBuilder &builder, Location loc); + OpBuilder &builder, Location loc, + bool useOnlyFiniteValue = false); /// Return the identity numeric value associated to the give op. Return /// std::nullopt if there is no known neutral element. -std::optional<TypedAttr> getNeutralElement(Operation *op); +/// \see getIdentityValueAttr for a description of what `useOnlyFiniteValue` +/// does. +std::optional<TypedAttr> getNeutralElement(Operation *op, + bool useOnlyFiniteValue = false); /// Returns the identity value associated with an AtomicRMWKind op. +/// \see getIdentityValueAttr for a description of what `useOnlyFiniteValue` +/// does. Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, - Location loc); + Location loc, bool useOnlyFiniteValue = false); /// Returns the value obtained by applying the reduction operation kind /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2387,13 +2387,17 @@ /// Returns the identity value attribute associated with an AtomicRMWKind op. TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, - OpBuilder &builder, Location loc) { + OpBuilder &builder, Location loc, + bool useOnlyFiniteValue) { switch (kind) { - case AtomicRMWKind::maxf: - return builder.getFloatAttr( - resultType, - APFloat::getInf(llvm::cast<FloatType>(resultType).getFloatSemantics(), - /*Negative=*/true)); + case AtomicRMWKind::maxf: { + const llvm::fltSemantics &semantic = + llvm::cast<FloatType>(resultType).getFloatSemantics(); + APFloat identity = useOnlyFiniteValue + ? APFloat::getLargest(semantic, /*Negative=*/true) + : APFloat::getInf(semantic, /*Negative=*/true); + return builder.getFloatAttr(resultType, identity); + } case AtomicRMWKind::addf: case AtomicRMWKind::addi: case AtomicRMWKind::maxu: @@ -2407,11 +2411,15 @@ return builder.getIntegerAttr( resultType, APInt::getSignedMinValue( llvm::cast<IntegerType>(resultType).getWidth())); - case AtomicRMWKind::minf: - return builder.getFloatAttr( - resultType, - APFloat::getInf(llvm::cast<FloatType>(resultType).getFloatSemantics(), - /*Negative=*/false)); + case AtomicRMWKind::minf: { + const llvm::fltSemantics &semantic = + llvm::cast<FloatType>(resultType).getFloatSemantics(); + APFloat identity = useOnlyFiniteValue + ? APFloat::getLargest(semantic, /*Negative=*/false) + : APFloat::getInf(semantic, /*Negative=*/false); + + return builder.getFloatAttr(resultType, identity); + } case AtomicRMWKind::mins: return builder.getIntegerAttr( resultType, APInt::getSignedMaxValue( @@ -2433,7 +2441,8 @@ } /// Return the identity numeric value associated to the give op.
-std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) { +std::optional<TypedAttr> +mlir::arith::getNeutralElement(Operation *op, bool useOnlyFiniteValue) { std::optional<AtomicRMWKind> maybeKind = llvm::TypeSwitch<Operation *, std::optional<AtomicRMWKind>>(op) // Floating-point operations. @@ -2461,13 +2470,16 @@ OpBuilder b(op->getContext()); Type resultType = op->getResult(0).getType(); - return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc()); + return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(), + useOnlyFiniteValue); } /// Returns the identity value associated with an AtomicRMWKind op. Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, - OpBuilder &builder, Location loc) { - auto attr = getIdentityValueAttr(op, resultType, builder, loc); + OpBuilder &builder, Location loc, + bool useOnlyFiniteValue) { + auto attr = + getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue); return builder.create<arith::ConstantOp>(loc, attr); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2492,7 +2492,8 @@ // Step 1: Compute max along dim. Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType); Value neutralForMaxF = - arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc); + arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc, + /*useOnlyFiniteValue=*/true); Value neutralForMaxFInit = b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce) .result(); @@ -2503,8 +2504,8 @@ Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim); // Step 3: Compute sum along dim.
- Value zero = - arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc); + Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, + b, loc, /*useOnlyFiniteValue=*/true); Value zeroInit = b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result(); Value denominator = diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir --- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir @@ -210,7 +210,7 @@ // CHECK-LABEL: func.func @softmax( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> { // CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32> -// CHECK-DAG: %[[CST:.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: %[[CST:.+]] = arith.constant -3.40282347E+38 : f32 // CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32> // CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", // CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) {