Index: mlir/include/mlir/Dialect/Vector/VectorTransforms.h =================================================================== --- mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -117,21 +117,20 @@ }; /// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape` /// declaratively. -template -struct UnrollVectorPattern : public OpRewritePattern { - using FilterConstraintType = std::function; +struct UnrollVectorPattern : public RewritePattern { + using FilterConstraintType = std::function; UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options) - : OpRewritePattern(context), options(options) {} - LogicalResult matchAndRewrite(OpTy op, + : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()), options(options) {} + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (options.filterConstraint && failed(options.filterConstraint(op))) return failure(); if (!options.nativeShape) { - return op.emitError("vector unrolling expects the native shape or native" + return op->emitError("vector unrolling expects the native shape or native" "shape call back function to be set"); } auto unrollableVectorOp = - dyn_cast(op.getOperation()); + dyn_cast(op); if (!unrollableVectorOp) return failure(); auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); @@ -139,12 +138,12 @@ return failure(); Optional> targetShape = options.nativeShape(op); if (!targetShape) - return op.emitError("failed to get target shape for vector unroll"); + return op->emitError("failed to get target shape for vector unroll"); auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape); if (!maybeShapeRatio || llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) return failure(); - if (std::is_same::value) { + if (isa(op)) { if (failed(unrollTransferWriteOp(rewriter, op, *targetShape))) return failure(); rewriter.eraseOp(op); Index: mlir/test/lib/Transforms/TestVectorTransforms.cpp =================================================================== --- mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -27,14 +27,21 @@ void runOnFunction() override { OwningRewritePatternList patterns; auto *ctx = &getContext(); - patterns.insert>( - ctx, UnrollVectorOptions().setNativeShape(ArrayRef{2, 2})); - patterns.insert>( - ctx, UnrollVectorOptions().setNativeShape(ArrayRef{2, 2, 2})); + patterns.insert( + ctx, UnrollVectorOptions().setNativeShapeFn(getShape)); populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } + +private: + static Optional> getShape(Operation *op) { + if (isa(op)) + return SmallVector(2, 2); + if (isa(op)) + return SmallVector(3, 2); + return llvm::None; + } }; struct TestVectorSlicesConversion @@ -120,8 +127,11 @@ void runOnFunction() override { MLIRContext *ctx = &getContext(); OwningRewritePatternList patterns; - patterns.insert>( - ctx, UnrollVectorOptions().setNativeShape(ArrayRef{2, 2})); + patterns.insert( + ctx, UnrollVectorOptions() + .setNativeShape(ArrayRef{2, 2}) + .setFilterContraint( + [](Operation *op) { return success(isa(op)); })); if (unrollBasedOnType) { UnrollVectorOptions::NativeShapeFnType nativeShapeFn = @@ -137,12 +147,19 @@ } return nativeShape; }; - patterns.insert>( - ctx, UnrollVectorOptions().setNativeShapeFn(nativeShapeFn)); + patterns.insert( + ctx, UnrollVectorOptions() + .setNativeShapeFn(nativeShapeFn) + .setFilterContraint([](Operation *op) { + return success(isa(op)); + })); } else { - patterns.insert>( - ctx, - UnrollVectorOptions().setNativeShape(ArrayRef{2, 2, 2})); + patterns.insert( + ctx, UnrollVectorOptions() + .setNativeShape(ArrayRef{2, 2, 2}) + .setFilterContraint([](Operation *op) { + return success(isa(op)); + })); } populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); @@ -273,10 +290,14 @@ void runOnFunction() override { MLIRContext *ctx = &getContext(); OwningRewritePatternList patterns; - patterns.insert>( - ctx, UnrollVectorOptions().setNativeShape(ArrayRef{2, 2})); - patterns.insert>( - ctx, UnrollVectorOptions().setNativeShape(ArrayRef{2, 2})); + patterns.insert( + ctx, + UnrollVectorOptions() + .setNativeShape(ArrayRef{2, 2}) + .setFilterContraint([](Operation *op) { + return success( + isa(op)); + })); populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));