diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -820,8 +820,12 @@
     Value shapedOp = valuesToTile[opOperand->getOperandNumber()];
     LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
     AffineMap map = linalgOp.getTiedIndexingMap(opOperand);
-    // If the shape is not tiled, we can use it as is.
-    if (!isTiled(map, tileSizes)) {
+    // Use `opOperand` as is if it is not tiled and not an output tensor. Having
+    // an extract/insert slice pair for all output tensors simplifies follow-up
+    // transformations such as padding and bufferization since the
+    // extract/insert slice pairs make the accessed iteration argument
+    // subdomains explicit.
+    if (!isTiled(map, tileSizes) && !linalgOp.isOutputTensor(opOperand)) {
       tiledShapes.push_back(shapedOp);
       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use shape: "
                               << opOperand->get().getType() << "\n");
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -263,7 +263,9 @@
 // CHECK: %[[ST_FILL:.*]] = linalg.fill(%[[C0]], %[[ST]]) {__internal_linalg_transform__ = "after_out_fusion_producer"} : f32, tensor<?x?xf32> -> tensor<?x?xf32>
 // CHECK: %[[ST_MM_RES:.*]] = scf.for %[[K:.*]]{{.*}}iter_args(%[[BB:.*]] = %[[ST_FILL]]) -> (tensor<?x?xf32>) {
 // CHECK-NOT: fill
-// CHECK: %[[ST_MM:.*]] = linalg.matmul {__internal_linalg_transform__ = "after_out_fusion"} ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[BB]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[ST_FILL_SUB:.*]] = tensor.extract_slice %[[BB]][0, 0]
+// CHECK: %[[ST_MM_SUB:.*]] = linalg.matmul {__internal_linalg_transform__ = "after_out_fusion"} ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ST_FILL_SUB]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[ST_MM:.*]] = tensor.insert_slice %[[ST_MM_SUB]] into %[[BB]]
 // CHECK: scf.yield %[[ST_MM]] : tensor<?x?xf32>
 // CHECK: %[[MM:.*]] = tensor.insert_slice %[[ST_MM_RES]] into {{.*}}
 // CHECK: scf.yield %[[MM]] : tensor<?x?xf32>
@@ -307,11 +309,13 @@
 // TLOOP: %[[A_SUB_SUB:.*]] = tensor.extract_slice %[[A_SUB_]][0, %[[K]]]
 // TLOOP: %[[B_SUB_SUB:.*]] = tensor.extract_slice %[[B_SUB_]][%[[K]], 0]
+// TLOOP: %[[INIT_SUB_SUB:.*]] = tensor.extract_slice %[[INIT_SUB_]][0, 0]
 // TLOOP: %[[AB_SUB_SUB:.*]] = linalg.matmul
 // TLOOP-SAME: ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
-// TLOOP-SAME: outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]]
-// TLOOP: linalg.yield %[[AB_SUB_SUB]] : [[TY]]
+// TLOOP-SAME: outs(%[[INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
+// TLOOP: %[[AB_SUB_:.*]] = tensor.insert_slice %[[AB_SUB_SUB]] into %[[INIT_SUB_]]
+// TLOOP: linalg.yield %[[AB_SUB_]] : [[TY]]
 // TLOOP: }
 // TLOOP: %[[SUB_RESULT:.*]] = tensor.insert_slice %[[AB_SUB]]
 // TLOOP-SAME: into %[[OUT_]][%[[I]], %[[J]]]
@@ -380,11 +384,13 @@
 // TLOOP: %[[A_SUB_SUB:.*]] = tensor.extract_slice %[[A_SUB_]][0, %[[K]]]
 // TLOOP: %[[B_SUB_SUB:.*]] = tensor.extract_slice %[[B_SUB_]][%[[K]], 0]
+// TLOOP: %[[INIT_SUB_SUB:.*]] = tensor.extract_slice %[[INIT_SUB_]][0, 0]
 // TLOOP: %[[AB_SUB_SUB:.*]] = linalg.matmul
 // TLOOP-SAME: ins(%[[A_SUB_SUB]], %[[B_SUB_SUB]] : [[TY]], [[TY]])
-// TLOOP-SAME: outs(%[[INIT_SUB_]] : [[TY]]) -> [[TY]]
-// TLOOP: linalg.yield %[[AB_SUB_SUB]] : [[TY]]
+// TLOOP-SAME: outs(%[[INIT_SUB_SUB]] : [[TY]]) -> [[TY]]
+// TLOOP: %[[AB_SUB_:.*]] = tensor.insert_slice %[[AB_SUB_SUB]] into %[[INIT_SUB_]]
+// TLOOP: linalg.yield %[[AB_SUB_]] : [[TY]]
 // TLOOP: }
 // TLOOP: %[[SUB_RESULT:.*]] = tensor.insert_slice %[[AB_SUB]]
 // TLOOP-SAME: into %[[OUT_]][%[[I]], %[[J]]]
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
@@ -47,6 +47,8 @@
 builtin.func @fuse_output(%arg0: tensor<24x12xf32>,
                           %arg1: tensor<12x25xf32>,
                           %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  // MATMUL-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // MATMUL-DAG: %[[C1:.*]] = arith.constant 1 : index
   %c0 = arith.constant 0 : index
   %c12 = arith.constant 12 : index
   %c25 = arith.constant 25 : index
@@ -67,7 +69,17 @@
 // MATMUL-SAME: %[[TS1]], %[[TS0]]
 // MATMUL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
 // MATMUL: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
-// MATMUL: %{{.*}} = linalg.matmul {{.*}} outs(%[[ARG5]]
+
+  // Check there is an extract/insert slice pair for the output operand.
+// MATMUL-DAG: %[[D0:.*]] = tensor.dim %[[ARG5]], %[[C0]]
+// MATMUL-DAG: %[[D1:.*]] = tensor.dim %[[ARG5]], %[[C1]]
+// MATMUL: %[[T2:.*]] = tensor.extract_slice %[[ARG5]]
+// MATMUL-SAME: 0, 0
+// MATMUL-SAME: %[[D0]], %[[D1]]
+// MATMUL: %[[T3:.*]] = linalg.matmul {{.*}} outs(%[[T2]]
+// MATMUL: %{{.*}} = tensor.insert_slice %[[T3]] into %[[ARG5]]
+// MATMUL-SAME: 0, 0
+// MATMUL-SAME: %[[D0]], %[[D1]]
   %1 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%0 : tensor<24x25xf32>) -> tensor<24x25xf32>
   return %1 : tensor<24x25xf32>
 }
@@ -185,7 +197,8 @@
 // MATMUL: %[[T2:.*]] = tensor.extract_slice %[[ARG0]]
 // MATMUL-SAME: %[[IV1]], %[[IV2]]
 // MATMUL: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
-// MATMUL: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[ARG5]]
+// MATMUL: %[[T4:.*]] = tensor.extract_slice %[[ARG5]]
+// MATMUL: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T4]]
   %2 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%1 : tensor<24x25xf32>) -> tensor<24x25xf32>
   return %2 : tensor<24x25xf32>
 }
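
Note: to make the effect of the Utils.cpp change concrete, here is a minimal before/after sketch of the IR produced when only the reduction dimension of a matmul is tiled, so the output operand itself is untiled. The IR is hypothetical and loosely modeled on the @fuse_output test above; the SSA names, the static 24x25 output shape, and the tile size 4 are illustrative only. The actual code derives the slice sizes from tensor.dim ops on the iteration argument, which is what the %[[D0]]/%[[D1]] checks match.

// Before this change: the untiled output is passed to the tiled op as is.
%res = scf.for %k = %c0 to %c12 step %c4 iter_args(%arg = %init) -> (tensor<24x25xf32>) {
  %a = tensor.extract_slice %A[0, %k] [24, 4] [1, 1] : tensor<24x12xf32> to tensor<24x4xf32>
  %b = tensor.extract_slice %B[%k, 0] [4, 25] [1, 1] : tensor<12x25xf32> to tensor<4x25xf32>
  %mm = linalg.matmul ins(%a, %b : tensor<24x4xf32>, tensor<4x25xf32>)
                      outs(%arg : tensor<24x25xf32>) -> tensor<24x25xf32>
  scf.yield %mm : tensor<24x25xf32>
}

// After this change: a (full-size) extract/insert slice pair makes the
// accessed subdomain of the iteration argument explicit.
%res = scf.for %k = %c0 to %c12 step %c4 iter_args(%arg = %init) -> (tensor<24x25xf32>) {
  %a = tensor.extract_slice %A[0, %k] [24, 4] [1, 1] : tensor<24x12xf32> to tensor<24x4xf32>
  %b = tensor.extract_slice %B[%k, 0] [4, 25] [1, 1] : tensor<12x25xf32> to tensor<4x25xf32>
  %out = tensor.extract_slice %arg[0, 0] [24, 25] [1, 1] : tensor<24x25xf32> to tensor<24x25xf32>
  %mm = linalg.matmul ins(%a, %b : tensor<24x4xf32>, tensor<4x25xf32>)
                      outs(%out : tensor<24x25xf32>) -> tensor<24x25xf32>
  %upd = tensor.insert_slice %mm into %arg[0, 0] [24, 25] [1, 1] : tensor<24x25xf32> into tensor<24x25xf32>
  scf.yield %upd : tensor<24x25xf32>
}

The pair is a no-op on the data but gives follow-up transformations such as padding and bufferization an explicit anchor for the accessed output subdomain.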