diff --git a/mlir/docs/Canonicalization.md b/mlir/docs/Canonicalization.md
--- a/mlir/docs/Canonicalization.md
+++ b/mlir/docs/Canonicalization.md
@@ -56,9 +56,9 @@
 ## Defining Canonicalizations

 Two mechanisms are available with which to define canonicalizations;
-`getCanonicalizationPatterns` and `fold`.
+general `RewritePattern`s and the `fold` method.

-### Canonicalizing with `getCanonicalizationPatterns`
+### Canonicalizing with `RewritePattern`s

 This mechanism allows for providing canonicalizations as a set of
 `RewritePattern`s, either imperatively defined in C++ or declaratively as
@@ -67,13 +67,21 @@
 These transformations may be as simple as replacing a multiplication with a
 shift, or even replacing a conditional branch with an unconditional one.

-In [ODS](OpDefinitions.md), an operation can set the `hasCanonicalizer` bit to
-generate a declaration for the `getCanonicalizationPatterns` method.
+In [ODS](OpDefinitions.md), an operation can set the `hasCanonicalizer` bit to
+declare a `getCanonicalizationPatterns` method, or the `hasCanonicalizeMethod`
+bit to declare a static `canonicalize` method that is registered automatically:

 ```tablegen
 def MyOp : ... {
+  // I want to define a fully general set of patterns for this op.
   let hasCanonicalizer = 1;
 }
+
+def OtherOp : ... {
+  // A single "matchAndRewrite" style RewritePattern implemented as a method
+  // is good enough for me.
+  let hasCanonicalizeMethod = 1;
+}
 ```

 Canonicalization patterns can then be provided in the source file:
@@ -83,12 +91,17 @@
                                           MLIRContext *context) {
   patterns.add<...>(...);
 }
+
+LogicalResult OtherOp::canonicalize(OtherOp op, PatternRewriter &rewriter) {
+  // patterns and rewrites go here.
+  return failure();
+}
 ```

 See the [quickstart guide](Tutorials/QuickstartRewrites.md) for information on
 defining operation rewrites.

-### Canonicalizing with `fold`
+### Canonicalizing with the `fold` method

 The `fold` mechanism is an intentionally limited, but powerful mechanism that
 allows for applying canonicalizations in many places throughout the compiler.
diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md
--- a/mlir/docs/OpDefinitions.md
+++ b/mlir/docs/OpDefinitions.md
@@ -919,6 +919,13 @@
 for this operation. If it is `1`, then `::getCanonicalizationPatterns()` should
 be defined.

+### `hasCanonicalizeMethod`
+
+When this boolean field is set to `true`, it indicates that the op implements a
+`canonicalize` method for simple "matchAndRewrite" style canonicalization
+patterns. ODS then also generates an implementation of
+`::getCanonicalizationPatterns()` that calls this method.
+
 ### `hasFolder`

 This boolean field indicate whether general folding rules have been defined for
diff --git a/mlir/docs/Tutorials/QuickstartRewrites.md b/mlir/docs/Tutorials/QuickstartRewrites.md
--- a/mlir/docs/Tutorials/QuickstartRewrites.md
+++ b/mlir/docs/Tutorials/QuickstartRewrites.md
@@ -159,10 +159,61 @@
 use to collect all the generated patterns inside `patterns` and then use
 `patterns` in any pass you would like.

-### C++ rewrite specification
+### Simple C++ `matchAndRewrite` style specifications

-In case patterns are not sufficient there is also the fully C++ way of
-expressing a rewrite:
+Many simple rewrites can be expressed with a `matchAndRewrite` style of
+pattern, e.g. when converting a multiply by a power of two into a shift. For
+these cases, you can define the pattern as a simple function:
+
+```c++
+static LogicalResult
+convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp(
+      op, op->getResult(0).getType(), op->getOperand(0),
+      /*alpha=*/op->getAttrOfType("alpha"));
+  return success();
+}
+
+void populateRewrites(RewritePatternSet &patternSet) {
+  // Add it to a pattern set.
+  patternSet.add(convertTFLeakyRelu);
+}
+```
+
+ODS provides a simple way to define a function-style canonicalization for your
+operation. In the TableGen definition of the op, specify
+`let hasCanonicalizeMethod = 1;` and then implement the `canonicalize` method in
+your .cpp file:
+
+```c++
+// Example from the CIRCT project, which has a variadic integer multiply.
+LogicalResult circt::MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
+  auto inputs = op.inputs();
+  APInt value;
+
+  // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
+  if (inputs.size() == 2 && matchPattern(inputs.back(), m_RConstant(value)) &&
+      value.isPowerOf2()) {
+    auto shift = rewriter.create(op.getLoc(), op.getType(),
+                                 value.exactLogBase2());
+    auto shlOp =
+        rewriter.create(op.getLoc(), inputs[0], shift);
+    rewriter.replaceOpWithNewOp(op, op.getType(),
+                                ArrayRef(shlOp));
+    return success();
+  }
+
+  return failure();
+}
+```
+
+However, if you need the full generality of canonicalization patterns, you can
+specify an arbitrary list of `RewritePattern`s instead.
+
+### Fully general C++ `RewritePattern` specifications
+
+In case ODS patterns and `matchAndRewrite`-style functions are not sufficient,
+you can also specify rewrites as a general set of `RewritePattern`s:

 ```c++
 /// Multi-step rewrite using "match" and "rewrite". This allows for separating
