diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md
--- a/mlir/docs/DeclarativeRewrites.md
+++ b/mlir/docs/DeclarativeRewrites.md
@@ -672,7 +672,14 @@
 (location $symbol0, $symbol1, ...)
 ```
 
-where all `$symbol` should be bound previously in the pattern.
+where all `$symbol` should be bound previously in the pattern and one optional
+string may be specified as an attribute. The following locations are created:
+
+* If only 1 symbol is specified then that symbol's location is used;
+* If multiple arguments are specified then a fused location is created, with
+  the string (if any) attached as its metadata;
+* If no symbol is specified then a string must be specified and a NameLoc is
+  created instead.
 
 `location` must be used as the last argument to an op creation. For example,
 
diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h
--- a/mlir/include/mlir/TableGen/Pattern.h
+++ b/mlir/include/mlir/TableGen/Pattern.h
@@ -77,6 +77,9 @@
   // Returns true if this DAG leaf is specifying an enum attribute case.
   bool isEnumAttrCase() const;
 
+  // Returns true if this DAG leaf is specifying a string attribute.
+  bool isStringAttr() const;
+
   // Returns this DAG leaf as a constraint. Asserts if fails.
   Constraint getAsConstraint() const;
 
@@ -95,6 +98,10 @@
   // Precondition: isNativeCodeCall()
   StringRef getNativeCodeTemplate() const;
 
+  // Returns the unquoted string value of this leaf.
+  // Precondition: isStringAttr()
+  std::string getStringAttr() const;
+
   void print(raw_ostream &os) const;
 
 private:
diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp
--- a/mlir/lib/TableGen/Pattern.cpp
+++ b/mlir/lib/TableGen/Pattern.cpp
@@ -56,6 +56,10 @@
   return isSubClassOf("EnumAttrCaseInfo");
 }
 
+bool tblgen::DagLeaf::isStringAttr() const {
+  return isa<llvm::StringInit>(def) || isa<llvm::CodeInit>(def);
+}
+
 tblgen::Constraint tblgen::DagLeaf::getAsConstraint() const {
   assert((isOperandMatcher() || isAttrMatcher()) &&
          "the DAG leaf must be operand or attribute");
@@ -81,6 +85,11 @@
   return cast<llvm::DefInit>(def)->getDef()->getValueAsString("expression");
 }
 
+std::string tblgen::DagLeaf::getStringAttr() const {
+  assert(isStringAttr() && "the DAG leaf must be string attribute");
+  return def->getAsUnquotedString();
+}
+
 bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
   if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(def))
     return defInit->getDef()->isSubClassOf(superclass);
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1018,8 +1018,9 @@
             (TestLocationSrcOp:$res3 $input))),
           (TestLocationDstOp
             (TestLocationDstOp
-              (TestLocationDstOp $input, (location $res1))),
-            (location $res2, $res3))>;
+              (TestLocationDstOp $input, (location $res1)),
+              (location "named")),
+            (location "fused", $res2, $res3))>;
 
 //===----------------------------------------------------------------------===//
 // Test Legalization
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -17,11 +17,8 @@
   %2 = "test.loc_src"(%1) : (i32) -> i32 loc("loc1")
 
   // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc("loc1")
-  // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused[
-  // CHECK-SAME: "loc1"
-  // CHECK-SAME: "loc3"
-  // CHECK-SAME: "loc2"
-  // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused["loc2", "loc3"])
+  // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc("named")
+  // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused<"fused">["loc2", "loc3"])
   return %1 : i32
 }
 
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -112,8 +112,8 @@
   // Returns the symbol of the old value serving as the replacement.
   StringRef handleReplaceWithValue(DagNode tree);
 
-  // Returns the symbol of the value whose location to use.
-  std::string handleUseLocationOf(DagNode tree);
+  // Returns the C++ expression building the location described by `tree`.
+  std::string handleLocationDirective(DagNode tree);
 
   // Emits the C++ statement to build a new op out of the given DAG `tree` and
   // returns the variable name that this op is assigned to. If the root op in
@@ -681,27 +681,68 @@
   return tree.getArgName(0);
 }
 
-std::string PatternEmitter::handleUseLocationOf(DagNode tree) {
+std::string PatternEmitter::handleLocationDirective(DagNode tree) {
   assert(tree.isLocationDirective());
   auto lookUpArgLoc = [this, &tree](int idx) {
     const auto *const lookupFmt = "(*{0}.begin()).getLoc()";
     return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt);
   };
 
-  if (tree.getNumArgs() != 1) {
-    std::string ret;
-    llvm::raw_string_ostream os(ret);
-    os << "rewriter.getFusedLoc({";
-    for (int i = 0, e = tree.getNumArgs(); i != e; ++i)
-      os << (i ? ", " : "") << lookUpArgLoc(i);
-    os << "})";
-    return os.str();
-  }
+  // Renders `str` as a C++ string literal. Escaping keeps the generated code
+  // well-formed when the string contains quotes or backslashes.
+  auto toCppStringLiteral = [](StringRef str) {
+    std::string literal;
+    llvm::raw_string_ostream litOs(literal);
+    litOs << '"';
+    litOs.write_escaped(str);
+    litOs << '"';
+    return litOs.str();
+  };
+
+  if (tree.getNumArgs() == 0)
+    PrintFatalError(loc,
+                    "At least one argument to location directive required");
 
   if (!tree.getSymbol().empty())
     PrintFatalError(loc, "cannot bind symbol to location");
 
-  return lookUpArgLoc(0);
+  // A single argument is either a symbol, whose location is reused as-is, or
+  // a string, which becomes a NameLoc.
+  if (tree.getNumArgs() == 1) {
+    DagLeaf leaf = tree.getArgAsLeaf(0);
+    if (leaf.isStringAttr())
+      return formatv("mlir::NameLoc::get(rewriter.getIdentifier({0}), "
+                     "rewriter.getContext())",
+                     toCppStringLiteral(leaf.getStringAttr()))
+          .str();
+    return lookUpArgLoc(0);
+  }
+
+  // Multiple arguments produce a fused location over all symbol locations; at
+  // most one string may appear and is attached as the fused loc's metadata.
+  std::string ret;
+  llvm::raw_string_ostream os(ret);
+  std::string strAttr;
+  bool hasStrAttr = false;
+  bool first = true;
+  os << "rewriter.getFusedLoc({";
+  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+    DagLeaf leaf = tree.getArgAsLeaf(i);
+    if (leaf.isStringAttr()) {
+      if (hasStrAttr)
+        PrintFatalError(loc, "Only one string attribute may be specified");
+      hasStrAttr = true;
+      strAttr = leaf.getStringAttr();
+      continue;
+    }
+    os << (first ? "" : ", ") << lookUpArgLoc(i);
+    first = false;
+  }
+  os << "}";
+  if (hasStrAttr)
+    os << ", rewriter.getStringAttr(" << toCppStringLiteral(strAttr) << ")";
+  os << ")";
+  return os.str();
 }
 
 std::string PatternEmitter::handleOpArgument(DagLeaf leaf,
@@ -789,7 +830,7 @@
     if (numPatArgs != 0) {
       if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1))
         if (lastArg.isLocationDirective())
-          locToUse = handleUseLocationOf(lastArg);
+          locToUse = handleLocationDirective(lastArg);
     }
 
     auto inPattern = numPatArgs - !locToUse.empty();