diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5136,12 +5136,107 @@
   }
 };
 
+/// Returns a copy of `oldType` with trailing unit dims dropped. Scalable unit
+/// dims (i.e. "[1]") are preserved, as a scalable 1 is not guaranteed to be
+/// 1 at runtime.
+static VectorType trimTrailingOneDims(VectorType oldType) {
+  ArrayRef<int64_t> oldShape = oldType.getShape();
+  ArrayRef<int64_t> newShape(oldShape.begin(), oldShape.end());
+
+  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
+  ArrayRef<bool> newScalableDims(oldScalableDims);
+
+  while (!newShape.empty() && newShape.back() == 1 && !newScalableDims.back()) {
+    newShape = newShape.drop_back(1);
+    newScalableDims = newScalableDims.drop_back(1);
+  }
+
+  // Make sure we have at least 1 dimension.
+  // TODO: Add support for 0-D vectors.
+  if (newShape.empty()) {
+    newShape = oldShape.take_back();
+    newScalableDims = oldType.getScalableDims().take_back();
+  }
+
+  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+}
+
+/// Folds shape_cast(create_mask) into a new create_mask (and, analogously,
+/// shape_cast(constant_mask) into a new constant_mask) when the shape_cast
+/// only drops trailing unit dimensions.
+class FoldShapeCastCreateMask final : public OpRewritePattern<ShapeCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
+                                PatternRewriter &rewriter) const override {
+    Value shapeOpSrc = shapeOp->getOperand(0);
+    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
+    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
+    if (!createMaskOp && !constantMaskOp)
+      return failure();
+
+    VectorType shapeOpResTy = shapeOp.getResultVectorType();
+    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
+
+    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
+    if (newVecType != shapeOpResTy)
+      return failure();
+
+    auto numDimsToDrop =
+        shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
+
+    if (createMaskOp) {
+      auto maskOperands = createMaskOp.getOperands();
+      auto numMaskOperands = maskOperands.size();
+
+      // Check that every mask dim size that would be dropped is a constant 1.
+      for (size_t i = numMaskOperands - 1;
+           i >= numMaskOperands - numDimsToDrop; i--) {
+        auto constant = maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
+        // TODO: Support mask values other than 1 (0 might be the only other
+        // viable option)
+        if (!constant || (constant.value() != 1))
+          return failure();
+      }
+      SmallVector<Value> newMaskOperands =
+          llvm::to_vector(maskOperands.drop_back(numDimsToDrop));
+
+      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+          shapeOp, shapeOp.getResultVectorType(), newMaskOperands);
+      return success();
+    }
+
+    if (constantMaskOp) {
+      auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
+      auto numMaskOperands = maskDimSizes.size();
+
+      // Check that every mask dim size that would be dropped is a constant 1.
+      for (size_t i = numMaskOperands - 1;
+           i >= numMaskOperands - numDimsToDrop; i--) {
+        // TODO: Support mask values other than 1 (0 might be the only other
+        // viable option)
+        if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
+          return failure();
+      }
+
+      auto newMaskOperands =
+          constantMaskOp.getMaskDimSizesAttr().getValue().drop_back(
+              numDimsToDrop);
+      ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);
+
+      rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
+          shapeOp, shapeOp.getResultVectorType(), newMaskOperandsAttr);
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldTransposeCreateMask, FoldTransposedScalarBroadcast,
-              TransposeFolder>(context);
+  results
+      .add<FoldShapeCastCreateMask, FoldTransposeCreateMask,
+           FoldTransposedScalarBroadcast, TransposeFolder>(context);
 }
 
 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
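The net effect of the new pattern is to build the mask directly at the trimmed
rank instead of casting it down afterwards. A minimal sketch of the rewrite
(illustrative hand-written IR, not part of the patch; %c1 and %dim stand for a
constant 1 and an arbitrary dynamic size):

  // Before: the mask is built at rank 4; shape_cast drops the trailing unit
  // dims, whose mask sizes must all be the constant 1 for the fold to apply.
  %mask = vector.create_mask %c1, %dim, %c1, %c1 : vector<1x4x1x1xi1>
  %cast = vector.shape_cast %mask : vector<1x4x1x1xi1> to vector<1x4xi1>

  // After canonicalization: the mask is built directly at rank 2.
  %cast = vector.create_mask %c1, %dim : vector<1x4xi1>

Scalable unit dims ("[1]") are deliberately left alone, since their runtime
extent is a multiple of 1 and not necessarily 1.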
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2196,3 +2196,69 @@
   %0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
   return %0 : vector<3x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_shape_cast_with_mask(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x?xf32>) -> vector<1x4xi1> {
+func.func @fold_shape_cast_with_mask(%arg0: tensor<1x?xf32>) -> vector<1x4xi1> {
+// CHECK-NOT: vector.shape_cast
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<1x?xf32>
+// CHECK: %[[VAL_3:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]] : vector<1x4xi1>
+// CHECK: return %[[VAL_3]] : vector<1x4xi1>
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %dim_0 = tensor.dim %arg0, %c1 : tensor<1x?xf32>
+  %1 = vector.create_mask %c1, %dim_0, %c1, %c1 : vector<1x4x1x1xi1>
+  %2 = vector.shape_cast %1 : vector<1x4x1x1xi1> to vector<1x4xi1>
+  return %2 : vector<1x4xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_shape_cast_with_mask_scalable(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x?xf32>) -> vector<1x[4]xi1> {
+func.func @fold_shape_cast_with_mask_scalable(%arg0: tensor<1x?xf32>) -> vector<1x[4]xi1> {
+// CHECK-NOT: vector.shape_cast
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<1x?xf32>
+// CHECK: %[[VAL_3:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]] : vector<1x[4]xi1>
+// CHECK: return %[[VAL_3]] : vector<1x[4]xi1>
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %dim_0 = tensor.dim %arg0, %c1 : tensor<1x?xf32>
+  %1 = vector.create_mask %c1, %dim_0, %c1, %c1 : vector<1x[4]x1x1xi1>
+  %2 = vector.shape_cast %1 : vector<1x[4]x1x1xi1> to vector<1x[4]xi1>
+  return %2 : vector<1x[4]xi1>
+}
+
+// -----
+
+// Check that a scalable "1" (i.e. [1]) is not folded away.
+// CHECK-LABEL: func.func @fold_shape_cast_with_mask_scalable_one(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x?xf32>) -> vector<1x[1]xi1> {
+func.func @fold_shape_cast_with_mask_scalable_one(%arg0: tensor<1x?xf32>) -> vector<1x[1]xi1> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<1x?xf32>
+// CHECK: %[[VAL_3:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]] : vector<1x[1]xi1>
+// CHECK: return %[[VAL_3]] : vector<1x[1]xi1>
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %dim_0 = tensor.dim %arg0, %c1 : tensor<1x?xf32>
+  %1 = vector.create_mask %c1, %dim_0, %c1 : vector<1x[1]x1xi1>
+  %2 = vector.shape_cast %1 : vector<1x[1]x1xi1> to vector<1x[1]xi1>
+  return %2 : vector<1x[1]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fold_shape_cast_with_constant_mask() -> vector<4xi1> {
+func.func @fold_shape_cast_with_constant_mask() -> vector<4xi1> {
+// CHECK-NOT: vector.shape_cast
+// CHECK: %[[VAL_0:.*]] = vector.constant_mask [1] : vector<4xi1>
+// CHECK: return %[[VAL_0]] : vector<4xi1>
+  %1 = vector.constant_mask [1, 1, 1] : vector<4x1x1xi1>
+  %2 = vector.shape_cast %1 : vector<4x1x1xi1> to vector<4xi1>
+  return %2 : vector<4xi1>
+}
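To exercise the new tests locally, the usual FileCheck pipeline applies
(assuming a built mlir-opt and FileCheck; the RUN line at the top of the test
file is authoritative):

  mlir-opt mlir/test/Dialect/Vector/canonicalize.mlir -split-input-file -canonicalize | FileCheck mlir/test/Dialect/Vector/canonicalize.mlir

-split-input-file is needed because the test cases are separated by
// ----- markers, and -canonicalize is the pass that picks up the newly
registered pattern.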