diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -52,6 +52,38 @@
   });
 }
 
+/// Returns true if no operation nested in `loop`, other than `transferRead`
+/// itself, uses the buffer read by `transferRead`. Uses are followed through
+/// view-like ops (e.g. memref.subview) in both directions, so an access made
+/// inside the loop through a view created outside of it is also detected.
+/// Any such use, including a plain read, is treated as a potential conflict.
+static bool noOtherUsesInLoop(vector::TransferReadOp transferRead,
+                              LoopLikeOpInterface loop) {
+  // Walk up through view-like producers (e.g. memref.subview) to the buffer
+  // they are derived from, so that accesses through sibling views are seen.
+  Value base = transferRead.getSource();
+  while (auto view =
+             llvm::dyn_cast_or_null<ViewLikeOpInterface>(base.getDefiningOp()))
+    base = view.getViewSource();
+
+  // Deduplicating worklist of users of `base` and of views derived from it.
+  llvm::SetVector<Operation *> users(base.getUsers().begin(),
+                                     base.getUsers().end());
+  for (size_t idx = 0; idx < users.size(); ++idx) {
+    Operation *user = users[idx];
+    // Skip current transfer_read.
+    if (user == transferRead.getOperation())
+      continue;
+    // Make sure use is not inside loop.
+    if (loop->isAncestor(user))
+      return false;
+    // A view created outside of the loop may still be accessed inside of it.
+    if (llvm::isa<ViewLikeOpInterface>(user))
+      users.insert(user->getUsers().begin(), user->getUsers().end());
+  }
+  return true;
+}
+
 void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
   bool changed = true;
   while (changed) {
@@ -95,9 +127,15 @@
         if (!loop.isDefinedOutsideOfLoop(operand))
           return WalkResult::advance();
 
-      // Only hoist transfer_read / transfer_write pairs for now.
-      if (!transferWrite)
+      // Only hoist transfer_read / transfer_write pairs and singleton
+      // transfer_reads for now.
+      if (!transferWrite) {
+        // Make sure there are no other accesses to the memref before
+        // hoisting transfer_read.
+        if (noOtherUsesInLoop(transferRead, loop))
+          loop.moveOutOfLoop(transferRead);
         return WalkResult::advance();
+      }
 
       LLVM_DEBUG(DBGS() << "Candidate: " << *transferWrite.getOperation()
                         << "\n");
diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir
--- a/mlir/test/Dialect/Linalg/hoisting.mlir
+++ b/mlir/test/Dialect/Linalg/hoisting.mlir
@@ -676,3 +676,81 @@
   transform.structured.hoist_redundant_tensor_subsets %0
     : (!pdl.operation) -> !pdl.operation
 }
+
+// -----
+
+// CHECK-LABEL: func.func @hoist_vector_transfer_read(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<32x64xf32>,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<1x1024x64xf32>,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref<32x128xf32>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG: %[[C1024:.+]] = arith.constant 1024 : index
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[D0:.+]] = vector.transfer_read %[[ARG2]][%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} :
+// CHECK-SAME: memref<32x128xf32>, vector<32x128xf32>
+// CHECK: %[[D1:.+]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true]} :
+// CHECK-SAME: memref<32x64xf32>, vector<32x64xf32>
+// CHECK: %[[D2:.+]] = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C1024]] step %[[C128]]
+// CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D0]]) -> (vector<32x128xf32>) {
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG1]][0, %[[ARG3]], 0] [1, 128, 64] [1, 1, 1] :
+// CHECK-SAME: memref<1x1024x64xf32> to memref<128x64xf32, strided<[64, 1], offset: ?>>
+// CHECK: %[[D3:.+]] = vector.transfer_read %[[SUBVIEW]][%[[C0]], %[[C0]]], %[[CST]] {in_bounds = [true, true]}
+// CHECK-SAME: : memref<128x64xf32, strided<[64, 1], offset: ?>>, vector<128x64xf32>
+// CHECK: %[[D4:.+]] = "some_use"(%[[D1]], %[[D3]], %[[ARG4]]) : (vector<32x64xf32>, vector<128x64xf32>,
+// CHECK-SAME: vector<32x128xf32>) -> vector<32x128xf32>
+// CHECK: scf.yield %[[D4]] : vector<32x128xf32>
+func.func @hoist_vector_transfer_read(%memref0 : memref<32x64xf32>, %memref1 : memref<1x1024x64xf32>, %memref2 : memref<32x128xf32>) {
+  %c0 = arith.constant 0 : index
+  %c128 = arith.constant 128 : index
+  %c1024 = arith.constant 1024 : index
+  %cst_2 = arith.constant 0.000000e+00 : f32
+  %0 = vector.transfer_read %memref2[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<32x128xf32>, vector<32x128xf32>
+  %1 = scf.for %arg0 = %c0 to %c1024 step %c128 iter_args(%arg1 = %0) -> (vector<32x128xf32>) {
+    %subview1 = memref.subview %memref1[0, %arg0, 0] [1, 128, 64] [1, 1, 1] : memref<1x1024x64xf32> to memref<128x64xf32, strided<[64, 1], offset: ?>>
+    %2 = vector.transfer_read %subview1[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<128x64xf32, strided<[64, 1], offset: ?>>, vector<128x64xf32>
+    %3 = vector.transfer_read %memref0[%c0, %c0], %cst_2 {in_bounds = [true, true]} : memref<32x64xf32>, vector<32x64xf32>
+    %4 = "some_use"(%3, %2, %arg1) : (vector<32x64xf32>, vector<128x64xf32>, vector<32x128xf32>) -> vector<32x128xf32>
+    scf.yield %4 : vector<32x128xf32>
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!pdl.operation) -> !pdl.operation
+}
+
+// -----
+
+// The read of %memref0 must stay inside the loop: the loop writes to the same
+// buffer through a subview that is created outside of the loop.
+// CHECK-LABEL: func.func @no_hoist_vector_transfer_read_aliasing_subview(
+// CHECK-NOT: vector.transfer_read
+// CHECK: scf.for
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+func.func @no_hoist_vector_transfer_read_aliasing_subview(%memref0 : memref<32x64xf32>) {
+  %c0 = arith.constant 0 : index
+  %c128 = arith.constant 128 : index
+  %c1024 = arith.constant 1024 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %subview = memref.subview %memref0[0, 0] [32, 64] [1, 1] : memref<32x64xf32> to memref<32x64xf32, strided<[64, 1]>>
+  scf.for %arg0 = %c0 to %c1024 step %c128 {
+    %0 = vector.transfer_read %memref0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<32x64xf32>, vector<32x64xf32>
+    %1 = "some_use"(%0) : (vector<32x64xf32>) -> vector<32x64xf32>
+    vector.transfer_write %1, %subview[%c0, %c0] {in_bounds = [true, true]} : vector<32x64xf32>, memref<32x64xf32, strided<[64, 1]>>
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["func.func"]} in %arg1
+    : (!pdl.operation) -> !pdl.operation
+  transform.structured.hoist_redundant_vector_transfers %0
+    : (!pdl.operation) -> !pdl.operation
+}