diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td
@@ -271,6 +271,7 @@
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($vector)";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_ShuffleOp :
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1110,6 +1110,36 @@
   return {};
 }
 
+namespace {
+
+// BroadcastOp can only add leading dimensions or stretch a dimension from 1 to
+// N. In the degenerate case where the broadcast only adds dimensions of size 1,
+// it can be replaced by a ShapeCastOp. This canonicalization compares the total
+// number of elements before and after the broadcast: they are equal exactly
+// when the only change to the vector type is the addition of unit dimensions.
+class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
+    if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
+                           srcVecType.getNumElements())
+      return failure();
+    rewriter.replaceOpWithNewOp<ShapeCastOp>(
+        broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
+    return success();
+  }
+};
+
+} // namespace
+
+void BroadcastOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<BroadcastToShapeCast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ShuffleOp
 //===----------------------------------------------------------------------===//
@@ -1768,7 +1798,8 @@
 
 namespace {
 
-// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
+// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
+// ConstantMaskOp.
 class StridedSliceConstantMaskFolder final
     : public OpRewritePattern<ExtractStridedSliceOp> {
 public:
@@ -1847,14 +1878,87 @@
   }
 };
 
+// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
+static SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
+                                              unsigned dropFront = 0,
+                                              unsigned dropBack = 0) {
+  assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds");
+  auto range = arrayAttr.getAsRange<IntegerAttr>();
+  SmallVector<int64_t, 4> res;
+  res.reserve(arrayAttr.size() - dropFront - dropBack);
+  for (auto it = range.begin() + dropFront, eit = range.end() - dropBack;
+       it != eit; ++it)
+    res.push_back((*it).getValue().getSExtValue());
+  return res;
+}
+
+// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
+// BroadcastOp(ExtractStridedSliceOp).
+class StridedSliceBroadcast final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto broadcast = op.vector().getDefiningOp<BroadcastOp>();
+    if (!broadcast)
+      return failure();
+    auto srcVecType = broadcast.source().getType().dyn_cast<VectorType>();
+    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
+    auto dstVecType = op.getType().cast<VectorType>();
+    unsigned dstRank = dstVecType.getRank();
+    unsigned rankDiff = dstRank - srcRank;
+    // Check if the innermost dimensions of the broadcast source survive the
+    // extract unchanged. A source dimension of size 1 may have been stretched
+    // by the broadcast, so any slice of it is still a broadcast of that single
+    // element and counts as a match. If all dimensions match, broadcasting the
+    // original source directly produces the extracted value.
+    bool lowerDimMatch = true;
+    for (unsigned i = 0; i < srcRank; i++) {
+      int64_t srcDim = srcVecType.getDimSize(i);
+      if (srcDim != 1 && srcDim != dstVecType.getDimSize(i + rankDiff)) {
+        lowerDimMatch = false;
+        break;
+      }
+    }
+    Value source = broadcast.source();
+    if (!lowerDimMatch) {
+      // Some non-unit inner dimension is sliced: extract from the source of
+      // the original broadcast and then broadcast the extracted value. A
+      // sliced non-unit source dimension implies `op.offsets()` has more than
+      // `rankDiff` entries, so dropping the leading `rankDiff` ones is safe.
+      SmallVector<int64_t, 4> offsets =
+          getI64SubArray(op.offsets(), /*dropFront=*/rankDiff);
+      SmallVector<int64_t, 4> sizes =
+          getI64SubArray(op.sizes(), /*dropFront=*/rankDiff);
+      SmallVector<int64_t, 4> strides =
+          getI64SubArray(op.strides(), /*dropFront=*/rankDiff);
+      // Offsets along stretched unit dimensions index the broadcast result and
+      // may be out of bounds for the source: take the whole unit dimension and
+      // let the new broadcast stretch it again.
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
+        if (srcVecType.getDimSize(i) == 1) {
+          offsets[i] = 0;
+          sizes[i] = 1;
+        }
+      }
+      source = rewriter.create<ExtractStridedSliceOp>(op->getLoc(), source,
+                                                      offsets, sizes, strides);
+    }
+    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
+    return success();
+  }
+};
+
 } // end anonymous namespace
 
 void ExtractStridedSliceOp::getCanonicalizationPatterns(
     OwningRewritePatternList &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
-  results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder>(
-      context);
+  results.insert<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
+                 StridedSliceBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2652,10 +2756,29 @@
     return source();
 
   // Canceling shape casts.
-  if (auto otherOp = source().getDefiningOp<ShapeCastOp>())
+  if (auto otherOp = source().getDefiningOp<ShapeCastOp>()) {
     if (result().getType() == otherOp.source().getType())
       return otherOp.source();
 
+    // Only fold when the composed cast is itself a valid shape_cast: chaining
+    // a collapse with an expansion can yield a pair of shapes where neither
+    // is a contiguous regrouping of the other (e.g. 2x8 -> 16 -> 4x2x2).
+    // Same-rank pairs cannot be checked by isValidShapeCast: do not fold.
+    auto srcType = otherOp.source().getType().cast<VectorType>();
+    auto resultType = result().getType().cast<VectorType>();
+    if (srcType.getRank() < resultType.getRank()) {
+      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
+        return {};
+    } else if (srcType.getRank() > resultType.getRank()) {
+      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
+        return {};
+    } else {
+      return {};
+    }
+
+    setOperand(otherOp.source());
+    return getResult();
+  }
   return {};
 }
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -613,4 +613,90 @@
   return %0, %1 : vector<12x2xf32>, vector<2x13x3xi32>
 }
 
