diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -54,22 +54,9 @@ unsigned short representation; }; -/// Pattern state is used by patterns that want to maintain state between their -/// match and rewrite phases. Patterns can define a pattern-specific subclass -/// of this. -class PatternState { -public: - virtual ~PatternState() {} - -protected: - // Must be subclassed. - PatternState() {} -}; - -/// This is the type returned by a pattern match. A match failure returns a -/// None value. A match success returns a Some value with any state the pattern -/// may need to maintain (but may also be null). -using PatternMatchResult = Optional>; +/// This is the type returned by a pattern match. +/// TODO: Replace usages with LogicalResult directly. +using PatternMatchResult = LogicalResult; //===----------------------------------------------------------------------===// // Pattern class @@ -97,9 +84,7 @@ //===--------------------------------------------------------------------===// /// Attempt to match against code rooted at the specified operation, - /// which is the same operation code as getRootKind(). On failure, this - /// returns a None value. On success it returns a (possibly null) - /// pattern-specific state wrapped in an Optional. + /// which is the same operation code as getRootKind(). virtual PatternMatchResult match(Operation *op) const = 0; virtual ~Pattern() {} @@ -108,14 +93,11 @@ // Helper methods to simplify pattern implementations //===--------------------------------------------------------------------===// - /// This method indicates that no match was found. - static PatternMatchResult matchFailure() { return None; } + /// Return a result, indicating that no match was found. + PatternMatchResult matchFailure() const { return failure(); } - /// This method indicates that a match was found and has the specified cost. - PatternMatchResult - matchSuccess(std::unique_ptr state = {}) const { - return PatternMatchResult(std::move(state)); - } + /// This method indicates that a match was found. + PatternMatchResult matchSuccess() const { return success(); } protected: /// Patterns must specify the root operation name they match against, and can @@ -136,21 +118,12 @@ /// separate the concerns of matching and rewriting. /// * Single-step RewritePattern with "matchAndRewrite" /// - By overloading the "matchAndRewrite" function, the user can perform -/// the rewrite in the same call as the match. This removes the need for -/// any PatternState. +/// the rewrite in the same call as the match. /// class RewritePattern : public Pattern { public: /// Rewrite the IR rooted at the specified operation with the result of /// this pattern, generating any new operations with the specified - /// rewriter. If an unexpected error is encountered (an internal - /// compiler error), it is emitted through the normal MLIR diagnostic - /// hooks and the IR is left in a valid state. - virtual void rewrite(Operation *op, std::unique_ptr state, - PatternRewriter &rewriter) const; - - /// Rewrite the IR rooted at the specified operation with the result of - /// this pattern, generating any new operations with the specified /// builder. If an unexpected error is encountered (an internal /// compiler error), it is emitted through the normal MLIR diagnostic /// hooks and the IR is left in a valid state. @@ -168,8 +141,8 @@ /// function will automatically perform the rewrite. virtual PatternMatchResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { - if (auto matchResult = match(op)) { - rewrite(op, std::move(*matchResult), rewriter); + if (succeeded(match(op))) { + rewrite(op, rewriter); return matchSuccess(); } return matchFailure(); @@ -206,10 +179,6 @@ : RewritePattern(SourceOp::getOperationName(), benefit, context) {} /// Wrappers around the RewritePattern methods that pass the derived op type. - void rewrite(Operation *op, std::unique_ptr state, - PatternRewriter &rewriter) const final { - rewrite(cast(op), std::move(state), rewriter); - } void rewrite(Operation *op, PatternRewriter &rewriter) const final { rewrite(cast(op), rewriter); } @@ -223,20 +192,16 @@ /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. - virtual void rewrite(SourceOp op, std::unique_ptr state, - PatternRewriter &rewriter) const { - rewrite(op, rewriter); - } virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const { - llvm_unreachable("must override matchAndRewrite or a rewrite method"); + llvm_unreachable("must override rewrite or matchAndRewrite"); } virtual PatternMatchResult match(SourceOp op) const { llvm_unreachable("must override match or matchAndRewrite"); } virtual PatternMatchResult matchAndRewrite(SourceOp op, PatternRewriter &rewriter) const { - if (auto matchResult = match(op)) { - rewrite(op, std::move(*matchResult), rewriter); + if (succeeded(match(op))) { + rewrite(op, rewriter); return matchSuccess(); } return matchFailure(); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -238,7 +238,7 @@ virtual PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - if (!match(op)) + if (failed(match(op))) return matchFailure(); rewrite(op, operands, rewriter); return matchSuccess(); @@ -285,7 +285,7 @@ virtual PatternMatchResult matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - if (!match(op)) + if (failed(match(op))) return matchFailure(); rewrite(op, operands, rewriter); return matchSuccess(); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -805,7 +805,7 @@ // multiple times. auto success = matchAndRewrite(insertStridedSliceOp, rewriter); (void)success; - assert(success && "Unexpected failure"); + assert(succeeded(success) && "Unexpected failure"); extractedSource = insertStridedSliceOp; } // 4. Insert the extractedSource into the res vector. @@ -1083,7 +1083,7 @@ // multiple times. auto success = matchAndRewrite(stridedSliceOp, rewriter); (void)success; - assert(success && "Unexpected failure"); + assert(succeeded(success) && "Unexpected failure"); extracted = stridedSliceOp; } res = insertOne(rewriter, loc, extracted, res, idx); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp @@ -318,9 +318,8 @@ auto *falseBlock = brConditionalOp.getSuccessor(1); auto *mergeBlock = selectionOp.getMergeBlock(); - if (!canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)) { + if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock))) return matchFailure(); - } auto trueValue = getSrcValue(trueBlock); auto falseValue = getSrcValue(falseBlock); diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -39,11 +39,6 @@ // RewritePattern and PatternRewriter implementation //===----------------------------------------------------------------------===// -void RewritePattern::rewrite(Operation *op, std::unique_ptr state, - PatternRewriter &rewriter) const { - rewrite(op, rewriter); -} - void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const { llvm_unreachable("need to implement either matchAndRewrite or one of the " "rewrite functions!"); @@ -191,7 +186,7 @@ // Try to match and rewrite this pattern. The patterns are sorted by // benefit, so if we match we can immediately rewrite and return. - if (pattern->matchAndRewrite(op, rewriter)) + if (succeeded(pattern->matchAndRewrite(op, rewriter))) return true; } return false; diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1237,12 +1237,12 @@ // Try to rewrite with the given pattern. rewriter.setInsertionPoint(op); - auto matchedPattern = pattern->matchAndRewrite(op, rewriter); + LogicalResult matchedPattern = pattern->matchAndRewrite(op, rewriter); #ifndef NDEBUG assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); #endif - if (!matchedPattern) { + if (failed(matchedPattern)) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "pattern failed to match")); return cleanupFailure(); }