diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -15,6 +15,7 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" @@ -171,11 +172,14 @@ outputExpr.push_back( b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1)); } - Value initTensor = b.create( - loc, newOutputShape, op.getRegionOutputArgs()[0].getType()); + Value allocTensor = b.create( + loc, + RankedTensorType::get(newOutputShape, + op.getRegionOutputArgs()[0].getType()), + ValueRange{}); Value constantOp = b.create(loc, identity); Value identityTensor = - b.create(op->getLoc(), constantOp, initTensor) + b.create(op->getLoc(), constantOp, allocTensor) .getResult(0); newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr, @@ -189,7 +193,7 @@ // Create the new op matching the original op with an extra parallel // dimension. GenericOp genericOp = b.create( - loc, TypeRange({initTensor.getType()}), newInputs, + loc, TypeRange({allocTensor.getType()}), newInputs, ValueRange({identityTensor}), newMaps, newIteratorTypes); b.inlineRegionBefore(op->getRegion(0), genericOp.region(), genericOp.region().begin()); @@ -297,7 +301,7 @@ return b.notifyMatchFailure(op, "unknown reduction neutral"); // TODO: relax this when multi-reduction support is available. - if (op.getNumOutputs() != neutralElements.size()) + if (op.getNumOutputs() != static_cast(neutralElements.size()) return b.notifyMatchFailure(op, "expect one reduction per output"); // Rewrite part. @@ -327,8 +331,7 @@ reductionDimSize / splitFactor, insertSplitDimension); SmallVector dims = tensor::createDynamicDimValues(b, loc, rankedTensor); - Value initTensor = b.create( - loc, dims, newT.getShape(), t.getElementType()); + Value initTensor = b.create(loc, newT, dims); Value constantOp = b.create(loc, std::get<1>(it)); fillOps.push_back( b.create(op->getLoc(), constantOp, initTensor)); diff --git a/mlir/test/Dialect/Linalg/split_reduction.mlir b/mlir/test/Dialect/Linalg/split_reduction.mlir --- a/mlir/test/Dialect/Linalg/split_reduction.mlir +++ b/mlir/test/Dialect/Linalg/split_reduction.mlir @@ -15,7 +15,7 @@ // CHECK-DAG: %[[ID:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<16x256xf32> into tensor<16x4x64xf32> // CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<256x32xf32> into tensor<4x64x32xf32> -// CHECK-DAG: %[[INI:.*]] = linalg.init_tensor [16, 32, 4] : tensor<16x32x4xf32> +// CHECK-DAG: %[[INI:.*]] = bufferization.alloc_tensor() : tensor<16x32x4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<16x32x4xf32>) -> tensor<16x32x4xf32> // CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]] // CHECK-SAME: , iterator_types = ["parallel", "parallel", "parallel", "reduction"]} @@ -57,7 +57,7 @@ //CHECK-LABEL: @generic_split_1d // CHECK: %[[ID:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1]] : tensor<32xf32> into tensor<4x8xf32> -// CHECK: %[[INI:.*]] = linalg.init_tensor [4] : tensor<4xf32> +// CHECK: %[[INI:.*]] = bufferization.alloc_tensor() : tensor<4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<4xf32>) -> tensor<4xf32> // CHECK: %[[G:.*]] = linalg.generic // CHECK: {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], @@ -103,7 +103,7 @@ // CHECK: %[[ID:.*]] = arith.constant -3.40282347E+38 : f32 // CHECK-DAG: %[[I1:.*]] = tensor.expand_shape %{{.*}}[0, 1], [2]] : tensor<32x2xf32> into tensor<4x8x2xf32> // CHECK-DAG: %[[I2:.*]] = tensor.expand_shape %{{.*}}[0], [1, 2]] : tensor<5x32xf32> into tensor<5x4x8xf32> -// CHECK: %[[INI:.*]] = linalg.init_tensor [5, 2, 4] : tensor<5x2x4xf32> +// CHECK: %[[INI:.*]] = bufferization.alloc_tensor() : tensor<5x2x4xf32> // CHECK: %[[F:.*]] = linalg.fill ins(%[[ID]] : f32) outs(%[[INI]] : tensor<5x2x4xf32>) -> tensor<5x2x4xf32> // CHECK: %[[G:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "reduction", "parallel", "parallel"]} // CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<4x8x2xf32>, tensor<5x4x8xf32>) outs(%[[F]] : tensor<5x2x4xf32>) { diff --git a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir --- a/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split-reduction-by-scaling.mlir @@ -3,6 +3,7 @@ // CHECK-LABEL: func.func @matmul_split func.func @matmul_split(%A : tensor, %B: tensor<256x32xf32>, %C: tensor) -> tensor { + // CHECK: bufferization.alloc_tensor({{.*}}) : tensor // CHECK: linalg.generic // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"] // CHECK-SAME: ins(%{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}}, %{{[a-zA-Z0-9]*}} : tensor, tensor<256x32xf32>, tensor<64x4xi1>) diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -41,6 +42,7 @@ void getDependentDialects(DialectRegistry ®istry) const override { // clang-format off registry.insert