diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -2904,13 +2904,11 @@ AffineMap map = op.permutation_map(); if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation)) return failure(); - AffineMap permutationMap = map.getPermutationMap(permutation, op.getContext()); if (permutationMap.isIdentity()) return failure(); - if (op.mask()) - return failure(); + // Caluclate the map of the new read by applying the inverse permutation. permutationMap = inversePermutation(permutationMap); AffineMap newMap = permutationMap.compose(map); @@ -2920,11 +2918,43 @@ for (auto pos : llvm::enumerate(permutation)) { newVectorShape[pos.value()] = originalShape[pos.index()]; } + + Value newMask; + if (op.mask()) { + // Build helper array of size "number of dimensions of the permutation + // map". For each dim, assign an increasing counter if the dim is used in + // the result. E.g.: + // permutation map: (d0, d1, d2, d3, d4, d5) -> (d5, 0, d3, 0, d2) + // dim in result? [ 0, 0, 1, 1, 0, 1] + // dimUseIndexer: [ 0, 0, 0, 1, 1, 2] + SmallVector dimUseIndexer(map.getNumDims()); + for (unsigned i = 0, pos = 0; i < map.getNumDims(); ++i) { + auto dimInResult = llvm::any_of(map.getResults(), [&](AffineExpr e) { + return e.isa() && + e.dyn_cast().getPosition() == i; + }); + dimUseIndexer[i] = dimInResult ? pos++ : pos; + } + + // Compute mask transpose indices. For each result dim, take corresponding + // mask dim from `dimUseIndexer`. Note: Mask vectors have a dimension for + // each result dim that is not a broadcast. + SmallVector maskTransposeIndices; + for (unsigned i = 0; i < map.getNumResults(); ++i) { + if (auto expr = map.getResult(i).dyn_cast()) + maskTransposeIndices.push_back(dimUseIndexer[expr.getPosition()]); + } + + newMask = rewriter.create(op.getLoc(), op.mask(), + maskTransposeIndices); + } + VectorType newReadType = VectorType::get(newVectorShape, op.getVectorType().getElementType()); Value newRead = rewriter.create( op.getLoc(), newReadType, op.source(), op.indices(), newMap, - op.padding(), op.in_bounds() ? *op.in_bounds() : ArrayAttr()); + op.padding(), newMask, op.in_bounds() ? *op.in_bounds() : ArrayAttr()); + SmallVector transposePerm(permutation.begin(), permutation.end()); rewriter.replaceOpWithNewOp(op, newRead, transposePerm); @@ -2945,8 +2975,6 @@ LogicalResult matchAndRewrite(vector::TransferReadOp op, PatternRewriter &rewriter) const override { - if (op.mask()) - return failure(); AffineMap map = op.permutation_map(); unsigned numLeadingBroadcast = 0; for (auto expr : map.getResults()) { @@ -2982,7 +3010,7 @@ : ArrayAttr(); Value newRead = rewriter.create( op.getLoc(), newReadType, op.source(), op.indices(), newMap, - op.padding(), newInBounds); + op.padding(), op.mask(), newInBounds); rewriter.replaceOpWithNewOp(op, originalVecType, newRead); return success(); diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir @@ -228,17 +228,24 @@ // CHECK-DAG: %[[C0:.*]] = constant 0 : index %cst = constant 0.000000e+00 : f32 %c0 = constant 0 : index + %m = constant 1 : i1 - %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map0} : memref, vector<7x14x8x16xf32> -// CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$MAP0]]} : memref, vector<14x7x8x16xf32> + %mask0 = splat %m : vector<7x14xi1> + %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {permutation_map = #map0} : memref, vector<7x14x8x16xf32> +// CHECK: %[[MASK0:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1> +// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {permutation_map = #[[$MAP0]]} : memref, vector<14x7x8x16xf32> // CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32> - %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {permutation_map = #map1} : memref, vector<7x14x8x16xf32> -// CHECK: vector.transfer_read {{.*}} {permutation_map = #[[$MAP0]]} : memref, vector<16x14x7x8xf32> + %mask1 = splat %m : vector<14x16xi1> + %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask1 {permutation_map = #map1} : memref, vector<7x14x8x16xf32> +// CHECK: %[[MASK1:.*]] = vector.transpose {{.*}} : vector<14x16xi1> to vector<16x14xi1> +// CHECK: vector.transfer_read {{.*}} %[[MASK1]] {permutation_map = #[[$MAP0]]} : memref, vector<16x14x7x8xf32> // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32> - %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, false, true], permutation_map = #map2} : memref, vector<7x14x8x16xf32> -// CHECK: vector.transfer_read {{.*}} {in_bounds = [true, false, true], permutation_map = #[[$MAP1]]} : memref, vector<14x16x7xf32> + %mask2 = splat %m : vector<7x14xi1> + %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, true, false, true], permutation_map = #map2} : memref, vector<7x14x8x16xf32> +// CHECK: %[[MASK2:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1> +// CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [true, false, true], permutation_map = #[[$MAP1]]} : memref, vector<14x16x7xf32> // CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32> // CHECK: vector.transpose %{{.*}}, [3, 1, 0, 2] : vector<8x14x16x7xf32> to vector<7x14x8x16xf32>