diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -122,16 +122,28 @@ const APFloat &rhs); /// Returns the identity value attribute associated with an AtomicRMWKind op. +/// `useOnlyFiniteValue` defines whether the identity value should steer away +/// from infinity representations or anything that is not a proper finite +/// number. +/// E.g., the identity value for maxf is in theory `-Inf`, but if we want to +/// stay finite, it is the lowest (most negative) finite value, e.g. -FLT_MAX. +/// The purpose of this boolean is to offer constants that will play nice +/// with fast math related optimizations. TypedAttr getIdentityValueAttr(AtomicRMWKind kind, Type resultType, - OpBuilder &builder, Location loc); + OpBuilder &builder, Location loc, + bool useOnlyFiniteValue = false); /// Return the identity numeric value associated to the give op. Return /// std::nullopt if there is no known neutral element. +/// If `op` has `FastMathFlags::ninf`, only finite values will be used +/// as the neutral element. std::optional<TypedAttr> getNeutralElement(Operation *op); /// Returns the identity value associated with an AtomicRMWKind op. +/// \see getIdentityValueAttr for a description of what `useOnlyFiniteValue` +/// does. Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, - Location loc); + Location loc, bool useOnlyFiniteValue = false); /// Returns the value obtained by applying the reduction operation kind /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`. diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2387,13 +2387,17 @@ /// Returns the identity value attribute associated with an AtomicRMWKind op.
TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, - OpBuilder &builder, Location loc) { + OpBuilder &builder, Location loc, + bool useOnlyFiniteValue) { switch (kind) { - case AtomicRMWKind::maxf: - return builder.getFloatAttr( - resultType, - APFloat::getInf(llvm::cast<FloatType>(resultType).getFloatSemantics(), - /*Negative=*/true)); + case AtomicRMWKind::maxf: { + const llvm::fltSemantics &semantic = + llvm::cast<FloatType>(resultType).getFloatSemantics(); + APFloat identity = useOnlyFiniteValue + ? APFloat::getLargest(semantic, /*Negative=*/true) + : APFloat::getInf(semantic, /*Negative=*/true); + return builder.getFloatAttr(resultType, identity); + } case AtomicRMWKind::addf: case AtomicRMWKind::addi: case AtomicRMWKind::maxu: @@ -2407,11 +2411,15 @@ return builder.getIntegerAttr( resultType, APInt::getSignedMinValue( llvm::cast<IntegerType>(resultType).getWidth())); - case AtomicRMWKind::minf: - return builder.getFloatAttr( - resultType, - APFloat::getInf(llvm::cast<FloatType>(resultType).getFloatSemantics(), - /*Negative=*/false)); + case AtomicRMWKind::minf: { + const llvm::fltSemantics &semantic = + llvm::cast<FloatType>(resultType).getFloatSemantics(); + APFloat identity = useOnlyFiniteValue + ? APFloat::getLargest(semantic, /*Negative=*/false) + : APFloat::getInf(semantic, /*Negative=*/false); + + return builder.getFloatAttr(resultType, identity); + } case AtomicRMWKind::mins: return builder.getIntegerAttr( resultType, APInt::getSignedMaxValue( @@ -2457,17 +2465,28 @@ return std::nullopt; } + bool useOnlyFiniteValue = false; + auto fmfOpInterface = dyn_cast<ArithFastMathInterface>(op); + if (fmfOpInterface) { + arith::FastMathFlagsAttr fmfAttr = fmfOpInterface.getFastMathFlagsAttr(); + useOnlyFiniteValue = + bitEnumContainsAny(fmfAttr.getValue(), arith::FastMathFlags::ninf); + } + // Builder only used as helper for attribute creation.
OpBuilder b(op->getContext()); Type resultType = op->getResult(0).getType(); - return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc()); + return getIdentityValueAttr(*maybeKind, resultType, b, op->getLoc(), + useOnlyFiniteValue); } /// Returns the identity value associated with an AtomicRMWKind op. Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, - OpBuilder &builder, Location loc) { - auto attr = getIdentityValueAttr(op, resultType, builder, loc); + OpBuilder &builder, Location loc, + bool useOnlyFiniteValue) { + auto attr = + getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue); return builder.create<arith::ConstantOp>(loc, attr); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2492,7 +2492,8 @@ // Step 1: Compute max along dim. Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType); Value neutralForMaxF = - arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc); + arith::getIdentityValue(arith::AtomicRMWKind::maxf, elementType, b, loc, + /*useOnlyFiniteValue=*/true); Value neutralForMaxFInit = b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce) .result(); @@ -2503,8 +2504,8 @@ Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim); // Step 3: Compute sum along dim.
- Value zero = - arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc); + Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, + b, loc, /*useOnlyFiniteValue=*/true); Value zeroInit = b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result(); Value denominator = diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir --- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir @@ -210,7 +210,7 @@ // CHECK-LABEL: func.func @softmax( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> { // CHECK-DAG: %[[D1:.+]] = tensor.empty() : tensor<2x16xf32> -// CHECK-DAG: %[[CST:.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: %[[CST:.+]] = arith.constant -3.40282347E+38 : f32 // CHECK: %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32> // CHECK: %[[D3:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]]], iterator_types = ["parallel", // CHECK-SAME: "parallel", "reduction"]} ins(%[[ARG0]] : tensor<2x16x32xf32>) outs(%[[D2]] : tensor<2x16xf32>) { diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir --- a/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction.mlir @@ -141,6 +141,62 @@ // ----- +// Check that we don't use -inf as the neutral element for maxf when maxf has +// ninf. Instead check that we use the lowest (most negative) finite value. +// Also check that the fastmath flags are set on the created maxf +// instructions.
+func.func @generic_split_3d_ninf(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>) + -> tensor<5x2xf32> +{ + %0 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d1, d0)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d2, d0)> + ], + iterator_types = ["parallel", "reduction", "parallel"] + } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %3 = arith.addf %arg0, %arg1 : f32 + %4 = arith.maxf %3, %arg2 fastmath<ninf> : f32 + linalg.yield %4 : f32 + } -> tensor<5x2xf32> + return %0 : tensor<5x2xf32> +} + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d1, d0)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)> +// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-LABEL: func @generic_split_3d_ninf +// CHECK-DAG: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32 +// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32> +// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32> +// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32> +// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32> +// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]} +// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) { +// CHECK: arith.addf +// CHECK: arith.maxf {{.*}} fastmath<ninf> +// CHECK: linalg.yield +// CHECK: } -> tensor<5x2x4xf32> +// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], 
#[[$MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} +// CHECK-SAME: ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) { +// CHECK: arith.maxf {{.*}} fastmath<ninf> +// CHECK: linalg.yield +// CHECK: } -> tensor<5x2xf32> +// CHECK: return %[[R]] : tensor<5x2xf32> + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2} + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) +} + +// ----- + func.func @matmul_split(%A : tensor<16x256xf32>, %B: tensor<256x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>) outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> @@ -279,3 +335,59 @@ %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2, inner_parallel} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) } + +// ----- + +// Check that we don't use +inf as the neutral element for minf when minf has +// ninf. Instead check that we use the largest finite floating point value. +// Also check that the fastmath flags are set on the created minf +// instructions.
+func.func @generic_split_3d_minf_ninf(%input: tensor<32x2xf32>, %input_2: tensor<5x32xf32>, %output: tensor<5x2xf32>) + -> tensor<5x2xf32> +{ + %0 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d1, d0)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1, d2) -> (d2, d0)> + ], + iterator_types = ["parallel", "reduction", "parallel"] + } ins(%input, %input_2 : tensor<32x2xf32>, tensor<5x32xf32>) outs(%output : tensor<5x2xf32>) { + ^bb0(%arg0: f32, %arg1: f32, %arg2: f32): + %3 = arith.addf %arg0, %arg1 : f32 + %4 = arith.minf %3, %arg2 fastmath<ninf> : f32 + linalg.yield %4 : f32 + } -> tensor<5x2xf32> + return %0 : tensor<5x2xf32> +} + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d0)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d0, d2)> +// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK-LABEL: func @generic_split_3d_minf_ninf +// CHECK-DAG: %[[ID:.*]] = arith.constant 3.40282347E+38 : f32 +// CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<8x4x2xf32> +// CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x8x4xf32> +// CHECK-DAG: %[[INI:.*]] = tensor.empty() : tensor<5x2x4xf32> +// CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32> +// CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]} +// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<8x4x2xf32>, tensor<5x8x4xf32>) outs(%[[F]] : tensor<5x2x4xf32>) { +// CHECK: arith.addf +// CHECK: arith.minf {{.*}} fastmath<ninf> +// CHECK: linalg.yield +// CHECK: } -> tensor<5x2x4xf32> +// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP3]], #[[$MAP4]]], 
iterator_types = ["parallel", "parallel", "reduction"]} +// CHECK-SAME: ins(%[[G]] : tensor<5x2x4xf32>) outs(%{{.*}} : tensor<5x2xf32>) { +// CHECK: arith.minf {{.*}} fastmath<ninf> +// CHECK: linalg.yield +// CHECK: } -> tensor<5x2xf32> +// CHECK: return %[[R]] : tensor<5x2xf32> + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1:4 = transform.structured.split_reduction %0 { split_factor = 4, insert_split_dimension = 2, inner_parallel} + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) +}