diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h --- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h +++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h @@ -30,12 +30,12 @@ // or in an externally linked library. // This is a generic entry point for all LinalgOp, except for CopyOp and // IndexedGenericOp, for which omre specialized patterns are provided. -class LinalgOpToLibraryCallRewrite : public RewritePattern { +class LinalgOpToLibraryCallRewrite + : public OpInterfaceRewritePattern { public: - LinalgOpToLibraryCallRewrite() - : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - LogicalResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const override; }; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h @@ -61,7 +61,7 @@ if (!opName.empty()) patternList.insert(opName, context, options, m); else - patternList.insert(m.addOpFilter(), options); + patternList.insert(context, m.addOpFilter(), options); } /// Promotion transformation enqueues a particular stage-1 pattern for diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -455,7 +455,7 @@ struct LinalgBaseTilingPattern : public RewritePattern { // Entry point to match any LinalgOp OpInterface. LinalgBaseTilingPattern( - LinalgTilingOptions options, + MLIRContext *context, LinalgTilingOptions options, LinalgTransformationFilter filter = LinalgTransformationFilter(), PatternBenefit benefit = 1); // Entry point to match a specific Linalg op. @@ -647,7 +647,8 @@ struct LinalgBaseVectorizationPattern : public RewritePattern { /// MatchAnyOpTag-based constructor with a mandatory `filter`. - LinalgBaseVectorizationPattern(LinalgTransformationFilter filter, + LinalgBaseVectorizationPattern(MLIRContext *context, + LinalgTransformationFilter filter, PatternBenefit benefit = 1); /// Name-based constructor with an optional `filter`. LinalgBaseVectorizationPattern( @@ -666,10 +667,10 @@ /// These constructors are available to anyone. /// MatchAnyOpTag-based constructor with a mandatory `filter`. LinalgVectorizationPattern( - LinalgTransformationFilter filter, + MLIRContext *context, LinalgTransformationFilter filter, LinalgVectorizationOptions options = LinalgVectorizationOptions(), PatternBenefit benefit = 1) - : LinalgBaseVectorizationPattern(filter, benefit) {} + : LinalgBaseVectorizationPattern(context, filter, benefit) {} /// Name-based constructor with an optional `filter`. LinalgVectorizationPattern( StringRef opName, MLIRContext *context, @@ -708,7 +709,7 @@ linalg::LinalgVectorizationOptions options, linalg::LinalgTransformationFilter f) { patternList.insert( - f.addOpFilter(), options); + context, f.addOpFilter(), options); } /// Variadic helper function to insert vectorization patterns for C++ ops. @@ -743,7 +744,7 @@ MLIRContext *context, LinalgLoweringType loweringType, LinalgTransformationFilter filter = LinalgTransformationFilter(), ArrayRef interchangeVector = {}, PatternBenefit benefit = 1) - : RewritePattern(OpTy::getOperationName(), {}, benefit, context), + : RewritePattern(OpTy::getOperationName(), benefit, context), filter(filter), loweringType(loweringType), interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -122,7 +122,8 @@ struct UnrollVectorPattern : public RewritePattern { using FilterConstraintType = std::function; UnrollVectorPattern(MLIRContext *context, UnrollVectorOptions options) - : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()), options(options) {} + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), + options(options) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (options.filterConstraint && failed(options.filterConstraint(op))) @@ -215,7 +216,7 @@ FilterConstraintType filter = [](VectorTransferOpInterface op) { return success(); }, PatternBenefit benefit = 1) - : RewritePattern(benefit, MatchAnyOpTypeTag()), options(options), + : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options), filter(filter) {} /// Performs the rewrite. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -1537,6 +1537,13 @@ #endif return false; } + /// Provide `classof` support for other OpBase derived classes, such as + /// Interfaces. + template + static std::enable_if_t::value, bool> + classof(const T *op) { + return classof(const_cast(op)->getOperation()); + } /// Expose the type we are instantiated on to template machinery that may want /// to introspect traits on this operation. diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -136,11 +136,19 @@ return interfaceMap.lookup(); } + /// Returns true if this operation has the given interface registered to it. + bool hasInterface(TypeID interfaceID) const { + return interfaceMap.contains(interfaceID); + } + /// Returns true if the operation has a particular trait. template