+// -----
+
+// CHECK-LABEL: extract_strided_broadcast
+// CHECK:        %[[B:.*]] = vector.broadcast %{{.*}} : vector<4xf16> to vector<2x4xf16>
+// CHECK-NEXT:   return %[[B]] : vector<2x4xf16>
+func @extract_strided_broadcast(%arg0: vector<4xf16>) -> vector<2x4xf16> {
+  %0 = vector.broadcast %arg0 : vector<4xf16> to vector<16x4xf16>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} :
+    vector<16x4xf16> to vector<2x4xf16>
+  return %1 : vector<2x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_broadcast2
+// CHECK:        %[[E:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [2], strides = [1]} : vector<4xf16> to vector<2xf16>
+// CHECK-NEXT:   %[[B:.*]] = vector.broadcast %[[E]] : vector<2xf16> to vector<2x2xf16>
+// CHECK-NEXT:   return %[[B]] : vector<2x2xf16>
+func @extract_strided_broadcast2(%arg0: vector<4xf16>) -> vector<2x2xf16> {
+  %0 = vector.broadcast %arg0 : vector<4xf16> to vector<16x4xf16>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
+    vector<16x4xf16> to vector<2x2xf16>
+  return %1 : vector<2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_broadcast3
+// CHECK:        %[[B:.*]] = vector.broadcast %{{.*}} : vector<1xf32> to vector<1x4xf32>
+// CHECK-NEXT:   return %[[B]] : vector<1x4xf32>
+func @extract_strided_broadcast3(%arg0: vector<1xf32>) -> vector<1x4xf32> {
+  %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x8xf32>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [0, 4], sizes = [1, 4], strides = [1, 1]} :
+    vector<1x8xf32> to vector<1x4xf32>
+  return %1 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: extract_strided_broadcast4
+// CHECK:        %[[E:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 1], sizes = [1, 2], strides = [1, 1]} : vector<1x4xf32> to vector<1x2xf32>
+// CHECK-NEXT:   %[[B:.*]] = vector.broadcast %[[E]] : vector<1x2xf32> to vector<2x2xf32>
+// CHECK-NEXT:   return %[[B]] : vector<2x2xf32>
+func @extract_strided_broadcast4(%arg0: vector<1x4xf32>) -> vector<2x2xf32> {
+  %0 = vector.broadcast %arg0 : vector<1x4xf32> to vector<8x4xf32>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [3, 1], sizes = [2, 2], strides = [1, 1]} :
+    vector<8x4xf32> to vector<2x2xf32>
+  return %1 : vector<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: consecutive_shape_cast
+// CHECK:        %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
+// CHECK-NEXT:   return %[[C]] : vector<4x4xf16>
+func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
+  %0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
+  %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<4x4xf16>
+  return %1 : vector<4x4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: shape_cast_chain_no_fold
+// CHECK:        %[[C:.*]] = vector.shape_cast %{{.*}} : vector<2x8xf16> to vector<16xf16>
+// CHECK-NEXT:   %[[D:.*]] = vector.shape_cast %[[C]] : vector<16xf16> to vector<4x2x2xf16>
+// CHECK-NEXT:   return %[[D]] : vector<4x2x2xf16>
+func @shape_cast_chain_no_fold(%arg0: vector<2x8xf16>) -> vector<4x2x2xf16> {
+  %0 = vector.shape_cast %arg0 : vector<2x8xf16> to vector<16xf16>
+  %1 = vector.shape_cast %0 : vector<16xf16> to vector<4x2x2xf16>
+  return %1 : vector<4x2x2xf16>
+}
+
+// -----
+
+// CHECK-LABEL: broadcast_to_shapecast
+// CHECK:        %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16>
+// CHECK-NEXT:   return %[[C]] : vector<1x4x4xf16>
+func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> {
+  %0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16>
+  return %0 : vector<1x4x4xf16>
+}
 