@@ -202,19 +253,6 @@
 construction. While in the pattern generator a simple heuristic is currently
 employed based around the number of ops matched and replaced.

-In the case where you have a registered op and want to use a benefit of 1, you
-can even define the pattern as a C function:
-
-```c++
-static LogicalResult
-convertTFLeakyRelu(TFLeakyReluOp op, PatternRewriter &rewriter) {
-  rewriter.replaceOpWithNewOp(
-      op, op->getResult(0).getType(), op->getOperand(0),
-      /*alpha=*/op->getAttrOfType("alpha"));
-  return success();
-}
-```
-
 The above rule did not capture the matching operands/attributes, but in general
 the `match` function in a multi-step rewrite may populate and return a
 `PatternState` (or class derived from one) to pass information extracted during
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -33,6 +33,7 @@
 class Builder;
 class FuncOp;
 class OpBuilder;
+class PatternRewriter;

 /// Return the list of Range (i.e. offset, size, stride). Each Range
 /// entry contains either the dynamic value or a ConstantIndexOp constructed
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -339,7 +339,7 @@
   // AssertOp is fully verified by its traits.
   let verifier = ?;

-  let hasCanonicalizer = 1;
+  let hasCanonicalizeMethod = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -500,7 +500,7 @@
     void eraseOperand(unsigned index);
   }];

-  let hasCanonicalizer = 1;
+  let hasCanonicalizeMethod = 1;
   let assemblyFormat = [{
     $dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
   }];
@@ -629,7 +629,7 @@
   }];

   let verifier = ?;
-  let hasCanonicalizer = 1;
+  let hasCanonicalizeMethod = 1;

   let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)";
 }
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2141,11 +2141,12 @@
   code verifier = ?;

   // Whether this op has associated canonicalization patterns.
-  // TODO: figure out a better way to write canonicalization patterns in
-  // TableGen rules directly instead of using this marker and C++
-  // implementations.
   bit hasCanonicalizer = 0;

+  // Whether this op has a static "canonicalize" method to perform simple
+  // "match and rewrite" style canonicalizations.
+  bit hasCanonicalizeMethod = 0;
+
   // Whether this op has a folder.
   bit hasFolder = 0;

diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -308,25 +308,13 @@
 // AssertOp
 //===----------------------------------------------------------------------===//

-namespace {
-struct EraseRedundantAssertions : public OpRewritePattern {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(AssertOp op,
-                                PatternRewriter &rewriter) const override {
-    // Erase assertion if argument is constant true.
-    if (matchPattern(op.arg(), m_One())) {
-      rewriter.eraseOp(op);
-      return success();
-    }
-    return failure();
+LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
+  // Erase assertion if argument is constant true.
+  if (matchPattern(op.arg(), m_One())) {
+    rewriter.eraseOp(op);
+    return success();
   }
-};
-} // namespace
-
-void AssertOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
-                                           MLIRContext *context) {
-  patterns.add(context);
+  return failure();
 }

 //===----------------------------------------------------------------------===//
@@ -498,26 +486,21 @@
   return success();
 }

-namespace {
 /// Simplify a branch to a block that has a single predecessor. This effectively
 /// merges the two blocks.
-struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BranchOp op,
-                                PatternRewriter &rewriter) const override {
-    // Check that the successor block has a single predecessor.
-    Block *succ = op.getDest();
-    Block *opParent = op->getBlock();
-    if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
-      return failure();
+static LogicalResult
+simplifyBrToBlockWithSinglePred(BranchOp op, PatternRewriter &rewriter) {
+  // Check that the successor block has a single predecessor.
+  Block *succ = op.getDest();
+  Block *opParent = op->getBlock();
+  if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors()))
+    return failure();

-    // Merge the successor into the current block and erase the branch.
-    rewriter.mergeBlocks(succ, opParent, op.getOperands());
-    rewriter.eraseOp(op);
-    return success();
-  }
-};
+  // Merge the successor into the current block and erase the branch.
+  rewriter.mergeBlocks(succ, opParent, op.getOperands());
+  rewriter.eraseOp(op);
+  return success();
+}

 /// br ^bb1
 /// ^bb1
@@ -525,27 +508,27 @@
 ///
 ///  -> br ^bbN(...)
 ///
