diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -247,7 +247,7 @@ /// The `matchAndRewrite` hooks on ConversionPatterns take an additional /// `operands` parameter, containing the remapped operands of the original /// operation. - virtual PatternMatchResult + virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const; }; diff --git a/mlir/docs/QuickstartRewrites.md b/mlir/docs/QuickstartRewrites.md --- a/mlir/docs/QuickstartRewrites.md +++ b/mlir/docs/QuickstartRewrites.md @@ -171,8 +171,8 @@ ConvertTFLeakyRelu(MLIRContext *context) : RewritePattern("tf.LeakyRelu", 1, context) {} - PatternMatchResult match(Operation *op) const override { - return matchSuccess(); + LogicalResult match(Operation *op) const override { + return success(); } void rewrite(Operation *op, PatternRewriter &rewriter) const override { @@ -188,12 +188,12 @@ ConvertTFLeakyRelu(MLIRContext *context) : RewritePattern("tf.LeakyRelu", 1, context) {} - PatternMatchResult matchAndRewrite(Operation *op, + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), op->getOperand(0), /*alpha=*/op->getAttrOfType("alpha")); - return matchSuccess(); + return success(); } }; ``` diff --git a/mlir/docs/Tutorials/Toy/Ch-3.md b/mlir/docs/Tutorials/Toy/Ch-3.md --- a/mlir/docs/Tutorials/Toy/Ch-3.md +++ b/mlir/docs/Tutorials/Toy/Ch-3.md @@ -86,7 +86,7 @@ /// This method is attempting to match a pattern and rewrite it. The rewriter /// argument is the orchestrator of the sequence of rewrites. It is expected /// to interact with it to perform any changes to the IR from here. - mlir::PatternMatchResult + mlir::LogicalResult matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. @@ -96,11 +96,11 @@ // Input defined by another transpose? If not, no match. if (!transposeInputOp) - return matchFailure(); + return failure(); // Otherwise, we have a redundant transpose. Use the rewriter. rewriter.replaceOp(op, {transposeInputOp.getOperand()}, {transposeInputOp}); - return matchSuccess(); + return success(); } }; ``` diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md --- a/mlir/docs/Tutorials/Toy/Ch-5.md +++ b/mlir/docs/Tutorials/Toy/Ch-5.md @@ -106,7 +106,7 @@ /// Match and rewrite the given `toy.transpose` operation, with the given /// operands that have been remapped from `tensor<...>` to `memref<...>`. - mlir::PatternMatchResult + mlir::LogicalResult matchAndRewrite(mlir::Operation *op, ArrayRef operands, mlir::ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); @@ -132,7 +132,7 @@ SmallVector reverseIvs(llvm::reverse(loopIvs)); return rewriter.create(loc, input, reverseIvs); }); - return matchSuccess(); + return success(); } }; ``` diff --git a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch3/mlir/ToyCombine.cpp @@ -35,7 +35,7 @@ /// This method attempts to match a pattern and rewrite it. The rewriter /// argument is the orchestrator of the sequence of rewrites. The pattern is /// expected to interact with it to perform any changes to the IR from here. - mlir::PatternMatchResult + mlir::LogicalResult matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. @@ -45,11 +45,11 @@ // Input defined by another transpose? If not, no match. if (!transposeInputOp) - return matchFailure(); + return failure(); // Otherwise, we have a redundant transpose. Use the rewriter. rewriter.replaceOp(op, {transposeInputOp.getOperand()}); - return matchSuccess(); + return success(); } }; diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -40,7 +40,7 @@ /// This method attempts to match a pattern and rewrite it. The rewriter /// argument is the orchestrator of the sequence of rewrites. The pattern is /// expected to interact with it to perform any changes to the IR from here. - mlir::PatternMatchResult + mlir::LogicalResult matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. @@ -50,11 +50,11 @@ // Input defined by another transpose? If not, no match. if (!transposeInputOp) - return matchFailure(); + return failure(); // Otherwise, we have a redundant transpose. Use the rewriter. rewriter.replaceOp(op, {transposeInputOp.getOperand()}); - return matchSuccess(); + return success(); } }; diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -103,7 +103,7 @@ BinaryOpLowering(MLIRContext *ctx) : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); @@ -126,7 +126,7 @@ // Create the binary operation performed on the loaded values. return rewriter.create(loc, loadedLhs, loadedRhs); }); - return matchSuccess(); + return success(); } }; using AddOpLowering = BinaryOpLowering; @@ -139,8 +139,8 @@ struct ConstantOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(toy::ConstantOp op, - PatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(toy::ConstantOp op, + PatternRewriter &rewriter) const final { DenseElementsAttr constantValue = op.value(); Location loc = op.getLoc(); @@ -189,7 +189,7 @@ // Replace this operation with the generated alloc. rewriter.replaceOp(op, alloc); - return matchSuccess(); + return success(); } }; @@ -200,16 +200,16 @@ struct ReturnOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(toy::ReturnOp op, - PatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(toy::ReturnOp op, + PatternRewriter &rewriter) const final { // During this lowering, we expect that all function calls have been // inlined. if (op.hasOperand()) - return matchFailure(); + return failure(); // We lower "toy.return" directly to "std.return". rewriter.replaceOpWithNewOp(op); - return matchSuccess(); + return success(); } }; @@ -221,7 +221,7 @@ TransposeOpLowering(MLIRContext *ctx) : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); @@ -240,7 +240,7 @@ SmallVector reverseIvs(llvm::reverse(loopIvs)); return rewriter.create(loc, input, reverseIvs); }); - return matchSuccess(); + return success(); } }; diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -40,7 +40,7 @@ /// This method attempts to match a pattern and rewrite it. The rewriter /// argument is the orchestrator of the sequence of rewrites. The pattern is /// expected to interact with it to perform any changes to the IR from here. - mlir::PatternMatchResult + mlir::LogicalResult matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. @@ -50,11 +50,11 @@ // Input defined by another transpose? If not, no match. if (!transposeInputOp) - return matchFailure(); + return failure(); // Otherwise, we have a redundant transpose. Use the rewriter. rewriter.replaceOp(op, {transposeInputOp.getOperand()}); - return matchSuccess(); + return success(); } }; diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -103,7 +103,7 @@ BinaryOpLowering(MLIRContext *ctx) : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); @@ -126,7 +126,7 @@ // Create the binary operation performed on the loaded values. return rewriter.create(loc, loadedLhs, loadedRhs); }); - return matchSuccess(); + return success(); } }; using AddOpLowering = BinaryOpLowering; @@ -139,8 +139,8 @@ struct ConstantOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(toy::ConstantOp op, - PatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(toy::ConstantOp op, + PatternRewriter &rewriter) const final { DenseElementsAttr constantValue = op.value(); Location loc = op.getLoc(); @@ -189,7 +189,7 @@ // Replace this operation with the generated alloc. rewriter.replaceOp(op, alloc); - return matchSuccess(); + return success(); } }; @@ -200,16 +200,16 @@ struct ReturnOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(toy::ReturnOp op, - PatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(toy::ReturnOp op, + PatternRewriter &rewriter) const final { // During this lowering, we expect that all function calls have been // inlined. if (op.hasOperand()) - return matchFailure(); + return failure(); // We lower "toy.return" directly to "std.return". rewriter.replaceOpWithNewOp(op); - return matchSuccess(); + return success(); } }; @@ -221,7 +221,7 @@ TransposeOpLowering(MLIRContext *ctx) : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); @@ -240,7 +240,7 @@ SmallVector reverseIvs(llvm::reverse(loopIvs)); return rewriter.create(loc, input, reverseIvs); }); - return matchSuccess(); + return success(); } }; diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -41,7 +41,7 @@ explicit PrintOpLowering(MLIRContext *context) : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto memRefType = (*op->operand_type_begin()).cast(); @@ -91,7 +91,7 @@ // Notify the rewriter that this operation has been removed. rewriter.eraseOp(op); - return matchSuccess(); + return success(); } private: diff --git a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch6/mlir/ToyCombine.cpp @@ -40,7 +40,7 @@ /// This method attempts to match a pattern and rewrite it. The rewriter /// argument is the orchestrator of the sequence of rewrites. The pattern is /// expected to interact with it to perform any changes to the IR from here. - mlir::PatternMatchResult + mlir::LogicalResult matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. @@ -50,11 +50,11 @@ // Input defined by another transpose? If not, no match. if (!transposeInputOp) - return matchFailure(); + return failure(); // Otherwise, we have a redundant transpose. Use the rewriter. rewriter.replaceOp(op, {transposeInputOp.getOperand()}); - return matchSuccess(); + return success(); } }; diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -103,7 +103,7 @@ BinaryOpLowering(MLIRContext *ctx) : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); @@ -126,7 +126,7 @@ // Create the binary operation performed on the loaded values. return rewriter.create(loc, loadedLhs, loadedRhs); }); - return matchSuccess(); + return success(); } }; using AddOpLowering = BinaryOpLowering; @@ -139,8 +139,8 @@ struct ConstantOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(toy::ConstantOp op, - PatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(toy::ConstantOp op, + PatternRewriter &rewriter) const final { DenseElementsAttr constantValue = op.value(); Location loc = op.getLoc(); @@ -189,7 +189,7 @@ // Replace this operation with the generated alloc. rewriter.replaceOp(op, alloc); - return matchSuccess(); + return success(); } }; @@ -200,16 +200,16 @@ struct ReturnOpLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(toy::ReturnOp op, - PatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(toy::ReturnOp op, + PatternRewriter &rewriter) const final { // During this lowering, we expect that all function calls have been // inlined. if (op.hasOperand()) - return matchFailure(); + return failure(); // We lower "toy.return" directly to "std.return". rewriter.replaceOpWithNewOp(op); - return matchSuccess(); + return success(); } }; @@ -221,7 +221,7 @@ TransposeOpLowering(MLIRContext *ctx) : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto loc = op->getLoc(); @@ -240,7 +240,7 @@ SmallVector reverseIvs(llvm::reverse(loopIvs)); return rewriter.create(loc, input, reverseIvs); }); - return matchSuccess(); + return success(); } }; diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -41,7 +41,7 @@ explicit PrintOpLowering(MLIRContext *context) : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto memRefType = (*op->operand_type_begin()).cast(); @@ -91,7 +91,7 @@ // Notify the rewriter that this operation has been removed. rewriter.eraseOp(op); - return matchSuccess(); + return success(); } private: diff --git a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp --- a/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp @@ -58,7 +58,7 @@ /// This method attempts to match a pattern and rewrite it. The rewriter /// argument is the orchestrator of the sequence of rewrites. The pattern is /// expected to interact with it to perform any changes to the IR from here. - mlir::PatternMatchResult + mlir::LogicalResult matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { // Look through the input of the current transpose. @@ -68,11 +68,11 @@ // Input defined by another transpose? If not, no match. if (!transposeInputOp) - return matchFailure(); + return failure(); // Otherwise, we have a redundant transpose. Use the rewriter. rewriter.replaceOp(op, {transposeInputOp.getOperand()}); - return matchSuccess(); + return success(); } }; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransformPatterns.td @@ -53,7 +53,7 @@ "if (failed(tileAndFuseLinalgOpAndSetMarker($_builder, op, {" # StrJoinInt.result # "}, {" # StrJoinInt.result # "}," # " \"" # value # "\")))" # - " return matchFailure();">; + " return failure();">; //===----------------------------------------------------------------------===// // Linalg tiling patterns. @@ -70,22 +70,22 @@ "if (failed(tileLinalgOpAndSetMarker($_builder, op, {" # StrJoinInt.result # "}, \"" # value # "\", {" # StrJoinInt.result # "})))" # - " return matchFailure();">; + " return failure();">; //===----------------------------------------------------------------------===// // Linalg to loop patterns. //===----------------------------------------------------------------------===// class LinalgOpToLoops : NativeCodeCall< "if (failed(linalgOpToLoops<" # OpType # ">($_builder, op))) " # - " return matchFailure();">; + " return failure();">; class LinalgOpToParallelLoops : NativeCodeCall< "if (failed(linalgOpToParallelLoops<" # OpType # ">($_builder, op))) " # - " return matchFailure();">; + " return failure();">; class LinalgOpToAffineLoops : NativeCodeCall< "if (failed(linalgOpToAffineLoops<" # OpType # ">($_builder, op))) " # - " return matchFailure();">; + " return failure();">; //===----------------------------------------------------------------------===// // Linalg to vector patterns precondition and DRR. diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -54,10 +54,6 @@ unsigned short representation; }; -/// This is the type returned by a pattern match. -/// TODO: Replace usages with LogicalResult directly. -using PatternMatchResult = LogicalResult; - //===----------------------------------------------------------------------===// // Pattern class //===----------------------------------------------------------------------===// @@ -85,20 +81,10 @@ /// Attempt to match against code rooted at the specified operation, /// which is the same operation code as getRootKind(). - virtual PatternMatchResult match(Operation *op) const = 0; + virtual LogicalResult match(Operation *op) const = 0; virtual ~Pattern() {} - //===--------------------------------------------------------------------===// - // Helper methods to simplify pattern implementations - //===--------------------------------------------------------------------===// - - /// Return a result, indicating that no match was found. - PatternMatchResult matchFailure() const { return failure(); } - - /// This method indicates that a match was found. - PatternMatchResult matchSuccess() const { return success(); } - protected: /// Patterns must specify the root operation name they match against, and can /// also specify the benefit of the pattern matching. @@ -130,22 +116,19 @@ virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; /// Attempt to match against code rooted at the specified operation, - /// which is the same operation code as getRootKind(). On failure, this - /// returns a None value. On success, it returns a (possibly null) - /// pattern-specific state wrapped in an Optional. This state is passed back - /// into the rewrite function if this match is selected. - PatternMatchResult match(Operation *op) const override; + /// which is the same operation code as getRootKind(). + LogicalResult match(Operation *op) const override; /// Attempt to match against code rooted at the specified operation, /// which is the same operation code as getRootKind(). If successful, this /// function will automatically perform the rewrite. - virtual PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const { + virtual LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { if (succeeded(match(op))) { rewrite(op, rewriter); - return matchSuccess(); + return success(); } - return matchFailure(); + return failure(); } /// Return a list of operations that may be generated when rewriting an @@ -182,11 +165,11 @@ void rewrite(Operation *op, PatternRewriter &rewriter) const final { rewrite(cast(op), rewriter); } - PatternMatchResult match(Operation *op) const final { + LogicalResult match(Operation *op) const final { return match(cast(op)); } - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { return matchAndRewrite(cast(op), rewriter); } @@ -195,16 +178,16 @@ virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const { llvm_unreachable("must override rewrite or matchAndRewrite"); } - virtual PatternMatchResult match(SourceOp op) const { + virtual LogicalResult match(SourceOp op) const { llvm_unreachable("must override match or matchAndRewrite"); } - virtual PatternMatchResult matchAndRewrite(SourceOp op, - PatternRewriter &rewriter) const { + virtual LogicalResult matchAndRewrite(SourceOp op, + PatternRewriter &rewriter) const { if (succeeded(match(op))) { rewrite(op, rewriter); - return matchSuccess(); + return success(); } - return matchFailure(); + return failure(); } }; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -235,18 +235,18 @@ } /// Hook for derived classes to implement combined matching and rewriting. - virtual PatternMatchResult + virtual LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (failed(match(op))) - return matchFailure(); + return failure(); rewrite(op, operands, rewriter); - return matchSuccess(); + return success(); } /// Attempt to match and rewrite the IR root at the specified operation. - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const final; + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final; private: using RewritePattern::rewrite; @@ -266,7 +266,7 @@ ConversionPatternRewriter &rewriter) const final { rewrite(cast(op), operands, rewriter); } - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { return matchAndRewrite(cast(op), operands, rewriter); @@ -282,13 +282,13 @@ llvm_unreachable("must override matchAndRewrite or a rewrite method"); } - virtual PatternMatchResult + virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (failed(match(op))) - return matchFailure(); + return failure(); rewrite(op, operands, rewriter); - return matchSuccess(); + return success(); } private: diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -297,15 +297,15 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AffineMinOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AffineMinOp op, + PatternRewriter &rewriter) const override { Value reduced = lowerAffineMapMin(rewriter, op.getLoc(), op.map(), op.operands()); if (!reduced) - return matchFailure(); + return failure(); rewriter.replaceOp(op, reduced); - return matchSuccess(); + return success(); } }; @@ -313,15 +313,15 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AffineMaxOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AffineMaxOp op, + PatternRewriter &rewriter) const override { Value reduced = lowerAffineMapMax(rewriter, op.getLoc(), op.map(), op.operands()); if (!reduced) - return matchFailure(); + return failure(); rewriter.replaceOp(op, reduced); - return matchSuccess(); + return success(); } }; @@ -330,10 +330,10 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AffineTerminatorOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AffineTerminatorOp op, + PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op); - return matchSuccess(); + return success(); } }; @@ -341,8 +341,8 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AffineForOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AffineForOp op, + PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value lowerBound = lowerAffineLowerBound(op, rewriter); Value upperBound = lowerAffineUpperBound(op, rewriter); @@ -351,7 +351,7 @@ f.region().getBlocks().clear(); rewriter.inlineRegionBefore(op.region(), f.region(), f.region().end()); rewriter.eraseOp(op); - return matchSuccess(); + return success(); } }; @@ -359,8 +359,8 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AffineIfOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AffineIfOp op, + PatternRewriter &rewriter) const override { auto loc = op.getLoc(); // Now we just have to handle the condition logic. @@ -381,7 +381,7 @@ operandsRef.take_front(numDims), operandsRef.drop_front(numDims)); if (!affResult) - return matchFailure(); + return failure(); auto pred = isEquality ? CmpIPredicate::eq : CmpIPredicate::sge; Value cmpVal = rewriter.create(loc, pred, affResult, zeroConstant); @@ -402,7 +402,7 @@ // Ok, we're done! rewriter.eraseOp(op); - return matchSuccess(); + return success(); } }; @@ -412,15 +412,15 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AffineApplyOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AffineApplyOp op, + PatternRewriter &rewriter) const override { auto maybeExpandedMap = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), llvm::to_vector<8>(op.getOperands())); if (!maybeExpandedMap) - return matchFailure(); + return failure(); rewriter.replaceOp(op, *maybeExpandedMap); - return matchSuccess(); + return success(); } }; @@ -431,18 +431,18 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AffineLoadOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AffineLoadOp op, + PatternRewriter &rewriter) const override { // Expand affine map from 'affineLoadOp'. SmallVector indices(op.getMapOperands()); auto resultOperands = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!resultOperands) - return matchFailure(); + return failure(); // Build std.load memref[expandedMap.results]. rewriter.replaceOpWithNewOp(op, op.getMemRef(), *resultOperands); - return matchSuccess(); + return success(); } }; @@ -453,20 +453,20 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AffinePrefetchOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AffinePrefetchOp op, + PatternRewriter &rewriter) const override { // Expand affine map from 'affinePrefetchOp'. SmallVector indices(op.getMapOperands()); auto resultOperands = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!resultOperands) - return matchFailure(); + return failure(); // Build std.prefetch memref[expandedMap.results]. rewriter.replaceOpWithNewOp( op, op.memref(), *resultOperands, op.isWrite(), op.localityHint().getZExtValue(), op.isDataCache()); - return matchSuccess(); + return success(); } }; @@ -477,19 +477,19 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AffineStoreOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AffineStoreOp op, + PatternRewriter &rewriter) const override { // Expand affine map from 'affineStoreOp'. SmallVector indices(op.getMapOperands()); auto maybeExpandedMap = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!maybeExpandedMap) - return matchFailure(); + return failure(); // Build std.store valueToStore, memref[expandedMap.results]. rewriter.replaceOpWithNewOp(op, op.getValueToStore(), op.getMemRef(), *maybeExpandedMap); - return matchSuccess(); + return success(); } }; @@ -500,8 +500,8 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AffineDmaStartOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AffineDmaStartOp op, + PatternRewriter &rewriter) const override { SmallVector operands(op.getOperands()); auto operandsRef = llvm::makeArrayRef(operands); @@ -510,26 +510,26 @@ rewriter, op.getLoc(), op.getSrcMap(), operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1)); if (!maybeExpandedSrcMap) - return matchFailure(); + return failure(); // Expand affine map for DMA destination memref. auto maybeExpandedDstMap = expandAffineMap( rewriter, op.getLoc(), op.getDstMap(), operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1)); if (!maybeExpandedDstMap) - return matchFailure(); + return failure(); // Expand affine map for DMA tag memref. auto maybeExpandedTagMap = expandAffineMap( rewriter, op.getLoc(), op.getTagMap(), operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1)); if (!maybeExpandedTagMap) - return matchFailure(); + return failure(); // Build std.dma_start operation with affine map results. rewriter.replaceOpWithNewOp( op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(), *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(), *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride()); - return matchSuccess(); + return success(); } }; @@ -540,19 +540,19 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AffineDmaWaitOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AffineDmaWaitOp op, + PatternRewriter &rewriter) const override { // Expand affine map for DMA tag memref. SmallVector indices(op.getTagIndices()); auto maybeExpandedTagMap = expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices); if (!maybeExpandedTagMap) - return matchFailure(); + return failure(); // Build std.dma_wait operation with affine map results. rewriter.replaceOpWithNewOp( op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements()); - return matchSuccess(); + return success(); } }; diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -46,7 +46,7 @@ indexBitwidth(getIndexBitWidth(lowering_)) {} // Convert the kernel arguments to an LLVM type, preserve the rest. - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); @@ -63,7 +63,7 @@ newOp = rewriter.create(loc, LLVM::LLVMType::getInt32Ty(dialect)); break; default: - return matchFailure(); + return failure(); } if (indexBitwidth > 32) { @@ -75,7 +75,7 @@ } rewriter.replaceOp(op, {newOp}); - return matchSuccess(); + return success(); } }; diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -34,7 +34,7 @@ lowering_.getDialect()->getContext(), lowering_), f32Func(f32Func), f64Func(f64Func) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { using LLVM::LLVMFuncOp; @@ -49,13 +49,13 @@ LLVMType funcType = getFunctionType(resultType, operands); StringRef funcName = getFunctionName(resultType); if (funcName.empty()) - return matchFailure(); + return failure(); LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op); auto callOp = rewriter.create( op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands); rewriter.replaceOp(op, {callOp.getResult(0)}); - return matchSuccess(); + return success(); } private: diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -51,7 +51,7 @@ /// !llvm<"{ float, i1 }"> /// %shfl_pred = llvm.extractvalue %shfl[1 : index] : /// !llvm<"{ float, i1 }"> - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); @@ -84,7 +84,7 @@ loc, predTy, shfl, rewriter.getIndexArrayAttr(1)); rewriter.replaceOp(op, {shflValue, isActiveSrcLane}); - return matchSuccess(); + return success(); } }; @@ -94,7 +94,7 @@ typeConverter.getDialect()->getContext(), typeConverter) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.empty() && "func op is not expected to have operands"); @@ -219,7 +219,7 @@ signatureConversion); rewriter.eraseOp(gpuFuncOp); - return matchSuccess(); + return success(); } }; @@ -229,11 +229,11 @@ typeConverter.getDialect()->getContext(), typeConverter) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands); - return matchSuccess(); + return success(); } }; diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -26,7 +26,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(loop::ForOp forOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -37,7 +37,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(loop::IfOp IfOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -47,11 +47,11 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(loop::YieldOp terminatorOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.eraseOp(terminatorOp); - return matchSuccess(); + return success(); } }; @@ -62,7 +62,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -75,7 +75,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(gpu::BlockDimOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -85,7 +85,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; @@ -98,7 +98,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(gpu::GPUModuleOp moduleOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -109,7 +109,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -120,7 +120,7 @@ // loop::ForOp. //===----------------------------------------------------------------------===// -PatternMatchResult +LogicalResult ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { // loop::ForOp can be lowered to the structured control flow represented by @@ -186,14 +186,14 @@ rewriter.create(loc, header, updatedIndVar); rewriter.eraseOp(forOp); - return matchSuccess(); + return success(); } //===----------------------------------------------------------------------===// // loop::IfOp. //===----------------------------------------------------------------------===// -PatternMatchResult +LogicalResult IfOpConversion::matchAndRewrite(loop::IfOp ifOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { // When lowering `loop::IfOp` we explicitly create a selection header block @@ -238,7 +238,7 @@ elseBlock, ArrayRef()); rewriter.eraseOp(ifOp); - return matchSuccess(); + return success(); } //===----------------------------------------------------------------------===// @@ -261,36 +261,36 @@ } template -PatternMatchResult LaunchConfigConversion::matchAndRewrite( +LogicalResult LaunchConfigConversion::matchAndRewrite( SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto index = getLaunchConfigIndex(op); if (!index) - return this->matchFailure(); + return failure(); // SPIR-V invocation builtin variables are a vector of type <3xi32> auto spirvBuiltin = spirv::getBuiltinVariableValue(op, builtin, rewriter); rewriter.replaceOpWithNewOp( op, rewriter.getIntegerType(32), spirvBuiltin, rewriter.getI32ArrayAttr({index.getValue()})); - return this->matchSuccess(); + return success(); } -PatternMatchResult WorkGroupSizeConversion::matchAndRewrite( +LogicalResult WorkGroupSizeConversion::matchAndRewrite( gpu::BlockDimOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto index = getLaunchConfigIndex(op); if (!index) - return matchFailure(); + return failure(); auto workGroupSizeAttr = spirv::lookupLocalWorkGroupSize(op); auto val = workGroupSizeAttr.getValue(index.getValue()); auto convertedType = typeConverter.convertType(op.getResult().getType()); if (!convertedType) - return matchFailure(); + return failure(); rewriter.replaceOpWithNewOp( op, convertedType, IntegerAttr::get(convertedType, val)); - return matchSuccess(); + return success(); } //===----------------------------------------------------------------------===// @@ -343,11 +343,11 @@ return newFuncOp; } -PatternMatchResult GPUFuncOpConversion::matchAndRewrite( +LogicalResult GPUFuncOpConversion::matchAndRewrite( gpu::GPUFuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!gpu::GPUDialect::isKernel(funcOp)) - return matchFailure(); + return failure(); SmallVector argABI; for (auto argNum : llvm::seq(0, funcOp.getNumArguments())) { @@ -358,22 +358,22 @@ auto entryPointAttr = spirv::lookupEntryPointABI(funcOp); if (!entryPointAttr) { funcOp.emitRemark("match failure: missing 'spv.entry_point_abi' attribute"); - return matchFailure(); + return failure(); } spirv::FuncOp newFuncOp = lowerAsEntryFunction( funcOp, typeConverter, rewriter, entryPointAttr, argABI); if (!newFuncOp) - return matchFailure(); + return failure(); newFuncOp.removeAttr(Identifier::get(gpu::GPUDialect::getKernelFuncAttrName(), rewriter.getContext())); - return matchSuccess(); + return success(); } //===----------------------------------------------------------------------===// // ModuleOp with gpu.module. //===----------------------------------------------------------------------===// -PatternMatchResult GPUModuleConversion::matchAndRewrite( +LogicalResult GPUModuleConversion::matchAndRewrite( gpu::GPUModuleOp moduleOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto spvModule = rewriter.create( @@ -389,21 +389,21 @@ // legalized later. spvModuleRegion.back().erase(); rewriter.eraseOp(moduleOp); - return matchSuccess(); + return success(); } //===----------------------------------------------------------------------===// // GPU return inside kernel functions to SPIR-V return. //===----------------------------------------------------------------------===// -PatternMatchResult GPUReturnOpConversion::matchAndRewrite( +LogicalResult GPUReturnOpConversion::matchAndRewrite( gpu::ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!operands.empty()) - return matchFailure(); + return failure(); rewriter.replaceOpWithNewOp(returnOp); - return matchSuccess(); + return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -130,7 +130,7 @@ explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : ConvertToLLVMPattern(RangeOp::getOperationName(), context, lowering_) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto rangeOp = cast(op); @@ -146,7 +146,7 @@ desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1)); desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2)); rewriter.replaceOp(op, desc); - return matchSuccess(); + return success(); } }; @@ -160,14 +160,14 @@ : ConvertToLLVMPattern(ReshapeOp::getOperationName(), context, lowering_) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto reshapeOp = cast(op); MemRefType dstType = reshapeOp.getResult().getType().cast(); if (!dstType.hasStaticShape()) - return matchFailure(); + return failure(); int64_t offset; SmallVector strides; @@ -175,7 +175,7 @@ if (failed(res) || llvm::any_of(strides, [](int64_t val) { return ShapedType::isDynamicStrideOrOffset(val); })) - return matchFailure(); + return failure(); edsc::ScopedContext context(rewriter, op->getLoc()); ReshapeOpOperandAdaptor adaptor(operands); @@ -189,7 +189,7 @@ for (auto en : llvm::enumerate(strides)) desc.setConstantStride(en.index(), en.value()); rewriter.replaceOp(op, {desc}); - return matchSuccess(); + return success(); } }; @@ -204,7 +204,7 @@ explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : ConvertToLLVMPattern(SliceOp::getOperationName(), context, lowering_) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { edsc::ScopedContext context(rewriter, op->getLoc()); @@ -247,7 +247,7 @@ // Corner case, no sizes or strides: early return the descriptor. if (sliceOp.getShapedType().getRank() == 0) - return rewriter.replaceOp(op, {desc}), matchSuccess(); + return rewriter.replaceOp(op, {desc}), success(); Value zero = llvm_constant( int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); @@ -279,7 +279,7 @@ } rewriter.replaceOp(op, {desc}); - return matchSuccess(); + return success(); } }; @@ -297,7 +297,7 @@ : ConvertToLLVMPattern(TransposeOp::getOperationName(), context, lowering_) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { // Initialize the common boilerplate and alloca at the top of the FuncOp. @@ -308,7 +308,7 @@ auto transposeOp = cast(op); // No permutation, early exit. if (transposeOp.permutation().isIdentity()) - return rewriter.replaceOp(op, {baseDesc}), matchSuccess(); + return rewriter.replaceOp(op, {baseDesc}), success(); BaseViewConversionHelper desc( typeConverter.convertType(transposeOp.getShapedType())); @@ -330,7 +330,7 @@ } rewriter.replaceOp(op, {desc}); - return matchSuccess(); + return success(); } }; @@ -340,11 +340,11 @@ explicit YieldOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : ConvertToLLVMPattern(YieldOp::getOperationName(), context, lowering_) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands); - return matchSuccess(); + return success(); } }; } // namespace @@ -416,15 +416,15 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(LinalgOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(LinalgOp op, + PatternRewriter &rewriter) const override { auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); if (!libraryCallName) - return this->matchFailure(); + return failure(); rewriter.replaceOpWithNewOp( op, libraryCallName.getValue(), ArrayRef{}, op.getOperands()); - return this->matchSuccess(); + return success(); } }; @@ -434,22 +434,22 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(CopyOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(CopyOp op, + PatternRewriter &rewriter) const override { auto inputPerm = op.inputPermutation(); if (inputPerm.hasValue() && !inputPerm->isIdentity()) - return matchFailure(); + return failure(); auto outputPerm = op.outputPermutation(); if (outputPerm.hasValue() && !outputPerm->isIdentity()) - return matchFailure(); + return failure(); auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); if (!libraryCallName) - return matchFailure(); + return failure(); rewriter.replaceOpWithNewOp( op, libraryCallName.getValue(), ArrayRef{}, op.getOperands()); - return matchSuccess(); + return success(); } }; @@ -460,12 +460,12 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(IndexedGenericOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(IndexedGenericOp op, + PatternRewriter &rewriter) const override { auto libraryCallName = getLibraryCallSymbolRef(op, rewriter); if (!libraryCallName) - return this->matchFailure(); + return failure(); // TODO(pifon, ntv): Use induction variables values instead of zeros, when // IndexedGenericOp is tiled. @@ -483,7 +483,7 @@ } rewriter.replaceOpWithNewOp(op, libraryCallName.getValue(), ArrayRef{}, operands); - return this->matchSuccess(); + return success(); } }; @@ -495,8 +495,8 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(CopyOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(CopyOp op, + PatternRewriter &rewriter) const override { Value in = op.input(), out = op.output(); // If either inputPerm or outputPerm are non-identities, insert transposes. @@ -511,10 +511,10 @@ // If nothing was transposed, fail and let the conversion kick in. if (in == op.input() && out == op.output()) - return matchFailure(); + return failure(); rewriter.replaceOpWithNewOp(op, in, out); - return matchSuccess(); + return success(); } }; diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp @@ -54,7 +54,7 @@ static Optional matchAsPerformingReduction(linalg::GenericOp genericOp); - PatternMatchResult + LogicalResult matchAndRewrite(linalg::GenericOp genericOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -109,7 +109,7 @@ return linalg::RegionMatcher::matchAsScalarBinaryOp(genericOp); } -PatternMatchResult SingleWorkgroupReduction::matchAndRewrite( +LogicalResult SingleWorkgroupReduction::matchAndRewrite( linalg::GenericOp genericOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { Operation *op = genericOp.getOperation(); @@ -118,19 +118,19 @@ auto binaryOpKind = matchAsPerformingReduction(genericOp); if (!binaryOpKind) - return matchFailure(); + return failure(); // Query the shader interface for local workgroup size to make sure the // invocation configuration fits with the input memref's shape. DenseIntElementsAttr localSize = spirv::lookupLocalWorkGroupSize(genericOp); if (!localSize) - return matchFailure(); + return failure(); if ((*localSize.begin()).getSExtValue() != originalInputType.getDimSize(0)) - return matchFailure(); + return failure(); if (llvm::any_of(llvm::drop_begin(localSize.getIntValues(), 1), [](const APInt &size) { return !size.isOneValue(); })) - return matchFailure(); + return failure(); // TODO(antiagainst): Query the target environment to make sure the current // workload fits in a local workgroup. @@ -195,7 +195,7 @@ spirv::SelectionOp::createIfThen(loc, condition, createAtomicOp, &rewriter); rewriter.eraseOp(genericOp); - return matchSuccess(); + return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp --- a/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp +++ b/mlir/lib/Conversion/LoopToStandard/ConvertLoopToStandard.cpp @@ -98,8 +98,8 @@ struct ForLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(ForOp forOp, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(ForOp forOp, + PatternRewriter &rewriter) const override; }; // Create a CFG subgraph for the loop.if operation (including its "then" and @@ -147,20 +147,20 @@ struct IfLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(IfOp ifOp, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(IfOp ifOp, + PatternRewriter &rewriter) const override; }; struct ParallelLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(mlir::loop::ParallelOp parallelOp, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(mlir::loop::ParallelOp parallelOp, + PatternRewriter &rewriter) const override; }; } // namespace -PatternMatchResult -ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const { +LogicalResult ForLowering::matchAndRewrite(ForOp forOp, + PatternRewriter &rewriter) const { Location loc = forOp.getLoc(); // Start by splitting the block containing the 'loop.for' into two parts. @@ -189,7 +189,7 @@ auto step = forOp.step(); auto stepped = rewriter.create(loc, iv, step).getResult(); if (!stepped) - return matchFailure(); + return failure(); SmallVector loopCarried; loopCarried.push_back(stepped); @@ -202,7 +202,7 @@ Value lowerBound = forOp.lowerBound(); Value upperBound = forOp.upperBound(); if (!lowerBound || !upperBound) - return matchFailure(); + return failure(); // The initial values of loop-carried values is obtained from the operands // of the loop operation. @@ -222,11 +222,11 @@ // The result of the loop operation is the values of the condition block // arguments except the induction variable on the last iteration. rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front()); - return matchSuccess(); + return success(); } -PatternMatchResult -IfLowering::matchAndRewrite(IfOp ifOp, PatternRewriter &rewriter) const { +LogicalResult IfLowering::matchAndRewrite(IfOp ifOp, + PatternRewriter &rewriter) const { auto loc = ifOp.getLoc(); // Start by splitting the block containing the 'loop.if' into two parts. @@ -265,10 +265,10 @@ // Ok, we're done! rewriter.eraseOp(ifOp); - return matchSuccess(); + return success(); } -PatternMatchResult +LogicalResult ParallelLowering::matchAndRewrite(ParallelOp parallelOp, PatternRewriter &rewriter) const { Location loc = parallelOp.getLoc(); @@ -344,7 +344,7 @@ rewriter.replaceOp(parallelOp, loopResults); - return matchSuccess(); + return success(); } void mlir::populateLoopToStdConversionPatterns( diff --git a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp --- a/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp +++ b/mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp @@ -497,8 +497,8 @@ struct ParallelToGpuLaunchLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(ParallelOp parallelOp, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(ParallelOp parallelOp, + PatternRewriter &rewriter) const override; }; struct MappingAnnotation { @@ -742,7 +742,7 @@ /// the actual loop bound. This only works if an static upper bound for the /// dynamic loop bound can be defived, currently via analyzing `affine.min` /// operations. -PatternMatchResult +LogicalResult ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp, PatternRewriter &rewriter) const { // Create a launch operation. We start with bound one for all grid/block @@ -761,7 +761,7 @@ SmallVector worklist; if (failed(processParallelLoop(parallelOp, launchOp, cloningMap, worklist, launchBounds, rewriter))) - return matchFailure(); + return failure(); // Whether we have seen any side-effects. Reset when leaving an inner scope. bool seenSideeffects = false; @@ -778,13 +778,13 @@ // Before entering a nested scope, make sure there have been no // sideeffects until now. if (seenSideeffects) - return matchFailure(); + return failure(); // A nested loop.parallel needs insertion of code to compute indices. // Insert that now. This will also update the worklist with the loops // body. if (failed(processParallelLoop(nestedParallel, launchOp, cloningMap, worklist, launchBounds, rewriter))) - return matchFailure(); + return failure(); } else if (op == launchOp.getOperation()) { // Found our sentinel value. We have finished the operations from one // nesting level, pop one level back up. @@ -802,7 +802,7 @@ clone->getNumRegions() != 0; // If we are no longer in the innermost scope, sideeffects are disallowed. if (seenSideeffects && leftNestingScope) - return matchFailure(); + return failure(); } } @@ -812,7 +812,7 @@ launchOp.setOperand(std::get<0>(bound), std::get<1>(bound)); rewriter.eraseOp(parallelOp); - return matchSuccess(); + return success(); } void mlir::populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns, diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -946,7 +946,7 @@ bool emitCWrappers) : FuncOpConversionBase(dialect, converter), emitWrappers(emitCWrappers) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto funcOp = cast(op); @@ -962,7 +962,7 @@ } rewriter.eraseOp(op); - return matchSuccess(); + return success(); } private: @@ -976,7 +976,7 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase { using FuncOpConversionBase::FuncOpConversionBase; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto funcOp = cast(op); @@ -990,7 +990,7 @@ auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (newFuncOp.getBody().empty()) { rewriter.eraseOp(op); - return matchSuccess(); + return success(); } // Promote bare pointers from MemRef arguments to a MemRef descriptor struct @@ -1017,7 +1017,7 @@ } rewriter.eraseOp(op); - return matchSuccess(); + return success(); } }; @@ -1109,7 +1109,7 @@ // Convert the type of the result to an LLVM type, pass operands as is, // preserve attributes. - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { unsigned numResults = op->getNumResults(); @@ -1119,7 +1119,7 @@ packedType = this->typeConverter.packFunctionResults(op->getResultTypes()); if (!packedType) - return this->matchFailure(); + return failure(); } auto newOp = rewriter.create(op->getLoc(), packedType, operands, @@ -1127,10 +1127,10 @@ // If the operation produced 0 or 1 result, return them immediately. if (numResults == 0) - return rewriter.eraseOp(op), this->matchSuccess(); + return rewriter.eraseOp(op), success(); if (numResults == 1) return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)), - this->matchSuccess(); + success(); // Otherwise, it had been converted to an operation producing a structure. // Extract individual results from the structure and return them as list. @@ -1143,7 +1143,7 @@ rewriter.getI64ArrayAttr(i))); } rewriter.replaceOp(op, results); - return this->matchSuccess(); + return success(); } }; @@ -1207,7 +1207,7 @@ // Convert the type of the result to an LLVM type, pass operands as is, // preserve attributes. - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { ValidateOpCount(); @@ -1221,7 +1221,7 @@ // Cannot convert ops if their operands are not of LLVM type. for (Value operand : operands) { if (!operand || !operand.getType().isa()) - return this->matchFailure(); + return failure(); } auto llvmArrayTy = operands[0].getType().cast(); @@ -1230,7 +1230,7 @@ auto newOp = rewriter.create( op->getLoc(), operands[0].getType(), operands, op->getAttrs()); rewriter.replaceOp(op, newOp.getResult()); - return this->matchSuccess(); + return success(); } if (succeeded(HandleMultidimensionalVectors( @@ -1240,8 +1240,8 @@ operands, op->getAttrs()); }, rewriter))) - return this->matchSuccess(); - return this->matchFailure(); + return success(); + return failure(); } }; @@ -1381,24 +1381,24 @@ : LLVMLegalizationPattern(dialect_, converter), useAlloca(useAlloca) {} - PatternMatchResult match(Operation *op) const override { + LogicalResult match(Operation *op) const override { MemRefType type = cast(op).getType(); if (isSupportedMemRefType(type)) - return matchSuccess(); + return success(); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(type, strides, offset); if (failed(successStrides)) - return matchFailure(); + return failure(); // Dynamic strides are ok if they can be deduced from dynamic sizes (which // is guaranteed when succeeded(successStrides)). Dynamic offset however can // never be alloc'ed. if (offset == MemRefType::getDynamicStrideOrOffset()) - return matchFailure(); + return failure(); - return matchSuccess(); + return success(); } void rewrite(Operation *op, ArrayRef operands, @@ -1574,7 +1574,7 @@ using Super = CallOpInterfaceLowering; using Base = LLVMLegalizationPattern; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { OperandAdaptor transformed(operands); @@ -1595,7 +1595,7 @@ if (numResults != 0) { if (!(packedResult = this->typeConverter.packFunctionResults(resultTypes))) - return this->matchFailure(); + return failure(); } auto promoted = this->typeConverter.promoteMemRefDescriptors( @@ -1606,7 +1606,7 @@ // If < 2 results, packing did not do anything and we can just return. if (numResults < 2) { rewriter.replaceOp(op, newOp.getResults()); - return this->matchSuccess(); + return success(); } // Otherwise, it had been converted to an operation producing a structure. @@ -1624,7 +1624,7 @@ } rewriter.replaceOp(op, results); - return this->matchSuccess(); + return success(); } }; @@ -1647,11 +1647,11 @@ : LLVMLegalizationPattern(dialect_, converter), useAlloca(useAlloca) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { if (useAlloca) - return rewriter.eraseOp(op), matchSuccess(); + return rewriter.eraseOp(op), success(); assert(operands.size() == 1 && "dealloc takes one operand"); OperandAdaptor transformed(operands); @@ -1673,7 +1673,7 @@ memref.allocatedPtr(rewriter, op->getLoc())); rewriter.replaceOpWithNewOp( op, ArrayRef(), rewriter.getSymbolRefAttr(freeFunc), casted); - return matchSuccess(); + return success(); } bool useAlloca; @@ -1683,7 +1683,7 @@ struct RsqrtOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { OperandAdaptor transformed(operands); @@ -1691,7 +1691,7 @@ transformed.operand().getType().dyn_cast(); if (!operandType) - return matchFailure(); + return failure(); auto loc = op->getLoc(); auto resultType = *op->result_type_begin(); @@ -1709,12 +1709,12 @@ } auto sqrt = rewriter.create(loc, transformed.operand()); rewriter.replaceOpWithNewOp(op, operandType, one, sqrt); - return this->matchSuccess(); + return success(); } auto vectorType = resultType.dyn_cast(); if (!vectorType) - return this->matchFailure(); + return failure(); if (succeeded(HandleMultidimensionalVectors( op, operands, typeConverter, @@ -1732,8 +1732,8 @@ sqrt); }, rewriter))) - return this->matchSuccess(); - return this->matchFailure(); + return success(); + return failure(); } }; @@ -1741,7 +1741,7 @@ struct TanhOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { @@ -1753,7 +1753,7 @@ transformed.operand().getType().dyn_cast(); if (!operandType) - return matchFailure(); + return failure(); std::string functionName; if (operandType.isFloatTy()) @@ -1761,7 +1761,7 @@ else if (operandType.isDoubleTy()) functionName = "tanh"; else - return matchFailure(); + return failure(); // Get a reference to the tanh function, inserting it if necessary. Operation *tanhFunc = @@ -1783,14 +1783,14 @@ rewriter.replaceOpWithNewOp( op, operandType, rewriter.getSymbolRefAttr(tanhLLVMFunc), transformed.operand()); - return matchSuccess(); + return success(); } }; struct MemRefCastOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult match(Operation *op) const override { + LogicalResult match(Operation *op) const override { auto memRefCastOp = cast(op); Type srcType = memRefCastOp.getOperand().getType(); Type dstType = memRefCastOp.getType(); @@ -1801,8 +1801,8 @@ MemRefType targetType = memRefCastOp.getType().cast(); return (isSupportedMemRefType(targetType) && isSupportedMemRefType(sourceType)) - ? matchSuccess() - : matchFailure(); + ? success() + : failure(); } // At least one of the operands is unranked type @@ -1812,8 +1812,8 @@ // Unranked to unranked cast is disallowed return !(srcType.isa() && dstType.isa()) - ? matchSuccess() - : matchFailure(); + ? success() + : failure(); } void rewrite(Operation *op, ArrayRef operands, @@ -1886,17 +1886,17 @@ : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto castOp = cast(op); OperandAdaptor transformed(operands); if (transformed.in().getType() != typeConverter.convertType(castOp.getType())) { - return matchFailure(); + return failure(); } rewriter.replaceOp(op, transformed.in()); - return matchSuccess(); + return success(); } }; @@ -1905,7 +1905,7 @@ struct DimOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dimOp = cast(op); @@ -1922,7 +1922,7 @@ // Use constant for static size. rewriter.replaceOp( op, createIndexConstant(rewriter, op->getLoc(), shape[index])); - return matchSuccess(); + return success(); } }; @@ -1934,10 +1934,9 @@ using LLVMLegalizationPattern::LLVMLegalizationPattern; using Base = LoadStoreOpLowering; - PatternMatchResult match(Operation *op) const override { + LogicalResult match(Operation *op) const override { MemRefType type = cast(op).getMemRefType(); - return isSupportedMemRefType(type) ? this->matchSuccess() - : this->matchFailure(); + return isSupportedMemRefType(type) ? success() : failure(); } // Given subscript indices and array sizes in row-major order, @@ -2010,7 +2009,7 @@ struct LoadOpLowering : public LoadStoreOpLowering { using Base::Base; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loadOp = cast(op); @@ -2020,7 +2019,7 @@ Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), transformed.indices(), rewriter, getModule()); rewriter.replaceOpWithNewOp(op, dataPtr); - return matchSuccess(); + return success(); } }; @@ -2029,7 +2028,7 @@ struct StoreOpLowering : public LoadStoreOpLowering { using Base::Base; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto type = cast(op).getMemRefType(); @@ -2039,7 +2038,7 @@ transformed.indices(), rewriter, getModule()); rewriter.replaceOpWithNewOp(op, transformed.value(), dataPtr); - return matchSuccess(); + return success(); } }; @@ -2048,7 +2047,7 @@ struct PrefetchOpLowering : public LoadStoreOpLowering { using Base::Base; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto prefetchOp = cast(op); @@ -2072,7 +2071,7 @@ rewriter.replaceOpWithNewOp(op, dataPtr, isWrite, localityHint, isData); - return matchSuccess(); + return success(); } }; @@ -2083,7 +2082,7 @@ struct IndexCastOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { IndexCastOpOperandAdaptor transformed(operands); @@ -2104,7 +2103,7 @@ else rewriter.replaceOpWithNewOp(op, targetType, transformed.in()); - return matchSuccess(); + return success(); } }; @@ -2118,7 +2117,7 @@ struct CmpIOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpiOp = cast(op); @@ -2130,14 +2129,14 @@ convertCmpPredicate(cmpiOp.getPredicate()))), transformed.lhs(), transformed.rhs()); - return matchSuccess(); + return success(); } }; struct CmpFOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpfOp = cast(op); @@ -2149,7 +2148,7 @@ convertCmpPredicate(cmpfOp.getPredicate()))), transformed.lhs(), transformed.rhs()); - return matchSuccess(); + return success(); } }; @@ -2189,12 +2188,12 @@ using LLVMLegalizationPattern::LLVMLegalizationPattern; using Super = OneToOneLLVMTerminatorLowering; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands, op->getSuccessors(), op->getAttrs()); - return this->matchSuccess(); + return success(); } }; @@ -2207,7 +2206,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { unsigned numArguments = op->getNumOperands(); @@ -2216,12 +2215,12 @@ if (numArguments == 0) { rewriter.replaceOpWithNewOp( op, ArrayRef(), ArrayRef(), op->getAttrs()); - return matchSuccess(); + return success(); } if (numArguments == 1) { rewriter.replaceOpWithNewOp( op, ArrayRef(), operands.front(), op->getAttrs()); - return matchSuccess(); + return success(); } // Otherwise, we need to pack the arguments into an LLVM struct type before @@ -2237,7 +2236,7 @@ } rewriter.replaceOpWithNewOp(op, ArrayRef(), packed, op->getAttrs()); - return matchSuccess(); + return success(); } }; @@ -2256,13 +2255,13 @@ struct SplatOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto splatOp = cast(op); VectorType resultType = splatOp.getType().dyn_cast(); if (!resultType || resultType.getRank() != 1) - return matchFailure(); + return failure(); // First insert it into an undef vector so we can shuffle it. auto vectorType = typeConverter.convertType(splatOp.getType()); @@ -2280,7 +2279,7 @@ // Shuffle the value across the desired number of elements. ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); rewriter.replaceOpWithNewOp(op, v, undef, zeroAttrs); - return matchSuccess(); + return success(); } }; @@ -2290,14 +2289,14 @@ struct SplatNdOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto splatOp = cast(op); OperandAdaptor adaptor(operands); VectorType resultType = splatOp.getType().dyn_cast(); if (!resultType || resultType.getRank() == 1) - return matchFailure(); + return failure(); // First insert it into an undef vector so we can shuffle it. auto loc = op->getLoc(); @@ -2305,7 +2304,7 @@ auto llvmArrayTy = vectorTypeInfo.llvmArrayTy; auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; if (!llvmArrayTy || !llvmVectorTy) - return matchFailure(); + return failure(); // Construct returned value. Value desc = rewriter.create(loc, llvmArrayTy); @@ -2332,7 +2331,7 @@ position); }); rewriter.replaceOp(op, desc); - return matchSuccess(); + return success(); } }; @@ -2344,7 +2343,7 @@ struct SubViewOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); @@ -2376,7 +2375,7 @@ auto targetDescTy = typeConverter.convertType(viewMemRefType) .dyn_cast_or_null(); if (!sourceElementTy || !targetDescTy) - return matchFailure(); + return failure(); // Currently, only rank > 0 and full or no operands are supported. Fail to // convert otherwise. @@ -2385,22 +2384,22 @@ (!dynamicOffsets.empty() && rank != dynamicOffsets.size()) || (!dynamicSizes.empty() && rank != dynamicSizes.size()) || (!dynamicStrides.empty() && rank != dynamicStrides.size())) - return matchFailure(); + return failure(); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); if (failed(successStrides)) - return matchFailure(); + return failure(); // Fail to convert if neither a dynamic nor static offset is available. if (dynamicOffsets.empty() && offset == MemRefType::getDynamicStrideOrOffset()) - return matchFailure(); + return failure(); // Create the descriptor. if (!operands.front().getType().isa()) - return matchFailure(); + return failure(); MemRefDescriptor sourceMemRef(operands.front()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); @@ -2460,7 +2459,7 @@ } rewriter.replaceOp(op, {targetMemRef}); - return matchSuccess(); + return success(); } }; @@ -2505,7 +2504,7 @@ return createIndexConstant(rewriter, loc, 1); } - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); @@ -2520,14 +2519,13 @@ typeConverter.convertType(viewMemRefType).dyn_cast(); if (!targetDescTy) return op->emitWarning("Target descriptor type not converted to LLVM"), - matchFailure(); + failure(); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); if (failed(successStrides)) - return op->emitWarning("cannot cast to non-strided shape"), - matchFailure(); + return op->emitWarning("cannot cast to non-strided shape"), failure(); // Create the descriptor. MemRefDescriptor sourceMemRef(adaptor.source()); @@ -2560,12 +2558,11 @@ // Early exit for 0-D corner case. if (viewMemRefType.getRank() == 0) - return rewriter.replaceOp(op, {targetMemRef}), matchSuccess(); + return rewriter.replaceOp(op, {targetMemRef}), success(); // Fields 4 and 5: Update sizes and strides. if (strides.back() != 1) - return op->emitWarning("cannot cast to non-contiguous shape"), - matchFailure(); + return op->emitWarning("cannot cast to non-contiguous shape"), failure(); Value stride = nullptr, nextSize = nullptr; // Drop the dynamic stride from the operand list, if present. ArrayRef sizeOperands(sizeAndOffsetOperands); @@ -2583,7 +2580,7 @@ } rewriter.replaceOp(op, {targetMemRef}); - return matchSuccess(); + return success(); } }; @@ -2591,7 +2588,7 @@ : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { OperandAdaptor transformed(operands); @@ -2622,7 +2619,7 @@ rewriter.create(op->getLoc(), ptrValue, mask), zero)); rewriter.eraseOp(op); - return matchSuccess(); + return success(); } }; @@ -2657,13 +2654,13 @@ struct AtomicRMWOpLowering : public LoadStoreOpLowering { using Base::Base; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto atomicOp = cast(op); auto maybeKind = matchSimpleAtomicOp(atomicOp); if (!maybeKind) - return matchFailure(); + return failure(); OperandAdaptor adaptor(operands); auto resultType = adaptor.value().getType(); auto memRefType = atomicOp.getMemRefType(); @@ -2672,7 +2669,7 @@ rewriter.replaceOpWithNewOp( op, resultType, *maybeKind, dataPtr, adaptor.value(), LLVM::AtomicOrdering::acq_rel); - return matchSuccess(); + return success(); } }; @@ -2706,13 +2703,13 @@ struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering { using Base::Base; - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto atomicOp = cast(op); auto maybeKind = matchSimpleAtomicOp(atomicOp); if (maybeKind) - return matchFailure(); + return failure(); LLVM::FCmpPredicate predicate; switch (atomicOp.kind()) { @@ -2723,7 +2720,7 @@ predicate = LLVM::FCmpPredicate::olt; break; default: - return matchFailure(); + return failure(); } OperandAdaptor adaptor(operands); @@ -2779,7 +2776,7 @@ // The 'result' of the atomic_rmw op is the newly loaded value. rewriter.replaceOp(op, {newLoaded}); - return matchSuccess(); + return success(); } }; diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -31,7 +31,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(ConstantOp constCompositeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -45,7 +45,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(ConstantOp constIndexOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -55,7 +55,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -65,7 +65,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -81,14 +81,14 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(StdOp operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto resultType = this->typeConverter.convertType(operation.getResult().getType()); rewriter.template replaceOpWithNewOp( operation, resultType, operands, ArrayRef()); - return this->matchSuccess(); + return success(); } }; @@ -100,7 +100,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -111,7 +111,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -121,7 +121,7 @@ class SelectOpConversion final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -134,7 +134,7 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -145,22 +145,22 @@ // ConstantOp with composite type. //===----------------------------------------------------------------------===// -PatternMatchResult ConstantCompositeOpConversion::matchAndRewrite( +LogicalResult ConstantCompositeOpConversion::matchAndRewrite( ConstantOp constCompositeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto compositeType = constCompositeOp.getResult().getType().dyn_cast(); if (!compositeType) - return matchFailure(); + return failure(); auto spirvCompositeType = typeConverter.convertType(compositeType); if (!spirvCompositeType) - return matchFailure(); + return failure(); auto linearizedElements = constCompositeOp.value().dyn_cast(); if (!linearizedElements) - return matchFailure(); + return failure(); // If composite type has rank greater than one, then perform linearization. if (compositeType.getRank() > 1) { @@ -171,24 +171,24 @@ rewriter.replaceOpWithNewOp( constCompositeOp, spirvCompositeType, linearizedElements); - return matchSuccess(); + return success(); } //===----------------------------------------------------------------------===// // ConstantOp with index type. //===----------------------------------------------------------------------===// -PatternMatchResult ConstantIndexOpConversion::matchAndRewrite( +LogicalResult ConstantIndexOpConversion::matchAndRewrite( ConstantOp constIndexOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!constIndexOp.getResult().getType().isa()) { - return matchFailure(); + return failure(); } // The attribute has index type which is not directly supported in // SPIR-V. Get the integer value and create a new IntegerAttr. auto constAttr = constIndexOp.value().dyn_cast(); if (!constAttr) { - return matchFailure(); + return failure(); } // Use the bitwidth set in the value attribute to decide the result type @@ -197,7 +197,7 @@ auto constVal = constAttr.getValue(); auto constValType = constAttr.getType().dyn_cast(); if (!constValType) { - return matchFailure(); + return failure(); } auto spirvConstType = typeConverter.convertType(constIndexOp.getResult().getType()); @@ -205,14 +205,14 @@ rewriter.getIntegerAttr(spirvConstType, constAttr.getInt()); rewriter.replaceOpWithNewOp(constIndexOp, spirvConstType, spirvConstVal); - return matchSuccess(); + return success(); } //===----------------------------------------------------------------------===// // CmpFOp //===----------------------------------------------------------------------===// -PatternMatchResult +LogicalResult CmpFOpConversion::matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { CmpFOpOperandAdaptor cmpFOpOperands(operands); @@ -223,7 +223,7 @@ rewriter.replaceOpWithNewOp(cmpFOp, cmpFOp.getResult().getType(), \ cmpFOpOperands.lhs(), \ cmpFOpOperands.rhs()); \ - return matchSuccess(); + return success(); // Ordered. DISPATCH(CmpFPredicate::OEQ, spirv::FOrdEqualOp); @@ -245,14 +245,14 @@ default: break; } - return matchFailure(); + return failure(); } //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// -PatternMatchResult +LogicalResult CmpIOpConversion::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { CmpIOpOperandAdaptor cmpIOpOperands(operands); @@ -263,7 +263,7 @@ rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ cmpIOpOperands.lhs(), \ cmpIOpOperands.rhs()); \ - return matchSuccess(); + return success(); DISPATCH(CmpIPredicate::eq, spirv::IEqualOp); DISPATCH(CmpIPredicate::ne, spirv::INotEqualOp); @@ -278,14 +278,14 @@ #undef DISPATCH } - return matchFailure(); + return failure(); } //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// -PatternMatchResult +LogicalResult LoadOpConversion::matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { LoadOpOperandAdaptor loadOperands(operands); @@ -293,42 +293,42 @@ typeConverter, loadOp.memref().getType().cast(), loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter); rewriter.replaceOpWithNewOp(loadOp, loadPtr); - return matchSuccess(); + return success(); } //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// -PatternMatchResult +LogicalResult ReturnOpConversion::matchAndRewrite(ReturnOp returnOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (returnOp.getNumOperands()) { - return matchFailure(); + return failure(); } rewriter.replaceOpWithNewOp(returnOp); - return matchSuccess(); + return success(); } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// -PatternMatchResult +LogicalResult SelectOpConversion::matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { SelectOpOperandAdaptor selectOperands(operands); rewriter.replaceOpWithNewOp(op, selectOperands.condition(), selectOperands.true_value(), selectOperands.false_value()); - return matchSuccess(); + return success(); } //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// -PatternMatchResult +LogicalResult StoreOpConversion::matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { StoreOpOperandAdaptor storeOperands(operands); @@ -338,7 +338,7 @@ rewriter); rewriter.replaceOpWithNewOp(storeOp, storePtr, storeOperands.value()); - return matchSuccess(); + return success(); } namespace { diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -26,8 +26,8 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(LoadOp loadOp, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(LoadOp loadOp, + PatternRewriter &rewriter) const override; }; /// Merges subview operation with store operation. @@ -35,8 +35,8 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(StoreOp storeOp, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(StoreOp storeOp, + PatternRewriter &rewriter) const override; }; } // namespace @@ -107,43 +107,43 @@ // Folding SubViewOp and LoadOp. //===----------------------------------------------------------------------===// -PatternMatchResult +LogicalResult LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp, PatternRewriter &rewriter) const { auto subViewOp = dyn_cast_or_null(loadOp.memref().getDefiningOp()); if (!subViewOp) { - return matchFailure(); + return failure(); } SmallVector sourceIndices; if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, loadOp.indices(), sourceIndices))) - return matchFailure(); + return failure(); rewriter.replaceOpWithNewOp(loadOp, subViewOp.source(), sourceIndices); - return matchSuccess(); + return success(); } //===----------------------------------------------------------------------===// // Folding SubViewOp and StoreOp. //===----------------------------------------------------------------------===// -PatternMatchResult +LogicalResult StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp, PatternRewriter &rewriter) const { auto subViewOp = dyn_cast_or_null(storeOp.memref().getDefiningOp()); if (!subViewOp) { - return matchFailure(); + return failure(); } SmallVector sourceIndices; if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, storeOp.indices(), sourceIndices))) - return matchFailure(); + return failure(); rewriter.replaceOpWithNewOp(storeOp, storeOp.value(), subViewOp.source(), sourceIndices); - return matchSuccess(); + return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -133,13 +133,13 @@ : ConvertToLLVMPattern(vector::BroadcastOp::getOperationName(), context, typeConverter) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto broadcastOp = cast(op); VectorType dstVectorType = broadcastOp.getVectorType(); if (typeConverter.convertType(dstVectorType) == nullptr) - return matchFailure(); + return failure(); // Rewrite when the full vector type can be lowered (which // implies all 'reduced' types can be lowered too). auto adaptor = vector::BroadcastOpOperandAdaptor(operands); @@ -149,7 +149,7 @@ op, expandRanks(adaptor.source(), // source value to be expanded op->getLoc(), // location of original broadcast srcVectorType, dstVectorType, rewriter)); - return matchSuccess(); + return success(); } private: @@ -284,7 +284,7 @@ : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context, typeConverter) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto matmulOp = cast(op); @@ -293,7 +293,7 @@ op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(), matmulOp.rhs_columns()); - return matchSuccess(); + return success(); } }; @@ -304,7 +304,7 @@ : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context, typeConverter) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto reductionOp = cast(op); @@ -335,8 +335,8 @@ rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); else - return matchFailure(); - return matchSuccess(); + return failure(); + return success(); } else if (eltType.isF32() || eltType.isF64()) { // Floating-point reductions: add/mul/min/max @@ -364,10 +364,10 @@ rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); else - return matchFailure(); - return matchSuccess(); + return failure(); + return success(); } - return matchFailure(); + return failure(); } }; @@ -378,7 +378,7 @@ : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context, typeConverter) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); @@ -392,7 +392,7 @@ // Bail if result type cannot be lowered. if (!llvmType) - return matchFailure(); + return failure(); // Get rank and dimension sizes. int64_t rank = vectorType.getRank(); @@ -406,7 +406,7 @@ Value shuffle = rewriter.create( loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); rewriter.replaceOp(op, shuffle); - return matchSuccess(); + return success(); } // For all other cases, insert the individual values individually. @@ -425,7 +425,7 @@ llvmType, rank, insPos++); } rewriter.replaceOp(op, insert); - return matchSuccess(); + return success(); } }; @@ -436,7 +436,7 @@ : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(), context, typeConverter) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); @@ -446,11 +446,11 @@ // Bail if result type cannot be lowered. if (!llvmType) - return matchFailure(); + return failure(); rewriter.replaceOpWithNewOp( op, llvmType, adaptor.vector(), adaptor.position()); - return matchSuccess(); + return success(); } }; @@ -461,7 +461,7 @@ : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context, typeConverter) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); @@ -474,14 +474,14 @@ // Bail if result type cannot be lowered. if (!llvmResultType) - return matchFailure(); + return failure(); // One-shot extraction of vector from array (only requires extractvalue). if (resultType.isa()) { Value extracted = rewriter.create( loc, llvmResultType, adaptor.vector(), positionArrayAttr); rewriter.replaceOp(op, extracted); - return matchSuccess(); + return success(); } // Potential extraction of 1-D vector from array. @@ -505,7 +505,7 @@ rewriter.create(loc, extracted, constant); rewriter.replaceOp(op, extracted); - return matchSuccess(); + return success(); } }; @@ -530,17 +530,17 @@ : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context, typeConverter) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::FMAOpOperandAdaptor(operands); vector::FMAOp fmaOp = cast(op); VectorType vType = fmaOp.getVectorType(); if (vType.getRank() != 1) - return matchFailure(); + return failure(); rewriter.replaceOpWithNewOp(op, adaptor.lhs(), adaptor.rhs(), adaptor.acc()); - return matchSuccess(); + return success(); } }; @@ -551,7 +551,7 @@ : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(), context, typeConverter) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::InsertElementOpOperandAdaptor(operands); @@ -561,11 +561,11 @@ // Bail if result type cannot be lowered. if (!llvmType) - return matchFailure(); + return failure(); rewriter.replaceOpWithNewOp( op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); - return matchSuccess(); + return success(); } }; @@ -576,7 +576,7 @@ : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context, typeConverter) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); @@ -589,7 +589,7 @@ // Bail if result type cannot be lowered. if (!llvmResultType) - return matchFailure(); + return failure(); // One-shot insertion of a vector into an array (only requires insertvalue). if (sourceType.isa()) { @@ -597,7 +597,7 @@ loc, llvmResultType, adaptor.dest(), adaptor.source(), positionArrayAttr); rewriter.replaceOp(op, inserted); - return matchSuccess(); + return success(); } // Potential extraction of 1-D vector from array. @@ -632,7 +632,7 @@ } rewriter.replaceOp(op, inserted); - return matchSuccess(); + return success(); } }; @@ -661,11 +661,11 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(FMAOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(FMAOp op, + PatternRewriter &rewriter) const override { auto vType = op.getVectorType(); if (vType.getRank() < 2) - return matchFailure(); + return failure(); auto loc = op.getLoc(); auto elemType = vType.getElementType(); @@ -680,7 +680,7 @@ desc = rewriter.create(loc, fma, desc, i); } rewriter.replaceOp(op, desc); - return matchSuccess(); + return success(); } }; @@ -704,19 +704,19 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(InsertStridedSliceOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(InsertStridedSliceOp op, + PatternRewriter &rewriter) const override { auto srcType = op.getSourceVectorType(); auto dstType = op.getDestVectorType(); if (op.offsets().getValue().empty()) - return matchFailure(); + return failure(); auto loc = op.getLoc(); int64_t rankDiff = dstType.getRank() - srcType.getRank(); assert(rankDiff >= 0); if (rankDiff == 0) - return matchFailure(); + return failure(); int64_t rankRest = dstType.getRank() - rankDiff; // Extract / insert the subvector of matching rank and InsertStridedSlice @@ -735,7 +735,7 @@ op, stridedSliceInnerOp.getResult(), op.dest(), getI64SubArray(op.offsets(), /*dropFront=*/0, /*dropFront=*/rankRest)); - return matchSuccess(); + return success(); } }; @@ -753,22 +753,22 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(InsertStridedSliceOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(InsertStridedSliceOp op, + PatternRewriter &rewriter) const override { auto srcType = op.getSourceVectorType(); auto dstType = op.getDestVectorType(); if (op.offsets().getValue().empty()) - return matchFailure(); + return failure(); int64_t rankDiff = dstType.getRank() - srcType.getRank(); assert(rankDiff >= 0); if (rankDiff != 0) - return matchFailure(); + return failure(); if (srcType == dstType) { rewriter.replaceOp(op, op.source()); - return matchSuccess(); + return success(); } int64_t offset = @@ -813,7 +813,7 @@ } rewriter.replaceOp(op, res); - return matchSuccess(); + return success(); } }; @@ -824,7 +824,7 @@ : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context, typeConverter) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); @@ -837,18 +837,18 @@ // Only static shape casts supported atm. if (!sourceMemRefType.hasStaticShape() || !targetMemRefType.hasStaticShape()) - return matchFailure(); + return failure(); auto llvmSourceDescriptorTy = operands[0].getType().dyn_cast(); if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy()) - return matchFailure(); + return failure(); MemRefDescriptor sourceMemRef(operands[0]); auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) .dyn_cast_or_null(); if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) - return matchFailure(); + return failure(); int64_t offset; SmallVector strides; @@ -866,7 +866,7 @@ } // Only contiguous source tensors supported atm. if (failed(successStrides) || !isContiguous) - return matchFailure(); + return failure(); auto int64Ty = LLVM::LLVMType::getInt64Ty(typeConverter.getDialect()); @@ -901,7 +901,7 @@ } rewriter.replaceOp(op, {desc}); - return matchSuccess(); + return success(); } }; @@ -924,7 +924,7 @@ // // TODO(ajcbik): rely solely on libc in future? something else? // - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto printOp = cast(op); @@ -932,7 +932,7 @@ Type printType = printOp.getPrintType(); if (typeConverter.convertType(printType) == nullptr) - return matchFailure(); + return failure(); // Make sure element type has runtime support (currently just Float/Double). VectorType vectorType = printType.dyn_cast(); @@ -948,13 +948,13 @@ else if (eltType.isF64()) printer = getPrintDouble(op); else - return matchFailure(); + return failure(); // Unroll vector into elementary print calls. emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank); emitCall(rewriter, op->getLoc(), getPrintNewline(op)); rewriter.eraseOp(op); - return matchSuccess(); + return success(); } private: @@ -1047,8 +1047,8 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(StridedSliceOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(StridedSliceOp op, + PatternRewriter &rewriter) const override { auto dstType = op.getResult().getType().cast(); assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); @@ -1089,7 +1089,7 @@ res = insertOne(rewriter, loc, extracted, res, idx); } rewriter.replaceOp(op, {res}); - return matchSuccess(); + return success(); } }; diff --git a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp --- a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp +++ b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp @@ -198,8 +198,8 @@ } /// Performs the rewrite. - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; }; /// Lowers TransferReadOp into a combination of: @@ -246,7 +246,7 @@ /// Performs the rewrite. template <> -PatternMatchResult VectorTransferRewriter::matchAndRewrite( +LogicalResult VectorTransferRewriter::matchAndRewrite( Operation *op, PatternRewriter &rewriter) const { using namespace mlir::edsc::op; @@ -282,7 +282,7 @@ // 3. Propagate. rewriter.replaceOp(op, vectorValue.getValue()); - return matchSuccess(); + return success(); } /// Lowers TransferWriteOp into a combination of: @@ -304,7 +304,7 @@ /// TODO(ntv): implement alternatives to clipping. /// TODO(ntv): support non-data-parallel operations. template <> -PatternMatchResult VectorTransferRewriter::matchAndRewrite( +LogicalResult VectorTransferRewriter::matchAndRewrite( Operation *op, PatternRewriter &rewriter) const { using namespace edsc::op; @@ -340,7 +340,7 @@ (std_dealloc(tmp)); // vexing parse... rewriter.eraseOp(op); - return matchSuccess(); + return success(); } } // namespace diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -727,8 +727,8 @@ void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp, AffineMap map, ArrayRef mapOperands) const; - PatternMatchResult matchAndRewrite(AffineOpTy affineOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AffineOpTy affineOp, + PatternRewriter &rewriter) const override { static_assert(std::is_same::value || std::is_same::value || std::is_same::value || @@ -743,10 +743,10 @@ composeAffineMapAndOperands(&map, &resultOperands); if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(), resultOperands.begin())) - return this->matchFailure(); + return failure(); replaceAffineOp(rewriter, affineOp, map, resultOperands); - return this->matchSuccess(); + return success(); } }; @@ -1405,13 +1405,13 @@ struct AffineForEmptyLoopFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AffineForOp forOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AffineForOp forOp, + PatternRewriter &rewriter) const override { // Check that the body only contains a terminator. if (!has_single_element(*forOp.getBody())) - return matchFailure(); + return failure(); rewriter.eraseOp(forOp); - return matchSuccess(); + return success(); } }; } // end anonymous namespace diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -111,8 +111,8 @@ struct UniformDequantizePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(DequantizeCastOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(DequantizeCastOp op, + PatternRewriter &rewriter) const override { Type inputType = op.arg().getType(); Type outputType = op.getResult().getType(); @@ -121,16 +121,16 @@ Type expressedOutputType = inputElementType.castToExpressedType(inputType); if (expressedOutputType != outputType) { // Not a valid uniform cast. - return matchFailure(); + return failure(); } Value dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter); if (!dequantizedValue) { - return matchFailure(); + return failure(); } rewriter.replaceOp(op, dequantizedValue); - return matchSuccess(); + return success(); } }; @@ -313,40 +313,40 @@ struct UniformRealAddEwPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(RealAddEwOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(RealAddEwOp op, + PatternRewriter &rewriter) const override { const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(), op.clamp_max()); if (!info.isValid()) { - return matchFailure(); + return failure(); } // Try all of the permutations we support. if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) { - return matchSuccess(); + return success(); } - return matchFailure(); + return failure(); } }; struct UniformRealMulEwPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(RealMulEwOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(RealMulEwOp op, + PatternRewriter &rewriter) const override { const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(), op.clamp_max()); if (!info.isValid()) { - return matchFailure(); + return failure(); } // Try all of the permutations we support. if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) { - return matchSuccess(); + return success(); } - return matchFailure(); + return failure(); } }; diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -380,8 +380,8 @@ explicit GpuAllReduceConversion(MLIRContext *context) : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { auto funcOp = cast(op); auto callback = [&](gpu::AllReduceOp reduceOp) { GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite(); @@ -391,7 +391,7 @@ }; while (funcOp.walk(callback).wasInterrupted()) { } - return matchSuccess(); + return success(); } }; } // namespace diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -534,10 +534,10 @@ struct FuseGenericTensorOps : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(GenericOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(GenericOp op, + PatternRewriter &rewriter) const override { if (!op.hasTensorSemantics()) - return matchFailure(); + return failure(); // Find the first operand that is defined by another generic op on tensors. for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) { @@ -551,9 +551,9 @@ if (!fusedOp) continue; rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults()); - return matchSuccess(); + return success(); } - return matchFailure(); + return failure(); } }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -531,13 +531,13 @@ explicit LinalgRewritePattern(MLIRContext *context) : RewritePattern(ConcreteOp::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { using Impl = LinalgOpToLoopsImpl; if (failed(Impl::doit(op, rewriter))) - return matchFailure(); + return failure(); rewriter.eraseOp(op); - return matchSuccess(); + return success(); } }; @@ -595,26 +595,26 @@ FoldAffineOp(MLIRContext *context) : RewritePattern(AffineApplyOp::getOperationName(), 0, context) {} - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { AffineApplyOp affineApplyOp = cast(op); auto map = affineApplyOp.getAffineMap(); if (map.getNumResults() != 1 || map.getNumInputs() > 1) - return matchFailure(); + return failure(); AffineExpr expr = map.getResult(0); if (map.getNumInputs() == 0) { if (auto val = expr.dyn_cast()) { rewriter.replaceOpWithNewOp(op, val.getValue()); - return matchSuccess(); + return success(); } - return matchFailure(); + return failure(); } if (expr.dyn_cast() || expr.dyn_cast()) { rewriter.replaceOp(op, op->getOperand(0)); - return matchSuccess(); + return success(); } - return matchFailure(); + return failure(); } }; } // namespace diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp --- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp @@ -30,8 +30,8 @@ struct QuantizedConstRewrite : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(QuantizeCastOp qbarrier, + PatternRewriter &rewriter) const override; }; } // end anonymous namespace @@ -39,14 +39,14 @@ /// Matches a [constant] -> [qbarrier] where the qbarrier results type is /// quantized and the operand type is quantizable. -PatternMatchResult +LogicalResult QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, PatternRewriter &rewriter) const { Attribute value; // Is the operand a constant? if (!matchPattern(qbarrier.arg(), m_Constant(&value))) { - return matchFailure(); + return failure(); } // Does the qbarrier convert to a quantized type. This will not be true @@ -56,10 +56,10 @@ QuantizedType quantizedElementType = QuantizedType::getQuantizedElementType(qbarrierResultType); if (!quantizedElementType) { - return matchFailure(); + return failure(); } if (!QuantizedType::castToStorageType(qbarrierResultType)) { - return matchFailure(); + return failure(); } // Is the operand type compatible with the expressed type of the quantized @@ -67,20 +67,20 @@ // from and to a quantized type). if (!quantizedElementType.isCompatibleExpressedType( qbarrier.arg().getType())) { - return matchFailure(); + return failure(); } // Is the constant value a type expressed in a way that we support? if (!value.isa() && !value.isa() && !value.isa()) { - return matchFailure(); + return failure(); } Type newConstValueType; auto newConstValue = quantizeAttr(value, quantizedElementType, newConstValueType); if (!newConstValue) { - return matchFailure(); + return failure(); } // When creating the new const op, use a fused location that combines the @@ -92,7 +92,7 @@ rewriter.create(fusedLoc, newConstValueType, newConstValue); rewriter.replaceOpWithNewOp(qbarrier, qbarrier.getType(), newConstOp); - return matchSuccess(); + return success(); } void ConvertConstPass::runOnFunction() { diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp --- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp @@ -35,16 +35,16 @@ FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) : OpRewritePattern(ctx), hadFailure(hadFailure) {} - PatternMatchResult matchAndRewrite(FakeQuantOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(FakeQuantOp op, + PatternRewriter &rewriter) const override { // TODO: If this pattern comes up more frequently, consider adding core // support for failable rewrites. if (failableRewrite(op, rewriter)) { *hadFailure = true; - return Pattern::matchFailure(); + return failure(); } - return Pattern::matchSuccess(); + return success(); } private: diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp @@ -88,13 +88,13 @@ : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(spirv::AccessChainOp accessChainOp, + PatternRewriter &rewriter) const override { auto parentAccessChainOp = dyn_cast_or_null( accessChainOp.base_ptr().getDefiningOp()); if (!parentAccessChainOp) { - return matchFailure(); + return failure(); } // Combine indices. @@ -105,7 +105,7 @@ rewriter.replaceOpWithNewOp( accessChainOp, parentAccessChainOp.base_ptr(), indices); - return matchSuccess(); + return success(); } }; } // end anonymous namespace @@ -291,24 +291,24 @@ : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(spirv::SelectionOp selectionOp, + PatternRewriter &rewriter) const override { auto *op = selectionOp.getOperation(); auto &body = op->getRegion(0); // Verifier allows an empty region for `spv.selection`. if (body.empty()) { - return matchFailure(); + return failure(); } // Check that region consists of 4 blocks: // header block, `true` block, `false` block and merge block. if (std::distance(body.begin(), body.end()) != 4) { - return matchFailure(); + return failure(); } auto *headerBlock = selectionOp.getHeaderBlock(); if (!onlyContainsBranchConditionalOp(headerBlock)) { - return matchFailure(); + return failure(); } auto brConditionalOp = @@ -319,7 +319,7 @@ auto *mergeBlock = selectionOp.getMergeBlock(); if (failed(canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock))) - return matchFailure(); + return failure(); auto trueValue = getSrcValue(trueBlock); auto falseValue = getSrcValue(falseBlock); @@ -335,7 +335,7 @@ // `spv.selection` is not needed anymore. rewriter.eraseOp(op); - return matchSuccess(); + return success(); } private: @@ -345,9 +345,8 @@ // 2. Each `spv.Store` uses the same pointer and the same memory attributes. // 3. A control flow goes into the given merge block from the given // conditional blocks. - PatternMatchResult canCanonicalizeSelection(Block *trueBlock, - Block *falseBlock, - Block *mergeBlock) const; + LogicalResult canCanonicalizeSelection(Block *trueBlock, Block *falseBlock, + Block *mergeBlock) const; bool onlyContainsBranchConditionalOp(Block *block) const { return std::next(block->begin()) == block->end() && @@ -382,12 +381,12 @@ } }; -PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection( +LogicalResult ConvertSelectionOpToSelect::canCanonicalizeSelection( Block *trueBlock, Block *falseBlock, Block *mergeBlock) const { // Each block must consists of 2 operations. if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) || (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) { - return matchFailure(); + return failure(); } auto trueBrStoreOp = dyn_cast(trueBlock->front()); @@ -399,7 +398,7 @@ if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp || !falseBrBranchOp) { - return matchFailure(); + return failure(); } // Check that each `spv.Store` uses the same pointer, memory access @@ -407,15 +406,15 @@ if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) || !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || !isValidType(trueBrStoreOp.value().getType())) { - return matchFailure(); + return failure(); } if ((trueBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock) || (falseBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock)) { - return matchFailure(); + return failure(); } - return matchSuccess(); + return success(); } } // end anonymous namespace diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -177,25 +177,25 @@ public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; } // namespace -PatternMatchResult +LogicalResult FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { auto fnType = funcOp.getType(); // TODO(antiagainst): support converting functions with one result. if (fnType.getNumResults()) - return matchFailure(); + return failure(); TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs()); for (auto argType : enumerate(funcOp.getType().getInputs())) { auto convertedType = typeConverter.convertType(argType.value()); if (!convertedType) - return matchFailure(); + return failure(); signatureConverter.addInputs(argType.index(), convertedType); } @@ -216,7 +216,7 @@ newFuncOp.end()); rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter); rewriter.eraseOp(funcOp); - return matchSuccess(); + return success(); } void mlir::populateBuiltinFuncToSPIRVPatterns( diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateSPIRVCompositeTypeLayoutPass.cpp @@ -27,8 +27,8 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(spirv::GlobalVariableOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(spirv::GlobalVariableOp op, + PatternRewriter &rewriter) const override { spirv::StructType::LayoutInfo structSize = 0; VulkanLayoutUtils::Size structAlignment = 1; SmallVector globalVarAttrs; @@ -50,7 +50,7 @@ rewriter.replaceOpWithNewOp( op, TypeAttr::get(decoratedType), globalVarAttrs); - return matchSuccess(); + return success(); } }; @@ -59,15 +59,15 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(spirv::AddressOfOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(spirv::AddressOfOp op, + PatternRewriter &rewriter) const override { auto spirvModule = op.getParentOfType(); auto varName = op.variable(); auto varOp = spirvModule.lookupSymbol(varName); rewriter.replaceOpWithNewOp( op, varOp.type(), rewriter.getSymbolRefAttr(varName)); - return matchSuccess(); + return success(); } }; } // namespace diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -138,7 +138,7 @@ class ProcessInterfaceVarABI final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; - PatternMatchResult + LogicalResult matchAndRewrite(spirv::FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override; }; @@ -151,13 +151,13 @@ }; } // namespace -PatternMatchResult ProcessInterfaceVarABI::matchAndRewrite( +LogicalResult ProcessInterfaceVarABI::matchAndRewrite( spirv::FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (!funcOp.getAttrOfType( spirv::getEntryPointABIAttrName())) { // TODO(ravishankarm) : Non-entry point functions are not handled. - return matchFailure(); + return failure(); } TypeConverter::SignatureConversion signatureConverter( funcOp.getType().getNumInputs()); @@ -171,12 +171,12 @@ // to pass around scalar/vector values and return a scalar/vector. For now // non-entry point functions are not handled in this ABI lowering and will // produce an error. - return matchFailure(); + return failure(); } auto var = createGlobalVariableForArg(funcOp, rewriter, argType.index(), abiInfo); if (!var) { - return matchFailure(); + return failure(); } OpBuilder::InsertionGuard funcInsertionGuard(rewriter); @@ -207,7 +207,7 @@ signatureConverter.getConvertedTypes(), llvm::None)); rewriter.applySignatureConversion(&funcOp.getBody(), signatureConverter); }); - return matchSuccess(); + return success(); } void LowerABIAttributesPass::runOnOperation() { diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -313,14 +313,14 @@ struct SimplifyAllocConst : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AllocOp alloc, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AllocOp alloc, + PatternRewriter &rewriter) const override { // Check to see if any dimensions operands are constants. If so, we can // substitute and drop them. if (llvm::none_of(alloc.getOperands(), [](Value operand) { return matchPattern(operand, m_ConstantIndex()); })) - return matchFailure(); + return failure(); auto memrefType = alloc.getType(); @@ -364,7 +364,7 @@ alloc.getType()); rewriter.replaceOp(alloc, {resultCast}); - return matchSuccess(); + return success(); } }; @@ -373,13 +373,13 @@ struct SimplifyDeadAlloc : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(AllocOp alloc, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(AllocOp alloc, + PatternRewriter &rewriter) const override { if (alloc.use_empty()) { rewriter.eraseOp(alloc); - return matchSuccess(); + return success(); } - return matchFailure(); + return failure(); } }; } // end anonymous namespace. @@ -461,18 +461,18 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(BranchOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(BranchOp op, + PatternRewriter &rewriter) const override { // Check that the successor block has a single predecessor. Block *succ = op.getDest(); Block *opParent = op.getOperation()->getBlock(); if (succ == opParent || !has_single_element(succ->getPredecessors())) - return matchFailure(); + return failure(); // Merge the successor into the current block and erase the branch. rewriter.mergeBlocks(succ, opParent, op.getOperands()); rewriter.eraseOp(op); - return matchSuccess(); + return success(); } }; } // end anonymous namespace. @@ -545,18 +545,18 @@ : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(CallIndirectOp indirectCall, + PatternRewriter &rewriter) const override { // Check that the callee is a constant callee. SymbolRefAttr calledFn; if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) - return matchFailure(); + return failure(); // Replace with a direct call. rewriter.replaceOpWithNewOp(indirectCall, calledFn, indirectCall.getResultTypes(), indirectCall.getArgOperands()); - return matchSuccess(); + return success(); } }; } // end anonymous namespace. @@ -733,20 +733,20 @@ struct SimplifyConstCondBranchPred : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(CondBranchOp condbr, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(CondBranchOp condbr, + PatternRewriter &rewriter) const override { if (matchPattern(condbr.getCondition(), m_NonZero())) { // True branch taken. rewriter.replaceOpWithNewOp(condbr, condbr.getTrueDest(), condbr.getTrueOperands()); - return matchSuccess(); + return success(); } else if (matchPattern(condbr.getCondition(), m_Zero())) { // False branch taken. rewriter.replaceOpWithNewOp(condbr, condbr.getFalseDest(), condbr.getFalseOperands()); - return matchSuccess(); + return success(); } - return matchFailure(); + return failure(); } }; } // end anonymous namespace. @@ -958,21 +958,21 @@ struct SimplifyDeadDealloc : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(DeallocOp dealloc, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(DeallocOp dealloc, + PatternRewriter &rewriter) const override { // Check that the memref operand's defining operation is an AllocOp. Value memref = dealloc.memref(); if (!isa_and_nonnull(memref.getDefiningOp())) - return matchFailure(); + return failure(); // Check that all of the uses of the AllocOp are other DeallocOps. for (auto *user : memref.getUsers()) if (!isa(user)) - return matchFailure(); + return failure(); // Erase the dealloc operation. rewriter.eraseOp(dealloc); - return matchSuccess(); + return success(); } }; } // end anonymous namespace. @@ -2003,8 +2003,8 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(SubViewOp subViewOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(SubViewOp subViewOp, + PatternRewriter &rewriter) const override { MemRefType subViewType = subViewOp.getType(); // Follow all or nothing approach for shapes for now. If all the operands // for sizes are constants then fold it into the type of the result memref. @@ -2012,7 +2012,7 @@ llvm::any_of(subViewOp.sizes(), [](Value operand) { return !matchPattern(operand, m_ConstantIndex()); })) { - return matchFailure(); + return failure(); } SmallVector staticShape(subViewOp.getNumSizes()); for (auto size : llvm::enumerate(subViewOp.sizes())) { @@ -2028,7 +2028,7 @@ // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, subViewOp.getType()); - return matchSuccess(); + return success(); } }; @@ -2037,10 +2037,10 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(SubViewOp subViewOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(SubViewOp subViewOp, + PatternRewriter &rewriter) const override { if (subViewOp.getNumStrides() == 0) { - return matchFailure(); + return failure(); } // Follow all or nothing approach for strides for now. If all the operands // for strides are constants then fold it into the strides of the result @@ -2056,7 +2056,7 @@ llvm::any_of(subViewOp.strides(), [](Value stride) { return !matchPattern(stride, m_ConstantIndex()); })) { - return matchFailure(); + return failure(); } SmallVector staticStrides(subViewOp.getNumStrides()); @@ -2077,7 +2077,7 @@ // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, subViewOp.getType()); - return matchSuccess(); + return success(); } }; @@ -2086,10 +2086,10 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(SubViewOp subViewOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(SubViewOp subViewOp, + PatternRewriter &rewriter) const override { if (subViewOp.getNumOffsets() == 0) { - return matchFailure(); + return failure(); } // Follow all or nothing approach for offsets for now. If all the operands // for offsets are constants then fold it into the offset of the result @@ -2106,7 +2106,7 @@ llvm::any_of(subViewOp.offsets(), [](Value stride) { return !matchPattern(stride, m_ConstantIndex()); })) { - return matchFailure(); + return failure(); } auto staticOffset = baseOffset; @@ -2128,7 +2128,7 @@ // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, subViewOp.getType()); - return matchSuccess(); + return success(); } }; @@ -2347,18 +2347,18 @@ struct ViewOpShapeFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(ViewOp viewOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(ViewOp viewOp, + PatternRewriter &rewriter) const override { // Return if none of the operands are constants. if (llvm::none_of(viewOp.getOperands(), [](Value operand) { return matchPattern(operand, m_ConstantIndex()); })) - return matchFailure(); + return failure(); // Get result memref type. auto memrefType = viewOp.getType(); if (memrefType.getAffineMaps().size() > 1) - return matchFailure(); + return failure(); auto map = memrefType.getAffineMaps().empty() ? AffineMap::getMultiDimIdentityMap(memrefType.getRank(), rewriter.getContext()) @@ -2368,7 +2368,7 @@ int64_t oldOffset; SmallVector oldStrides; if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) - return matchFailure(); + return failure(); SmallVector newOperands; @@ -2444,27 +2444,27 @@ // Insert a cast so we have the same type as the old memref type. rewriter.replaceOpWithNewOp(viewOp, newViewOp, viewOp.getType()); - return matchSuccess(); + return success(); } }; struct ViewOpMemrefCastFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(ViewOp viewOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(ViewOp viewOp, + PatternRewriter &rewriter) const override { Value memrefOperand = viewOp.getOperand(0); MemRefCastOp memrefCastOp = dyn_cast_or_null(memrefOperand.getDefiningOp()); if (!memrefCastOp) - return matchFailure(); + return failure(); Value allocOperand = memrefCastOp.getOperand(); AllocOp allocOp = dyn_cast_or_null(allocOperand.getDefiningOp()); if (!allocOp) - return matchFailure(); + return failure(); rewriter.replaceOpWithNewOp(viewOp, viewOp.getType(), allocOperand, viewOp.operands()); - return matchSuccess(); + return success(); } }; diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -1145,18 +1145,18 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(StridedSliceOp stridedSliceOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(StridedSliceOp stridedSliceOp, + PatternRewriter &rewriter) const override { // Return if 'stridedSliceOp' operand is not defined by a ConstantMaskOp. auto defOp = stridedSliceOp.vector().getDefiningOp(); auto constantMaskOp = dyn_cast_or_null(defOp); if (!constantMaskOp) - return matchFailure(); + return failure(); // Return if 'stridedSliceOp' has non-unit strides. if (llvm::any_of(stridedSliceOp.strides(), [](Attribute attr) { return attr.cast().getInt() != 1; })) - return matchFailure(); + return failure(); // Gather constant mask dimension sizes. SmallVector maskDimSizes; populateFromInt64AttrArray(constantMaskOp.mask_dim_sizes(), maskDimSizes); @@ -1187,7 +1187,7 @@ rewriter.replaceOpWithNewOp( stridedSliceOp, stridedSliceOp.getResult().getType(), vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes)); - return matchSuccess(); + return success(); } }; @@ -1619,14 +1619,14 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(CreateMaskOp createMaskOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(CreateMaskOp createMaskOp, + PatternRewriter &rewriter) const override { // Return if any of 'createMaskOp' operands are not defined by a constant. auto is_not_def_by_constant = [](Value operand) { return !isa_and_nonnull(operand.getDefiningOp()); }; if (llvm::any_of(createMaskOp.operands(), is_not_def_by_constant)) - return matchFailure(); + return failure(); // Gather constant mask dimension sizes. SmallVector maskDimSizes; for (auto operand : createMaskOp.operands()) { @@ -1637,7 +1637,7 @@ rewriter.replaceOpWithNewOp( createMaskOp, createMaskOp.getResult().getType(), vector::getVectorSubscriptAttr(rewriter, maskDimSizes)); - return matchSuccess(); + return success(); } }; diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -545,18 +545,18 @@ struct SplitTransferReadOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(vector::TransferReadOp xferReadOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::TransferReadOp xferReadOp, + PatternRewriter &rewriter) const override { // TODO(andydavis, ntv) Support splitting TransferReadOp with non-identity // permutation maps. Repurpose code from MaterializeVectors transformation. if (!isIdentitySuffix(xferReadOp.permutation_map())) - return matchFailure(); + return failure(); // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp. Value xferReadResult = xferReadOp.getResult(); auto extractSlicesOp = dyn_cast(*xferReadResult.getUsers().begin()); if (!xferReadResult.hasOneUse() || !extractSlicesOp) - return matchFailure(); + return failure(); // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user. auto sourceVectorType = extractSlicesOp.getSourceVectorType(); @@ -593,7 +593,7 @@ rewriter.replaceOpWithNewOp( xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(), extractSlicesOp.strides()); - return matchSuccess(); + return success(); } }; @@ -601,23 +601,23 @@ struct SplitTransferWriteOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(vector::TransferWriteOp xferWriteOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::TransferWriteOp xferWriteOp, + PatternRewriter &rewriter) const override { // TODO(andydavis, ntv) Support splitting TransferWriteOp with non-identity // permutation maps. Repurpose code from MaterializeVectors transformation. if (!isIdentitySuffix(xferWriteOp.permutation_map())) - return matchFailure(); + return failure(); // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'. auto *vectorDefOp = xferWriteOp.vector().getDefiningOp(); auto insertSlicesOp = dyn_cast_or_null(vectorDefOp); if (!insertSlicesOp) - return matchFailure(); + return failure(); // Get TupleOp operand of 'insertSlicesOp'. auto tupleOp = dyn_cast_or_null( insertSlicesOp.vectors().getDefiningOp()); if (!tupleOp) - return matchFailure(); + return failure(); // Get 'sizes' and 'strides' parameters from InsertSlicesOp user. auto sourceTupleType = insertSlicesOp.getSourceTupleType(); @@ -644,7 +644,7 @@ // Erase old 'xferWriteOp'. rewriter.eraseOp(xferWriteOp); - return matchSuccess(); + return success(); } }; @@ -653,15 +653,15 @@ struct ShapeCastOpDecomposer : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, + PatternRewriter &rewriter) const override { // Check if 'shapeCastOp' has tuple source/result type. auto sourceTupleType = shapeCastOp.source().getType().dyn_cast_or_null(); auto resultTupleType = shapeCastOp.result().getType().dyn_cast_or_null(); if (!sourceTupleType || !resultTupleType) - return matchFailure(); + return failure(); assert(sourceTupleType.size() == resultTupleType.size()); // Create single-vector ShapeCastOp for each source tuple element. @@ -679,7 +679,7 @@ // Replace 'shapeCastOp' with tuple of 'resultElements'. rewriter.replaceOpWithNewOp(shapeCastOp, resultTupleType, resultElements); - return matchSuccess(); + return success(); } }; @@ -702,21 +702,21 @@ struct ShapeCastOpFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp, + PatternRewriter &rewriter) const override { // Check if 'shapeCastOp' has vector source/result type. auto sourceVectorType = shapeCastOp.source().getType().dyn_cast_or_null(); auto resultVectorType = shapeCastOp.result().getType().dyn_cast_or_null(); if (!sourceVectorType || !resultVectorType) - return matchFailure(); + return failure(); // Check if shape cast op source operand is also a shape cast op. auto sourceShapeCastOp = dyn_cast_or_null( shapeCastOp.source().getDefiningOp()); if (!sourceShapeCastOp) - return matchFailure(); + return failure(); auto operandSourceVectorType = sourceShapeCastOp.source().getType().cast(); auto operandResultVectorType = @@ -725,10 +725,10 @@ // Check if shape cast operations invert each other. if (operandSourceVectorType != resultVectorType || operandResultVectorType != sourceVectorType) - return matchFailure(); + return failure(); rewriter.replaceOp(shapeCastOp, sourceShapeCastOp.source()); - return matchSuccess(); + return success(); } }; @@ -738,30 +738,30 @@ struct TupleGetFolderOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(vector::TupleGetOp tupleGetOp, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp, + PatternRewriter &rewriter) const override { // Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp. auto extractSlicesOp = dyn_cast_or_null( tupleGetOp.vectors().getDefiningOp()); if (!extractSlicesOp) - return matchFailure(); + return failure(); // Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp. auto insertSlicesOp = dyn_cast_or_null( extractSlicesOp.vector().getDefiningOp()); if (!insertSlicesOp) - return matchFailure(); + return failure(); // Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp. auto tupleOp = dyn_cast_or_null( insertSlicesOp.vectors().getDefiningOp()); if (!tupleOp) - return matchFailure(); + return failure(); // Forward Value from 'tupleOp' at 'tupleGetOp.index'. Value tupleValue = tupleOp.getOperand(tupleGetOp.getIndex()); rewriter.replaceOp(tupleGetOp, tupleValue); - return matchSuccess(); + return success(); } }; @@ -778,8 +778,8 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(vector::ExtractSlicesOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::ExtractSlicesOp op, + PatternRewriter &rewriter) const override { auto loc = op.getLoc(); VectorType vectorType = op.getSourceVectorType(); @@ -806,7 +806,7 @@ } rewriter.replaceOpWithNewOp(op, tupleType, tupleValues); - return matchSuccess(); + return success(); } }; @@ -825,8 +825,8 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(vector::InsertSlicesOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::InsertSlicesOp op, + PatternRewriter &rewriter) const override { auto loc = op.getLoc(); VectorType vectorType = op.getResultVectorType(); @@ -860,7 +860,7 @@ } rewriter.replaceOp(op, result); - return matchSuccess(); + return success(); } }; @@ -881,8 +881,8 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(vector::OuterProductOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::OuterProductOp op, + PatternRewriter &rewriter) const override { auto loc = op.getLoc(); VectorType rhsType = op.getOperandVectorTypeRHS(); @@ -907,7 +907,7 @@ result = rewriter.create(loc, resType, m, result, pos); } rewriter.replaceOp(op, result); - return matchSuccess(); + return success(); } }; @@ -934,11 +934,11 @@ : OpRewritePattern(context), vectorTransformsOptions(vectorTransformsOptions) {} - PatternMatchResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override { // TODO(ajcbik): implement masks if (llvm::size(op.masks()) != 0) - return matchFailure(); + return failure(); // TODO(ntv, ajcbik): implement benefits, cost models, separate this out in // a new pattern. @@ -977,7 +977,7 @@ rewriter.replaceOpWithNewOp(op, op.acc(), mul); else rewriter.replaceOpWithNewOp(op, op.acc(), mul); - return matchSuccess(); + return success(); } } @@ -987,7 +987,7 @@ int64_t lhsIndex = batchDimMap[0].first; int64_t rhsIndex = batchDimMap[0].second; rewriter.replaceOp(op, lowerParallel(op, lhsIndex, rhsIndex, rewriter)); - return matchSuccess(); + return success(); } // Collect contracting dimensions. @@ -1007,7 +1007,7 @@ if (lhsContractingDimSet.count(lhsIndex) == 0) { rewriter.replaceOp( op, lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter)); - return matchSuccess(); + return success(); } } @@ -1018,17 +1018,17 @@ if (rhsContractingDimSet.count(rhsIndex) == 0) { rewriter.replaceOp( op, lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter)); - return matchSuccess(); + return success(); } } // Lower the first remaining reduction dimension. if (!contractingDimMap.empty()) { rewriter.replaceOp(op, lowerReduction(op, rewriter)); - return matchSuccess(); + return success(); } - return matchFailure(); + return failure(); } private: @@ -1275,12 +1275,12 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(vector::ShapeCastOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::ShapeCastOp op, + PatternRewriter &rewriter) const override { auto sourceVectorType = op.getSourceVectorType(); auto resultVectorType = op.getResultVectorType(); if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1) - return matchFailure(); + return failure(); auto loc = op.getLoc(); auto elemType = sourceVectorType.getElementType(); @@ -1295,7 +1295,7 @@ /*offsets=*/i * mostMinorVectorSize, /*strides=*/1); } rewriter.replaceOp(op, desc); - return matchSuccess(); + return success(); } }; @@ -1309,12 +1309,12 @@ public: using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(vector::ShapeCastOp op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(vector::ShapeCastOp op, + PatternRewriter &rewriter) const override { auto sourceVectorType = op.getSourceVectorType(); auto resultVectorType = op.getResultVectorType(); if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2) - return matchFailure(); + return failure(); auto loc = op.getLoc(); auto elemType = sourceVectorType.getElementType(); @@ -1330,7 +1330,7 @@ desc = rewriter.create(loc, vec, desc, i); } rewriter.replaceOp(op, desc); - return matchSuccess(); + return success(); } }; diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -44,7 +44,7 @@ "rewrite functions!"); } -PatternMatchResult RewritePattern::match(Operation *op) const { +LogicalResult RewritePattern::match(Operation *op) const { llvm_unreachable("need to implement either match or matchAndRewrite!"); } diff --git a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp --- a/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp +++ b/mlir/lib/Quantizer/Transforms/RemoveInstrumentationPass.cpp @@ -35,13 +35,13 @@ RemoveIdentityOpRewrite(MLIRContext *context) : RewritePattern(OpTy::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { assert(op->getNumOperands() == 1); assert(op->getNumResults() == 1); rewriter.replaceOp(op, op->getOperand(0)); - return matchSuccess(); + return success(); } }; diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1010,7 +1010,7 @@ //===----------------------------------------------------------------------===// /// Attempt to match and rewrite the IR root at the specified operation. -PatternMatchResult +LogicalResult ConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { SmallVector operands; @@ -1705,7 +1705,7 @@ : OpConversionPattern(ctx), converter(converter) {} /// Hook for derived classes to implement combined matching and rewriting. - PatternMatchResult + LogicalResult matchAndRewrite(FuncOp funcOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { FunctionType type = funcOp.getType(); @@ -1714,12 +1714,12 @@ TypeConverter::SignatureConversion result(type.getNumInputs()); for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) if (failed(converter.convertSignatureArg(i, type.getInput(i), result))) - return matchFailure(); + return failure(); // Convert the original function results. SmallVector convertedResults; if (failed(converter.convertTypes(type.getResults(), convertedResults))) - return matchFailure(); + return failure(); // Update the function signature in-place. rewriter.updateRootInPlace(funcOp, [&] { @@ -1727,7 +1727,7 @@ convertedResults, funcOp.getContext())); rewriter.applySignatureConversion(&funcOp.getBody(), result); }); - return matchSuccess(); + return success(); } /// The type converter to use when rewriting the signature. diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -94,32 +94,32 @@ struct ConvertToAtomCmpExchangeWeak : public RewritePattern { ConvertToAtomCmpExchangeWeak(MLIRContext *context); - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; }; struct ConvertToBitReverse : public RewritePattern { ConvertToBitReverse(MLIRContext *context); - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; }; struct ConvertToGroupNonUniformBallot : public RewritePattern { ConvertToGroupNonUniformBallot(MLIRContext *context); - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; }; struct ConvertToModule : public RewritePattern { ConvertToModule(MLIRContext *context); - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; }; struct ConvertToSubgroupBallot : public RewritePattern { ConvertToSubgroupBallot(MLIRContext *context); - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; }; } // end anonymous namespace @@ -145,7 +145,7 @@ : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", {"spv.AtomicCompareExchangeWeak"}, 1, context) {} -PatternMatchResult +LogicalResult ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { Value ptr = op->getOperand(0); @@ -159,21 +159,21 @@ spirv::MemorySemantics::AcquireRelease | spirv::MemorySemantics::AtomicCounterMemory, spirv::MemorySemantics::Acquire, value, comparator); - return matchSuccess(); + return success(); } ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context) : RewritePattern("test.convert_to_bit_reverse_op", {"spv.BitReverse"}, 1, context) {} -PatternMatchResult +LogicalResult ConvertToBitReverse::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { Value predicate = op->getOperand(0); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), predicate); - return matchSuccess(); + return success(); } ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot( @@ -181,39 +181,39 @@ : RewritePattern("test.convert_to_group_non_uniform_ballot_op", {"spv.GroupNonUniformBallot"}, 1, context) {} -PatternMatchResult ConvertToGroupNonUniformBallot::matchAndRewrite( +LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite( Operation *op, PatternRewriter &rewriter) const { Value predicate = op->getOperand(0); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate); - return matchSuccess(); + return success(); } ConvertToModule::ConvertToModule(MLIRContext *context) : RewritePattern("test.convert_to_module_op", {"spv.module"}, 1, context) {} -PatternMatchResult +LogicalResult ConvertToModule::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp( op, spirv::AddressingModel::PhysicalStorageBuffer64, spirv::MemoryModel::Vulkan); - return matchSuccess(); + return success(); } ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context) : RewritePattern("test.convert_to_subgroup_ballot_op", {"spv.SubgroupBallotKHR"}, 1, context) {} -PatternMatchResult +LogicalResult ConvertToSubgroupBallot::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { Value predicate = op->getOperand(0); rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), predicate); - return matchSuccess(); + return success(); } namespace mlir { diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -283,10 +283,10 @@ : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(TestOpWithRegionPattern op, - PatternRewriter &rewriter) const override { + LogicalResult matchAndRewrite(TestOpWithRegionPattern op, + PatternRewriter &rewriter) const override { rewriter.eraseOp(op); - return matchSuccess(); + return success(); } }; } // end anonymous namespace diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -141,7 +141,7 @@ TestRegionRewriteBlockMovement(MLIRContext *ctx) : ConversionPattern("test.region", 1, ctx) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Inline this region into the parent region. @@ -155,7 +155,7 @@ // Drop this operation. rewriter.eraseOp(op); - return matchSuccess(); + return success(); } }; /// This pattern is a simple pattern that generates a region containing an @@ -164,8 +164,8 @@ TestRegionRewriteUndo(MLIRContext *ctx) : RewritePattern("test.region_builder", 1, ctx) {} - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { // Create the region operation with an entry block containing arguments. OperationState newRegion(op->getLoc(), "test.region"); newRegion.addRegion(); @@ -179,7 +179,7 @@ // Drop this operation. rewriter.eraseOp(op); - return matchSuccess(); + return success(); } }; @@ -191,7 +191,7 @@ TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) : ConversionPattern("test.drop_region_op", 1, ctx), converter(converter) { } - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Region ®ion = op->getRegion(0); @@ -202,12 +202,12 @@ for (unsigned i = 0, e = entry->getNumArguments(); i != e; ++i) if (failed(converter.convertSignatureArg( i, entry->getArgument(i).getType(), result))) - return matchFailure(); + return failure(); // Convert the region signature and just drop the operation. rewriter.applySignatureConversion(®ion, result); rewriter.eraseOp(op); - return matchSuccess(); + return success(); } /// The type converter to use when rewriting the signature. @@ -217,35 +217,35 @@ struct TestPassthroughInvalidOp : public ConversionPattern { TestPassthroughInvalidOp(MLIRContext *ctx) : ConversionPattern("test.invalid", 1, ctx) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { rewriter.replaceOpWithNewOp(op, llvm::None, operands, llvm::None); - return matchSuccess(); + return success(); } }; /// This pattern handles the case of a split return value. struct TestSplitReturnType : public ConversionPattern { TestSplitReturnType(MLIRContext *ctx) : ConversionPattern("test.return", 1, ctx) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Check for a return of F32. if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) - return matchFailure(); + return failure(); // Check if the first operation is a cast operation, if it is we use the // results directly. auto *defOp = operands[0].getDefiningOp(); if (auto packerOp = llvm::dyn_cast_or_null(defOp)) { rewriter.replaceOpWithNewOp(op, packerOp.getOperands()); - return matchSuccess(); + return success(); } // Otherwise, fail to match. - return matchFailure(); + return failure(); } }; @@ -254,52 +254,52 @@ struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { TestChangeProducerTypeI32ToF32(MLIRContext *ctx) : ConversionPattern("test.type_producer", 1, ctx) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // If the type is I32, change the type to F32. if (!Type(*op->result_type_begin()).isSignlessInteger(32)) - return matchFailure(); + return failure(); rewriter.replaceOpWithNewOp(op, rewriter.getF32Type()); - return matchSuccess(); + return success(); } }; struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { TestChangeProducerTypeF32ToF64(MLIRContext *ctx) : ConversionPattern("test.type_producer", 1, ctx) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // If the type is F32, change the type to F64. if (!Type(*op->result_type_begin()).isF32()) return rewriter.notifyMatchFailure(op, "expected single f32 operand"); rewriter.replaceOpWithNewOp(op, rewriter.getF64Type()); - return matchSuccess(); + return success(); } }; struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) : ConversionPattern("test.type_producer", 10, ctx) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Always convert to B16, even though it is not a legal type. This tests // that values are unmapped correctly. rewriter.replaceOpWithNewOp(op, rewriter.getBF16Type()); - return matchSuccess(); + return success(); } }; struct TestUpdateConsumerType : public ConversionPattern { TestUpdateConsumerType(MLIRContext *ctx) : ConversionPattern("test.type_consumer", 1, ctx) {} - PatternMatchResult + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Verify that the incoming operand has been successfully remapped to F64. if (!operands[0].getType().isF64()) - return matchFailure(); + return failure(); rewriter.replaceOpWithNewOp(op, operands[0]); - return matchSuccess(); + return success(); } }; @@ -312,15 +312,15 @@ TestNonRootReplacement(MLIRContext *ctx) : RewritePattern("test.replace_non_root", 1, ctx) {} - PatternMatchResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const final { + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { auto resultType = *op->result_type_begin(); auto illegalOp = rewriter.create(op->getLoc(), resultType); auto legalOp = rewriter.create(op->getLoc(), resultType); rewriter.replaceOp(illegalOp, {legalOp}); rewriter.replaceOp(op, {illegalOp}); - return matchSuccess(); + return success(); } }; } // namespace @@ -475,7 +475,7 @@ : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; - PatternMatchResult + LogicalResult matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto origOps = op.getOperands(); @@ -490,7 +490,7 @@ rewriter.replaceOpWithNewOp(op, op.getResultTypes(), remappedOperands); - return matchSuccess(); + return success(); } }; diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -215,7 +215,7 @@ // Skip the operand matching at depth 0 as the pattern rewriter already does. if (depth != 0) { // Skip if there is no defining operation (e.g., arguments to function). - os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n", + os.indent(indent) << formatv("if (!castedOp{0}) return failure();\n", depth); } if (tree.getNumArgs() != op.getNumArgs()) { @@ -300,7 +300,7 @@ os.indent(indent) << "if (!(" << std::string(tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf(self))) - << ")) return matchFailure();\n"; + << ")) return failure();\n"; } } @@ -344,7 +344,7 @@ // should just capture a mlir::Attribute() to signal the missing state. // That is precisely what getAttr() returns on missing attributes. } else { - os.indent(indent) << "if (!tblgen_attr) return matchFailure();\n"; + os.indent(indent) << "if (!tblgen_attr) return failure();\n"; } auto matcher = tree.getArgAsLeaf(argIndex); @@ -360,7 +360,7 @@ os.indent(indent) << "if (!(" << std::string(tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr"))) - << ")) return matchFailure();\n"; + << ")) return failure();\n"; } // Capture the value @@ -383,7 +383,7 @@ auto &entities = appliedConstraint.entities; auto condition = constraint.getConditionTemplate(); - auto cmd = "if (!({0})) return matchFailure();\n"; + auto cmd = "if (!({0})) return failure();\n"; if (isa(constraint)) { auto self = formatv("({0}.getType())", @@ -468,7 +468,7 @@ // Emit matchAndRewrite() function. os << R"( - PatternMatchResult matchAndRewrite(Operation *op0, + LogicalResult matchAndRewrite(Operation *op0, PatternRewriter &rewriter) const override { )"; @@ -501,7 +501,7 @@ os.indent(4) << "// Rewrite\n"; emitRewriteLogic(); - os.indent(4) << "return matchSuccess();\n"; + os.indent(4) << "return success();\n"; os << " };\n"; os << "};\n"; }