diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -135,12 +135,14 @@ << "\n"); // Approximate aliasing by checking that: - // 1. indices are the same, + // 1. indices, vector type and permutation map are the same (i.e., the + // transfer_read/transfer_write ops are matching), // 2. no other operations in the loop access the same memref except // for transfer_read/transfer_write accessing statically disjoint // slices. - if (transferRead.getIndices() != transferWrite.getIndices() && - transferRead.getVectorType() == transferWrite.getVectorType()) + if (transferRead.getIndices() != transferWrite.getIndices() || + transferRead.getVectorType() != transferWrite.getVectorType() || + transferRead.getPermutationMap() != transferWrite.getPermutationMap()) return WalkResult::advance(); // TODO: may want to memoize this information for performance but it diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -722,3 +722,37 @@ transform.structured.hoist_redundant_vector_transfers %0 : (!transform.any_op) -> !transform.any_op } + +// ----- + +// The transfers in this test case cannot be hoisted and replaced by a vector +// iter_arg because they do not match. + +// CHECK-LABEL: func.func @non_matching_transfers( +// CHECK: scf.for {{.*}} { +// CHECK: vector.transfer_read +// CHECK: vector.transfer_write +// CHECK: } +func.func @non_matching_transfers(%m: memref<6x1x7x32xf32>) { + %c0 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c128 = arith.constant 128 : index + %cst = arith.constant dense<5.5> : vector<6x7x32xf32> + %cst_0 = arith.constant 0.0 : f32 + scf.for %iv = %c0 to %c1024 step %c128 { + %read = vector.transfer_read %m[%c0, %c0, %c0, %c0], %cst_0 {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>} : memref<6x1x7x32xf32>, vector<6x7x32xf32> + %added = arith.addf %read, %cst : vector<6x7x32xf32> + %bc = vector.broadcast %added : vector<6x7x32xf32> to vector<1x6x7x32xf32> + %tr = vector.transpose %bc, [1, 0, 2, 3] : vector<1x6x7x32xf32> to vector<6x1x7x32xf32> + vector.transfer_write %tr, %m[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<6x1x7x32xf32>, memref<6x1x7x32xf32> + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op +}