diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -697,7 +697,7 @@ template struct UndoComplexPattern : public mlir::RewritePattern { UndoComplexPattern(mlir::MLIRContext *ctx) - : mlir::RewritePattern("fir.insert_value", {}, 2, ctx) {} + : mlir::RewritePattern("fir.insert_value", 2, ctx) {} mlir::LogicalResult matchAndRewrite(mlir::Operation *op, diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h --- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h +++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h @@ -30,12 +30,12 @@ // or in an externally linked library. // This is a generic entry point for all LinalgOp, except for CopyOp and // IndexedGenericOp, for which omre specialized patterns are provided. -class LinalgOpToLibraryCallRewrite : public RewritePattern { +class LinalgOpToLibraryCallRewrite + : public OpInterfaceRewritePattern { public: - LinalgOpToLibraryCallRewrite() - : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - LogicalResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override; }; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -60,7 +60,8 @@ if (!opName.empty()) patternList.add(opName, patternList.getContext(), options, m); else - patternList.add(m.addOpFilter(), options); + patternList.add(patternList.getContext(), + m.addOpFilter(), options); } /// Promotion transformation enqueues a particular stage-1 pattern for diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -452,7 +452,7 @@ struct LinalgBaseTilingPattern : public RewritePattern { // Entry point to match any LinalgOp OpInterface. LinalgBaseTilingPattern( - LinalgTilingOptions options, + MLIRContext *context, LinalgTilingOptions options, LinalgTransformationFilter filter = LinalgTransformationFilter(), PatternBenefit benefit = 1); // Entry point to match a specific Linalg op. @@ -644,7 +644,8 @@ struct LinalgBaseVectorizationPattern : public RewritePattern { /// MatchAnyOpTag-based constructor with a mandatory `filter`. - LinalgBaseVectorizationPattern(LinalgTransformationFilter filter, + LinalgBaseVectorizationPattern(MLIRContext *context, + LinalgTransformationFilter filter, PatternBenefit benefit = 1); /// Name-based constructor with an optional `filter`. LinalgBaseVectorizationPattern( @@ -663,10 +664,10 @@ /// These constructors are available to anyone. /// MatchAnyOpTag-based constructor with a mandatory `filter`. LinalgVectorizationPattern( - LinalgTransformationFilter filter, + MLIRContext *context, LinalgTransformationFilter filter, LinalgVectorizationOptions options = LinalgVectorizationOptions(), PatternBenefit benefit = 1) - : LinalgBaseVectorizationPattern(filter, benefit) {} + : LinalgBaseVectorizationPattern(context, filter, benefit) {} /// Name-based constructor with an optional `filter`. LinalgVectorizationPattern( StringRef opName, MLIRContext *context, @@ -702,8 +703,8 @@ void insertVectorizationPatternImpl(RewritePatternSet &patternList, linalg::LinalgVectorizationOptions options, linalg::LinalgTransformationFilter f) { - patternList.add(f.addOpFilter(), - options); + patternList.add( + patternList.getContext(), f.addOpFilter(), options); } /// Variadic helper function to insert vectorization patterns for C++ ops. @@ -737,7 +738,7 @@ MLIRContext *context, LinalgLoweringType loweringType, LinalgTransformationFilter filter = LinalgTransformationFilter(), ArrayRef interchangeVector = {}, PatternBenefit benefit = 1) - : RewritePattern(OpTy::getOperationName(), {}, benefit, context), + : RewritePattern(OpTy::getOperationName(), benefit, context), filter(filter), loweringType(loweringType), interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -123,7 +123,8 @@ struct UnrollVectorPattern : public RewritePattern { using FilterConstraintType = std::function; UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options) - : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()), options(options) {} + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), + options(options) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (options.filterConstraint && failed(options.filterConstraint(op))) @@ -216,7 +217,7 @@ FilterConstraintType filter = [](VectorTransferOpInterface op) { return success(); }, PatternBenefit benefit = 1) - : RewritePattern(benefit, MatchAnyOpTypeTag()), options(options), + : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options), filter(filter) {} /// Performs the rewrite. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1516,6 +1516,13 @@ #endif return false; } + /// Provide `classof` support for other OpBase derived classes, such as + /// Interfaces. + template + static std::enable_if_t::value, bool> + classof(const T *op) { + return classof(const_cast(op)->getOperation()); + } /// Expose the type we are instantiated on to template machinery that may want /// to introspect traits on this operation. diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -142,12 +142,20 @@ return interfaceMap.lookup(); } + /// Returns true if this operation has the given interface registered to it. + bool hasInterface(TypeID interfaceID) const { + return interfaceMap.contains(interfaceID); + } + /// Returns true if the operation has a particular trait. template