Remove the deprecated pattern hooks that take raw ArrayRef<Value> operands.

The SourceOp-typed rewrite/matchAndRewrite overloads taking ArrayRef<Value>
are deleted from both headers. The OpAdaptor-based rewrite now hits
llvm_unreachable unless overridden, and the OpAdaptor-based matchAndRewrite
defaults to match() followed by rewrite(). DialectConversion.h additionally
gains a final match(Operation *) wrapper and a SourceOp-typed virtual match().

diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -156,24 +156,6 @@
                            rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type. These must be
-  /// overridden by the derived pattern class.
-  /// NOTICE: These methods are deprecated and will be removed. All new code
-  /// should use the adaptor methods below instead.
-  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override rewrite or matchAndRewrite");
-  }
-  virtual LogicalResult
-  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const {
-    if (succeeded(match(op))) {
-      rewrite(op, OpAdaptor(operands, op->getAttrDictionary()), rewriter);
-      return success();
-    }
-    return failure();
-  }
-
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
   /// overridden by the derived pattern class.
   virtual LogicalResult match(SourceOp op) const {
@@ -181,21 +163,15 @@
   }
   virtual void rewrite(SourceOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter) const {
-    ValueRange operands = adaptor.getOperands();
-    rewrite(op,
-            ArrayRef<Value>(operands.getBase().get<const Value *>(),
-                            operands.size()),
-            rewriter);
+    llvm_unreachable("must override rewrite or matchAndRewrite");
   }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
-    ValueRange operands = adaptor.getOperands();
-    return matchAndRewrite(
-        op,
-        ArrayRef<Value>(operands.getBase().get<const Value *>(),
-                        operands.size()),
-        rewriter);
+    if (failed(match(op)))
+      return failure();
+    rewrite(op, adaptor, rewriter);
+    return success();
   }
 
 private:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -353,7 +353,7 @@
   /// Construct a conversion pattern with the given converter, and forward the
   /// remaining arguments to RewritePattern.
   template <typename... Args>
-  ConversionPattern(TypeConverter &typeConverter, Args &&... args)
+  ConversionPattern(TypeConverter &typeConverter, Args &&...args)
       : RewritePattern(std::forward<Args>(args)...),
         typeConverter(&typeConverter) {}
 
@@ -382,6 +382,9 @@
 
   /// Wrappers around the ConversionPattern methods that pass the derived op
   /// type.
+  LogicalResult match(Operation *op) const final {
+    return match(cast<SourceOp>(op));
+  }
   void rewrite(Operation *op, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const final {
     rewrite(cast<SourceOp>(op), OpAdaptor(operands, op->getAttrDictionary()),
@@ -395,42 +398,22 @@
                            rewriter);
   }
 
-  /// Rewrite and Match methods that operate on the SourceOp type and accept the
-  /// raw operand values.
-  /// NOTICE: These methods are deprecated and will be removed. All new code
-  /// should use the adaptor methods below instead.
-  virtual void rewrite(SourceOp op, ArrayRef<Value> operands,
-                       ConversionPatternRewriter &rewriter) const {
-    llvm_unreachable("must override matchAndRewrite or a rewrite method");
-  }
-  virtual LogicalResult
-  matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const {
-    if (failed(match(op)))
-      return failure();
-    rewrite(op, OpAdaptor(operands, op->getAttrDictionary()), rewriter);
-    return success();
-  }
-
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
   /// overridden by the derived pattern class.
+  virtual LogicalResult match(SourceOp op) const {
+    llvm_unreachable("must override match or matchAndRewrite");
+  }
   virtual void rewrite(SourceOp op, OpAdaptor adaptor,
                        ConversionPatternRewriter &rewriter) const {
-    ValueRange operands = adaptor.getOperands();
-    rewrite(op,
-            ArrayRef<Value>(operands.getBase().get<const Value *>(),
-                            operands.size()),
-            rewriter);
+    llvm_unreachable("must override matchAndRewrite or a rewrite method");
   }
   virtual LogicalResult
   matchAndRewrite(SourceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const {
-    ValueRange operands = adaptor.getOperands();
-    return matchAndRewrite(
-        op,
-        ArrayRef<Value>(operands.getBase().get<const Value *>(),
-                        operands.size()),
-        rewriter);
+    if (failed(match(op)))
+      return failure();
+    rewrite(op, adaptor, rewriter);
+    return success();
   }
 
 private: