diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -46,6 +46,21 @@
   return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
 }
 
+/// Extend the rank of the mask Value `vec` by `addedRank` trailing unit dims:
+/// broadcast to add leading unit dims, then transpose them to the innermost.
+static Value extendMaskRank(OpBuilder &builder, Location loc, Value vec,
+                            int64_t addedRank) {
+  Value broadcasted = extendVectorRank(builder, loc, vec, addedRank);
+  SmallVector<int64_t> permutation;
+  for (int64_t i = addedRank,
+               e = broadcasted.getType().cast<VectorType>().getRank();
+       i < e; ++i)
+    permutation.push_back(i);
+  for (int64_t i = 0; i < addedRank; ++i)
+    permutation.push_back(i);
+  return builder.create<vector::TransposeOp>(loc, broadcasted, permutation);
+}
+
 //===----------------------------------------------------------------------===//
 // populateVectorTransferPermutationMapLoweringPatterns
 //===----------------------------------------------------------------------===//
@@ -246,9 +261,14 @@
       missingInnerDim.push_back(i);
       exprs.push_back(rewriter.getAffineDimExpr(i));
     }
-    // Add unit dims at the beginning of the shape.
+    // Vector: add unit dims at the beginning of the shape.
     Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
                                     missingInnerDim.size());
+    // Mask: add unit dims at the end of the shape.
+    Value newMask;
+    if (op.getMask())
+      newMask = extendMaskRank(rewriter, op.getLoc(), op.getMask(),
+                               missingInnerDim.size());
     exprs.append(map.getResults().begin(), map.getResults().end());
     AffineMap newMap =
         AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
@@ -263,7 +283,7 @@
     }
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
-        op.getMask(), newInBoundsAttr);
+        newMask, newInBoundsAttr);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-transfer-permutation-lowering.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @lower_permutation_with_mask(
+// CHECK: %[[vec:.*]] = arith.constant dense<-2.000000e+00> : vector<7x1xf32>
+// CHECK: %[[mask:.*]] = arith.constant dense<[true, false, true, false, true, true, true]> : vector<7xi1>
+// CHECK: %[[b:.*]] = vector.broadcast %[[mask]] : vector<7xi1> to vector<1x7xi1>
+// CHECK: %[[tp:.*]] = vector.transpose %[[b]], [1, 0] : vector<1x7xi1> to vector<7x1xi1>
+// CHECK: vector.transfer_write %[[vec]], %{{.*}}[%{{.*}}, %{{.*}}], %[[tp]] {in_bounds = [false, true]} : vector<7x1xf32>, memref<?x?xf32>
+func.func @lower_permutation_with_mask(%A : memref<?x?xf32>, %base1 : index,
+                                       %base2 : index) {
+  %fn1 = arith.constant -2.0 : f32
+  %vf0 = vector.splat %fn1 : vector<7xf32>
+  %mask = arith.constant dense<[1, 0, 1, 0, 1, 1, 1]> : vector<7xi1>
+  vector.transfer_write %vf0, %A[%base1, %base2], %mask
+    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [false]}
+    : vector<7xf32>, memref<?x?xf32>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op
+    : (!transform.any_op) -> !transform.any_op
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.transfer_permutation_patterns
+  } : !transform.any_op
+}