Index: mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -80,6 +80,29 @@ } } +/// Return true if we can prove that the transfer operations access disjoint +/// memory. +template +static bool isDisjoint(TransferTypeA transferA, TransferTypeB transferB) { + if (transferA.memref() != transferB.memref()) + return false; + for (unsigned i = 0, e = std::min(transferA.indices().size(), + transferB.indices().size()); + i < e; i++) { + auto indexA = transferA.indices()[i].template getDefiningOp(); + auto indexB = transferB.indices()[i].template getDefiningOp(); + // If any of the indices are dynamic we cannot prove anything. + if (!indexA || !indexB) + return false; + // If we find a different static index then different slices are being + // accessed. + if (indexA.getValue().template cast().getInt() != + indexB.getValue().template cast().getInt()) + return true; + } + return false; +} + void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { bool changed = true; while (changed) { @@ -129,9 +152,9 @@ // Approximate aliasing by checking that: // 1. indices are the same, - // 2. no other use either dominates the transfer_read or is dominated - // by the transfer_write (i.e. aliasing between the write and the read - // across the loop). + // 2. no other operations in the loop access the same memref except + // for transferRead/transferWrite accessing statically disjoint + // slices. 
if (transferRead.indices() != transferWrite.indices()) return WalkResult::advance(); @@ -140,11 +163,26 @@ DominanceInfo dom(loop); if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) return WalkResult::advance(); - for (auto &use : transferRead.memref().getUses()) - if (dom.properlyDominates(use.getOwner(), - transferRead.getOperation()) || - dom.properlyDominates(transferWrite, use.getOwner())) + for (auto &use : transferRead.memref().getUses()) { + if (!dom.properlyDominates(loop, use.getOwner())) + continue; + if (use.getOwner() == transferRead.getOperation() || + use.getOwner() == transferWrite.getOperation()) + continue; + if (auto transferWriteUse = + dyn_cast(use.getOwner())) { + if (!isDisjoint(transferWrite, transferWriteUse)) + return WalkResult::advance(); + } else if (auto transferReadUse = + dyn_cast(use.getOwner())) { + if (!isDisjoint(transferWrite, transferReadUse)) + return WalkResult::advance(); + } else { + // Unknown use, we cannot prove that it doesn't alias with the + // transferRead/transferWrite operations. return WalkResult::advance(); + } + } // Hoist read before. 
if (failed(loop.moveOutOfLoop({transferRead}))) Index: mlir/test/Dialect/Linalg/hoisting.mlir =================================================================== --- mlir/test/Dialect/Linalg/hoisting.mlir +++ mlir/test/Dialect/Linalg/hoisting.mlir @@ -132,18 +132,76 @@ %r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref, vector<4xf32> "some_crippling_use"(%memref4) : (memref) -> () %r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref, vector<5xf32> + %r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref, vector<6xf32> + "some_crippling_use"(%memref5) : (memref) -> () %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32> %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32> %u2 = "some_use"(%memref2) : (memref) -> vector<3xf32> %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32> %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32> + %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32> vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref + vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref "some_crippling_use"(%memref3) : (memref) -> () } } return } + +// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_disjoint( +// VECTOR_TRANSFERS-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref, +// VECTOR_TRANSFERS-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref, +// VECTOR_TRANSFERS-SAME: %[[VAL:[a-zA-Z0-9]*]]: index, +// VECTOR_TRANSFERS-SAME: %[[LB:[a-zA-Z0-9]*]]: index, +// VECTOR_TRANSFERS-SAME: %[[UB:[a-zA-Z0-9]*]]: index, +// VECTOR_TRANSFERS-SAME: %[[STEP:[a-zA-Z0-9]*]]: index, +// VECTOR_TRANSFERS-SAME: %[[RANDOM:[a-zA-Z0-9]*]]: index, +// VECTOR_TRANSFERS-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1 +func 
@hoist_vector_transfer_pairs_disjoint( + %memref0: memref, %memref1: memref, %val: index, + %lb : index, %ub : index, %step: index, %random_index : index, + %cmp: i1) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %cst = constant 0.0 : f32 + +// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref, vector<1xf32> +// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref, vector<1xf32> +// VECTOR_TRANSFERS: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<1xf32>) { +// VECTOR_TRANSFERS: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<1xf32>) { +// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref, vector<2xf32> +// VECTOR_TRANSFERS: vector.transfer_read %{{.*}} : memref, vector<2xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32> +// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<2xf32>, memref +// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<2xf32>, memref +// VECTOR_TRANSFERS: scf.yield {{.*}} : vector<1xf32>, vector<1xf32> +// VECTOR_TRANSFERS: } +// VECTOR_TRANSFERS: scf.yield {{.*}} : vector<1xf32>, vector<1xf32> +// VECTOR_TRANSFERS: } +// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<1xf32>, memref +// VECTOR_TRANSFERS: vector.transfer_write %{{.*}} : vector<1xf32>, memref + scf.for %i = %lb to %ub step %step { + scf.for %j = %lb to %ub step %step { + %r00 = vector.transfer_read %memref1[%c0, %c0], %cst: memref, vector<1xf32> + %r01 = vector.transfer_read %memref1[%c1, %c1], %cst: memref, vector<1xf32> + %r10 = vector.transfer_read %memref0[%i, %i], %cst: memref, vector<2xf32> + %r11 = vector.transfer_read %memref0[%random_index, %random_index], %cst: 
memref, vector<2xf32> + %u00 = "some_use"(%r00) : (vector<1xf32>) -> vector<1xf32> + %u01 = "some_use"(%r01) : (vector<1xf32>) -> vector<1xf32> + %u10 = "some_use"(%r10) : (vector<2xf32>) -> vector<2xf32> + %u11 = "some_use"(%r11) : (vector<2xf32>) -> vector<2xf32> + vector.transfer_write %u00, %memref1[%c0, %c0] : vector<1xf32>, memref + vector.transfer_write %u01, %memref1[%c1, %c1] : vector<1xf32>, memref + vector.transfer_write %u10, %memref0[%i, %i] : vector<2xf32>, memref + vector.transfer_write %u11, %memref0[%random_index, %random_index] : vector<2xf32>, memref + } + } + return +}