diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -3069,13 +3069,11 @@ AffineMap map = op.permutation_map(); if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) return failure(); - AffineMap permutationMap = map.getPermutationMap(permutation, op.getContext()); if (permutationMap.isIdentity()) return failure(); - if (op.mask()) - return failure(); + // Caluclate the map of the new read by applying the inverse permutation. permutationMap = inversePermutation(permutationMap); AffineMap newMap = permutationMap.compose(map); @@ -3085,11 +3083,32 @@ for (auto pos : llvm::enumerate(permutation)) { newVectorShape[pos.value()] = originalShape[pos.index()]; } + + Value newMask; + if (op.mask()) { + // Remove unused dims from the permutation map, e.g.: + // map = (d0, d1, d2, d3) -> (d3, 0, d1, d2) + // comp = (d0, d1, d2) -> (d2, 0, d0, d1) + auto comp = compressUnusedDims(map); + // Mask dim i belongs to the i-th non-broadcast result, but the new read + // orders dims by input position, so transpose with the inverse, e.g.: + // comp results (d2, 0, d0, d1) => maskTransposeIndices = [1, 2, 0] + SmallVector<int64_t> maskTransposeIndices(comp.getNumDims()); + for (unsigned i = 0, maskDim = 0; i < comp.getNumResults(); ++i) { + if (auto expr = comp.getResult(i).dyn_cast<AffineDimExpr>()) + maskTransposeIndices[expr.getPosition()] = maskDim++; + } + + newMask = rewriter.create<vector::TransposeOp>(op.getLoc(), op.mask(), + maskTransposeIndices); + } + + VectorType newReadType = VectorType::get(newVectorShape, op.getVectorType().getElementType()); Value newRead = rewriter.create( op.getLoc(), newReadType, op.source(), op.indices(), newMap, - op.padding(), op.in_bounds() ? *op.in_bounds() : ArrayAttr()); + op.padding(), newMask, op.in_bounds() ? 
*op.in_bounds() : ArrayAttr()); + SmallVector transposePerm(permutation.begin(), permutation.end()); rewriter.replaceOpWithNewOp(op, newRead, transposePerm); @@ -3110,8 +3129,6 @@ LogicalResult matchAndRewrite(vector::TransferReadOp op, PatternRewriter &rewriter) const override { - if (op.mask()) - return failure(); AffineMap map = op.permutation_map(); unsigned numLeadingBroadcast = 0; for (auto expr : map.getResults()) { @@ -3147,7 +3164,7 @@ : ArrayAttr(); Value newRead = rewriter.create( op.getLoc(), newReadType, op.source(), op.indices(), newMap, - op.padding(), newInBounds); + op.padding(), op.mask(), newInBounds); rewriter.replaceOpWithNewOp(op, originalVecType, newRead); return success(); diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir @@ -228,17 +228,24 @@ // CHECK-DAG: %[[C0:.*]] = constant 0 : index %cst = constant 0.000000e+00 : f32 %c0 = constant 0 : index + %m = constant 1 : i1 - %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map0} : memref, vector<7x14x8x16xf32> -// CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$MAP0]]} : memref, vector<14x7x8x16xf32> + %mask0 = splat %m : vector<7x14xi1> + %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {permutation_map = #map0} : memref, vector<7x14x8x16xf32> +// CHECK: %[[MASK0:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1> +// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {permutation_map = #[[$MAP0]]} : memref, vector<14x7x8x16xf32> // CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32> - %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map1} : memref, vector<7x14x8x16xf32> -// CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$MAP0]]} : memref, 
vector<16x14x7x8xf32> + %mask1 = splat %m : vector<14x16xi1> + %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask1 {permutation_map = #map1} : memref, vector<7x14x8x16xf32> +// CHECK: %[[MASK1:.*]] = vector.transpose {{.*}} : vector<14x16xi1> to vector<16x14xi1> +// CHECK: vector.transfer_read {{.*}} %[[MASK1]] {permutation_map = #[[$MAP0]]} : memref, vector<16x14x7x8xf32> // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32> - %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, false, true], permutation_map = #map2} : memref, vector<7x14x8x16xf32> -// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, false, true], permutation_map = #[[$MAP1]]} : memref, vector<14x16x7xf32> + %mask2 = splat %m : vector<7x14xi1> + %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, true, false, true], permutation_map = #map2} : memref, vector<7x14x8x16xf32> +// CHECK: %[[MASK2:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1> +// CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [true, false, true], permutation_map = #[[$MAP1]]} : memref, vector<14x16x7xf32> // CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32> // CHECK: vector.transpose %{{.*}}, [3, 1, 0, 2] : vector<8x14x16x7xf32> to vector<7x14x8x16xf32>