diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5136,11 +5136,92 @@
   }
 };
 
+/// Trims all trailing unit dims from the input vector type. The result is
+/// guaranteed to keep at least one dim, per vector type requirements.
+static VectorType trimTrailingOneDims(VectorType oldType) {
+  ArrayRef<int64_t> oldShape = oldType.getShape();
+  ArrayRef<int64_t> newShape = oldShape;
+
+  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
+  ArrayRef<bool> newScalableDims = oldScalableDims;
+
+  // Only trim trailing _static_ unit dims - a trailing scalable unit dim
+  // contains vscale x 1 elements, not 1.
+  while (!newShape.empty() && newShape.back() == 1 &&
+         !newScalableDims.back()) {
+    newShape = newShape.drop_back(1);
+    newScalableDims = newScalableDims.drop_back(1);
+  }
+
+  // Make sure we have at least 1 dimension per vector type requirements.
+  if (newShape.empty()) {
+    newShape = oldShape.take_back();
+    newScalableDims = oldScalableDims.take_back();
+  }
+
+  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
+}
+
+/// Folds shape_cast(create_mask) into a new create_mask when the shape_cast
+/// only drops trailing unit dims.
+class FoldShapeCastCreateMask final : public OpRewritePattern<ShapeCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
+                                PatternRewriter &rewriter) const override {
+    Value shapeOpSrc = shapeOp->getOperand(0);
+    auto createMaskOp = shapeOpSrc.getDefiningOp<vector::CreateMaskOp>();
+    auto constantMaskOp = shapeOpSrc.getDefiningOp<vector::ConstantMaskOp>();
+    if (!createMaskOp && !constantMaskOp)
+      return failure();
+
+    VectorType shapeOpResTy = shapeOp.getResultVectorType();
+    VectorType shapeOpSrcTy = shapeOp.getSourceVectorType();
+
+    // Only fold shape_casts that merely drop the trailing unit dims.
+    VectorType newVecType = trimTrailingOneDims(shapeOpSrcTy);
+    if (newVecType != shapeOpResTy)
+      return failure();
+
+    if (createMaskOp) {
+      auto maskOperands = createMaskOp.getOperands();
+      auto numDimsToDrop =
+          shapeOpSrcTy.getShape().size() - shapeOpResTy.getShape().size();
+      auto numMaskOperands = maskOperands.size();
+
+      // Check that the mask operand for every dim to drop is a constant 1 -
+      // only then is dropping that dim a no-op.
+      for (size_t i = numMaskOperands - 1;
+           i >= numMaskOperands - numDimsToDrop; i--) {
+        auto constant =
+            maskOperands[i].getDefiningOp<arith::ConstantIndexOp>();
+        // TODO: Support mask values other than 1 (0 might be the only viable
+        // option).
+        if (!constant || (constant.value() != 1))
+          return failure();
+      }
+
+      SmallVector<Value> newMaskOperands =
+          llvm::to_vector(maskOperands.drop_back(numDimsToDrop));
+
+      rewriter.replaceOpWithNewOp<vector::CreateMaskOp>(
+          shapeOp, shapeOp.getResultVectorType(), newMaskOperands);
+      return success();
+    }
+
+    // TODO: Constant mask case.
+
+    return failure();
+  }
+};
+
 } // namespace
 
 void vector::TransposeOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
-  results.add<FoldTransposedScalarBroadcast, TransposeFolder>(context);
+  results.add<FoldShapeCastCreateMask, FoldTransposedScalarBroadcast,
+              TransposeFolder>(context);
 }
 
 void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2196,3 +2196,21 @@
   %0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
   return %0 : vector<3x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_shape_cast_with_mask(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x?xf32>) -> vector<1x4xi1> {
+func.func @fold_shape_cast_with_mask(%arg0: tensor<1x?xf32>) -> vector<1x4xi1> {
+// CHECK-NOT: vector.shape_cast
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = tensor.dim %[[VAL_0]], %[[VAL_1]] : tensor<1x?xf32>
+// CHECK: %[[VAL_3:.*]] = vector.create_mask %[[VAL_1]], %[[VAL_2]] : vector<1x4xi1>
+// CHECK: return %[[VAL_3]] : vector<1x4xi1>
+  %c1 = arith.constant 1 : index
+  %c0 = arith.constant 0 : index
+  %dim_0 = tensor.dim %arg0, %c1 : tensor<1x?xf32>
+  %1 = vector.create_mask %c1, %dim_0, %c1, %c1 : vector<1x4x1x1xi1>
+  %2 = vector.shape_cast %1 : vector<1x4x1x1xi1> to vector<1x4xi1>
+  return %2 : vector<1x4xi1>
+}
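
Note: the new fold runs as part of canonicalization, so a quick way to exercise
it locally is to run the canonicalizer over the new test case (a sketch,
assuming a standard MLIR build tree - adjust the binary path to your setup):

  # Hypothetical invocation; the lit RUN line in canonicalize.mlir passes
  # additional flags, but plain -canonicalize is enough to trigger the fold.
  build/bin/mlir-opt mlir/test/Dialect/Vector/canonicalize.mlir -canonicalize

Applied to the test above, the pattern rewrites

  %1 = vector.create_mask %c1, %dim_0, %c1, %c1 : vector<1x4x1x1xi1>
  %2 = vector.shape_cast %1 : vector<1x4x1x1xi1> to vector<1x4xi1>

into a single mask with the trailing unit dims (and their all-ones mask
operands) dropped:

  %2 = vector.create_mask %c1, %dim_0 : vector<1x4xi1>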