diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -279,6 +279,9 @@ SmallVector newOutputShape; ArrayRef oldShape = linalgOp.getShape(linalgOp.getDpsInitOperand(0)); + assert(sizes.size() == oldShape.size() + 1 && + "result tensor should have rank exactly one dimension smaller than " + "the number of loops."); SmallVector dynamicDims; for (int64_t idx : llvm::seq(0, oldShape.size() + 1)) { if (idx == insertSplitDimension) { diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -453,6 +453,19 @@ break; } } + { + auto origResultTensor = cast(op.getOperation()) + .getDpsInitOperand(0); + size_t origResultSize = 0; + if (auto shapedType = + origResultTensor->get().getType().dyn_cast()) + origResultSize = shapedType.getShape().size(); + if (iterationDomain.size() != origResultSize + 1) { + return b.notifyMatchFailure( + op, "only support result tensor whose rank is exactly one dimension " + "smaller than the number of loops."); + } + } // 1. create the inital tensor value. FailureOr identityTensor = op.generateInitialTensorForPartialReduction(b, loc, tileSize,