diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -324,7 +324,7 @@ // means we want to capture the op itself. if (op->getNumResults() == 0) { LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n"); - return std::string(name); + return formatv(fmt, name); } // We are referencing all results of the multi-result op. A specific result diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -690,6 +690,16 @@ let results = (outs I32:$output); } +def TestLocationSrcNoResOp : TEST_Op<"loc_src_no_res"> { + let arguments = (ins I32:$input); + let results = (outs); +} + +def TestLocationDstNoResOp : TEST_Op<"loc_dst_no_res"> { + let arguments = (ins I32:$input); + let results = (outs); +} + //===----------------------------------------------------------------------===// // Test Patterns //===----------------------------------------------------------------------===// @@ -1375,6 +1385,11 @@ (location "named")), (location "fused", $res2, $res3))>; +// Test that we can use the location of an op without results +def : Pat<(TestLocationSrcNoResOp:$loc + (TestLocationSrcOp (TestLocationSrcOp $input))), + (TestLocationDstNoResOp $input, (location $loc))>; + //===----------------------------------------------------------------------===// // Test Patterns (Type Builders) diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -15,10 +15,12 @@ %0 = "test.loc_src"(%arg0) : (i32) -> i32 loc("loc3") %1 = "test.loc_src"(%0) : (i32) -> i32 loc("loc2") %2 = "test.loc_src"(%1) : (i32) -> i32 loc("loc1") + "test.loc_src_no_res"(%2) : (i32) -> () loc("loc4") // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc("loc1") // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc("named") // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused<"fused">["loc2", "loc3"]) + // CHECK: "test.loc_dst_no_res"({{.*}}) : (i32) -> () loc("loc4") return %1 : i32 } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1148,8 +1148,8 @@ std::string PatternEmitter::handleLocationDirective(DagNode tree) { assert(tree.isLocationDirective()); auto lookUpArgLoc = [this, &tree](int idx) { - const auto *const lookupFmt = "(*{0}.begin()).getLoc()"; - return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt); + const auto *const lookupFmt = "{0}.getLoc()"; + return symbolInfoMap.getValueAndRangeUse(tree.getArgName(idx), lookupFmt); }; if (tree.getNumArgs() == 0)