Index: mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -276,6 +276,9 @@
     SmallVector<int64_t> newOutputShape;
     ArrayRef<int64_t> oldShape =
         linalgOp.getShape(linalgOp.getDpsInitOperand(0));
+    assert(sizes.size() == oldShape.size() + 1 &&
+           "expected the result tensor rank to be exactly one less than the "
+           "number of loops");
     SmallVector<Value> dynamicDims;
     for (int64_t idx : llvm::seq<int64_t>(0, oldShape.size() + 1)) {
       if (idx == insertSplitDimension) {
Index: mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
===================================================================
--- mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -453,6 +453,19 @@
       break;
     }
   }
+  {
+    auto origResultTensor = cast<DestinationStyleOpInterface>(op.getOperation())
+                                .getDpsInitOperand(0);
+    size_t origResultSize = 0;
+    if (auto shapedType =
+            origResultTensor->get().getType().dyn_cast<ShapedType>())
+      origResultSize = shapedType.getShape().size();
+    if (iterationDomain.size() != origResultSize + 1) {
+      return b.notifyMatchFailure(
+          op, "only result tensors whose rank is exactly one less than the "
+              "number of loops are supported");
+    }
+  }
   // 1. create the inital tensor value.
   FailureOr<Operation *> identityTensor =
       op.generateInitialTensorForPartialReduction(b, loc, tileSize,
Index: mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
===================================================================
--- mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -197,3 +197,37 @@
 // CHECK: linalg.yield
 // CHECK: } -> tensor<?xf32>
 // CHECK: return %[[R]] : tensor<?xf32>
+
+// -----
+
+func.func @reduction_bug(%arg0: tensor<32x32xi32>, %arg1: tensor<32x32xi32>, %out: tensor<32xi32>) -> tensor<32xi32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d0, d1)>,
+                                          affine_map<(d0, d1, d2) -> (d0)>],
+   iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<32x32xi32>, tensor<32x32xi32>) outs(%out : tensor<32xi32>) {
+  ^bb0(%a: i32, %b: i32, %c: i32):
+    %r1 = arith.muli %a, %b : i32
+    %r2 = arith.addi %c, %r1 : i32
+    linalg.yield %r2 : i32
+  } -> tensor<32xi32>
+  return %red : tensor<32xi32>
+}
+
+transform.sequence failures(suppress) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1, %2, %3 = transform.structured.tile_reduction_using_scf %0 { tile_sizes = [0, 0, 8] }
+}
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK-LABEL: func @reduction_bug
+// CHECK-SAME:     %[[I1:.+]]: tensor<32x32xi32>, %[[I2:.+]]: tensor<32x32xi32>, %[[F:.+]]: tensor<32xi32>
+// CHECK: %[[RED:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "reduction"]}
+// CHECK-SAME: ins(%[[I1]], %[[I2]] : tensor<32x32xi32>, tensor<32x32xi32>) outs(%[[F]] : tensor<32xi32>) {
+// CHECK: arith.muli
+// CHECK: arith.addi
+// CHECK: linalg.yield
+// CHECK: } -> tensor<32xi32>
+// CHECK: return %[[RED]] : tensor<32xi32>