diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5265,13 +5265,48 @@
   }
 };
 
+/// Folds transpose(create_mask) into a new transposed create_mask.
+///
+/// The mask bounds are permuted the same way as the vector dimensions, e.g.:
+///
+///   %0 = vector.create_mask %a, %b, %c : vector<2x3x4xi1>
+///   %1 = vector.transpose %0, [2, 0, 1]
+///          : vector<2x3x4xi1> to vector<4x2x3xi1>
+///
+/// is rewritten so that %1 is defined by:
+///
+///   %1 = vector.create_mask %c, %a, %b : vector<4x2x3xi1>
+class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    auto createMaskOp =
+        transposeOp.getVector().getDefiningOp<vector::CreateMaskOp>();
+    if (!createMaskOp)
+      return failure();
+
+    // Get the transpose permutation and apply it to the vector.create_mask
+    // operands.
+    auto maskOperands = createMaskOp.getOperands();
+    SmallVector<int64_t> permutation;
+    transposeOp.getTransp(permutation);
+    SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
+    applyPermutationToVector(newOperands, permutation);
+
+    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+        transposeOp, transposeOp.getResultVectorType(), newOperands);
+    return success();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results
-      .add<FoldTransposedScalarBroadcast, TransposeFolder, FoldTransposeSplat>(
-          context);
+  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
+              TransposeFolder, FoldTransposeSplat>(context);
 }
 
 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -54,6 +54,19 @@
 
 // -----
 
+// CHECK-LABEL: create_mask_transpose_to_transposed_create_mask
+// CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
+func.func @create_mask_transpose_to_transposed_create_mask(
+  %dim0: index, %dim1: index, %dim2: index) -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
+  // CHECK: vector.create_mask %[[DIM0]], %[[DIM1]], %[[DIM2]] : vector<2x3x4xi1>
+  // CHECK: vector.create_mask %[[DIM2]], %[[DIM0]], %[[DIM1]] : vector<4x2x3xi1>
+  %0 = vector.create_mask %dim0, %dim1, %dim2 : vector<2x3x4xi1>
+  %1 = vector.transpose %0, [2, 0, 1] : vector<2x3x4xi1> to vector<4x2x3xi1>
+  return %0, %1 : vector<2x3x4xi1>, vector<4x2x3xi1>
+}
+
+// -----
+
 func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
   %1 = vector.extract_strided_slice %0