diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -354,8 +354,6 @@
         return getValueOrCreateConstantIndexOp(b, loc, ofr);
       }));
 
-  Operation *tiledOp = nullptr;
-
   // 1. Create the ForallOp. We don't use the lambda body-builder
   // version because we require the use of RewriterBase in the body, so we
   // manually move the insertion point to the body below.
@@ -371,6 +369,8 @@
   // 3. Clone the tileable op and update its destination operands to use the
   // output bbArgs of the ForallOp.
   ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+  Operation *tiledOp = nullptr;
+  SmallVector<Value> tiledValues;
   {
     // 3.a. RAII guard, inserting within forallOp, before terminator.
     OpBuilder::InsertionGuard g(b);
@@ -395,13 +395,12 @@
     assert(tilingResult->tiledOps.size() == 1 &&
            "expected a single produced tiled op");
     tiledOp = tilingResult->tiledOps.front();
+    tiledValues = tilingResult->tiledValues;
   }
 
   // 5. Parallel insert back into the result tensor.
-  auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
-  assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
   for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
-                           tilingInterfaceOp->getResults(), destBbArgs)) {
+                           tiledValues, destBbArgs)) {
     // 5.a. Partial subset information is inserted just before the terminator.
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPoint(forallOp.getTerminator());
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -617,7 +617,8 @@
   // Create pad(extract_slice(x)).
   Value newSliceOp = b.create<tensor::ExtractSliceOp>(
       loc, padOp.getSource(), newOffsets, newLengths, newStrides);
-  auto newPadOp = b.create<PadOp>(loc, Type(), newSliceOp, newLows, newHighs);
+  auto newPadOp = b.create<PadOp>(loc, Type(), newSliceOp, newLows, newHighs,
+                                  /*nofold=*/padOp.getNofold());
 
   // Copy region to new PadOp.
   IRMapping bvm;
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -123,3 +123,28 @@
       -> tensor<128x128xf32>
   return %0, %1 : tensor<128x128xf32>, tensor<128x128xf32>
 }
+
+// -----
+
+// CHECK-LABEL: tile_tensor_pad
+func.func @tile_tensor_pad(
+  %arg0 : tensor<?x?xf32>, %cst : f32, %low: index, %high: index)
+    -> tensor<20x40xf32>
+{
+  // CHECK: scf.forall
+  // CHECK: scf.if
+  // CHECK: tensor.generate
+  // CHECK: else
+  // CHECK: tensor.pad {{.*}} nofold
+  %0 = tensor.pad %arg0 nofold low[%low, %low] high[%high, %high] {
+    ^bb0(%arg9: index, %arg10: index):
+      tensor.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<20x40xf32>
+  return %0 : tensor<20x40xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  transform.structured.tile_to_forall_op %0 tile_sizes[1, 1]
+}