Index: mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -80,6 +80,44 @@ } } +/// Return true if we can prove that the transfer operations access disjoint +/// memory. A false result only means disjointness could not be proven. +template +static bool isDisjoint(TransferTypeA transferA, TransferTypeB transferB) { + if (transferA.memref() != transferB.memref()) + return false; + // For simplicity only look at transfers of the same vector type. + if (transferA.getVectorType() != transferB.getVectorType()) + return false; + unsigned memrefRank = transferA.getMemRefType().getRank(); + unsigned resultVecRank = transferA.getVectorType().getRank(); + unsigned rankOffset = memrefRank - resultVecRank; + for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) { + auto indexA = transferA.indices()[i].template getDefiningOp(); + auto indexB = transferB.indices()[i].template getDefiningOp(); + // If either index is dynamic we cannot prove anything for this dimension. + if (!indexA || !indexB) + continue; + + if (i < rankOffset) { + // For leading dimensions (indexed but not spanned by the vector), if we + // can prove the indices differ we know we are accessing disjoint slices. + if (indexA.getValue().template cast().getInt() != + indexB.getValue().template cast().getInt()) + return true; + } else { + // Memref dim i pairs with vector dim i - rankOffset (NOTE(review): this + // assumes a minor-identity permutation map); intervals must not overlap. + int64_t distance = + std::abs(indexA.getValue().template cast().getInt() - + indexB.getValue().template cast().getInt()); + if (distance >= transferA.getVectorType().getDimSize(i - rankOffset)) + return true; + } + } + return false; +} + void mlir::linalg::hoistRedundantVectorTransfers(FuncOp func) { bool changed = true; while (changed) { @@ -129,9 +167,9 @@ // Approximate aliasing by checking that: // 1. indices are the same, - // 2.
no other use either dominates the transfer_read or is dominated - // by the transfer_write (i.e. aliasing between the write and the read - // across the loop). + // 2. no other operations in the loop access the same memref except + // for other transfer_read/transfer_write ops accessing statically disjoint + // slices. if (transferRead.indices() != transferWrite.indices()) return WalkResult::advance(); @@ -140,11 +178,26 @@ DominanceInfo dom(loop); if (!dom.properlyDominates(transferRead.getOperation(), transferWrite)) return WalkResult::advance(); - for (auto &use : transferRead.memref().getUses()) - if (dom.properlyDominates(use.getOwner(), - transferRead.getOperation()) || - dom.properlyDominates(transferWrite, use.getOwner())) + for (auto &use : transferRead.memref().getUses()) { + if (!dom.properlyDominates(loop, use.getOwner())) + continue; + if (use.getOwner() == transferRead.getOperation() || + use.getOwner() == transferWrite.getOperation()) + continue; + if (auto transferWriteUse = + dyn_cast(use.getOwner())) { + if (!isDisjoint(transferWrite, transferWriteUse)) + return WalkResult::advance(); + } else if (auto transferReadUse = + dyn_cast(use.getOwner())) { + if (!isDisjoint(transferWrite, transferReadUse)) + return WalkResult::advance(); + } else { + // Unknown use, we cannot prove that it doesn't alias with the + // transferRead/transferWrite operations. return WalkResult::advance(); + } + } // Hoist read before.
if (failed(loop.moveOutOfLoop({transferRead}))) Index: mlir/test/Dialect/Linalg/hoisting.mlir =================================================================== --- mlir/test/Dialect/Linalg/hoisting.mlir +++ mlir/test/Dialect/Linalg/hoisting.mlir @@ -132,18 +132,99 @@ %r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref, vector<4xf32> "some_crippling_use"(%memref4) : (memref) -> () %r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref, vector<5xf32> + %r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref, vector<6xf32> + "some_crippling_use"(%memref5) : (memref) -> () // Unknown use of %memref5 in the loop blocks hoisting. %u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32> %u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32> %u2 = "some_use"(%memref2) : (memref) -> vector<3xf32> %u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32> %u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32> + %u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32> vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref + vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref "some_crippling_use"(%memref3) : (memref) -> () } } return } + +// VECTOR_TRANSFERS-LABEL: func @hoist_vector_transfer_pairs_disjoint( +// VECTOR_TRANSFERS-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref, +// VECTOR_TRANSFERS-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref, +// VECTOR_TRANSFERS-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref, +// VECTOR_TRANSFERS-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref, +// VECTOR_TRANSFERS-SAME: %[[VAL:[a-zA-Z0-9]*]]: index, +// VECTOR_TRANSFERS-SAME: %[[LB:[a-zA-Z0-9]*]]: index, +// VECTOR_TRANSFERS-SAME: %[[UB:[a-zA-Z0-9]*]]: index, +// VECTOR_TRANSFERS-SAME: %[[STEP:[a-zA-Z0-9]*]]: index, +// VECTOR_TRANSFERS-SAME: 
%[[RANDOM:[a-zA-Z0-9]*]]: index, +// VECTOR_TRANSFERS-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1 +func @hoist_vector_transfer_pairs_disjoint( + %memref0: memref, %memref1: memref, + %memref2: memref, %memref3: memref, %val: index, %lb : index, %ub : index, + %step: index, %random_index : index, %cmp: i1) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c3 = constant 3 : index + %cst = constant 0.0 : f32 + +// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF2]]{{.*}} : memref, vector<3xf32> +// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF2]]{{.*}} : memref, vector<3xf32> +// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF3]]{{.*}} : memref, vector<4xf32> +// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF3]]{{.*}} : memref, vector<4xf32> +// VECTOR_TRANSFERS: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) { +// VECTOR_TRANSFERS: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) { +// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF1]]{{.*}} : memref, vector<2xf32> +// VECTOR_TRANSFERS: vector.transfer_read %[[MEMREF1]]{{.*}} : memref, vector<2xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<3xf32>) -> vector<3xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32> +// VECTOR_TRANSFERS: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32> +// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF1]]{{.*}} : vector<2xf32>, memref +// VECTOR_TRANSFERS: vector.transfer_write 
%{{.*}}, %[[MEMREF1]]{{.*}} : vector<2xf32>, memref +// VECTOR_TRANSFERS: scf.yield {{.*}} : vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32> +// VECTOR_TRANSFERS: } +// VECTOR_TRANSFERS: scf.yield {{.*}} : vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32> +// VECTOR_TRANSFERS: } +// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF3]]{{.*}} : vector<4xf32>, memref +// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF3]]{{.*}} : vector<4xf32>, memref +// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF2]]{{.*}} : vector<3xf32>, memref +// VECTOR_TRANSFERS: vector.transfer_write %{{.*}}, %[[MEMREF2]]{{.*}} : vector<3xf32>, memref + scf.for %i = %lb to %ub step %step { + scf.for %j = %lb to %ub step %step { + %r00 = vector.transfer_read %memref1[%c0, %c0], %cst: memref, vector<2xf32> + %r01 = vector.transfer_read %memref1[%c0, %c1], %cst: memref, vector<2xf32> // Overlaps %r00 (cols [0,2) vs [1,3)): stays in the loop. + %r20 = vector.transfer_read %memref2[%c0, %c0], %cst: memref, vector<3xf32> + %r21 = vector.transfer_read %memref2[%c0, %c3], %cst: memref, vector<3xf32> // Disjoint from %r20 (cols [0,3) vs [3,6)): hoisted. + %r30 = vector.transfer_read %memref3[%c0, %random_index], %cst: memref, vector<4xf32> + %r31 = vector.transfer_read %memref3[%c1, %random_index], %cst: memref, vector<4xf32> // Disjoint from %r30 (leading index 0 vs 1): hoisted. + %r10 = vector.transfer_read %memref0[%i, %i], %cst: memref, vector<2xf32> + %r11 = vector.transfer_read %memref0[%random_index, %random_index], %cst: memref, vector<2xf32> // Dynamic indices: not provably disjoint, stays in the loop. + %u00 = "some_use"(%r00) : (vector<2xf32>) -> vector<2xf32> + %u01 = "some_use"(%r01) : (vector<2xf32>) -> vector<2xf32> + %u20 = "some_use"(%r20) : (vector<3xf32>) -> vector<3xf32> + %u21 = "some_use"(%r21) : (vector<3xf32>) -> vector<3xf32> + %u30 = "some_use"(%r30) : (vector<4xf32>) -> vector<4xf32> + %u31 = "some_use"(%r31) : (vector<4xf32>) -> vector<4xf32> + %u10 = "some_use"(%r10) : 
(vector<2xf32>) -> vector<2xf32> + %u11 = "some_use"(%r11) : (vector<2xf32>) -> vector<2xf32> + vector.transfer_write %u00, %memref1[%c0, %c0] : vector<2xf32>, memref + 
vector.transfer_write %u01, %memref1[%c0, %c1] : vector<2xf32>, memref + vector.transfer_write %u20, %memref2[%c0, %c0] : vector<3xf32>, memref + vector.transfer_write %u21, %memref2[%c0, %c3] : vector<3xf32>, memref + vector.transfer_write %u30, %memref3[%c0, %random_index] : vector<4xf32>, memref + vector.transfer_write %u31, %memref3[%c1, %random_index] : vector<4xf32>, memref + vector.transfer_write %u10, %memref0[%i, %i] : vector<2xf32>, memref + vector.transfer_write %u11, %memref0[%random_index, %random_index] : vector<2xf32>, memref + } + } + return +}