diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -657,9 +657,54 @@ The fourth parameter to `Pattern` (and `Pat`) allows to manually tweak a pattern's benefit. Just supply `(addBenefit N)` to add `N` to the benefit value. -## Special directives +## Rewrite directives -[TODO] +### `location` + +By default the C++ pattern expanded from a DRR pattern uses the fused location +of all source ops as the location for all generated ops. This is not always the +best location mapping relationship. For such cases, DRR provides the `location` +directive to give finer control. + +`location` has the following syntax: + +```tablegen +(location $symbol0, $symbol1, ...) +``` + +where each `$symbol` must be bound earlier in the pattern. If multiple symbols +are given, the generated op uses the fused location of all of them. + +`location` must be used as the last argument to an op creation. For example, + +```tablegen +def : Pat<(LocSrc1Op:$src1 (LocSrc2Op:$src2 ...), + (LocDst1Op (LocDst2Op ..., (location $src2)))>; +``` + +In the above pattern, the generated `LocDst2Op` will use the matched location +of `LocSrc2Op` while the root `LocDst1Op` node will still use the fused location +of all source ops. + +### `replaceWithValue` + +The `replaceWithValue` directive is used to eliminate a matched op by replacing +all of its uses with a captured value. It has the following syntax: + +```tablegen +(replaceWithValue $symbol) +``` + +where `$symbol` must be a symbol bound earlier in the pattern. + +For example, + +```tablegen +def : Pat<(Foo $input), (replaceWithValue $input)>; +``` + +The above pattern removes `Foo` and replaces all uses of `Foo` with +`$input`.
## Debugging Tips diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2179,9 +2179,14 @@ } //===----------------------------------------------------------------------===// -// Common directives +// Rewrite directives //===----------------------------------------------------------------------===// +// Directive used in result pattern to specify the location of the generated +// op. It must be the last argument of the op creation DAG construct, and its +// arguments must be previously bound symbols (multiple locations are fused). +def location; + // Directive used in result pattern to indicate that no new op are generated, // so to replace the matched DAG with an existing SSA value. def replaceWithValue; diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -159,6 +159,9 @@ // value. bool isReplaceWithValue() const; + // Returns true if this DAG node is a `location` directive. + bool isLocationDirective() const; + // Returns true if this DAG node is wrapping native code call.
bool isNativeCodeCall() const; diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -103,7 +103,7 @@ } bool tblgen::DagNode::isOperation() const { - return !(isNativeCodeCall() || isReplaceWithValue()); + return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective(); } llvm::StringRef tblgen::DagNode::getNativeCodeTemplate() const { @@ -159,6 +159,11 @@ return dagOpDef->getName() == "replaceWithValue"; } +bool tblgen::DagNode::isLocationDirective() const { + auto *dagOpDef = cast<llvm::DefInit>(node->getOperator())->getDef(); + return dagOpDef->getName() == "location"; +} + void tblgen::DagNode::print(raw_ostream &os) const { if (node) node->print(os); @@ -533,7 +538,15 @@ auto numOpArgs = op.getNumArgs(); auto numTreeArgs = tree.getNumArgs(); - if (numOpArgs != numTreeArgs) { + // The pattern might have the last argument specifying the location. + bool hasLocDirective = false; + if (numTreeArgs != 0) { + if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) + hasLocDirective = lastArg.isLocationDirective(); + } + + if (numOpArgs != numTreeArgs - hasLocDirective) { auto err = formatv("op '{0}' argument number mismatch: " "{1} in pattern vs.
{2} in definition", - op.getOperationName(), numTreeArgs, numOpArgs); + op.getOperationName(), numTreeArgs - hasLocDirective, + numOpArgs); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -502,6 +502,20 @@ } //===----------------------------------------------------------------------===// +// Test Locations +//===----------------------------------------------------------------------===// + +def TestLocationSrcOp : TEST_Op<"loc_src"> { + let arguments = (ins I32:$input); + let results = (outs I32:$output); +} + +def TestLocationDstOp : TEST_Op<"loc_dst", [SameOperandsAndResultType]> { + let arguments = (ins I32:$input); + let results = (outs I32:$output); +} + +//===----------------------------------------------------------------------===// // Test Patterns //===----------------------------------------------------------------------===// @@ -996,6 +1010,18 @@ ConstantAttr)>; //===----------------------------------------------------------------------===// +// Test Patterns (Location) + +// Test that we can specify locations for generated ops.
+def : Pat<(TestLocationSrcOp:$res1 + (TestLocationSrcOp:$res2 + (TestLocationSrcOp:$res3 $input))), + (TestLocationDstOp + (TestLocationDstOp + (TestLocationDstOp $input, (location $res1))), + (location $res2, $res3))>; + +//===----------------------------------------------------------------------===// // Test Legalization //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-patterns -mlir-print-debuginfo %s | FileCheck %s +// RUN: mlir-opt -test-patterns -mlir-print-debuginfo %s | FileCheck %s --dump-input-on-failure // CHECK-LABEL: verifyFusedLocs func @verifyFusedLocs(%arg0 : i32) -> i32 { @@ -10,6 +10,21 @@ return %result : i32 } +// CHECK-LABEL: verifyDesignatedLoc +func @verifyDesignatedLoc(%arg0 : i32) -> i32 { + %0 = "test.loc_src"(%arg0) : (i32) -> i32 loc("loc3") + %1 = "test.loc_src"(%0) : (i32) -> i32 loc("loc2") + %2 = "test.loc_src"(%1) : (i32) -> i32 loc("loc1") + + // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc("loc1") + // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused[ + // CHECK-SAME: "loc1" + // CHECK-SAME: "loc3" + // CHECK-SAME: "loc2" + // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused["loc2", "loc3"]) + return %1 : i32 +} + // CHECK-LABEL: verifyZeroResult func @verifyZeroResult(%arg0 : i32) { // CHECK: "test.op_i"(%arg0) : (i32) -> () diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -109,9 +109,11 @@ // calling native C++ code. std::string handleReplaceWithNativeCodeCall(DagNode resultTree); - // Returns the C++ expression referencing the old value serving as the - // replacement.
- std::string handleReplaceWithValue(DagNode tree); + // Returns the symbol of the old value serving as the replacement. + StringRef handleReplaceWithValue(DagNode tree); + + // Returns the C++ expression for the location specified by `tree`. + std::string handleUseLocationOf(DagNode tree); // Emits the C++ statement to build a new op out of the given DAG `tree` and // returns the variable name that this op is assigned to. If the root op in @@ -580,11 +582,11 @@ PrintFatalError(loc, error); } - os.indent(4) << "auto loc = rewriter.getFusedLoc({"; + os.indent(4) << "auto odsLoc = rewriter.getFusedLoc({"; for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; } - os << "}); (void)loc;\n"; + os << "}); (void)odsLoc;\n"; // Process auxiliary result patterns. for (int i = 0; i < replStartIndex; ++i) { @@ -640,15 +642,19 @@ LLVM_DEBUG(resultTree.print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << '\n'); + if (resultTree.isLocationDirective()) { + PrintFatalError(loc, + "location directive can only be used with op creation"); + } + if (resultTree.isNativeCodeCall()) { auto symbol = handleReplaceWithNativeCodeCall(resultTree); symbolInfoMap.bindValue(symbol); return symbol; } - if (resultTree.isReplaceWithValue()) { - return handleReplaceWithValue(resultTree); - } + if (resultTree.isReplaceWithValue()) + return handleReplaceWithValue(resultTree).str(); // Normal op creation.
auto symbol = handleOpCreation(resultTree, resultIndex, depth); @@ -660,7 +666,7 @@ return symbol; } -std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { +StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { assert(tree.isReplaceWithValue()); if (tree.getNumArgs() != 1) { @@ -672,7 +678,30 @@ PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); } - return std::string(tree.getArgName(0)); + return tree.getArgName(0); +} + +std::string PatternEmitter::handleUseLocationOf(DagNode tree) { + assert(tree.isLocationDirective()); + if (tree.getNumArgs() == 0) + PrintFatalError(loc, "location directive requires at least one symbol"); + if (!tree.getSymbol().empty()) + PrintFatalError(loc, "cannot bind symbol to location"); + auto lookUpArgLoc = [this, &tree](int idx) { + const auto *const lookupFmt = "(*{0}.begin()).getLoc()"; + return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt); + }; + if (tree.getNumArgs() == 1) + return lookUpArgLoc(0); + + // Multiple locations are fused into a single one. + std::string ret; + llvm::raw_string_ostream os(ret); + os << "rewriter.getFusedLoc({"; + for (int i = 0, e = tree.getNumArgs(); i != e; ++i) + os << (i ? ", " : "") << lookUpArgLoc(i); + os << "})"; + return os.str(); } std::string PatternEmitter::handleOpArgument(DagLeaf leaf, @@ -753,14 +782,28 @@ Operator &resultOp = tree.getDialectOp(opMap); auto numOpArgs = resultOp.getNumArgs(); + auto numPatArgs = tree.getNumArgs(); + + // Get the location for this operation if explicitly provided. + std::string locToUse; + if (numPatArgs != 0) { + if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1)) + if (lastArg.isLocationDirective()) + locToUse = handleUseLocationOf(lastArg); + } - if (numOpArgs != tree.getNumArgs()) { - PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " - "{1} in pattern vs.
{2} in definition", - resultOp.getOperationName(), tree.getNumArgs(), - numOpArgs)); + auto inPattern = numPatArgs - !locToUse.empty(); + if (numOpArgs != inPattern) { + PrintFatalError(loc, + formatv("resultant op '{0}' argument number mismatch: " + "{1} in pattern vs. {2} in definition", + resultOp.getOperationName(), inPattern, numOpArgs)); } + // If no explicit location is given, use the default, all fused, location. + if (locToUse.empty()) + locToUse = "odsLoc"; + // A map to collect all nested DAG child nodes' names, with operand index as // the key. This includes both bound and unbound child nodes. ChildNodeIndexNameMap childNodeNames; @@ -769,9 +812,8 @@ // create ops for them and remember the symbol names for them, so that we can // use the results in the current node. This happens in a recursive manner. for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { - if (auto child = tree.getArgAsNestedDag(i)) { + if (auto child = tree.getArgAsNestedDag(i)) childNodeNames[i] = handleResultPattern(child, i, depth + 1); - } } // The name of the local variable holding this op. @@ -811,10 +853,11 @@ // First prepare local variables for op arguments used in builder call. createAggregateLocalVarsForOpArgs(tree, childNodeNames); + // Then create the op. os.indent(6) << formatv( - "{0} = rewriter.create<{1}>(loc, tblgen_values, tblgen_attrs);\n", - valuePackName, resultOp.getQualCppClassName()); + "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);\n", + valuePackName, resultOp.getQualCppClassName(), locToUse); os.indent(4) << "}\n"; return resultValue; } @@ -831,8 +874,9 @@ // here given that it's easier for developers to write compared to // aggregate-parameter builders.
createSeparateLocalVarsForOpArgs(tree, childNodeNames); - os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc", valuePackName, - resultOp.getQualCppClassName()); + + os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, + resultOp.getQualCppClassName(), locToUse); supplyValuesForOpArgs(tree, childNodeNames); os << "\n );\n"; os.indent(4) << "}\n"; @@ -858,9 +902,10 @@ "tblgen_types.push_back(v.getType()); }\n", resultIndex + i); } - os.indent(6) << formatv("{0} = rewriter.create<{1}>(loc, tblgen_types, " + os.indent(6) << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " "tblgen_values, tblgen_attrs);\n", - valuePackName, resultOp.getQualCppClassName()); + valuePackName, resultOp.getQualCppClassName(), + locToUse); os.indent(4) << "}\n"; return resultValue; }