Fold vector.transpose of vector.constant_mask into a permuted vector.constant_mask.

The transpose(create_mask) canonicalization in VectorOps.cpp now also matches a
vector.constant_mask source. Its mask dimension sizes are permuted with the
transpose permutation, and a vector.constant_mask of the transposed result type
replaces the transpose. Tests: a new
constant_mask_transpose_to_transposed_constant_mask case, plus a CHECK-NOT for
vector.transpose in the existing create_mask case.

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5269,23 +5269,37 @@
 public:
   using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(TransposeOp transposeOp,
+  LogicalResult matchAndRewrite(TransposeOp transpOp,
                                 PatternRewriter &rewriter) const override {
-    auto createMaskOp =
-        transposeOp.getVector().getDefiningOp<vector::CreateMaskOp>();
-    if (!createMaskOp)
+    Value transposeSrc = transpOp.getVector();
+    auto createMaskOp = transposeSrc.getDefiningOp<vector::CreateMaskOp>();
+    auto constantMaskOp = transposeSrc.getDefiningOp<vector::ConstantMaskOp>();
+    if (!createMaskOp && !constantMaskOp)
       return failure();
 
     // Get the transpose permutation and apply it to the vector.create_mask
-    // operands.
-    auto maskOperands = createMaskOp.getOperands();
+    // operands or the vector.constant_mask dimension sizes.
     SmallVector<int64_t> permutation;
-    transposeOp.getTransp(permutation);
-    SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
-    applyPermutationToVector(newOperands, permutation);
+    transpOp.getTransp(permutation);
+
+    if (createMaskOp) {
+      auto maskOperands = createMaskOp.getOperands();
+      SmallVector<Value> newOperands(maskOperands.begin(), maskOperands.end());
+      applyPermutationToVector(newOperands, permutation);
+
+      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+          transpOp, transpOp.getResultVectorType(), newOperands);
+      return success();
+    }
+
+    // ConstantMaskOp case.
+    auto maskDimSizes = constantMaskOp.getMaskDimSizes();
+    SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
+    applyPermutationToVector(newMaskDimSizes, permutation);
 
-    rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
-        transposeOp, transposeOp.getResultVectorType(), newOperands);
+    rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+        transpOp, transpOp.getResultVectorType(),
+        ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -58,8 +58,9 @@
 // CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
 func.func @create_mask_transpose_to_transposed_create_mask(
   %dim0: index, %dim1: index, %dim2: index) -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
-  // CHECK: vector.create_mask %[[DIM0]], %[[DIM1]], %[[DIM2]] : vector<2x3x4xi1>
-  // CHECK: vector.create_mask %[[DIM2]], %[[DIM0]], %[[DIM1]] : vector<4x2x3xi1>
+  // CHECK:      vector.create_mask %[[DIM0]], %[[DIM1]], %[[DIM2]] : vector<2x3x4xi1>
+  // CHECK:      vector.create_mask %[[DIM2]], %[[DIM0]], %[[DIM1]] : vector<4x2x3xi1>
+  // CHECK-NOT:  vector.transpose
   %0 = vector.create_mask %dim0, %dim1, %dim2 : vector<2x3x4xi1>
   %1 = vector.transpose %0, [2, 0, 1] : vector<2x3x4xi1> to vector<4x2x3xi1>
   return %0, %1 : vector<2x3x4xi1>, vector<4x2x3xi1>
@@ -67,6 +68,18 @@
 
 // -----
 
+// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
+func.func @constant_mask_transpose_to_transposed_constant_mask() -> (vector<2x3x4xi1>, vector<4x2x3xi1>) {
+  // CHECK:      vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
+  // CHECK:      vector.constant_mask [3, 1, 2] : vector<4x2x3xi1>
+  // CHECK-NOT:  vector.transpose
+  %0 = vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
+  %1 = vector.transpose %0, [2, 0, 1] : vector<2x3x4xi1> to vector<4x2x3xi1>
+  return %0, %1 : vector<2x3x4xi1>, vector<4x2x3xi1>
+}
+
+// -----
+
 func.func @extract_strided_slice_of_constant_mask() -> (vector<2x2xi1>) {
   %0 = vector.constant_mask [2, 2] : vector<4x3xi1>
   %1 = vector.extract_strided_slice %0