diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -384,10 +384,12 @@ is called _special placeholder_, while the latter is called _positional placeholder_. -`NativeCodeCall` right now only supports two special placeholders: `$_builder` -and `$_self`: +`NativeCodeCall` right now only supports three special placeholders: +`$_builder`, `$_loc`, and `$_self`: * `$_builder` will be replaced by the current `mlir::PatternRewriter`. +* `$_loc` will be replaced by the fused location or a custom location (as + determined by the location directive). * `$_self` will be replaced with the entity `NativeCodeCall` is attached to. We have seen how `$_builder` can be used in the above; it allows us to pass a diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -724,7 +724,8 @@ // Test that NativeCodeCall is not ignored if it is not used to directly // replace the matched root op. def : Pattern<(OpNativeCodeCall3 $input), - [(NativeCodeCall<"createOpI($_builder, $0)"> $input), (OpK)]>; + [(NativeCodeCall<"createOpI($_builder, $_loc, $0)"> $input), + (OpK)]>; // Test AllAttrConstraintsOf. def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -21,8 +21,8 @@ return choice.getValue() ? 
input1 : input2; } -static void createOpI(PatternRewriter &rewriter, Value input) { - rewriter.create<OpI>(rewriter.getUnknownLoc(), input); +static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { + rewriter.create<OpI>(loc, input); } static void handleNoResultOp(PatternRewriter &rewriter, diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -112,6 +112,9 @@ // Returns the symbol of the old value serving as the replacement. StringRef handleReplaceWithValue(DagNode tree); + // Returns whether a location directive is present and the location to use. + std::pair<bool, std::string> getLocation(DagNode tree); + // Returns the location value to use. std::string handleLocationDirective(DagNode tree); @@ -779,13 +782,18 @@ PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + Twine(tree.getNumArgs())); } - for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { + bool hasLocationDirective; + std::string locToUse; + std::tie(hasLocationDirective, locToUse) = getLocation(tree); + + for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i << " replacement: " << attrs[i] << "\n"); } - return std::string(tgfmt(fmt, &fmtCtx, attrs[0], attrs[1], attrs[2], attrs[3], - attrs[4], attrs[5], attrs[6], attrs[7])); + return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0], + attrs[1], attrs[2], attrs[3], attrs[4], attrs[5], + attrs[6], attrs[7])); } int PatternEmitter::getNodeValueCount(DagNode node) { @@ -804,6 +812,20 @@ return 1; } +std::pair<bool, std::string> PatternEmitter::getLocation(DagNode tree) { + auto numPatArgs = tree.getNumArgs(); + + if (numPatArgs != 0) { + if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1)) + if (lastArg.isLocationDirective()) { + return std::make_pair(true, handleLocationDirective(lastArg)); + 
} + } + + // If no explicit location is given, use the default, all fused, location. + return std::make_pair(false, "odsLoc"); +} + std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, int depth) { LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); @@ -814,15 +836,11 @@ auto numOpArgs = resultOp.getNumArgs(); auto numPatArgs = tree.getNumArgs(); - // Get the location for this operation if explicitly provided. + bool hasLocationDirective; std::string locToUse; - if (numPatArgs != 0) { - if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1)) - if (lastArg.isLocationDirective()) - locToUse = handleLocationDirective(lastArg); - } + std::tie(hasLocationDirective, locToUse) = getLocation(tree); - auto inPattern = numPatArgs - !locToUse.empty(); + auto inPattern = numPatArgs - hasLocationDirective; if (numOpArgs != inPattern) { PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " @@ -830,10 +848,6 @@ resultOp.getOperationName(), inPattern, numOpArgs)); } - // If no explicit location is given, use the default, all fused, location. - if (locToUse.empty()) - locToUse = "odsLoc"; - // A map to collect all nested DAG child nodes' names, with operand index as // the key. This includes both bound and unbound child nodes. ChildNodeIndexNameMap childNodeNames;