diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2465,31 +2465,32 @@
   Type elementType = inputType.getElementType();
   int64_t reductionDim = getDimension();
   SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
-  Value outputNd = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  Value output = getOutput();
   dims.erase(dims.begin() + reductionDim);
   // Step 1: Compute max along dim.
-  Value output = b.create<tensor::EmptyOp>(loc, dims, elementType);
+  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
   Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxf,
                                                  elementType, b, loc);
   Value neutralForMaxFInit =
-      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, output).result();
+      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
+          .result();
   Value max = reduce(b, loc, input, neutralForMaxFInit, reductionDim);

   // Step 2: Subtract max from input and exponentiate.
-  Value numerator =
-      buildSubAndExpOp(b, loc, input, max, outputNd, reductionDim);
+  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);

   // Step 3: Compute sum along dim.
   Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
                                        b, loc);
-  Value zeroInit = b.create<linalg::FillOp>(loc, Value{zero}, output).result();
+  Value zeroInit =
+      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
   Value denominator = reduce(b, loc, numerator, zeroInit, reductionDim);

   // Step 4: Compute softmax.
   Value result =
-      buildDivOp(b, loc, numerator, denominator, outputNd, reductionDim);
+      buildDivOp(b, loc, numerator, denominator, output, reductionDim);

   return SmallVector<Value>{result};
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -208,8 +208,7 @@
 }

 // CHECK-LABEL: func.func @softmax(
-//CHECK-SAME:  %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
-// CHECK-DAG:   %[[D0:.+]] = tensor.empty() : tensor<2x16x32xf32>
+// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>, %[[DST:[a-zA-Z0-9_]+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
 // CHECK-DAG:   %[[D1:.+]] = tensor.empty() : tensor<2x16xf32>
 // CHECK-DAG:   %[[CST:.+]] = arith.constant 0xFF800000 : f32
 // CHECK:       %[[D2:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[D1]] : tensor<2x16xf32>) -> tensor<2x16xf32>
@@ -221,7 +220,7 @@
 // CHECK:       } -> tensor<2x16xf32>
 // CHECK:       %[[D4:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
 // CHECK-SAME:    ["parallel", "parallel", "parallel"]} ins(%[[ARG0]], %[[D3]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
-// CHECK-SAME:    outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK-SAME:    outs(%[[DST]] : tensor<2x16x32xf32>) {
 // CHECK:       ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:         %[[D8]] = arith.subf %[[IN]], %[[IN_1]] : f32
 // CHECK:         %[[D9:.+]] = math.exp %[[D8]] : f32
@@ -237,13 +236,12 @@
 // CHECK:       } -> tensor<2x16xf32>
 // CHECK:       %[[D7:.+]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP1]], #[[$MAP]]], iterator_types =
 // CHECK-SAME:    ["parallel", "parallel", "parallel"]} ins(%[[D4]], %[[D6]] : tensor<2x16x32xf32>, tensor<2x16xf32>)
-// CHECK-SAME:    outs(%[[D0]] : tensor<2x16x32xf32>) {
+// CHECK-SAME:    outs(%[[DST]] : tensor<2x16x32xf32>) {
 // CHECK:       ^bb0(%[[IN:.+]]: f32, %[[IN_1:.+]]: f32, %[[OUT:.+]]: f32):
 // CHECK:         %[[D8]] = arith.divf %[[IN]], %[[IN_1]] : f32
 // CHECK:         linalg.yield %[[D8]] : f32
 // CHECK:       } -> tensor<2x16x32xf32>
 // CHECK:       return %[[D7]] : tensor<2x16x32xf32>
-// CHECK:     }

 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):