[mlir][vector] Fold vector.extract and vector.extract_strided_slice of splat

foldExtractFromBroadcast now also accepts a SplatOp producer: the producer's
operand is forwarded when the extract result type matches it. A new
StridedSliceSplat canonicalization pattern rewrites
ExtractStridedSliceOp(SplatOp) into a SplatOp of the sliced result type.
Both folds are covered by new tests in canonicalize.mlir.

diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1094,17 +1094,18 @@
   return Value();
 }
 
-/// Fold extractOp with scalar result coming from BroadcastOp.
+/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
-  auto broadcastOp = extractOp.vector().getDefiningOp<vector::BroadcastOp>();
-  if (!broadcastOp)
+  Operation *defOp = extractOp.vector().getDefiningOp();
+  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
     return Value();
-  if (extractOp.getType() == broadcastOp.getSourceType())
-    return broadcastOp.source();
+  Value source = defOp->getOperand(0);
+  if (extractOp.getType() == source.getType())
+    return source;
   auto getRank = [](Type type) {
     return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
   };
-  unsigned broadcasrSrcRank = getRank(broadcastOp.getSourceType());
+  unsigned broadcasrSrcRank = getRank(source.getType());
   unsigned extractResultRank = getRank(extractOp.getType());
   if (extractResultRank < broadcasrSrcRank) {
     auto extractPos = extractVector<int64_t>(extractOp.position());
@@ -1112,7 +1113,7 @@
     extractPos.erase(
         extractPos.begin(),
         std::next(extractPos.begin(), extractPos.size() - rankDiff));
-    extractOp.setOperand(broadcastOp.source());
+    extractOp.setOperand(source);
     // OpBuilder is only used as a helper to build an I64ArrayAttr.
     OpBuilder b(extractOp.getContext());
     extractOp->setAttr(ExtractOp::getPositionAttrName(),
@@ -2259,6 +2260,21 @@
   }
 };
 
+/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
+class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    auto splat = op.vector().getDefiningOp<SplatOp>();
+    if (!splat)
+      return failure();
+    rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.input());
+    return success();
+  }
+};
+
 } // end anonymous namespace
 
 void ExtractStridedSliceOp::getCanonicalizationPatterns(
@@ -2266,7 +2282,7 @@
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
   results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
-              StridedSliceBroadcast>(context);
+              StridedSliceBroadcast, StridedSliceSplat>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -462,6 +462,17 @@
 
 // -----
 
+// CHECK-LABEL: fold_extract_splat
+//  CHECK-SAME:   %[[A:.*]]: f32
+//       CHECK:   return %[[A]] : f32
+func @fold_extract_splat(%a : f32) -> f32 {
+  %b = splat %a : vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_vector
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
 //       CHECK:   return %[[A]] : vector<4xf32>
@@ -1047,3 +1058,16 @@
   // CHECK: return %[[SOURCE]]
   return %0: vector<16x16xf16>
 }
+
+// -----
+
+// CHECK-LABEL: extract_strided_splat
+//       CHECK:   %[[B:.*]] = splat %{{.*}} : vector<2x4xf16>
+//  CHECK-NEXT:   return %[[B]] : vector<2x4xf16>
+func @extract_strided_splat(%arg0: f16) -> vector<2x4xf16> {
+  %0 = splat %arg0 : vector<16x4xf16>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [1, 0], sizes = [2, 4], strides = [1, 1]} :
+    vector<16x4xf16> to vector<2x4xf16>
+  return %1 : vector<2x4xf16>
+}