diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1969,7 +1969,7 @@
       (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
        static_cast<unsigned>(destVectorType.getRank())))
     return emitOpError("expected position attribute rank + source rank to "
-                        "match dest vector rank");
+                       "match dest vector rank");
   if (!srcVectorType &&
       (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
     return emitOpError(
@@ -2302,8 +2302,7 @@
   int64_t numFixedVectorSizes = fixedVectorSizes.size();
 
   if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
-    return emitError("invalid input shape for vector type ")
-           << inputVectorType;
+    return emitError("invalid input shape for vector type ") << inputVectorType;
 
   if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
     return emitError("invalid output shape for vector type ")
@@ -2396,24 +2395,29 @@
   auto sizes = getSizesAttr();
   auto strides = getStridesAttr();
   if (offsets.size() != sizes.size() || offsets.size() != strides.size())
-    return emitOpError("expected offsets, sizes and strides attributes of same size");
+    return emitOpError(
+        "expected offsets, sizes and strides attributes of same size");
 
   auto shape = type.getShape();
   auto offName = getOffsetsAttrName();
   auto sizesName = getSizesAttrName();
   auto stridesName = getStridesAttrName();
-  if (failed(isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
-      failed(isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
+  if (failed(
+          isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
+      failed(
+          isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
       failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
                                                 stridesName)) ||
-      failed(isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
+      failed(
+          isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
       failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                                /*halfOpen=*/false,
                                                /*min=*/1)) ||
-      failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1, stridesName,
+      failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
+                                               stridesName,
                                                /*halfOpen=*/false)) ||
-      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes, shape,
-                                                    offName, sizesName,
+      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
+                                                    shape, offName, sizesName,
                                                     /*halfOpen=*/false)))
     return failure();
 
@@ -4193,12 +4197,49 @@
   }
 };
 
+/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
+/// This only applies when the shape of the broadcast source is a prefix or a
+/// suffix of the shape of the result (i.e. when broadcast without reshape is
+/// expressive enough to capture the result in a single op).
+class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto broadcastOp =
+        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
+    if (!broadcastOp)
+      return failure();
+
+    auto broadcastSourceVectorType =
+        broadcastOp.getSourceType().dyn_cast<VectorType>();
+    auto broadcastSourceShape = broadcastSourceVectorType
+                                    ? broadcastSourceVectorType.getShape()
+                                    : ArrayRef<int64_t>{};
+    auto shapeCastTargetShape = shapeCastOp.getResultVectorType().getShape();
+
+    bool isPrefix = (broadcastSourceShape == shapeCastTargetShape.take_front(
+                                                 broadcastSourceShape.size()));
+    bool isSuffix = (broadcastSourceShape == shapeCastTargetShape.take_back(
+                                                 broadcastSourceShape.size()));
+    // Bail if `broadcastSourceShape` is neither a prefix nor a suffix of the
+    // result.
+    if (!isPrefix && !isSuffix)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        shapeCastOp, shapeCastOp.getResultVectorType(),
+        broadcastOp.getSource());
+    return success();
+  }
+};
+
 } // namespace
 
 void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  // Pattern to rewrite a ShapeCastOp(ConstantOp) -> ConstantOp.
-  results.add<ShapeCastConstantFolder>(context);
+  results.add<ShapeCastConstantFolder, ShapeCastBroadcastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4223,7 +4264,7 @@
   if (sourceVectorType.getRank() == 0) {
     if (sourceElementBits != resultElementBits)
       return emitOpError("source/result bitwidth of the 0-D vector element "
-                          "types must be equal");
+                         "types must be equal");
   } else if (sourceElementBits * sourceVectorType.getShape().back() !=
              resultElementBits * resultVectorType.getShape().back()) {
     return emitOpError(
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -646,10 +646,10 @@
 
 // -----
 
-// CHECK-LABEL: func @dont_fold_broadcast_shapecast_scalar
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast_scalar
 // CHECK: vector.broadcast
-// CHECK: vector.shape_cast
-func.func @dont_fold_broadcast_shapecast_scalar(%arg0: f32) -> vector<1xf32> {
+// CHECK-NOT: vector.shape_cast
+func.func @canonicalize_broadcast_shapecast_scalar(%arg0: f32) -> vector<1xf32> {
   %0 = vector.broadcast %arg0 : f32 to vector<1x1x1xf32>
   %1 = vector.shape_cast %0 : vector<1x1x1xf32> to vector<1xf32>
   return %1 : vector<1xf32>
@@ -668,6 +668,17 @@
 
 // -----
 
+// CHECK-LABEL: func @canonicalize_broadcast_shapecast
+// CHECK: vector.broadcast
+// CHECK-NOT: vector.shape_cast
+func.func @canonicalize_broadcast_shapecast(%arg0: vector<3xf32>) -> vector<8x3xf32> {
+  %0 = vector.broadcast %arg0 : vector<3xf32> to vector<2x4x3xf32>
+  %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>
+  return %1 : vector<8x3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: fold_vector_transfers
 func.func @fold_vector_transfers(%A: memref<?x8xf32>) -> (vector<4x8xf32>, vector<4x9xf32>) {
   %c0 = arith.constant 0 : index
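
A minimal sketch (not part of the patch itself) of the rewrite that ShapeCastBroadcastFolder performs, using the IR from the canonicalize_broadcast_shapecast test above; the post-canonicalization form is inferred from the pattern logic (the broadcast source shape vector<3xf32> is a suffix of the shape_cast result shape vector<8x3xf32>), not taken from the diff:

  // Before canonicalization (as written in the test):
  %0 = vector.broadcast %arg0 : vector<3xf32> to vector<2x4x3xf32>
  %1 = vector.shape_cast %0 : vector<2x4x3xf32> to vector<8x3xf32>

  // Expected after canonicalization: the shape_cast is folded away and the
  // broadcast is re-targeted directly at the shape_cast result type.
  %1 = vector.broadcast %arg0 : vector<3xf32> to vector<8x3xf32>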