diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -52,6 +52,30 @@
   });
 }
 
+static bool noOtherUsesInLoop(vector::TransferReadOp transferRead,
+                              LoopLikeOpInterface loop) {
+  Value source = transferRead.getSource();
+  while (auto subView = source.getDefiningOp<memref::SubViewOp>())
+    source = subView.getSource();
+  llvm::SmallVector<Operation *> users(source.getUsers().begin(),
+                                       source.getUsers().end());
+  llvm::SmallDenseSet<Operation *> processed;
+  while (!users.empty()) {
+    Operation *user = users.pop_back_val();
+    // If the user has already been processed skip.
+    if (!processed.insert(user).second)
+      continue;
+    if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
+      users.append(subView->getUsers().begin(), subView->getUsers().end());
+      continue;
+    }
+    if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
+      continue;
+    return false;
+  }
+  return true;
+}
+
 void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
   bool changed = true;
   while (changed) {
@@ -95,9 +119,15 @@
       if (!loop.isDefinedOutsideOfLoop(operand))
         return WalkResult::advance();
 
-      // Only hoist transfer_read / transfer_write pairs for now.
-      if (!transferWrite)
+      // Only hoist transfer_read / transfer_write pairs and singleton
+      // transfer_reads for now.
+      if (!transferWrite) {
+        // Make sure there are no other accesses to the memref before
+        // hoisting transfer_read.
+        if (noOtherUsesInLoop(transferRead, loop))
+          loop.moveOutOfLoop(transferRead);
         return WalkResult::advance();
+      }
 
       LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
                         << "\n");
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -676,3 +676,47 @@
   transform.structured.hoist_redundant_tensor_subsets %0
     : (!pdl.operation) -> !pdl.operation
 }
+
+// -----
+
+// CHECK-LABEL: func.func @hoist_vector_transfer_read(
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x64xf32>
+// CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<32x128xf32>
+// CHECK: %[[CAST:.+]] = memref.cast %[[ALLOC_0]] : memref<32x128xf32> to memref<32x128xf32, strided<[128, 1],
+// CHECK-SAME: offset: ?>>
+// CHECK: %[[D0:.+]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} :
+// CHECK-SAME: memref<32x64xf32>, vector<32x64xf32>
+// CHECK: scf.for %[[ARG0:.+]] = %[[C0]] to %[[C1024]] step %[[C128]] {
+// CHECK: %[[D1:.+]] = vector.transfer_read %[[ALLOC_0]][%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true]}
+// CHECK-SAME: : memref<32x128xf32>, vector<32x128xf32>
+// CHECK: "some_use"(%[[D0]], %[[D1]], %[[CAST]]) : (vector<32x64xf32>, vector<32x128xf32>, memref<32x128xf32,
+// CHECK-SAME: strided<[128, 1], offset: ?>>) -> ()
+// CHECK: }
+// CHECK: return
+func.func @hoist_vector_transfer_read() {
+  %c0 = arith.constant 0 : index
+  %c128 = arith.constant 128 : index
+  %c1024 = arith.constant 1024 : index
+  %cst_2 = arith.constant 0.000000e+00 : f32
+  %memref0 = memref.alloc() : memref<32x64xf32>
+  %memref2 = memref.alloc() : memref<32x128xf32>
+  %subview2 = memref.subview %memref2[%c0, %c0] [32, 128] [1, 1]: memref<32x128xf32> to memref<32x128xf32, strided<[128, 1], offset: ?>>
+  scf.for %arg0 = %c0 to %c1024 step %c128 {
+    %2 = vector.transfer_read %memref2[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<32x128xf32>, vector<32x128xf32>
+    %3 = vector.transfer_read %memref0[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<32x64xf32>, vector<32x64xf32>
+    "some_use"(%3, %2, %subview2) : (vector<32x64xf32>, vector<32x128xf32>, memref<32x128xf32, strided<[128, 1], offset: ?>>) -> ()
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!pdl.operation) -> !pdl.operation
+}
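
Not part of the patch above: a minimal sketch of how the modified entry point could be exercised outside the transform dialect, assuming the func::FuncOp-based signature of linalg::hoistRedundantVectorTransfers used in this change. The pass class, flag name, and registration helper below are hypothetical.

// Hypothetical standalone pass that runs the hoisting on every function.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/TypeID.h"

namespace {
struct TestVectorTransferHoistingPass
    : public mlir::PassWrapper<TestVectorTransferHoistingPass,
                               mlir::OperationPass<mlir::func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorTransferHoistingPass)

  llvm::StringRef getArgument() const final {
    // Hypothetical command-line flag for mlir-opt-style drivers.
    return "test-hoist-vector-transfers";
  }

  void runOnOperation() override {
    // Hoists loop-invariant transfer_read/transfer_write pairs and, with the
    // patch above, singleton transfer_reads whose source memref has no other
    // conflicting uses inside the loop.
    mlir::linalg::hoistRedundantVectorTransfers(getOperation());
  }
};
} // namespace

// Registration hook, to be called from the hosting tool's pass registry.
void registerTestVectorTransferHoistingPass() {
  mlir::PassRegistration<TestVectorTransferHoistingPass>();
}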