diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3137,6 +3137,88 @@
   }
 };
 
+/// Lower transfer_write op with permutation into a transfer_write with a
+/// minor identity permutation map. (transfer_write ops cannot have broadcasts.)
+/// The vector (and mask, if any) is transposed with the *inverse* of the
+/// permutation map so that its dimensions appear in memref order.
+/// Ex:
+///     vector.transfer_write %v ...
+///         permutation_map: (d0, d1, d2) -> (d2, d0, d1)
+/// into:
+///     %tmp = vector.transpose %v, [1, 2, 0]
+///     vector.transfer_write %tmp ...
+///         permutation_map: (d0, d1, d2) -> (d0, d1, d2)
+///
+///     vector.transfer_write %v ...
+///         permutation_map: (d0, d1, d2, d3) -> (d3, d2)
+/// into:
+///     %tmp = vector.transpose %v, [1, 0]
+///     vector.transfer_write %tmp ...
+///         permutation_map: (d0, d1, d2, d3) -> (d2, d3)
+struct TransferWritePermutationLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<unsigned> permutation;
+    AffineMap map = op.permutation_map();
+    if (map.isMinorIdentity())
+      return failure();
+    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+      return failure();
+
+    // Remove unused dims from the permutation map. E.g.:
+    //        (d0, d1, d2, d3, d4, d5) -> (d5, d3, d4)
+    // comp = (d0, d1, d2) -> (d2, d0, d1)
+    auto comp = compressUnusedDims(map);
+    // `comp` sends vector dim i to memref dim comp(i). The rewritten op must
+    // use the minor identity map, so its vector dim j has to hold the old
+    // vector dim i with comp(i) == j, i.e. the transpose permutation is the
+    // *inverse* of comp. For the map above: inverse(comp) =
+    // (d0, d1, d2) -> (d1, d2, d0), so the vector is transposed by [1, 2, 0].
+    // Bail out if comp is not invertible (transfer_write has no broadcasts).
+    AffineMap invComp = inversePermutation(comp);
+    if (!invComp)
+      return failure();
+    // indices[j] = old vector dim that becomes new vector dim j.
+    SmallVector<int64_t> indices;
+    llvm::transform(invComp.getResults(), std::back_inserter(indices),
+                    [](AffineExpr expr) {
+                      return expr.cast<AffineDimExpr>().getPosition();
+                    });
+
+    // Transpose mask operand.
+    Value newMask = op.mask() ? rewriter.create<vector::TransposeOp>(
+                                    op.getLoc(), op.mask(), indices)
+                              : Value();
+
+    // Transpose in_bounds attribute: new dim j inherits the flag of old dim
+    // indices[j].
+    ArrayAttr newInBounds;
+    if (op.in_bounds()) {
+      ArrayAttr oldInBounds = op.in_bounds().getValue();
+      SmallVector<bool> newInBoundsValues;
+      newInBoundsValues.reserve(indices.size());
+      for (int64_t oldDim : indices)
+        newInBoundsValues.push_back(
+            oldInBounds.getValue()[oldDim].cast<BoolAttr>().getValue());
+      newInBounds = rewriter.getBoolArrayAttr(newInBoundsValues);
+    }
+
+    // Generate new transfer_write operation.
+    Value newVec = rewriter.create<vector::TransposeOp>(
+        op.getLoc(), op.vector(), indices);
+    auto newMap = AffineMap::getMinorIdentityMap(
+        map.getNumDims(), map.getNumResults(), rewriter.getContext());
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        op, Type(), newVec, op.source(), op.indices(), newMap, newMask,
+        newInBounds);
+
+    return success();
+  }
+};
+
 /// Lower transfer_read op with broadcast in the leading dimensions into
 /// transfer_read of lower rank + vector.broadcast.
 /// Ex: vector.transfer_read ...
@@ -4089,7 +4171,8 @@
     RewritePatternSet &patterns) {
   patterns
       .add(
+           TransferReadPermutationLowering, TransferWritePermutationLowering,
+           TransferOpReduceRank>(
           patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -267,3 +267,25 @@
       vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
       vector<7x14x8x16xf32>, vector<7x14x8x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_permutations
+func @transfer_write_permutations(%arg0 : memref<?x?x?x?xf32>,
+    %v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> () {
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  %c0 = constant 0 : index
+  %m = constant 1 : i1
+
+  %mask0 = splat %m : vector<7x14x8x16xi1>
+  vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, memref<?x?x?x?xf32>
+  // CHECK: %[[NEW_MASK0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xi1> to vector<16x14x7x8xi1>
+  // CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [3, 1, 0, 2] : vector<7x14x8x16xf32> to vector<16x14x7x8xf32>
+  // CHECK: vector.transfer_write %[[NEW_VEC0]], %arg0[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[NEW_MASK0]] {in_bounds = [true, false, true, false]} : vector<16x14x7x8xf32>, memref<?x?x?x?xf32>
+
+  vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %{{.*}} [1, 0] : vector<8x16xf32> to vector<16x8xf32>
+  // CHECK: vector.transfer_write %[[NEW_VEC1]], %arg0[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] : vector<16x8xf32>, memref<?x?x?x?xf32>
+
+  return
+}