diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -380,11 +380,14 @@
   auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
   if (destinationStyleOp) {
     for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
-      auto *it = llvm::find(dest, outOperand->get());
-      if (it == dest.end())
-        return op->emitOpError("must have \"tensor semantic\" for tiling");
-      unsigned destNum = std::distance(dest.begin(), it);
-      outOperand->set(destBbArgs[destNum]);
+      // Swap tensor inits with the corresponding block argument of the
+      // scf.forall op. Memref inits remain as is.
+      if (outOperand->get().getType().isa<TensorType>()) {
+        auto *it = llvm::find(dest, outOperand->get());
+        assert(it != dest.end() && "could not find destination tensor");
+        unsigned destNum = std::distance(dest.begin(), it);
+        outOperand->set(destBbArgs[destNum]);
+      }
     }
   }
 
diff --git a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
--- a/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir
@@ -40,6 +40,53 @@
 
 // -----
 
+module {
+  // CHECK-LABEL: func @matmul_memref(
+  // CHECK: scf.forall (%{{.*}}, %{{.*}}) in (10, 20) {
+  // CHECK:   memref.subview
+  // CHECK:   memref.subview
+  // CHECK:   memref.subview
+  // CHECK:   linalg.matmul
+  // CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+  func.func @matmul_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>, %C: memref<?x?xf32>) {
+    linalg.matmul ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
+                  outs(%C : memref<?x?xf32>)
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.structured.tile_to_forall_op %0 num_threads [10, 20] (mapping = [ #gpu.thread<y>, #gpu.thread<x> ] )
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  }
+}
+
+// -----
+
+module {
+  // CHECK-LABEL: func @copy_memref(
+  // CHECK: scf.forall (%{{.*}}, %{{.*}}) in (10, 20) {
+  // CHECK:   memref.subview
+  // CHECK:   memref.subview
+  // CHECK:   linalg.copy
+  // CHECK: } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+  func.func @copy_memref(%A: memref<?x?xf32>, %B: memref<?x?xf32>) {
+    linalg.copy ins(%A : memref<?x?xf32>)
+                outs(%B : memref<?x?xf32>)
+    return
+  }
+
+  transform.sequence failures(propagate) {
+  ^bb1(%arg1: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.copy"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1:2 = transform.structured.tile_to_forall_op %0 num_threads [10, 20] (mapping = [ #gpu.thread<y>, #gpu.thread<x> ] )
+           : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+  }
+}
+
+// -----
+
 // In this test case, matmul dims and tile size are dynamic.
 
 // CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>