diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4779,6 +4779,114 @@
   }
 };
 
+/// Helper function that computes a new vector type based on the input vector
+/// type by removing the trailing unit dims:
+///
+///   vector<4x1x1xi1> --> vector<4xi1>
+///
+static VectorType trimTrailingOneDims(VectorType oldType) {
+  ArrayRef<int64_t> oldShape = oldType.getShape();
+  ArrayRef<int64_t> newShape = oldShape;
+
+  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
+  ArrayRef<bool> newScalableDims = oldScalableDims;
+
+  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
+    newShape = newShape.drop_back(1);
+    newScalableDims = newScalableDims.drop_back(1);
+  }
+
+  // Make sure we have at least 1 dimension.
+  // TODO: Add support for 0-D vectors.
+  if (newShape.empty()) {
+    newShape = oldShape.take_back();
+    newScalableDims = oldScalableDims.take_back();
+  }
+
+  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+}
+
+/// Folds qualifying shape_cast(create_mask) into a new create_mask
+///
+/// Looks at `vector.shape_cast` Ops that simply "drop" trailing unit
+/// dimensions. If the input vector comes from `vector.create_mask` (or
+/// `vector.constant_mask`) with a mask size of 1 for every dropped dim (e.g.
+/// `%c1` below), then it is safe to fold the shape_cast into the mask op.
+///
+/// BEFORE:
+///    %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
+///    %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
+/// AFTER:
+///    %0 = vector.create_mask %c1, %dim : vector<1x[4]xi1>
+class ShapeCastCreateMaskFolderTrailingOneDim final
+    : public OpRewritePattern<ShapeCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
+                                PatternRewriter &rewriter) const override {
+    Value shapeOpSrc = shapeOp->getOperand(0);
+    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
+    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
+    if (!createMaskOp && !constantMaskOp)
+      return failure();
+
+    VectorType shapeOpResTy = shapeOp.getResultVectorType();
+    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
+
+    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
+    if (newVecType != shapeOpResTy)
+      return failure();
+
+    auto numDimsToDrop =
+        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
+
+    // No unit dims to drop
+    if (!numDimsToDrop)
+      return failure();
+
+    if (createMaskOp) {
+      auto maskOperands = createMaskOp.getOperands();
+      auto numMaskOperands = maskOperands.size();
+
+      // Check every mask dim size to see whether it can be dropped
+      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
+           --i) {
+        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
+        if (!constant || (constant.value() != 1))
+          return failure();
+      }
+      SmallVector<Value> newMaskOperands =
+          maskOperands.drop_back(numDimsToDrop);
+
+      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(shapeOp, shapeOpResTy,
+                                                        newMaskOperands);
+      return success();
+    }
+
+    if (constantMaskOp) {
+      auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+      auto numMaskOperands = maskDimSizes.size();
+
+      // Check every mask dim size to see whether it can be dropped
+      for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
+           --i) {
+        if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
+          return failure();
+      }
+
+      auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
+      ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);
+
+      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
+                                                          newMaskOperandsAttr);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 /// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
 /// This only applies when the shape of the broadcast source
 /// 1. is a suffix of the shape of the result (i.e. when broadcast without
@@ -4831,7 +4939,8 @@
 
 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<ShapeCastConstantFolder, ShapeCastBroadcastFolder>(context);
+  results.add<ShapeCastConstantFolder, ShapeCastCreateMaskFolderTrailingOneDim,
+              ShapeCastBroadcastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2219,3 +2219,83 @@
   %0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
   return %0 : vector<3x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_shape_cast_with_mask(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x?xf32>) -> vector<1x4xi1> {
+func.func @fold_shape_cast_with_mask(%arg0: tensor<1x?xf32>) -> vector<1x4xi1> {
+// CHECK-NOT: vector.shape_cast
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<1x?xf32>
+// CHECK: %[[VAL_3:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]] : vector<1x4xi1>
+// CHECK: return %[[VAL_3]] : vector<1x4xi1>
+  %c1 = arith.constant 1 : index
+  %dim = tensor.dim %arg0, %c1 : tensor<1x?xf32>
+  %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x4x1x1xi1>
+  %2 = vector.shape_cast %1 : vector<1x4x1x1xi1> to vector<1x4xi1>
+  return %2 : vector<1x4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_shape_cast_with_mask_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x?xf32>) -> vector<1x[4]xi1> {
+func.func @fold_shape_cast_with_mask_scalable(%arg0: tensor<1x?xf32>) -> vector<1x[4]xi1> {
+// CHECK-NOT: vector.shape_cast
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<1x?xf32>
+// CHECK: %[[VAL_3:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]] : vector<1x[4]xi1>
+// CHECK: return %[[VAL_3]] : vector<1x[4]xi1>
+  %c1 = arith.constant 1 : index
+  %dim = tensor.dim %arg0, %c1 : tensor<1x?xf32>
+  %1 = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x[4]x1x1xi1>
+  %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
+  return %2 : vector<1x[4]xi1>
+}
+
+// -----
+
+// Check that a scalable unit dim (i.e. [1]) is not folded away
+// CHECK-LABEL: func.func @fold_shape_cast_with_mask_scalable_one(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x?xf32>) -> vector<1x[1]xi1> {
+func.func @fold_shape_cast_with_mask_scalable_one(%arg0: tensor<1x?xf32>) -> vector<1x[1]xi1> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<1x?xf32>
+// CHECK: %[[VAL_3:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]] : vector<1x[1]xi1>
+// CHECK: return %[[VAL_3]] : vector<1x[1]xi1>
+  %c1 = arith.constant 1 : index
+  %dim = tensor.dim %arg0, %c1 : tensor<1x?xf32>
+  %1 = vector.create_mask %c1, %dim, %c1 : vector<1x[1]x1xi1>
+  %2 = vector.shape_cast %1 : vector<1x[1]x1xi1> to vector<1x[1]xi1>
+  return %2 : vector<1x[1]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_shape_cast_with_constant_mask() -> vector<4xi1> {
+func.func @fold_shape_cast_with_constant_mask() -> vector<4xi1> {
+// CHECK-NOT: vector.shape_cast
+// CHECK: %[[VAL_0:.*]] = vector.constant_mask [1] : vector<4xi1>
+// CHECK: return %[[VAL_0]] : vector<4xi1>
+  %1 = vector.constant_mask [1, 1, 1] : vector<4x1x1xi1>
+  %2 = vector.shape_cast %1 : vector<4x1x1xi1> to vector<4xi1>
+  return %2 : vector<4xi1>
+}
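+
+// -----
+
+// An illustrative negative sketch (the test name and the %c0 operand are
+// hypothetical, not from the original tests): the trailing mask operand is 0
+// rather than a constant 1, so dropping that dim would change the mask and
+// the pattern is expected to bail out, leaving the shape_cast in place.
+// CHECK-LABEL: func.func @negative_fold_shape_cast_with_mask(
+func.func @negative_fold_shape_cast_with_mask(%arg0: tensor<1x?xf32>) -> vector<1x4xi1> {
+// CHECK: vector.shape_cast
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %dim = tensor.dim %arg0, %c1 : tensor<1x?xf32>
+  %1 = vector.create_mask %c1, %dim, %c1, %c0 : vector<1x4x1x1xi1>
+  %2 = vector.shape_cast %1 : vector<1x4x1x1xi1> to vector<1x4xi1>
+  return %2 : vector<1x4xi1>
+}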