diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/TilingInterface.h"
@@ -286,6 +287,31 @@
   return v.cast();
 }
 
+/// Rewire the tiles of a loop nest to read from the loop-carried values.
+///
+/// Every `tensor.extract_slice` nested in the body of the innermost loop
+/// `innerFor` whose source is one of the init operands (`iter_operands`) of
+/// the outermost loop `outerFor` is updated in place to slice the
+/// corresponding region iter_arg of `innerFor` instead. Operands are matched
+/// by position, so both loops must carry the same number of iter_args. Uses
+/// of the init operands by any other kind of op are left untouched.
+static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor,
+                            PatternRewriter &rewriter) {
+  assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() &&
+         "expect same number of iter args");
+  BlockAndValueMapping mapping;
+  mapping.map(outerFor.getIterOperands(), innerFor.getRegionIterArgs());
+  innerFor.getRegion().walk([&](tensor::ExtractSliceOp sliceOp) {
+    // Only slices that read an outer init operand directly are rewired.
+    Value source = sliceOp.getSource();
+    Value replacement = mapping.lookupOrNull(source);
+    if (!replacement)
+      return;
+    rewriter.updateRootInPlace(
+        sliceOp, [&]() { sliceOp.getSourceMutable().assign(replacement); });
+  });
+}
+
 FailureOr
 scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
     TilingInterface op, PatternRewriter &rewriter) const {
@@ -401,5 +427,10 @@
       }
     }
   }
+  // Make the slices inside the innermost loop read the loop-carried values
+  // instead of the original init tensors. Nothing to do without loops.
+  if (!tileAndFuseResult.loops.empty())
+    replaceIterArgs(tileAndFuseResult.loops.front(),
+                    tileAndFuseResult.loops.back(), rewriter);
   return tileAndFuseResult;
 }
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -23,7 +23,7 @@
 // CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
 // CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
 // CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
-// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
+// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
 // CHECK: %[[FILL_TILE:.+]] = linalg.fill
 // CHECK-SAME: outs(%[[INIT_TILE]] :
 // CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
@@ -68,7 +68,7 @@
 // CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
 // CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
 // CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
-// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
+// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
 // CHECK: %[[FILL_TILE:.+]] = linalg.fill
 // CHECK-SAME: outs(%[[INIT_TILE]] :
 // CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
@@ -123,7 +123,7 @@
 // CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
 // CHECK-SAME: outs(%[[FILL0_TILE]] :
 // CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
-// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
+// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG]][%[[IV]], 0]
 // CHECK: %[[FILL1_TILE:.+]] = linalg.fill
 // CHECK-SAME: outs(%[[INIT1_TILE]] :
 // CHECK: %[[GEMM1_TILE:.+]] = linalg.matmul