diff --git a/mlir/include/mlir/Parser/AsmParserState.h b/mlir/include/mlir/Parser/AsmParserState.h
--- a/mlir/include/mlir/Parser/AsmParserState.h
+++ b/mlir/include/mlir/Parser/AsmParserState.h
@@ -104,6 +104,12 @@
   /// state.
   iterator_range<OperationDefIterator> getOpDefs() const;
 
+  /// Returns (heuristically) the range of an identifier given an SMLoc
+  /// corresponding to the start of an identifier location. Returns an invalid
+  /// range if `loc` is invalid. The buffer that `loc` points into is expected
+  /// to be null-terminated.
+  static llvm::SMRange convertIdLocToRange(llvm::SMLoc loc);
+
   //===--------------------------------------------------------------------===//
   // Populate State
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/AsmParserState.cpp b/mlir/lib/Parser/AsmParserState.cpp
--- a/mlir/lib/Parser/AsmParserState.cpp
+++ b/mlir/lib/Parser/AsmParserState.cpp
@@ -11,23 +11,6 @@
 
 using namespace mlir;
 
-/// Given a SMLoc corresponding to an identifier location, return a location
-/// representing the full range of the identifier.
-static llvm::SMRange convertIdLocToRange(llvm::SMLoc loc) {
-  if (!loc.isValid())
-    return llvm::SMRange();
-
-  // Return if the given character is a valid identifier character.
-  auto isIdentifierChar = [](char c) {
-    return isalnum(c) || c == '$' || c == '.' || c == '_' || c == '-';
-  };
-
-  const char *curPtr = loc.getPointer();
-  while (isIdentifierChar(*(++curPtr)))
-    continue;
-  return llvm::SMRange(loc, llvm::SMLoc::getFromPointer(curPtr));
-}
-
 //===----------------------------------------------------------------------===//
 // AsmParserState::Impl
 //===----------------------------------------------------------------------===//
@@ -74,6 +57,28 @@
   return llvm::make_pointee_range(llvm::makeArrayRef(impl->operations));
 }
 
+/// Returns (heuristically) the range of an identifier given an SMLoc
+/// corresponding to the start of an identifier location.
+llvm::SMRange AsmParserState::convertIdLocToRange(llvm::SMLoc loc) {
+  if (!loc.isValid())
+    return llvm::SMRange();
+
+  // Returns true if the given character is a valid identifier character.
+  auto isIdentifierChar = [](char c) {
+    // Cast to unsigned char: passing a negative char to isalnum is undefined.
+    return isalnum(static_cast<unsigned char>(c)) || c == '$' || c == '.' ||
+           c == '_' || c == '-';
+  };
+
+  // Include the character at `loc` (unless it is the null terminator) and any
+  // identifier characters following it. Checking `*curPtr` first prevents
+  // reading past the end of the buffer when `loc` points at its terminator.
+  const char *curPtr = loc.getPointer();
+  while (*curPtr && isIdentifierChar(*(++curPtr)))
+    continue;
+  return llvm::SMRange(loc, llvm::SMLoc::getFromPointer(curPtr));
+}
+
 //===----------------------------------------------------------------------===//
 // Populate State
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -67,7 +67,10 @@
 /// one couldn't be created. `uri` is an optional additional filter that, when
 /// present, is used to filter sub locations that do not share the same uri.
+/// `sourceMgr` is used to extend the returned range over the identifier that
+/// starts at the location, when one can be found in its main file buffer.
 static Optional<lsp::Location>
-getLocationFromLoc(Location loc, const lsp::URIForFile *uri = nullptr) {
+getLocationFromLoc(llvm::SourceMgr &sourceMgr, Location loc,
+                   const lsp::URIForFile *uri = nullptr) {
   Optional<lsp::Location> location;
   loc->walk([&](Location nestedLoc) {
     FileLineColLoc fileLoc = nestedLoc.dyn_cast<FileLineColLoc>();
@@ -77,6 +80,20 @@
     Optional<lsp::Location> sourceLoc = getLocationFromLoc(fileLoc);
     if (sourceLoc && (!uri || sourceLoc->uri == *uri)) {
       location = *sourceLoc;
+      // NOTE: The identifier range is computed against the main file buffer,
+      // so it is only meaningful when `fileLoc` refers to the main file.
+      llvm::SMLoc smLoc = sourceMgr.FindLocForLineAndColumn(
+          sourceMgr.getMainFileID(), fileLoc.getLine(), fileLoc.getColumn());
+
+      // Use range of potential identifier starting at location, else length 1
+      // range. `smLoc` is invalid when the line/column lies outside the
+      // buffer, and getLineAndColumn must not be called on an invalid location.
+      location->range.end.character += 1;
+      llvm::SMRange idRange = AsmParserState::convertIdLocToRange(smLoc);
+      if (idRange.isValid() && idRange.Start != idRange.End) {
+        auto lineCol = sourceMgr.getLineAndColumn(idRange.End);
+        location->range.end.character = lineCol.second - 1;
+      }
       return WalkResult::interrupt();
     }
     return WalkResult::advance();
@@ -195,7 +212,8 @@
 }
 
 /// Convert the given MLIR diagnostic to the LSP form.
-static lsp::Diagnostic getLspDiagnoticFromDiag(Diagnostic &diag,
+static lsp::Diagnostic getLspDiagnoticFromDiag(llvm::SourceMgr &sourceMgr,
+                                               Diagnostic &diag,
                                                const lsp::URIForFile &uri) {
   lsp::Diagnostic lspDiag;
   lspDiag.source = "mlir";
@@ -208,7 +226,7 @@
   // TODO: For simplicity, we just grab the first one. It may be likely that we
   // will need a more interesting heuristic here.'
   Optional<lsp::Location> lspLocation =
-      getLocationFromLoc(diag.getLocation(), &uri);
+      getLocationFromLoc(sourceMgr, diag.getLocation(), &uri);
   if (lspLocation)
     lspDiag.range = lspLocation->range;
 
@@ -232,7 +250,8 @@
   std::vector<lsp::DiagnosticRelatedInformation> relatedDiags;
   for (Diagnostic &note : diag.getNotes()) {
     lsp::Location noteLoc;
-    if (Optional<lsp::Location> loc = getLocationFromLoc(note.getLocation()))
+    if (Optional<lsp::Location> loc =
+            getLocationFromLoc(sourceMgr, note.getLocation()))
       noteLoc = *loc;
     else
       noteLoc.uri = uri;
@@ -306,7 +325,7 @@
     : context(registry) {
   context.allowUnregisteredDialects();
   ScopedDiagnosticHandler handler(&context, [&](Diagnostic &diag) {
-    diagnostics.push_back(getLspDiagnoticFromDiag(diag, uri));
+    diagnostics.push_back(getLspDiagnoticFromDiag(sourceMgr, diag, uri));
   });
 
   // Try to parsed the given IR string.
diff --git a/mlir/test/mlir-lsp-server/diagnostics.test b/mlir/test/mlir-lsp-server/diagnostics.test
--- a/mlir/test/mlir-lsp-server/diagnostics.test
+++ b/mlir/test/mlir-lsp-server/diagnostics.test
@@ -15,7 +15,7 @@
 // CHECK-NEXT:         "message": "custom op 'func' expected valid '@'-identifier for symbol name",
 // CHECK-NEXT:         "range": {
 // CHECK-NEXT:           "end": {
-// CHECK-NEXT:             "character": 6,
+// CHECK-NEXT:             "character": 7,
 // CHECK-NEXT:             "line": 0
 // CHECK-NEXT:           },
 // CHECK-NEXT:           "start": {