-struct SimplifyPassThroughBr : public OpRewritePattern {
-  using OpRewritePattern::OpRewritePattern;
+static LogicalResult simplifyPassThroughBr(BranchOp op,
+                                           PatternRewriter &rewriter) {
+  Block *dest = op.getDest();
+  ValueRange destOperands = op.getOperands();
+  SmallVector destOperandStorage;
+
+  // Try to collapse the successor if it points somewhere other than this
+  // block.
+  if (dest == op->getBlock() ||
+      failed(collapseBranch(dest, destOperands, destOperandStorage)))
+    return failure();

-  LogicalResult matchAndRewrite(BranchOp op,
-                                PatternRewriter &rewriter) const override {
-    Block *dest = op.getDest();
-    ValueRange destOperands = op.getOperands();
-    SmallVector destOperandStorage;
-
-    // Try to collapse the successor if it points somewhere other than this
-    // block.
-    if (dest == op->getBlock() ||
-        failed(collapseBranch(dest, destOperands, destOperandStorage)))
-      return failure();
+  // Create a new branch with the collapsed successor.
+  rewriter.replaceOpWithNewOp(op, dest, destOperands);
+  return success();
+}

-    // Create a new branch with the collapsed successor.
-    rewriter.replaceOpWithNewOp(op, dest, destOperands);
-    return success();
-  }
-};
-} // end anonymous namespace.
+LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
+  return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
+                 succeeded(simplifyPassThroughBr(op, rewriter)));
+}

 Block *BranchOp::getDest() { return getSuccessor(); }

@@ -553,11 +536,6 @@

 void BranchOp::eraseOperand(unsigned index) { (*this)->eraseOperand(index); }

-void BranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                           MLIRContext *context) {
-  results.add(context);
-}
-
 Optional
 BranchOp::getMutableSuccessorOperands(unsigned index) {
   assert(index == 0 && "invalid successor index");
@@ -608,31 +586,20 @@
 //===----------------------------------------------------------------------===//
 // CallIndirectOp
 //===----------------------------------------------------------------------===//
-namespace {
-/// Fold indirect calls that have a constant function as the callee operand.
-struct SimplifyIndirectCallWithKnownCallee
-    : public OpRewritePattern {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(CallIndirectOp indirectCall,
-                                PatternRewriter &rewriter) const override {
-    // Check that the callee is a constant callee.
-    SymbolRefAttr calledFn;
-    if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
-      return failure();

-    // Replace with a direct call.
-    rewriter.replaceOpWithNewOp(indirectCall, calledFn,
-                                indirectCall.getResultTypes(),
-                                indirectCall.getArgOperands());
-    return success();
-  }
-};
-} // end anonymous namespace.
+/// Fold indirect calls that have a constant function as the callee operand.
+LogicalResult CallIndirectOp::canonicalize(CallIndirectOp indirectCall,
+                                           PatternRewriter &rewriter) {
+  // Check that the callee is a constant callee.
+  SymbolRefAttr calledFn;
+  if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
+    return failure();

-void CallIndirectOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                                 MLIRContext *context) {
-  results.add(context);
+  // Replace with a direct call.
+  rewriter.replaceOpWithNewOp(indirectCall, calledFn,
+                              indirectCall.getResultTypes(),
+                              indirectCall.getArgOperands());
+  return success();
 }

 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1674,15 +1674,35 @@
 }

 void OpEmitter::genCanonicalizerDecls() {
-  if (!def.getValueAsBit("hasCanonicalizer"))
+  bool hasCanonicalizeMethod = false;
+  if (def.getValueAsBit("hasCanonicalizeMethod")) {
+    // static LogicalResult FooOp::
+    // canonicalize(FooOp op, PatternRewriter &rewriter);
+    hasCanonicalizeMethod = true;
+    SmallVector paramList;
+    paramList.emplace_back(op.getCppClassName(), "op");
+    paramList.emplace_back("::mlir::PatternRewriter &", "rewriter");
+    opClass.addMethodAndPrune("LogicalResult", "canonicalize",
+                              OpMethod::MP_StaticDeclaration,
+                              std::move(paramList));
+  }
+
+  if (!hasCanonicalizeMethod && !def.getValueAsBit("hasCanonicalizer"))
     return;

+  // Add a signature for getCanonicalizationPatterns if implemented by the
+  // dialect or if synthesized to call 'canonicalize'.
   SmallVector paramList;
   paramList.emplace_back("::mlir::RewritePatternSet &", "results");
   paramList.emplace_back("::mlir::MLIRContext *", "context");
-  opClass.addMethodAndPrune("void", "getCanonicalizationPatterns",
-                            OpMethod::MP_StaticDeclaration,
-                            std::move(paramList));
+  auto kind = hasCanonicalizeMethod ? OpMethod::MP_Static
+                                    : OpMethod::MP_StaticDeclaration;
+  auto *method = opClass.addMethodAndPrune(
+      "void", "getCanonicalizationPatterns", kind, std::move(paramList));
+
+  // If synthesizing the method, fill it in.
+  if (hasCanonicalizeMethod)
+    method->body() << " results.add(canonicalize);\n";
 }

 void OpEmitter::genFolderDecls() {