diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -91,7 +91,7 @@
   /// Callback function that indicates whether vector unrolling should be
   /// attempted on the operation.
   FilterConstraintFnType filterConstraint = nullptr;
-  UnrollVectorOptions &setFilterContraint(FilterConstraintFnType constraint) {
+  UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) {
     filterConstraint = constraint;
     return *this;
   }
@@ -117,21 +117,23 @@
 };
 /// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape`
 /// declaratively.
-template <typename OpTy>
-struct UnrollVectorPattern : public OpRewritePattern<OpTy> {
-  using FilterConstraintType = std::function<LogicalResult(OpTy op)>;
+/// The pattern is not tied to a single op type: it matches any operation
+/// implementing `VectorUnrollOpInterface`. Use
+/// `UnrollVectorOptions::setFilterConstraint` to restrict the set of ops it
+/// is applied to.
+struct UnrollVectorPattern : public RewritePattern {
+  using FilterConstraintType = std::function<LogicalResult(Operation *op)>;
   UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options)
-      : OpRewritePattern<OpTy>(context), options(options) {}
-  LogicalResult matchAndRewrite(OpTy op,
+      : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()), options(options) {}
+  LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     if (options.filterConstraint && failed(options.filterConstraint(op)))
       return failure();
     if (!options.nativeShape) {
-      return op.emitError("vector unrolling expects the native shape or native"
-                          "shape call back function to be set");
+      return op->emitError("vector unrolling expects the native shape or native"
+                           " shape call back function to be set");
     }
-    auto unrollableVectorOp =
-        dyn_cast<VectorUnrollOpInterface>(op.getOperation());
+    auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op);
     if (!unrollableVectorOp)
       return failure();
     auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll();
@@ -139,12 +141,15 @@
       return failure();
     Optional<SmallVector<int64_t, 4>> targetShape = options.nativeShape(op);
+    // The native shape callback may return llvm::None for ops it has no target
+    // shape for. Since this pattern matches any op, treat that as "do not
+    // unroll this op" rather than emitting an error diagnostic.
     if (!targetShape)
-      return op.emitError("failed to get target shape for vector unroll");
+      return failure();
     auto maybeShapeRatio = shapeRatio(*maybeUnrollShape, *targetShape);
     if (!maybeShapeRatio ||
         llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; }))
       return failure();
-    if (std::is_same<OpTy, TransferWriteOp>::value) {
+    if (isa<TransferWriteOp>(op)) {
       if (failed(unrollTransferWriteOp(rewriter, op, *targetShape)))
         return failure();
       rewriter.eraseOp(op);
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -27,14 +27,22 @@
   void runOnFunction() override {
     OwningRewritePatternList patterns;
     auto *ctx = &getContext();
-    patterns.insert<UnrollVectorPattern<AddFOp>>(
-        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
-    patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
-        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
+    patterns.insert<UnrollVectorPattern>(
+        ctx, UnrollVectorOptions().setNativeShapeFn(getShape));
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
+
+private:
+  // Return the target shape based on op type, or llvm::None to skip the op.
+  static Optional<SmallVector<int64_t, 4>> getShape(Operation *op) {
+    if (isa<AddFOp>(op))
+      return SmallVector<int64_t, 4>(2, 2);
+    if (isa<vector::ContractionOp>(op))
+      return SmallVector<int64_t, 4>(3, 2);
+    return llvm::None;
+  }
 };
 
 struct TestVectorSlicesConversion
@@ -120,8 +128,11 @@
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
-    patterns.insert<UnrollVectorPattern<AddFOp>>(
-        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
+    patterns.insert<UnrollVectorPattern>(
+        ctx, UnrollVectorOptions()
+                 .setNativeShape(ArrayRef<int64_t>{2, 2})
+                 .setFilterConstraint(
+                     [](Operation *op) { return success(isa<AddFOp>(op)); }));
 
     if (unrollBasedOnType) {
       UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
@@ -137,12 +148,19 @@
         }
         return nativeShape;
       };
-      patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
-          ctx, UnrollVectorOptions().setNativeShapeFn(nativeShapeFn));
+      patterns.insert<UnrollVectorPattern>(
+          ctx, UnrollVectorOptions()
+                   .setNativeShapeFn(nativeShapeFn)
+                   .setFilterConstraint([](Operation *op) {
+                     return success(isa<vector::ContractionOp>(op));
+                   }));
     } else {
-      patterns.insert<UnrollVectorPattern<vector::ContractionOp>>(
-          ctx,
-          UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2, 2}));
+      patterns.insert<UnrollVectorPattern>(
+          ctx, UnrollVectorOptions()
+                   .setNativeShape(ArrayRef<int64_t>{2, 2, 2})
+                   .setFilterConstraint([](Operation *op) {
+                     return success(isa<vector::ContractionOp>(op));
+                   }));
     }
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
@@ -273,10 +291,14 @@
   void runOnFunction() override {
     MLIRContext *ctx = &getContext();
     OwningRewritePatternList patterns;
-    patterns.insert<UnrollVectorPattern<vector::TransferReadOp>>(
-        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
-    patterns.insert<UnrollVectorPattern<vector::TransferWriteOp>>(
-        ctx, UnrollVectorOptions().setNativeShape(ArrayRef<int64_t>{2, 2}));
+    patterns.insert<UnrollVectorPattern>(
+        ctx,
+        UnrollVectorOptions()
+            .setNativeShape(ArrayRef<int64_t>{2, 2})
+            .setFilterConstraint([](Operation *op) {
+              return success(
+                  isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
+            }));
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));