diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -571,11 +571,9 @@ &typeConverter.getContext(), typeConverter, benefit) {} - /// Wrappers around the RewritePattern methods that pass the derived op type. - void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - rewrite(cast(op), operands, rewriter); - } +private: + /// Wrappers around the ConversionPattern methods that pass the derived op + /// type. LogicalResult match(Operation *op) const final { return match(cast(op)); } @@ -584,6 +582,10 @@ ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast(op), operands, rewriter); } + void rewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + rewrite(cast(op), operands, rewriter); + } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. @@ -603,10 +605,6 @@ } return failure(); } - -private: - using ConvertToLLVMPattern::match; - using ConvertToLLVMPattern::matchAndRewrite; }; namespace LLVM { @@ -636,6 +634,7 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = OneToOneConvertToLLVMPattern; +private: /// Converts the type of the result to an LLVM type, pass operands as is, /// preserve attributes. LogicalResult @@ -655,6 +654,7 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = VectorConvertToLLVMPattern; +private: LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -156,17 +156,6 @@ public: virtual ~RewritePattern() {} - /// Rewrite the IR rooted at the specified operation with the result of - /// this pattern, generating any new operations with the specified - /// builder. If an unexpected error is encountered (an internal - /// compiler error), it is emitted through the normal MLIR diagnostic - /// hooks and the IR is left in a valid state. - virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; - - /// Attempt to match against code rooted at the specified operation, - /// which is the same operation code as getRootKind(). - virtual LogicalResult match(Operation *op) const; - /// Attempt to match against code rooted at the specified operation, /// which is the same operation code as getRootKind(). If successful, this /// function will automatically perform the rewrite. @@ -183,6 +172,18 @@ /// Inherit the base constructors from `Pattern`. using Pattern::Pattern; + /// Attempt to match against code rooted at the specified operation, + /// which is the same operation code as getRootKind(). + virtual LogicalResult match(Operation *op) const; + +private: + /// Rewrite the IR rooted at the specified operation with the result of + /// this pattern, generating any new operations with the specified + /// builder. If an unexpected error is encountered (an internal + /// compiler error), it is emitted through the normal MLIR diagnostic + /// hooks and the IR is left in a valid state. + virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; + /// An anchor for the virtual table. virtual void anchor(); }; @@ -190,12 +191,15 @@ /// OpRewritePattern is a wrapper around RewritePattern that allows for /// matching and rewriting against an instance of a derived operation class as /// opposed to a raw Operation. -template struct OpRewritePattern : public RewritePattern { +template +class OpRewritePattern : public RewritePattern { +public: /// Patterns must specify the root operation name they match against, and can /// also specify the benefit of the pattern matching. OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) : RewritePattern(SourceOp::getOperationName(), benefit, context) {} +private: /// Wrappers around the RewritePattern methods that pass the derived op type. void rewrite(Operation *op, PatternRewriter &rewriter) const final { rewrite(cast(op), rewriter); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -313,6 +313,30 @@ /// patterns of this type can only be used with the 'apply*' methods below. class ConversionPattern : public RewritePattern { public: + /// Return the type converter held by this pattern, or nullptr if the pattern + /// does not require type conversion. + TypeConverter *getTypeConverter() const { return typeConverter; } + +protected: + /// See `RewritePattern::RewritePattern` for information on the other + /// available constructors. + using RewritePattern::RewritePattern; + /// Construct a conversion pattern that matches an operation with the given + /// root name. This constructor allows for providing a type converter to use + /// within the pattern. + ConversionPattern(StringRef rootName, PatternBenefit benefit, + TypeConverter &typeConverter, MLIRContext *ctx) + : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {} + /// Construct a conversion pattern that matches any operation type. This + /// constructor allows for providing a type converter to use within the + /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" + /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should + /// always be supplied here. + ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter, + MatchAnyOpTypeTag tag) + : RewritePattern(benefit, tag), typeConverter(&typeConverter) {} + +private: /// Hook for derived classes to implement rewriting. `op` is the (first) /// operation matched by the pattern, `operands` is a list of the rewritten /// operand values that are passed to `op`, `rewriter` can be used to emit the @@ -323,6 +347,10 @@ llvm_unreachable("unimplemented rewrite"); } + void rewrite(Operation *op, PatternRewriter &rewriter) const final { + llvm_unreachable("never called"); + } + /// Hook for derived classes to implement combined matching and rewriting. virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -337,42 +365,17 @@ LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final; - /// Return the type converter held by this pattern, or nullptr if the pattern - /// does not require type conversion. - TypeConverter *getTypeConverter() const { return typeConverter; } - -protected: - /// See `RewritePattern::RewritePattern` for information on the other - /// available constructors. - using RewritePattern::RewritePattern; - /// Construct a conversion pattern that matches an operation with the given - /// root name. This constructor allows for providing a type converter to use - /// within the pattern. - ConversionPattern(StringRef rootName, PatternBenefit benefit, - TypeConverter &typeConverter, MLIRContext *ctx) - : RewritePattern(rootName, benefit, ctx), typeConverter(&typeConverter) {} - /// Construct a conversion pattern that matches any operation type. This - /// constructor allows for providing a type converter to use within the - /// pattern. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" - /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should - /// always be supplied here. - ConversionPattern(PatternBenefit benefit, TypeConverter &typeConverter, - MatchAnyOpTypeTag tag) - : RewritePattern(benefit, tag), typeConverter(&typeConverter) {} - protected: /// An optional type converter for use by this pattern. TypeConverter *typeConverter = nullptr; - -private: - using RewritePattern::rewrite; }; /// OpConversionPattern is a wrapper around ConversionPattern that allows for /// matching and rewriting against an instance of a derived operation class as /// opposed to a raw Operation. template -struct OpConversionPattern : public ConversionPattern { +class OpConversionPattern : public ConversionPattern { +public: OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, @@ -380,6 +383,7 @@ : ConversionPattern(SourceOp::getOperationName(), benefit, typeConverter, context) {} +private: /// Wrappers around the ConversionPattern methods that pass the derived op /// type. void rewrite(Operation *op, ArrayRef operands, @@ -409,9 +413,6 @@ rewrite(op, operands, rewriter); return success(); } - -private: - using ConversionPattern::matchAndRewrite; }; /// Add a pattern to the given pattern list to convert the signature of a FuncOp