diff --git a/mlir/include/mlir/Tools/mlir-query/Matcher/Diagnostics.h b/mlir/include/mlir/Tools/mlir-query/Matcher/Diagnostics.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/mlir-query/Matcher/Diagnostics.h @@ -0,0 +1,159 @@ +//===--- Diagnostics.h - Helper class for error diagnostics ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Diagnostics class to manage error messages. Implementation shares similarity +// to clang-query Diagnostics. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERDIAGNOSTICS_H +#define MLIR_TOOLS_MLIRQUERY_MATCHERDIAGNOSTICS_H + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/raw_ostream.h" +#include +#include + +namespace mlir { +namespace query { +namespace matcher { + +// Represents the line and column numbers in a source file. +struct SourceLocation { + unsigned line{}; + unsigned column{}; +}; + +// Represents a range in a source file, defined by its start and end locations. +struct SourceRange { + SourceLocation start{}; + SourceLocation end{}; +}; + +// Diagnostics class to manage error messages. +class Diagnostics { +public: + // Parser context types. + enum ContextType { CT_MatcherArg, CT_MatcherConstruct }; + + // All errors from the system. 
+ enum ErrorType { + ET_None, + + ET_RegistryMatcherNotFound, + ET_RegistryWrongArgCount, + ET_RegistryWrongArgType, + ET_RegistryValueNotFound, + + ET_ParserStringError, + ET_ParserNoOpenParen, + ET_ParserNoCloseParen, + ET_ParserNoComma, + ET_ParserNoCode, + ET_ParserNotAMatcher, + ET_ParserInvalidToken, + ET_ParserTrailingCode, + ET_ParserOverloadedType, + ET_ParserFailedToBuildMatcher + }; + + // Helper stream class for constructing error messages. + class ArgStream { + public: + ArgStream(std::vector *out) : out(out) {} + template + ArgStream &operator<<(const T &arg) { + return operator<<(llvm::Twine(arg)); + } + ArgStream &operator<<(const llvm::Twine &arg); + + private: + std::vector *out; + }; + + // Context for constructing a matcher or parsing its argument. + struct Context { + enum ConstructMatcherEnum { ConstructMatcher }; + Context(ConstructMatcherEnum, Diagnostics *error, + llvm::StringRef matcherName, SourceRange matcherRange); + enum MatcherArgEnum { MatcherArg }; + Context(MatcherArgEnum, Diagnostics *error, llvm::StringRef matcherName, + SourceRange matcherRange, unsigned argNumber); + ~Context(); + + private: + Diagnostics *const error; + }; + + // Context for managing overloaded matcher construction. + struct OverloadContext { + // Construct an overload context with the given error. + OverloadContext(Diagnostics *error); + ~OverloadContext(); + // Revert all errors that occurred within this context. + void revertErrors(); + + private: + Diagnostics *const error; + unsigned beginIndex{}; + }; + + // Add an error message with the specified range and error type. + // Returns an ArgStream object to allow constructing the error message using + // the << operator. + ArgStream addError(SourceRange range, ErrorType error); + + // Information stored for one frame of the context. + struct ContextFrame { + ContextType type; + SourceRange range; + std::vector args; + }; + + // Information stored for each error found. 
+ struct ErrorContent { + std::vector contextStack; + struct Message { + SourceRange range; + ErrorType type; + std::vector args; + }; + std::vector messages; + }; + + // Get an array reference to the error contents. + llvm::ArrayRef errors() const { return errorValues; } + + // Print all error messages to the specified output stream. + void printToStream(llvm::raw_ostream &OS) const; + // Get a string representation of all error messages. + std::string toString() const; + + // Print the full error messages, including the context information, to the + // specified output stream. + void printToStreamFull(llvm::raw_ostream &OS) const; + // Get the full string representation of all error messages, including the + // context information. + std::string toStringFull() const; + +private: + // Push a new context frame onto the context stack with the specified type and + // range. + ArgStream pushContextFrame(ContextType type, SourceRange range); + + std::vector contextStack; + std::vector errorValues; +}; + +} // namespace matcher +} // namespace query +} // namespace mlir + +#endif // MLIR_TOOLS_MLIRQUERY_MATCHERDIAGNOSTICS_H diff --git a/mlir/include/mlir/Tools/mlir-query/Matcher/Marshallers.h b/mlir/include/mlir/Tools/mlir-query/Matcher/Marshallers.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/mlir-query/Matcher/Marshallers.h @@ -0,0 +1,210 @@ +//===--- Marshallers.h - Generic matcher function marshallers -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains function templates and classes to wrap matcher construct +// functions. It provides a collection of template function and classes that +// present a generic marshalling layer on top of matcher construct functions. 
+// The registry uses these to export all marshaller constructors with a uniform +// interface. This mechanism takes inspiration from clang-query. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_MARSHALLERS_H +#define MLIR_TOOLS_MLIRQUERY_MATCHERS_MARSHALLERS_H + +#include "Diagnostics.h" +#include "VariantValue.h" +#include "mlir/IR/Matchers.h" +#include "llvm/Support/type_traits.h" + +#include "llvm/Support/Debug.h" +using llvm::dbgs; + +#define DEBUG_TYPE "mlir-query" +#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") + +namespace mlir { +namespace query { +namespace matcher { + +namespace internal { + +// Helper template class for jumping from argument type to the correct is/get +// functions in VariantValue. This is used for verifying and extracting the +// matcher arguments. +template +struct ArgTypeTraits; +template +struct ArgTypeTraits : public ArgTypeTraits {}; + +template <> +struct ArgTypeTraits { + + static bool hasCorrectType(const VariantValue &value) { + return value.isString(); + } + + static const StringRef &get(const VariantValue &value) { + return value.getString(); + } + + static ArgKind getKind() { return ArgKind(ArgKind::AK_String); } + + static std::optional getBestGuess(const VariantValue &) { + return std::nullopt; + } +}; + +template <> +struct ArgTypeTraits { + + static bool hasCorrectType(const VariantValue &value) { + return value.isMatcher(); + } + + static DynMatcher get(const VariantValue &value) { + return *value.getMatcher().getDynMatcher(); + } + + static ArgKind getKind() { return ArgKind(ArgKind::AK_Matcher); } + + static std::optional getBestGuess(const VariantValue &) { + return std::nullopt; + } +}; + +// Interface for generic matcher descriptor. +// Offers a create() method that constructs the matcher from the provided +// arguments. 
+class MatcherDescriptor { +public: + virtual ~MatcherDescriptor() = default; + virtual VariantMatcher create(SourceRange nameRange, + const ArrayRef args, + Diagnostics *error) const = 0; + + // Returns the number of arguments accepted by the matcher. + virtual unsigned getNumArgs() const = 0; + + // Append the set of argument types accepted for argument 'ArgNo' to + // 'ArgKinds'. + virtual void getArgKinds(unsigned argNo, + std::vector &argKinds) const = 0; +}; + +class FixedArgCountMatcherDescriptor : public MatcherDescriptor { +public: + using MarshallerType = VariantMatcher (*)(void (*func)(), + StringRef matcherName, + SourceRange nameRange, + ArrayRef args, + Diagnostics *error); + + // Marshaller Function to unpack the arguments and call Func. Func is the + // Matcher construct function. This is the function that the matcher + // expressions would use to create the matcher. + FixedArgCountMatcherDescriptor(MarshallerType marshaller, void (*func)(), + StringRef matcherName, + ArrayRef argKinds) + : marshaller(marshaller), func(func), matcherName(matcherName), + argKinds(argKinds.begin(), argKinds.end()) {} + + VariantMatcher create(SourceRange nameRange, ArrayRef args, + Diagnostics *error) const override { + return marshaller(func, matcherName, nameRange, args, error); + } + + unsigned getNumArgs() const override { return argKinds.size(); } + + void getArgKinds(unsigned argNo, std::vector &kinds) const override { + kinds.push_back(argKinds[argNo]); + } + +private: + const MarshallerType marshaller; + void (*const func)(); + const StringRef matcherName; + const std::vector argKinds; +}; + +// Helper function to check if argument count matches expected count +inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount, + ArrayRef args, Diagnostics *error) { + if (args.size() != expectedArgCount) { + error->addError(nameRange, error->ET_RegistryWrongArgCount) + << expectedArgCount << args.size(); + return false; + } + return true; +} + +// Helper 
function for checking argument type +template +inline bool checkArgTypeAtIndex(StringRef matcherName, + ArrayRef args, + Diagnostics *error) { + if (!ArgTypeTraits::hasCorrectType(args[Index].value)) { + error->addError(args[Index].range, error->ET_RegistryWrongArgType) + << matcherName << Index + 1; + return false; + } + return true; +} + +// Marshaller function for fixed number of arguments +template +static VariantMatcher +matcherMarshallFixedImpl(void (*func)(), StringRef matcherName, + SourceRange nameRange, ArrayRef args, + Diagnostics *error, std::index_sequence) { + using FuncType = ReturnType (*)(ArgTypes...); + + // Check if the argument count matches the expected count + if (!checkArgCount(nameRange, sizeof...(ArgTypes), args, error)) { + return VariantMatcher(); + } + + // Check if each argument at the corresponding index has the correct type + if ((... && checkArgTypeAtIndex(matcherName, args, error))) { + ReturnType fnPointer = reinterpret_cast(func)( + ArgTypeTraits::get(args[Is].value)...); + return VariantMatcher::SingleMatcher( + *DynMatcher::constructDynMatcherFromMatcherFn(fnPointer)); + } else { + return VariantMatcher(); + } +} + +template +static VariantMatcher +matcherMarshallFixed(void (*func)(), StringRef matcherName, + SourceRange nameRange, ArrayRef args, + Diagnostics *error) { + return matcherMarshallFixedImpl( + func, matcherName, nameRange, args, error, + std::index_sequence_for{}); +} + +// Fixed number of arguments overload +template +std::unique_ptr +makeMatcherAutoMarshall(ReturnType (*func)(ArgTypes...), + StringRef matcherName) { + // Create a vector of argument kinds + std::vector argKinds = {ArgTypeTraits::getKind()...}; + return std::make_unique( + matcherMarshallFixed, + reinterpret_cast(func), matcherName, argKinds); +} + +} // namespace internal +} // namespace matcher +} // namespace query +} // namespace mlir + +#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_MARSHALLERS_H diff --git 
a/mlir/include/mlir/Tools/mlir-query/Matcher/MatchersInternal.h b/mlir/include/mlir/Tools/mlir-query/Matcher/MatchersInternal.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/mlir-query/Matcher/MatchersInternal.h @@ -0,0 +1,97 @@ +//===- MatchersInternal.h - Structural query framework --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implements the base layer of the matcher framework. +// +// Matchers are methods that return a Matcher which provides a method +// match(Operation *op) +// +// The matcher functions are defined in include/mlir/IR/Matchers.h. +// This file contains the wrapper classes needed to construct matchers for +// mlir-query. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERS_MATCHERSINTERNAL_H +#define MLIR_TOOLS_MLIRQUERY_MATCHERS_MATCHERSINTERNAL_H + +#include "mlir/IR/Matchers.h" +#include "llvm/ADT/IntrusiveRefCntPtr.h" + +#include "llvm/Support/Debug.h" +using llvm::dbgs; + +#define DEBUG_TYPE "mlir-query" +#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") + +namespace mlir { +namespace query { +namespace matcher { + +// Generic interface for matchers on an MLIR operation. +class MatcherInterface + : public llvm::ThreadSafeRefCountedBase { +public: + virtual ~MatcherInterface() = default; + + virtual bool match(Operation *op) = 0; +}; + +// MatcherFnImpl takes a matcher function object and implements +// MatcherInterface. 
+template +class MatcherFnImpl : public MatcherInterface { +public: + MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {} + bool match(Operation *op) override { return matcherFn.match(op); } + +private: + MatcherFn matcherFn; +}; + +// DynMatcher wraps a MatcherInterface implementation and provides a match() +// method that redirects calls to the underlying implementation. +class DynMatcher { +public: + // Takes ownership of the provided implementation pointer. + DynMatcher(MatcherInterface *implementation) + : implementation(implementation) {} + + template + static std::unique_ptr + constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) { + auto impl = std::make_unique>(matcherFn); + return std::make_unique(impl.release()); + } + + bool match(Operation *op) const { return implementation->match(op); } + +private: + llvm::IntrusiveRefCntPtr implementation; +}; + +// MatchFinder is used to find all operations that match a given matcher. +class MatchFinder { +public: + // Returns all operations that match the given matcher. + std::vector getMatches(Operation *root, DynMatcher matcher) { + std::vector matches; + + root->walk([&](Operation *subOp) { + if (matcher.match(subOp)) + matches.push_back(subOp); + }); + + return matches; + } +}; + +} // namespace matcher +} // namespace query +} // namespace mlir + +#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_MATCHERSINTERNAL_H diff --git a/mlir/include/mlir/Tools/mlir-query/Matcher/Parser.h b/mlir/include/mlir/Tools/mlir-query/Matcher/Parser.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/mlir-query/Matcher/Parser.h @@ -0,0 +1,178 @@ +//===--- Parser.h - Matcher expression parser -----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Simple matcher expression parser. +// +// This file contains the Parser class, which is responsible for parsing +// expressions in a specific format: matcherName(Arg0, Arg1, ..., ArgN). The +// parser can also interpret simple types, like strings. +// +// The actual processing of the matchers is handled by a Sema object that is +// provided to the parser. +// +// The grammar for the supported expressions is as follows: +// <Expression> := <StringLiteral> | <MatcherExpression> +// <StringLiteral> := "quoted string" +// <MatcherExpression> := <MatcherName>(<ArgumentList>) +// <MatcherName> := [a-zA-Z]+ +// <ArgumentList> := <Expression> | <Expression>,<ArgumentList> +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERPARSER_H +#define MLIR_TOOLS_MLIRQUERY_MATCHERPARSER_H + +#include "Diagnostics.h" +#include "Registry.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace query { +namespace matcher { + +// Matcher expression parser. +class Parser { +public: + // Interface to connect the parser with the registry and more. The parser uses + // the Sema instance passed into parseMatcherExpression() to handle all + // matcher tokens. + class Sema { + public: + virtual ~Sema(); + + // Process a matcher expression. The caller takes ownership of the Matcher + // object returned. + virtual VariantMatcher actOnMatcherExpression(MatcherCtor ctor, + SourceRange nameRange, + ArrayRef args, + Diagnostics *error) = 0; + + // Look up a matcher constructor by the matcher name found by the parser. + virtual std::optional + lookupMatcherCtor(llvm::StringRef matcherName) = 0; + + // Compute the list of completion types for Context. + virtual std::vector getAcceptedCompletionTypes( + llvm::ArrayRef> Context); + + // Compute the list of completions that match any of acceptedTypes.
+ virtual std::vector + getMatcherCompletions(llvm::ArrayRef acceptedTypes); + }; + + // An implementation of the Sema interface that uses the matcher registry to + // process tokens. + class RegistrySema : public Parser::Sema { + public: + ~RegistrySema() override; + + std::optional + lookupMatcherCtor(llvm::StringRef matcherName) override; + + VariantMatcher actOnMatcherExpression(MatcherCtor ctor, + SourceRange nameRange, + ArrayRef args, + Diagnostics *error) override; + + std::vector getAcceptedCompletionTypes( + llvm::ArrayRef> context) override; + + std::vector + getMatcherCompletions(llvm::ArrayRef acceptedTypes) override; + }; + + using NamedValueMap = llvm::StringMap; + + // Methods to parse a matcher expression and return a DynMatcher object, + // transferring ownership to the caller. + static std::optional + parseMatcherExpression(llvm::StringRef &matcherCode, Sema *sema, + const NamedValueMap *namedValues, Diagnostics *error); + static std::optional + parseMatcherExpression(llvm::StringRef &matcherCode, Sema *sema, + Diagnostics *error) { + return parseMatcherExpression(matcherCode, sema, nullptr, error); + } + static std::optional + parseMatcherExpression(llvm::StringRef &matcherCode, Diagnostics *error) { + return parseMatcherExpression(matcherCode, nullptr, error); + } + + // Methods to parse any expression supported by this parser. + static bool parseExpression(llvm::StringRef &code, Sema *sema, + const NamedValueMap *namedValues, + VariantValue *value, Diagnostics *error); + + static bool parseExpression(llvm::StringRef &code, Sema *sema, + VariantValue *value, Diagnostics *error) { + return parseExpression(code, sema, nullptr, value, error); + } + static bool parseExpression(llvm::StringRef &code, VariantValue *value, + Diagnostics *error) { + return parseExpression(code, nullptr, value, error); + } + + // Methods to complete an expression at a given offset. 
+ static std::vector + completeExpression(llvm::StringRef &code, unsigned completionOffset, + Sema *sema, const NamedValueMap *namedValues); + static std::vector + completeExpression(llvm::StringRef &code, unsigned completionOffset, + Sema *sema) { + return completeExpression(code, completionOffset, sema, nullptr); + } + static std::vector + completeExpression(llvm::StringRef &code, unsigned completionOffset) { + return completeExpression(code, completionOffset, nullptr); + } + +private: + class CodeTokenizer; + struct ScopedContextEntry; + struct TokenInfo; + + Parser(CodeTokenizer *tokenizer, Sema *sema, const NamedValueMap *namedValues, + Diagnostics *error); + + bool parseExpressionImpl(VariantValue *value); + + bool parseMatcherArgs(std::vector &args, MatcherCtor ctor, + const TokenInfo &nameToken, TokenInfo &endToken); + + bool parseMatcherExpressionImpl(const TokenInfo &nameToken, + const TokenInfo &openToken, + std::optional ctor, + VariantValue *value); + + bool parseIdentifierPrefixImpl(VariantValue *value); + + void addCompletion(const TokenInfo &compToken, + const MatcherCompletion &completion); + void addExpressionCompletions(); + + std::vector + getNamedValueCompletions(ArrayRef acceptedTypes); + + CodeTokenizer *const tokenizer; + Sema *const sema; + const NamedValueMap *const namedValues; + Diagnostics *const error; + + using ContextStackTy = std::vector>; + + ContextStackTy contextStack; + std::vector completions; +}; + +} // namespace matcher +} // namespace query +} // namespace mlir + +#endif // MLIR_TOOLS_MLIRQUERY_MATCHERPARSER_H diff --git a/mlir/include/mlir/Tools/mlir-query/Matcher/Registry.h b/mlir/include/mlir/Tools/mlir-query/Matcher/Registry.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/mlir-query/Matcher/Registry.h @@ -0,0 +1,70 @@ +//===--- Registry.h - Matcher registry ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Registry of all known matchers. +// +// The registry provides a generic interface to construct any matcher by name. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERREGISTRY_H +#define MLIR_TOOLS_MLIRQUERY_MATCHERREGISTRY_H + +#include "Diagnostics.h" +#include "Marshallers.h" +#include "VariantValue.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include + +namespace mlir { +namespace query { +namespace matcher { + +using MatcherCtor = const internal::MatcherDescriptor *; + +struct MatcherCompletion { + MatcherCompletion() = default; + MatcherCompletion(llvm::StringRef typedText, llvm::StringRef matcherDecl) + : typedText(typedText.str()), matcherDecl(matcherDecl.str()) {} + + bool operator==(const MatcherCompletion &other) const { + return typedText == other.typedText && matcherDecl == other.matcherDecl; + } + + // The text to type to select this matcher. + std::string typedText; + + // The "declaration" of the matcher, with type information. 
+ std::string matcherDecl; +}; + +class Registry { +public: + Registry() = delete; + + static std::optional + lookupMatcherCtor(llvm::StringRef matcherName); + + static std::vector getAcceptedCompletionTypes( + llvm::ArrayRef> context); + + static std::vector + getMatcherCompletions(ArrayRef acceptedTypes); + + static VariantMatcher constructMatcher(MatcherCtor ctor, + SourceRange nameRange, + ArrayRef args, + Diagnostics *error); +}; + +} // namespace matcher +} // namespace query +} // namespace mlir + +#endif // MLIR_TOOLS_MLIRQUERY_MATCHERREGISTRY_H diff --git a/mlir/include/mlir/Tools/mlir-query/Matcher/VariantValue.h b/mlir/include/mlir/Tools/mlir-query/Matcher/VariantValue.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/mlir-query/Matcher/VariantValue.h @@ -0,0 +1,148 @@ +//===--- VariantValue.h ---------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Supports all the types required for dynamic Matcher construction. +// Used by the registry to construct matchers in a generic way. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MATCHERVARIANTVALUE_H +#define MLIR_TOOLS_MLIRQUERY_MATCHERVARIANTVALUE_H + +#include "Diagnostics.h" +#include "MatchersInternal.h" +#include "mlir/IR/Matchers.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/type_traits.h" + +namespace mlir { +namespace query { +namespace matcher { + +// Kind identifier that supports all types that VariantValue can contain. 
+class ArgKind { +public: + enum Kind { AK_Matcher, AK_String }; + ArgKind(Kind k) : k(k) {} + + Kind getArgKind() const { return k; } + + bool operator<(const ArgKind &other) const { return k < other.k; } + + // String representation of the type. + std::string asString() const; + +private: + Kind k; +}; + +// A variant matcher object to abstract simple and complex matchers into a +// single object type. +class VariantMatcher { + class MatcherOps; + + // Payload interface to be specialized by each matcher type. It follows a + // similar interface as VariantMatcher itself. + class Payload { + public: + virtual ~Payload(); + virtual std::optional getDynMatcher() const = 0; + virtual std::string getTypeAsString() const = 0; + }; + +public: + /// A null matcher. + VariantMatcher(); + + // Clones the provided matcher. + static VariantMatcher SingleMatcher(DynMatcher matcher); + + // Makes the matcher the "null" matcher. + void reset(); + + // Checks if the matcher is null. + bool isNull() const { return !value; } + + /// Returns the matcher + std::optional getDynMatcher() const; + + // String representation of the type of the value. + std::string getTypeAsString() const; + +private: + explicit VariantMatcher(std::shared_ptr value) + : value(std::move(value)) {} + + class SinglePayload; + + std::shared_ptr value; +}; + +// Variant value class with a tagged union with value type semantics. It is used +// by the registry as the return value and argument type for the matcher factory +// methods. It can be constructed from any of the supported types: +// - StringRef +// - VariantMatcher +class VariantValue { +public: + VariantValue() : type(VT_Nothing) {} + + VariantValue(const VariantValue &other); + ~VariantValue(); + VariantValue &operator=(const VariantValue &other); + + // Specific constructors for each supported type. + VariantValue(const StringRef String); + VariantValue(const VariantMatcher &Matcher); + + // String value functions. 
+ bool isString() const; + const StringRef &getString() const; + void setString(const StringRef &String); + + // Matcher value functions. + bool isMatcher() const; + const VariantMatcher &getMatcher() const; + void setMatcher(const VariantMatcher &Matcher); + + // String representation of the type of the value. + std::string getTypeAsString() const; + +private: + void reset(); + + // All supported value types. + enum ValueType { + VT_Nothing, + VT_String, + VT_Matcher, + }; + + // All supported value types. + union AllValues { + StringRef *String; + VariantMatcher *Matcher; + }; + + ValueType type; + AllValues value; +}; + +// A VariantValue instance annotated with its parser context. +struct ParserValue { + ParserValue() {} + llvm::StringRef text; + SourceRange range; + VariantValue value; +}; + +} // end namespace matcher +} // end namespace query +} // end namespace mlir + +#endif // MLIR_TOOLS_MLIRQUERY_MATCHERVARIANTVALUE_H diff --git a/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h b/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h @@ -0,0 +1,28 @@ +//===- MlirQueryMain.h - MLIR Query main ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Main entry function for mlir-query for when built as standalone +// binary. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_MLIRQUERYMAIN_H +#define MLIR_TOOLS_MLIRQUERY_MLIRQUERYMAIN_H + +#include "Query.h" +#include "QueryParser.h" +#include "QuerySession.h" +#include "mlir/Support/LogicalResult.h" +namespace mlir { + +class MLIRContext; +LogicalResult mlirQueryMain(int argc, char **argv, MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TOOLS_MLIRQUERY_MLIRQUERYMAIN_H diff --git a/mlir/include/mlir/Tools/mlir-query/Query.h b/mlir/include/mlir/Tools/mlir-query/Query.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/mlir-query/Query.h @@ -0,0 +1,82 @@ +//===--- Query.h - mlir-query ---------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_QUERY_H +#define MLIR_TOOLS_MLIRQUERY_QUERY_H + +#include "Matcher/VariantValue.h" +#include "mlir/IR/Matchers.h" +#include "llvm/ADT/IntrusiveRefCntPtr.h" +#include "llvm/ADT/Twine.h" +#include + +namespace mlir { +namespace query { + +enum QueryKind { QK_Invalid, QK_NoOp, QK_Help, QK_Match }; + +class QuerySession; + +struct Query : llvm::RefCountedBase { + Query(QueryKind kind) : kind(kind) {} + virtual ~Query(); + + // Perform the query on QS and print output to OS. + // Return false if an error occurs, otherwise return true. + virtual bool run(llvm::raw_ostream &OS, QuerySession &QS) const = 0; + + llvm::StringRef remainingContent; + const QueryKind kind; +}; + +typedef llvm::IntrusiveRefCntPtr QueryRef; + +// Any query which resulted in a parse error. The error message is in ErrStr. 
+struct InvalidQuery : Query { + InvalidQuery(const llvm::Twine &errStr) + : Query(QK_Invalid), errStr(errStr.str()) {} + bool run(llvm::raw_ostream &OS, QuerySession &QS) const override; + + std::string errStr; + + static bool classof(const Query *Q) { return Q->kind == QK_Invalid; } +}; + +// No-op query (i.e. a blank line). +struct NoOpQuery : Query { + NoOpQuery() : Query(QK_NoOp) {} + bool run(llvm::raw_ostream &OS, QuerySession &QS) const override; + + static bool classof(const Query *Q) { return Q->kind == QK_NoOp; } +}; + +// Query for "help". +struct HelpQuery : Query { + HelpQuery() : Query(QK_Help) {} + bool run(llvm::raw_ostream &OS, QuerySession &QS) const override; + + static bool classof(const Query *Q) { return Q->kind == QK_Help; } +}; + +// Query for "match MATCHER". +struct MatchQuery : Query { + MatchQuery(StringRef source, const matcher::DynMatcher &matcher) + : Query(QK_Match), matcher(matcher), source(source) {} + bool run(llvm::raw_ostream &OS, QuerySession &QS) const override; + + const matcher::DynMatcher matcher; + + StringRef source; + + static bool classof(const Query *Q) { return Q->kind == QK_Match; } +}; + +} // namespace query +} // namespace mlir + +#endif diff --git a/mlir/include/mlir/Tools/mlir-query/QueryParser.h b/mlir/include/mlir/Tools/mlir-query/QueryParser.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/mlir-query/QueryParser.h @@ -0,0 +1,66 @@ +//===--- QueryParser.h - mlir-query ---------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_QUERYPARSER_H +#define MLIR_TOOLS_MLIRQUERY_QUERYPARSER_H + +#include "Matcher/Diagnostics.h" +#include "Matcher/Parser.h" +#include "Query.h" +#include "QuerySession.h" + +#include "mlir/IR/Matchers.h" +#include "llvm/LineEditor/LineEditor.h" +#include + +namespace mlir { +namespace query { + +class QuerySession; + +class QueryParser { +public: + // Parse line as a query and return a QueryRef representing the query, which + // may be an InvalidQuery. + static QueryRef parse(StringRef line, const QuerySession &QS); + + static std::vector + complete(StringRef line, size_t pos, const QuerySession &QS); + +private: + QueryParser(StringRef line, const QuerySession &QS) + : line(line), completionPos(nullptr), QS(QS) {} + + StringRef lexWord(); + + template + struct LexOrCompleteWord; + + QueryRef parseSetBool(bool QuerySession::*Var); + template + QueryRef parseSetOutputKind(); + QueryRef completeMatcherExpression(); + + QueryRef endQuery(QueryRef Q); + + // Parse [Begin, End) and returns a reference to the parsed query object, + // which may be an InvalidQuery if a parse error occurs. + QueryRef doParse(); + + StringRef line; + + const char *completionPos; + std::vector completions; + + const QuerySession &QS; +}; + +} // namespace query +} // namespace mlir + +#endif // MLIR_TOOLS_MLIRQUERY_QUERYPARSER_H diff --git a/mlir/include/mlir/Tools/mlir-query/QuerySession.h b/mlir/include/mlir/Tools/mlir-query/QuerySession.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/mlir-query/QuerySession.h @@ -0,0 +1,39 @@ +//===--- QuerySession.h - mlir-query --------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
+#define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
+
+#include "Query.h"
+#include "mlir/Tools/ParseUtilities.h"
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+namespace query {
+
+// Represents the state for a particular mlir-query session.
+// NOTE(review): the template arguments of std::shared_ptr and llvm::StringMap
+// below are missing (angle-bracket text appears stripped from this patch).
+class QuerySession {
+public:
+  // Stores the root operation pointer and a copy of the shared source-manager
+  // handle; the session starts with terminate == false.
+  QuerySession(Operation *rootOp,
+               const std::shared_ptr &sourceMgr)
+      : rootOp(rootOp), sourceMgr(sourceMgr), terminate(false) {}
+
+  // Returns the stored handle by const reference (no copy is made).
+  // NOTE(review): not const-qualified, so it cannot be called through a
+  // `const QuerySession &` such as the one QueryParser holds.
+  const std::shared_ptr &getSourceManager() {
+    return sourceMgr;
+  }
+
+  // Raw, non-owning pointer to the root operation.
+  Operation *rootOp;
+  // Shared handle fixed at construction (const member).
+  const std::shared_ptr sourceMgr;
+  // Initialised to false; presumably set to end the session -- confirm.
+  bool terminate;
+  // Name -> value map, default-constructed empty.
+  llvm::StringMap namedValues;
+};
+
+} // namespace query
+} // namespace mlir
+
+#endif // MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
diff --git a/mlir/lib/Tools/CMakeLists.txt b/mlir/lib/Tools/CMakeLists.txt
--- a/mlir/lib/Tools/CMakeLists.txt
+++ b/mlir/lib/Tools/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_subdirectory(mlir-lsp-server)
 add_subdirectory(mlir-opt)
 add_subdirectory(mlir-pdll-lsp-server)
+add_subdirectory(mlir-query)
 add_subdirectory(mlir-reduce)
 add_subdirectory(mlir-tblgen)
 add_subdirectory(mlir-translate)
diff --git a/mlir/lib/Tools/mlir-query/CMakeLists.txt b/mlir/lib/Tools/mlir-query/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/CMakeLists.txt
@@ -0,0 +1,12 @@
+include_directories(${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-query)
+
+add_mlir_library(MLIRQuery
+  Query.cpp
+  QueryParser.cpp
+
+  LINK_LIBS PRIVATE
+  MLIRQueryMatcher
+  )
+
+add_subdirectory(Matcher)
+add_subdirectory(Tool)
diff --git a/mlir/lib/Tools/mlir-query/Matcher/CMakeLists.txt b/mlir/lib/Tools/mlir-query/Matcher/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/Matcher/CMakeLists.txt
@@ -0,0 +1,8 @@
+include_directories(${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-query/Matcher)
+
+add_mlir_library(MLIRQueryMatcher
+  Parser.cpp
+  Registry.cpp
+  VariantValue.cpp
+  Diagnostics.cpp
+  )
diff --git a/mlir/lib/Tools/mlir-query/Matcher/Diagnostics.cpp b/mlir/lib/Tools/mlir-query/Matcher/Diagnostics.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/Matcher/Diagnostics.cpp
@@ -0,0 +1,229 @@
+//===- MatcherDiagnostic.cpp ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Diagnostics.h"
+
+#include "llvm/Support/Debug.h"
+using llvm::dbgs;
+
+#define DEBUG_TYPE "mlir-query"
+#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
+
+namespace mlir {
+namespace query {
+namespace matcher {
+
+// Pushes a new frame onto contextStack and returns a stream that appends to
+// that frame's argument list.
+Diagnostics::ArgStream Diagnostics::pushContextFrame(ContextType type,
+                                                     SourceRange range) {
+  contextStack.emplace_back();
+  ContextFrame &data = contextStack.back();
+  data.type = type;
+  data.range = range;
+  return ArgStream(&data.args);
+}
+
+// Scoped "building matcher" frame; its single argument ($0) is the matcher
+// name. The frame is popped again by ~Context().
+Diagnostics::Context::Context(ConstructMatcherEnum, Diagnostics *error,
+                              llvm::StringRef matcherName,
+                              SourceRange matcherRange)
+    : error(error) {
+  error->pushContextFrame(CT_MatcherConstruct, matcherRange) << matcherName;
+}
+
+// Scoped "parsing argument" frame; arguments are pushed in the order the
+// format string expects them: $0 = argument number, $1 = matcher name.
+Diagnostics::Context::Context(MatcherArgEnum, Diagnostics *error,
+                              llvm::StringRef matcherName,
+                              SourceRange matcherRange, unsigned argnumber)
+    : error(error) {
+  error->pushContextFrame(CT_MatcherArg, matcherRange)
+      << argnumber << matcherName;
+}
+
+// Pops the frame pushed by whichever constructor ran.
+Diagnostics::Context::~Context() { error->contextStack.pop_back(); }
+
+// Remembers how many errors existed on entry so that errors added afterwards
+// can be merged (destructor) or discarded (revertErrors).
+Diagnostics::OverloadContext::OverloadContext(Diagnostics *error)
+    : error(error), beginIndex(error->errorValues.size()) {}
+
+Diagnostics::OverloadContext::~OverloadContext() {
+  // Merge all errors that happened while in this context.
+  if (beginIndex < error->errorValues.size()) {
+    // The first error added inside the context becomes the destination; only
+    // the first message (messages[0]) of each later error is appended to it.
+    Diagnostics::ErrorContent &dest = error->errorValues[beginIndex];
+    for (size_t i = beginIndex + 1, e = error->errorValues.size(); i < e; ++i) {
+      dest.messages.push_back(error->errorValues[i].messages[0]);
+    }
+    error->errorValues.resize(beginIndex + 1);
+  }
+}
+
+void Diagnostics::OverloadContext::revertErrors() {
+  // Revert the errors.
+  error->errorValues.resize(beginIndex);
+}
+
+// Renders the Twine to a string and appends it to the target argument list.
+Diagnostics::ArgStream &
+Diagnostics::ArgStream::operator<<(const llvm::Twine &arg) {
+  out->push_back(arg.str());
+  return *this;
+}
+
+// Records a new error holding a snapshot (copy) of the current context stack
+// and a single message; the returned stream fills that message's arguments.
+Diagnostics::ArgStream Diagnostics::addError(SourceRange range,
+                                             ErrorType error) {
+  errorValues.emplace_back();
+  ErrorContent &last = errorValues.back();
+  last.contextStack = contextStack;
+  last.messages.emplace_back();
+  last.messages.back().range = range;
+  last.messages.back().type = error;
+  return ArgStream(&last.messages.back().args);
+}
+
+// Maps a context type to its format string; "$N" refers to the N-th argument.
+static llvm::StringRef
+contextTypeToFormatString(Diagnostics::ContextType type) {
+  switch (type) {
+  case Diagnostics::CT_MatcherConstruct:
+    return "Error building matcher $0.";
+  case Diagnostics::CT_MatcherArg:
+    return "Error parsing argument $0 for matcher $1.";
+  }
+  llvm_unreachable("Unknown ContextType value.");
+}
+
+// Maps an error type to its format string; "$N" refers to the N-th argument.
+static llvm::StringRef errorTypeToFormatString(Diagnostics::ErrorType type) {
+  switch (type) {
+  case Diagnostics::ET_RegistryMatcherNotFound:
+    return "Matcher not found: $0";
+  case Diagnostics::ET_RegistryWrongArgCount:
+    return "Incorrect argument count. (Expected = $0) != (Actual = $1)";
+  case Diagnostics::ET_RegistryWrongArgType:
+    return "Incorrect type for arg $0. (Expected = $1) != (Actual = $2)";
+  case Diagnostics::ET_RegistryValueNotFound:
+    return "Value not found: $0";
+
+  case Diagnostics::ET_ParserStringError:
+    return "Error parsing string token: <$0>";
+  case Diagnostics::ET_ParserNoOpenParen:
+    return "Error parsing matcher. Found token <$0> while looking for '('.";
+  case Diagnostics::ET_ParserNoCloseParen:
+    return "Error parsing matcher. Found end-of-code while looking for ')'.";
+  case Diagnostics::ET_ParserNoComma:
+    return "Error parsing matcher. Found token <$0> while looking for ','.";
+  case Diagnostics::ET_ParserNoCode:
+    return "End of code found while looking for token.";
+  case Diagnostics::ET_ParserNotAMatcher:
+    return "Input value is not a matcher expression.";
+  case Diagnostics::ET_ParserInvalidToken:
+    return "Invalid token <$0> found when looking for a value.";
+  case Diagnostics::ET_ParserTrailingCode:
+    return "Unexpected end of code.";
+  case Diagnostics::ET_ParserOverloadedType:
+    return "Input value has unresolved overloaded type: $0";
+  case Diagnostics::ET_ParserFailedToBuildMatcher:
+    return "Failed to build matcher: $0.";
+
+  case Diagnostics::ET_None:
+    return "";
+  }
+  llvm_unreachable("Unknown ErrorType value.");
+}
+
+// Writes formatString to OS, replacing each "$<digit>" with args[digit].
+// A '$' followed by a non-digit consumes that character and prints nothing
+// for it; a trailing '$' ends the output.
+static void formatErrorString(llvm::StringRef formatString,
+                              llvm::ArrayRef args,
+                              llvm::raw_ostream &OS) {
+  while (!formatString.empty()) {
+    std::pair pieces =
+        formatString.split("$");
+    OS << pieces.first.str();
+    if (pieces.second.empty())
+      break;
+
+    const char next = pieces.second.front();
+    formatString = pieces.second.drop_front();
+    if (next >= '0' && next <= '9') {
+      const unsigned index = next - '0';
+      if (index < args.size()) {
+        OS << args[index];
+      } else {
+        // NOTE(review): prints nothing for a missing argument; a placeholder
+        // in angle brackets may have been stripped from this literal --
+        // confirm against upstream.
+        OS << "";
+      }
+    }
+  }
+}
+
+// Emits "line:column: " only when both start coordinates are non-zero.
+static void maybeAddLineAndColumn(SourceRange range, llvm::raw_ostream &OS) {
+  if (range.start.line > 0 && range.start.column > 0) {
+    OS << range.start.line << ":" << range.start.column << ": ";
+  }
+}
+
+// Prints one context frame: optional location prefix, then formatted text.
+static void printContextFrameToStream(const Diagnostics::ContextFrame &frame,
+                                      llvm::raw_ostream &OS) {
+  maybeAddLineAndColumn(frame.range, OS);
+  formatErrorString(contextTypeToFormatString(frame.type), frame.args, OS);
+}
+
+// Prints one message: optional location prefix, Prefix, then formatted text.
+static void
+printMessageToStream(const Diagnostics::ErrorContent::Message &message,
+                     const llvm::Twine Prefix, llvm::raw_ostream &OS) {
+  maybeAddLineAndColumn(message.range, OS);
+  OS << Prefix;
+  formatErrorString(errorTypeToFormatString(message.type), message.args, OS);
+}
+
+// A single message is printed bare; otherwise each message is printed on its
+// own line with a 1-based "Candidate N: " prefix.
+static void printErrorContentToStream(const Diagnostics::ErrorContent &content,
+                                      llvm::raw_ostream &OS) {
+  if (content.messages.size() == 1) {
+    printMessageToStream(content.messages[0], "", OS);
+  } else {
+    for (size_t i = 0, e = content.messages.size(); i != e; ++i) {
+      if (i != 0)
+        OS << "\n";
+      printMessageToStream(content.messages[i],
+                           "Candidate " + llvm::Twine(i + 1) + ": ", OS);
+    }
+  }
+}
+
+// Prints every recorded error, newline-separated: each error's context frames
+// (one per line) followed by its message(s).
+void Diagnostics::printToStream(llvm::raw_ostream &OS) const {
+  for (const ErrorContent &error : errorValues) {
+    if (&error != &errorValues.front())
+      OS << "\n";
+    for (const ContextFrame &frame : error.contextStack) {
+      printContextFrameToStream(frame, OS);
+      OS << "\n";
+    }
+    printErrorContentToStream(error, OS);
+  }
+}
+
+// Same output as printToStream, collected into a std::string.
+std::string Diagnostics::toString() const {
+  std::string S;
+  llvm::raw_string_ostream OS(S);
+  printToStream(OS);
+  return S;
+}
+
+// NOTE(review): this body is statement-for-statement identical to
+// printToStream above, so the "full" and regular forms currently produce the
+// same text -- confirm whether a shorter regular form was intended.
+void Diagnostics::printToStreamFull(llvm::raw_ostream &OS) const {
+  for (const ErrorContent &error : errorValues) {
+    if (&error != &errorValues.front())
+      OS << "\n";
+    for (const ContextFrame &frame : error.contextStack) {
+      printContextFrameToStream(frame, OS);
+      OS << "\n";
+    }
+    printErrorContentToStream(error, OS);
+  }
+}
+
+// Same output as printToStreamFull, collected into a std::string.
+std::string Diagnostics::toStringFull() const {
+  std::string S;
+  llvm::raw_string_ostream OS(S);
+  printToStreamFull(OS);
+  return S;
+}
+
+} // namespace matcher
+} // namespace query
+} // namespace mlir
diff --git a/mlir/lib/Tools/mlir-query/Matcher/Parser.cpp b/mlir/lib/Tools/mlir-query/Matcher/Parser.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/Matcher/Parser.cpp
@@ -0,0 +1,565 @@
+//===- MatcherParser.cpp - Matcher expression parser ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Recursive parser implementation for the matcher expression grammar.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Parser.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/ManagedStatic.h"
+// NOTE(review): the header names on the next two lines are missing;
+// angle-bracket text (includes and template arguments throughout this file)
+// appears to have been stripped from this patch -- restore before applying.
+#include
+#include
+
+#include "llvm/Support/Debug.h"
+using llvm::dbgs;
+
+#define DEBUG_TYPE "mlir-query"
+#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
+
+namespace mlir {
+namespace query {
+namespace matcher {
+
+// Simple structure to hold information for one token from the parser.
+struct Parser::TokenInfo {
+  // Different possible tokens.
+  enum TokenKind {
+    TK_Eof,
+    TK_NewLine,
+    TK_OpenParen,
+    TK_CloseParen,
+    TK_Comma,
+    TK_Period,
+    TK_Literal,
+    TK_Ident,
+    TK_InvalidChar,
+    TK_CodeCompletion,
+    TK_Error
+  };
+
+  TokenInfo() = default;
+
+  // Method to set the kind and text of the token
+  void set(TokenKind newKind, llvm::StringRef newText) {
+    kind = newKind;
+    text = newText;
+  }
+
+  // View into the tokenized input; does not own its characters.
+  llvm::StringRef text;
+  TokenKind kind = TK_Eof;
+  SourceRange range;
+  // Only assigned for TK_Literal tokens (see consumeStringLiteral).
+  VariantValue value;
+};
+
+// Splits matcher code into tokens with one token of lookahead (nextToken).
+class Parser::CodeTokenizer {
+public:
+  // Constructor with matcherCode and error
+  explicit CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error)
+      : code(matcherCode), startOfLine(matcherCode), line(1), error(error) {
+    nextToken = getNextToken();
+  }
+
+  // Constructor with matcherCode, error, and codeCompletionOffset
+  CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error,
+                unsigned codeCompletionOffset)
+      : code(matcherCode), startOfLine(matcherCode), error(error),
+        codeCompletionLocation(matcherCode.data() + codeCompletionOffset) {
+    nextToken = getNextToken();
+  }
+
+  // Peek at next token without consuming it
+  const TokenInfo &peekNextToken() const { return nextToken; }
+
+  // Consume and return the next token
+  TokenInfo consumeNextToken() {
+    TokenInfo thisToken = nextToken;
+    nextToken = getNextToken();
+    return thisToken;
+  }
+
+  // Skip any newline tokens
+  TokenInfo skipNewlines() {
+    while (nextToken.kind == TokenInfo::TK_NewLine)
+      nextToken = getNextToken();
+    return nextToken;
+  }
+
+  // Consume and return next token, ignoring newlines
+  TokenInfo consumeNextTokenIgnoreNewlines() {
+    skipNewlines();
+    return nextToken.kind == TokenInfo::TK_Eof ? nextToken : consumeNextToken();
+  }
+
+  // Return kind of next token
+  TokenInfo::TokenKind nextTokenKind() const { return nextToken.kind; }
+
+private:
+  // Helper function to get the first character as a new StringRef and drop it
+  // from the original string
+  llvm::StringRef firstCharacterAndDrop(llvm::StringRef &str) {
+    assert(!str.empty());
+    llvm::StringRef firstChar = str.substr(0, 1);
+    str = str.drop_front();
+    return firstChar;
+  }
+
+  // Get next token, consuming whitespaces and handling different token types
+  TokenInfo getNextToken() {
+    consumeWhitespace();
+    TokenInfo result;
+    result.range.start = currentLocation();
+
+    // Code completion case
+    if (codeCompletionLocation && codeCompletionLocation <= code.data()) {
+      result.set(TokenInfo::TK_CodeCompletion,
+                 llvm::StringRef(codeCompletionLocation, 0));
+      codeCompletionLocation = nullptr;
+      return result;
+    }
+
+    // End of file case
+    if (code.empty()) {
+      result.set(TokenInfo::TK_Eof, "");
+      return result;
+    }
+
+    // Switch to handle specific characters
+    switch (code[0]) {
+    case '#':
+      // '#' starts a comment: skip to the end of the line and retry.
+      code = code.drop_until([](char c) { return c == '\n'; });
+      return getNextToken();
+    case ',':
+      result.set(TokenInfo::TK_Comma, firstCharacterAndDrop(code));
+      break;
+    case '.':
+      result.set(TokenInfo::TK_Period, firstCharacterAndDrop(code));
+      break;
+    case '\n':
+      ++line;
+      startOfLine = code.drop_front();
+      result.set(TokenInfo::TK_NewLine, firstCharacterAndDrop(code));
+      break;
+    case '(':
+      result.set(TokenInfo::TK_OpenParen, firstCharacterAndDrop(code));
+      break;
+    case ')':
+      result.set(TokenInfo::TK_CloseParen, firstCharacterAndDrop(code));
+      break;
+    case '"':
+    case '\'':
+      consumeStringLiteral(&result);
+      break;
+    default:
+      parseIdentifierOrInvalid(&result);
+      break;
+    }
+
+    result.range.end = currentLocation();
+    return result;
+  }
+
+  // Consume a string literal, handle escape sequences and missing closing
+  // quote.
+  void consumeStringLiteral(TokenInfo *result) {
+    bool inEscape = false;
+    const char marker = code[0];
+    for (size_t length = 1; length < code.size(); ++length) {
+      if (inEscape) {
+        inEscape = false;
+        continue;
+      }
+      if (code[length] == '\\') {
+        inEscape = true;
+        continue;
+      }
+      if (code[length] == marker) {
+        // text keeps the quotes; value is the content between them (escape
+        // sequences are skipped over above but not decoded).
+        result->kind = TokenInfo::TK_Literal;
+        result->text = code.substr(0, length + 1);
+        result->value = code.substr(1, length - 1);
+        code = code.drop_front(length + 1);
+        return;
+      }
+    }
+    // No closing quote: consume the rest of the input and report an error.
+    llvm::StringRef errorText = code;
+    code = code.drop_front(code.size());
+    SourceRange range;
+    range.start = result->range.start;
+    range.end = currentLocation();
+    error->addError(range, error->ET_ParserStringError) << errorText;
+    result->kind = TokenInfo::TK_Error;
+  }
+
+  // Lex an identifier (a run of alphanumeric characters), or a single
+  // TK_InvalidChar token when the input does not start with one.
+  void parseIdentifierOrInvalid(TokenInfo *result) {
+    // isalnum requires a value representable as unsigned char; passing a
+    // plain (possibly negative) char is undefined behaviour, hence the casts.
+    if (isalnum(static_cast<unsigned char>(code[0]))) {
+      // Parse an identifier
+      size_t tokenLength = 1;
+
+      while (true) {
+        // A code completion location in/immediately after an identifier will
+        // cause the portion of the identifier before the code completion
+        // location to become a code completion token.
+        if (codeCompletionLocation == code.data() + tokenLength) {
+          codeCompletionLocation = nullptr;
+          result->kind = TokenInfo::TK_CodeCompletion;
+          result->text = code.substr(0, tokenLength);
+          code = code.drop_front(tokenLength);
+          return;
+        }
+        if (tokenLength == code.size() ||
+            !(isalnum(static_cast<unsigned char>(code[tokenLength]))))
+          break;
+        ++tokenLength;
+      }
+      result->kind = TokenInfo::TK_Ident;
+      result->text = code.substr(0, tokenLength);
+      code = code.drop_front(tokenLength);
+    } else {
+      result->kind = TokenInfo::TK_InvalidChar;
+      result->text = code.substr(0, 1);
+      code = code.drop_front(1);
+    }
+  }
+
+  // Consume all leading whitespace from code, except newlines
+  void consumeWhitespace() {
+    code = code.drop_while(
+        [](char c) { return llvm::StringRef(" \t\v\f\r").contains(c); });
+  }
+
+  // Returns the current location in the source code
+  SourceLocation currentLocation() {
+    SourceLocation location;
+    location.line = line;
+    // Columns are 1-based offsets from the start of the current line.
+    location.column = code.data() - startOfLine.data() + 1;
+    return location;
+  }
+
+  // Remaining, not yet tokenized input.
+  llvm::StringRef code;
+  // Input starting at the current line; used to compute columns.
+  llvm::StringRef startOfLine;
+  unsigned line = 1;
+  Diagnostics *error;
+  // One-token lookahead buffer.
+  TokenInfo nextToken;
+  // Pending completion position, or nullptr once consumed / not requested.
+  const char *codeCompletionLocation = nullptr;
+};
+
+Parser::Sema::~Sema() = default;
+
+// Default implementation: no accepted types.
+std::vector Parser::Sema::getAcceptedCompletionTypes(
+    llvm::ArrayRef> context) {
+  return {};
+}
+
+// Default implementation: no completions.
+std::vector
+Parser::Sema::getMatcherCompletions(llvm::ArrayRef acceptedTypes) {
+  return {};
+}
+
+// Entry for the scope of a parser
+struct Parser::ScopedContextEntry {
+  Parser *parser;
+
+  // Pushes (ctor, argument index 0) onto the parser's context stack.
+  ScopedContextEntry(Parser *parser, MatcherCtor c) : parser(parser) {
+    parser->contextStack.push_back({c, 0u});
+  }
+
+  ~ScopedContextEntry() { parser->contextStack.pop_back(); }
+
+  // Advances the argument index of the innermost entry.
+  void nextArg() { ++parser->contextStack.back().second; }
+};
+
+// Parse and validate expressions starting with an identifier.
+// This function can parse named values and matchers. In case of failure, it
+// will try to determine the user's intent to give an appropriate error message.
+bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
+  const TokenInfo nameToken = tokenizer->consumeNextToken();
+
+  if (tokenizer->nextTokenKind() != TokenInfo::TK_OpenParen) {
+    // Parse as a named value.
+    auto namedValue =
+        namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
+
+    if (!namedValue.isMatcher()) {
+      error->addError(tokenizer->peekNextToken().range,
+                      error->ET_ParserNotAMatcher);
+      return false;
+    }
+
+    if (tokenizer->nextTokenKind() == TokenInfo::TK_NewLine) {
+      error->addError(tokenizer->peekNextToken().range,
+                      error->ET_ParserNoOpenParen)
+          << "NewLine";
+      return false;
+    }
+
+    // If the syntax is correct and the name is not a matcher either, report
+    // an unknown named value.
+    if ((tokenizer->nextTokenKind() == TokenInfo::TK_Comma ||
+         tokenizer->nextTokenKind() == TokenInfo::TK_CloseParen ||
+         tokenizer->nextTokenKind() == TokenInfo::TK_NewLine ||
+         tokenizer->nextTokenKind() == TokenInfo::TK_Eof) &&
+        !sema->lookupMatcherCtor(nameToken.text)) {
+      error->addError(nameToken.range, error->ET_RegistryValueNotFound)
+          << nameToken.text;
+      return false;
+    }
+    // Otherwise, fallback to the matcher parser.
+  }
+
+  tokenizer->skipNewlines();
+
+  assert(nameToken.kind == TokenInfo::TK_Ident);
+  TokenInfo openToken = tokenizer->consumeNextToken();
+  if (openToken.kind != TokenInfo::TK_OpenParen) {
+    error->addError(openToken.range, error->ET_ParserNoOpenParen)
+        << openToken.text;
+    return false;
+  }
+
+  std::optional ctor = sema->lookupMatcherCtor(nameToken.text);
+
+  // Parse as a matcher expression.
+  return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
+}
+
+// Parse the arguments of a matcher
+bool Parser::parseMatcherArgs(std::vector &args,
+                              MatcherCtor ctor, const TokenInfo &nameToken,
+                              TokenInfo &endToken) {
+  ScopedContextEntry sce(this, ctor);
+
+  while (tokenizer->nextTokenKind() != TokenInfo::TK_Eof) {
+    if (tokenizer->nextTokenKind() == TokenInfo::TK_CloseParen) {
+      // end of args.
+      endToken = tokenizer->consumeNextToken();
+      break;
+    }
+
+    if (!args.empty()) {
+      // We must find a , token to continue.
+      TokenInfo commaToken = tokenizer->consumeNextToken();
+      if (commaToken.kind != TokenInfo::TK_Comma) {
+        error->addError(commaToken.range, error->ET_ParserNoComma)
+            << commaToken.text;
+        return false;
+      }
+    }
+
+    // Argument numbers in diagnostics are 1-based.
+    Diagnostics::Context ctx(Diagnostics::Context::MatcherArg, error,
+                             nameToken.text, nameToken.range, args.size() + 1);
+    ParserValue argValue;
+    tokenizer->skipNewlines();
+
+    argValue.text = tokenizer->peekNextToken().text;
+    argValue.range = tokenizer->peekNextToken().range;
+    if (!parseExpressionImpl(&argValue.value)) {
+      return false;
+    }
+
+    tokenizer->skipNewlines();
+    args.push_back(argValue);
+    sce.nextArg();
+  }
+
+  return true;
+}
+
+/// Parse and validate a matcher expression.
+bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
+                                        const TokenInfo &openToken,
+                                        std::optional ctor,
+                                        VariantValue *value) {
+  if (!ctor) {
+    error->addError(nameToken.range, error->ET_RegistryMatcherNotFound)
+        << nameToken.text;
+    // Do not return here. We need to continue to give completion suggestions.
+  }
+
+  std::vector args;
+  TokenInfo endToken;
+
+  tokenizer->skipNewlines();
+
+  if (!parseMatcherArgs(args, ctor.value_or(nullptr), nameToken, endToken)) {
+    return false;
+  }
+
+  if (!ctor)
+    return false;
+  // Merge the start and end infos.
+  Diagnostics::Context ctx(Diagnostics::Context::ConstructMatcher, error,
+                           nameToken.text, nameToken.range);
+  SourceRange matcherRange = nameToken.range;
+  matcherRange.end = endToken.range.end;
+  VariantMatcher result =
+      sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
+  if (result.isNull())
+    return false;
+  *value = result;
+  return true;
+}
+
+// If the prefix of this completion matches the completion token, add it to
+// completions minus the prefix.
+void Parser::addCompletion(const TokenInfo &compToken,
+                           const MatcherCompletion &completion) {
+  if (llvm::StringRef(completion.typedText).startswith(compToken.text)) {
+    completions.emplace_back(completion.typedText.substr(compToken.text.size()),
+                             completion.matcherDecl);
+  }
+}
+
+// Builds one completion per named value ("<type> <name>" as the declaration).
+// acceptedTypes is not consulted here.
+std::vector
+Parser::getNamedValueCompletions(ArrayRef acceptedTypes) {
+  if (!namedValues)
+    return {};
+
+  std::vector result;
+  for (const auto &entry : *namedValues) {
+    std::string decl =
+        (entry.getValue().getTypeAsString() + " " + entry.getKey()).str();
+    result.emplace_back(entry.getKey(), decl);
+  }
+  return result;
+}
+
+// Consumes the code-completion token and collects matching matcher and
+// named-value completions into `completions`.
+void Parser::addExpressionCompletions() {
+  const TokenInfo compToken = tokenizer->consumeNextTokenIgnoreNewlines();
+  assert(compToken.kind == TokenInfo::TK_CodeCompletion);
+
+  // We cannot complete code if there is an invalid element on the context
+  // stack.
+  for (const auto &entry : contextStack) {
+    if (!entry.first)
+      return;
+  }
+
+  auto acceptedTypes = sema->getAcceptedCompletionTypes(contextStack);
+  for (const auto &completion : sema->getMatcherCompletions(acceptedTypes)) {
+    addCompletion(compToken, completion);
+  }
+
+  for (const auto &completion : getNamedValueCompletions(acceptedTypes)) {
+    addCompletion(compToken, completion);
+  }
+}
+
+// Parse an expression: a string literal, or an identifier-led named value /
+// matcher expression. Returns false on error or after handling a completion.
+bool Parser::parseExpressionImpl(VariantValue *value) {
+  switch (tokenizer->nextTokenKind()) {
+  case TokenInfo::TK_Literal:
+    *value = tokenizer->consumeNextToken().value;
+    return true;
+  case TokenInfo::TK_Ident:
+    return parseIdentifierPrefixImpl(value);
+  case TokenInfo::TK_CodeCompletion:
+    addExpressionCompletions();
+    return false;
+  case TokenInfo::TK_Eof:
+    error->addError(tokenizer->consumeNextToken().range,
+                    error->ET_ParserNoCode);
+    return false;
+
+  case TokenInfo::TK_Error:
+    // This error was already reported by the tokenizer.
+    return false;
+  case TokenInfo::TK_NewLine:
+  case TokenInfo::TK_OpenParen:
+  case TokenInfo::TK_CloseParen:
+  case TokenInfo::TK_Comma:
+  case TokenInfo::TK_Period:
+  case TokenInfo::TK_InvalidChar:
+    const TokenInfo token = tokenizer->consumeNextToken();
+    error->addError(token.range, error->ET_ParserInvalidToken)
+        << (token.kind == TokenInfo::TK_NewLine ? "NewLine" : token.text);
+    return false;
+  }
+
+  llvm_unreachable("Unknown token kind.");
+}
+
+static llvm::ManagedStatic defaultRegistrySema;
+
+// A null `sema` selects the shared default registry-backed Sema.
+Parser::Parser(CodeTokenizer *tokenizer, Sema *sema,
+               const NamedValueMap *namedValues, Diagnostics *error)
+    : tokenizer(tokenizer), sema(sema ? sema : &*defaultRegistrySema),
+      namedValues(namedValues), error(error) {}
+
+Parser::RegistrySema::~RegistrySema() = default;
+
+// The RegistrySema methods below forward directly to Registry.
+std::optional
+Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
+  return Registry::lookupMatcherCtor(matcherName);
+}
+
+VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
+    MatcherCtor ctor, SourceRange nameRange, ArrayRef args,
+    Diagnostics *error) {
+  return Registry::constructMatcher(ctor, nameRange, args, error);
+}
+
+std::vector Parser::RegistrySema::getAcceptedCompletionTypes(
+    ArrayRef> context) {
+  return Registry::getAcceptedCompletionTypes(context);
+}
+
+std::vector
+Parser::RegistrySema::getMatcherCompletions(ArrayRef acceptedTypes) {
+  return Registry::getMatcherCompletions(acceptedTypes);
+}
+
+// Parses one expression from `code` into `value`. Fails (adding an
+// ET_ParserTrailingCode error) when anything other than end of input or a
+// newline follows the expression.
+bool Parser::parseExpression(llvm::StringRef &code, Sema *sema,
+                             const NamedValueMap *namedValues,
+                             VariantValue *value, Diagnostics *error) {
+  CodeTokenizer tokenizer(code, error);
+  Parser parser(&tokenizer, sema, namedValues, error);
+  if (!parser.parseExpressionImpl(value))
+    return false;
+  // Bind by reference: the tokenizer is not advanced below, and copying a
+  // TokenInfo would also copy its VariantValue.
+  const TokenInfo &nextToken = tokenizer.peekNextToken();
+  if (nextToken.kind != TokenInfo::TK_Eof &&
+      nextToken.kind != TokenInfo::TK_NewLine) {
+    error->addError(nextToken.range, error->ET_ParserTrailingCode);
+    return false;
+  }
+  return true;
+}
+
+// Runs the parser with a completion position and returns whatever
+// completions were collected; diagnostics go to a local, discarded object.
+std::vector
+Parser::completeExpression(llvm::StringRef &code, unsigned completionOffset,
+                           Sema *sema, const NamedValueMap *namedValues) {
+  Diagnostics error;
+  CodeTokenizer tokenizer(code, &error, completionOffset);
+  Parser parser(&tokenizer, sema, namedValues, &error);
+  VariantValue dummy;
+  parser.parseExpressionImpl(&dummy);
+
+  return parser.completions;
+}
+
+// Parses `code` and requires the result to be a matcher. Returns
+// std::nullopt (with an error recorded) on parse failure, on a non-matcher
+// value, or when the matcher yields no DynMatcher.
+std::optional
+Parser::parseMatcherExpression(llvm::StringRef &code, Sema *sema,
+                               const NamedValueMap *namedValues,
+                               Diagnostics *error) {
+  VariantValue value;
+  if (!parseExpression(code, sema, namedValues, &value, error))
+    return std::nullopt;
+  if (!value.isMatcher()) {
+    error->addError(SourceRange(), error->ET_ParserNotAMatcher);
+    return std::nullopt;
+  }
+  std::optional result = value.getMatcher().getDynMatcher();
+  if (!result) {
+    error->addError(SourceRange(), error->ET_ParserOverloadedType)
+        << value.getTypeAsString();
+  }
+  return result;
+}
+
+} // namespace matcher
+} // namespace query
+} // namespace mlir
diff --git a/mlir/lib/Tools/mlir-query/Matcher/Registry.cpp b/mlir/lib/Tools/mlir-query/Matcher/Registry.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/Matcher/Registry.cpp
@@ -0,0 +1,181 @@
+//===- MatcherRegistry.cpp - Matcher registry -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Registry map populated at static initialization time.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Registry.h"
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/ManagedStatic.h"
+// NOTE(review): the header names on the next two lines are missing;
+// angle-bracket text (includes, template arguments and the static_cast
+// targets below) appears to have been stripped from this patch -- restore
+// before applying.
+#include
+#include
+
+#include "llvm/Support/Debug.h"
+using llvm::dbgs;
+
+#define DEBUG_TYPE "mlir-query"
+#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
+
+namespace mlir {
+namespace query {
+namespace matcher {
+namespace {
+
+using ConstructorMap =
+    llvm::StringMap>;
+
+// This is needed because these matchers are defined as overloaded functions.
+using IsConstantOp = detail::constant_op_matcher();
+using HasOpAttrName = detail::AttrOpMatcher(StringRef);
+using HasOpName = detail::NameOpMatcher(StringRef);
+
+// Owns the name -> matcher-descriptor table; filled in by the constructor.
+class RegistryMaps {
+public:
+  RegistryMaps();
+  ~RegistryMaps();
+
+  const ConstructorMap &constructors() const { return constructorMap; }
+
+private:
+  void registerMatcher(llvm::StringRef matcherName,
+                       std::unique_ptr callback);
+
+  ConstructorMap constructorMap;
+};
+
+} // namespace
+
+// Adds one entry; registering the same name twice is a programming error
+// (checked by the assert only).
+void RegistryMaps::registerMatcher(
+    llvm::StringRef matcherName,
+    std::unique_ptr callback) {
+  assert(!constructorMap.contains(matcherName));
+  constructorMap[matcherName] = std::move(callback);
+}
+
+// Generate a registry map with all the known matchers.
+RegistryMaps::RegistryMaps() {
+  auto registerOpMatcher = [&](const std::string &name, auto matcher) {
+    registerMatcher(name, internal::makeMatcherAutoMarshall(matcher, name));
+  };
+
+  // Register matchers using the template function (added in alphabetical order
+  // for consistency)
+  registerOpMatcher("hasOpAttrName", static_cast(m_Attr));
+  registerOpMatcher("hasOpName", static_cast(m_Op));
+  registerOpMatcher("isConstantOp", static_cast(m_Constant));
+  registerOpMatcher("isNegInfFloat", m_NegInfFloat);
+  registerOpMatcher("isNegZeroFloat", m_NegZeroFloat);
+  registerOpMatcher("isNonZero", m_NonZero);
+  registerOpMatcher("isOne", m_One);
+  registerOpMatcher("isOneFloat", m_OneFloat);
+  registerOpMatcher("isPosInfFloat", m_PosInfFloat);
+  registerOpMatcher("isPosZeroFloat", m_PosZeroFloat);
+  registerOpMatcher("isZero", m_Zero);
+  registerOpMatcher("isZeroFloat", m_AnyZeroFloat);
+}
+
+RegistryMaps::~RegistryMaps() = default;
+
+static llvm::ManagedStatic registryData;
+
+// Returns the descriptor registered under matcherName, or an empty optional
+// when no such matcher exists.
+std::optional
+Registry::lookupMatcherCtor(llvm::StringRef matcherName) {
+  auto it = registryData->constructors().find(matcherName);
+  return it == registryData->constructors().end() ? std::optional()
+                                                  : it->second.get();
+}
+
+// Computes the set of argument kinds acceptable at the completion point.
+// Every ctor in `context` is dereferenced, so entries must be non-null.
+std::vector Registry::getAcceptedCompletionTypes(
+    ArrayRef> context) {
+  // Seed with the top-level matcher kind, then add the kinds accepted by the
+  // argument indicated by each context element.
+  std::set typeSet;
+  typeSet.insert(ArgKind(ArgKind::AK_Matcher));
+
+  for (const auto &ctxEntry : context) {
+    MatcherCtor ctor = ctxEntry.first;
+    unsigned argNumber = ctxEntry.second;
+    std::vector nextTypeSet;
+
+    // Arguments past the end of the signature contribute nothing.
+    if (argNumber < ctor->getNumArgs())
+      ctor->getArgKinds(argNumber, nextTypeSet);
+
+    typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
+  }
+
+  return std::vector(typeSet.begin(), typeSet.end());
+}
+
+// Builds one completion per registered matcher: the text to type
+// ("name(", "name()" or "name(\"") and a human-readable declaration.
+std::vector
+Registry::getMatcherCompletions(ArrayRef acceptedTypes) {
+  std::vector completions;
+
+  // Search the registry for acceptable matchers.
+  for (const auto &m : registryData->constructors()) {
+    const internal::MatcherDescriptor &matcher = *m.getValue();
+    StringRef name = m.getKey();
+
+    unsigned numArgs = matcher.getNumArgs();
+    std::vector> argKinds(numArgs);
+
+    // Argument kinds are only collected when a matcher is acceptable here;
+    // otherwise every per-argument list stays empty.
+    for (const ArgKind &kind : acceptedTypes) {
+      if (kind.getArgKind() != kind.AK_Matcher)
+        continue;
+
+      for (unsigned arg = 0; arg != numArgs; ++arg)
+        matcher.getArgKinds(arg, argKinds[arg]);
+    }
+
+    std::string decl;
+    llvm::raw_string_ostream OS(decl);
+
+    std::string typedText = std::string(name);
+    OS << "Matcher: " << name << "(";
+
+    for (const std::vector &arg : argKinds) {
+      if (&arg != &argKinds[0])
+        OS << ", ";
+
+      bool firstArgKind = true;
+      // Print the alternatives for this argument separated by '|'.
+      for (const ArgKind &argKind : arg) {
+        if (!firstArgKind)
+          OS << "|";
+
+        firstArgKind = false;
+        OS << argKind.asString();
+      }
+    }
+
+    OS << ")";
+    typedText += "(";
+
+    // The first argument's kind list can be empty (see the loop above), so
+    // check before indexing to avoid reading out of bounds.
+    if (argKinds.empty())
+      typedText += ")";
+    else if (!argKinds[0].empty() &&
+             argKinds[0][0].getArgKind() == ArgKind::AK_String)
+      typedText += "\"";
+
+    completions.emplace_back(typedText, OS.str());
+  }
+
+  return completions;
+}
+
+// Forwards to the descriptor's create(); ctor must be non-null.
+VariantMatcher Registry::constructMatcher(MatcherCtor ctor,
+                                          SourceRange nameRange,
+                                          ArrayRef args,
+                                          Diagnostics *error) {
+  return ctor->create(nameRange, args, error);
+}
+
+} // namespace matcher
+} // namespace query
+} // namespace mlir
diff --git a/mlir/lib/Tools/mlir-query/Matcher/VariantValue.cpp b/mlir/lib/Tools/mlir-query/Matcher/VariantValue.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/Matcher/VariantValue.cpp
@@ -0,0 +1,143 @@
+//===--- MatcherVariantvalue.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements the ArgKind, VariantMatcher and VariantValue member functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "VariantValue.h"
+
+namespace mlir {
+namespace query {
+namespace matcher {
+
+// Returns the printable name of this argument kind.
+std::string ArgKind::asString() const {
+  switch (getArgKind()) {
+  case AK_String:
+    return "string";
+  case AK_Matcher:
+    return "Matcher";
+  }
+  llvm_unreachable("Unhandled ArgKind");
+}
+
+VariantMatcher::Payload::~Payload() = default;
+
+// Payload that holds exactly one DynMatcher.
+class VariantMatcher::SinglePayload : public VariantMatcher::Payload {
+public:
+  explicit SinglePayload(DynMatcher matcher) : matcher(std::move(matcher)) {}
+
+  std::optional getDynMatcher() const override { return matcher; }
+
+  std::string getTypeAsString() const override { return "Matcher"; }
+
+private:
+  DynMatcher matcher;
+};
+
+VariantMatcher::VariantMatcher() = default;
+
+// Wraps a single matcher in a payload created with std::make_shared.
+VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) {
+  return VariantMatcher(std::make_shared(std::move(matcher)));
+}
+
+// Returns the payload's matcher, or std::nullopt when no payload is set.
+std::optional VariantMatcher::getDynMatcher() const {
+  return value ? value->getDynMatcher() : std::nullopt;
+}
+
+// Drops the payload, if any.
+void VariantMatcher::reset() { value.reset(); }
+
+// NOTE(review): always returns an empty string and never consults the
+// payload's getTypeAsString() -- confirm this is intended.
+std::string VariantMatcher::getTypeAsString() const { return ""; }
+
+// Copy construction starts from VT_Nothing so that the reset() performed by
+// the copy assignment below has nothing to free.
+VariantValue::VariantValue(const VariantValue &other) : type(VT_Nothing) {
+  *this = other;
+}
+
+// Stores a heap-allocated copy of the StringRef object itself.
+// NOTE(review): presumably the caller keeps the referenced characters alive
+// for the lifetime of this value -- confirm.
+VariantValue::VariantValue(const StringRef String) : type(VT_String) {
+  value.String = new StringRef(String);
+}
+
+// Stores a heap-allocated copy of the matcher.
+VariantValue::VariantValue(const VariantMatcher &Matcher) : type(VT_Matcher) {
+  value.Matcher = new VariantMatcher(Matcher);
+}
+
+VariantValue::~VariantValue() { reset(); }
+
+// Copy assignment: frees the current payload, then duplicates the payload of
+// `other` according to its type tag.
+VariantValue &VariantValue::operator=(const VariantValue &other) {
+  // Self-assignment must not free the payload that is about to be copied.
+  if (this == &other)
+    return *this;
+  reset();
+  switch (other.type) {
+  case VT_String:
+    setString(other.getString());
+    break;
+  case VT_Matcher:
+    setMatcher(other.getMatcher());
+    break;
+  case VT_Nothing:
+    type = VT_Nothing;
+    break;
+  }
+  return *this;
+}
+
+// Deletes the heap-allocated payload that matches the current type tag and
+// returns the value to VT_Nothing.
+void VariantValue::reset() {
+  switch (type) {
+  case VT_String:
+    delete value.String;
+    break;
+  case VT_Matcher:
+    delete value.Matcher;
+    break;
+  // Cases that do nothing.
+  case VT_Nothing:
+    break;
+  }
+  type = VT_Nothing;
+}
+
+bool VariantValue::isString() const { return type == VT_String; }
+
+// Returns the stored string; asserts that the value currently is a string.
+const StringRef &VariantValue::getString() const {
+  assert(isString());
+  return *value.String;
+}
+
+// Replaces whatever is stored with a heap-allocated copy of `newValue`.
+void VariantValue::setString(const StringRef &newValue) {
+  reset();
+  type = VT_String;
+  value.String = new StringRef(newValue);
+}
+
+bool VariantValue::isMatcher() const { return type == VT_Matcher; }
+
+// Returns the stored matcher; asserts that the value currently is a matcher.
+const VariantMatcher &VariantValue::getMatcher() const {
+  assert(isMatcher());
+  return *value.Matcher;
+}
+
+// Replaces whatever is stored with a heap-allocated copy of `newValue`.
+void VariantValue::setMatcher(const VariantMatcher &newValue) {
+  reset();
+  type = VT_Matcher;
+  value.Matcher = new VariantMatcher(newValue);
+}
+
+// Returns the printable name of the currently stored type.
+std::string VariantValue::getTypeAsString() const {
+  switch (type) {
+  case VT_String:
+    return "String";
+  case VT_Matcher:
+    return "Matcher";
+  case VT_Nothing:
+    return "Nothing";
+  }
+  llvm_unreachable("Invalid Type");
+}
+
+} // end namespace matcher
+} // end namespace query
+} // end namespace mlir
diff --git a/mlir/lib/Tools/mlir-query/Query.cpp b/mlir/lib/Tools/mlir-query/Query.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/Query.cpp
@@ -0,0 +1,73 @@
+//===---- Query.cpp - mlir-query query ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Query.h"
+#include "QuerySession.h"
+#include "llvm/Support/raw_ostream.h"
+#include
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/IRMapping.h"
+
+#include "llvm/Support/Debug.h"
+using llvm::dbgs;
+
+#define DEBUG_TYPE "mlir-query"
+#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
+
+namespace mlir {
+namespace query {
+
+Query::~Query() {}
+
+// Prints the stored error text followed by a newline and reports failure.
+bool InvalidQuery::run(llvm::raw_ostream &OS, QuerySession &QS) const {
+  OS << errStr << "\n";
+  return false;
+}
+
+// Does nothing and reports success.
+bool NoOpQuery::run(llvm::raw_ostream &OS, QuerySession &QS) const {
+  return true;
+}
+
+// Prints the list of available commands and reports success.
+bool HelpQuery::run(llvm::raw_ostream &OS, QuerySession &QS) const {
+  OS << "Available commands:\n\n"
+        " match MATCHER, m MATCHER "
+        "Match the mlir against the given matcher.\n\n";
+  return true;
+}
+
+// Runs a freshly constructed MatchFinder over rootOp with the given matcher
+// and returns whatever it reports.
+std::vector getMatches(Operation *rootOp,
+                       const matcher::DynMatcher &matcher) {
+  auto matchFinder = query::matcher::MatchFinder();
+  return matchFinder.getMatches(rootOp, matcher);
+}
+
+// Runs the stored matcher against the session's root operation. For every
+// match a numbered "Match #N:" heading is written to OS and a remark is
+// emitted on the matched op; the total is written to OS at the end. Always
+// returns true.
+bool MatchQuery::run(llvm::raw_ostream &OS, QuerySession &QS) const {
+  Operation *rootOp = QS.rootOp;
+  auto matches = getMatches(rootOp, matcher);
+
+  // The remarks emitted below go through this handler, which is built on the
+  // session's source manager and lives until the end of this function.
+  MLIRContext *context = rootOp->getContext();
+  context->printOpOnDiagnostic(false);
+  SourceMgrDiagnosticHandler sourceMgrHandler(*QS.sourceMgr, context);
+
+  unsigned matchCount = 0;
+  OS << "\n";
+  for (Operation *op : matches) {
+    OS << "Match #" << ++matchCount << ":\n\n";
+    // Placeholder "root" binding for the initial draft.
+    op->emitRemark("\"root\" binds here");
+  }
+  // Singular/plural summary line.
+  OS << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n");
+
+  return true;
+}
+
+} // namespace query
+} // namespace mlir
diff --git a/mlir/lib/Tools/mlir-query/QueryParser.cpp b/mlir/lib/Tools/mlir-query/QueryParser.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/QueryParser.cpp
@@ -0,0 +1,219 @@
+//===---- QueryParser.cpp - mlir-query command parser ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "QueryParser.h"
+#include "Query.h"
+#include "QuerySession.h"
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSwitch.h"
+
+#include "llvm/Support/Debug.h"
+using llvm::dbgs;
+
+#define DEBUG_TYPE "mlir-query"
+#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ")
+
+namespace mlir {
+namespace query {
+
+// Skips leading whitespace other than '\n' at the front of `line`, then lexes
+// a "word": a lone '#', or otherwise every character up to the next such
+// whitespace character. `line` is advanced past the returned word. If nothing
+// is left after skipping whitespace, the (empty) remainder of `line` is
+// returned instead.
+StringRef QueryParser::lexWord() {
+  line = line.drop_while([](char c) {
+    // Don't trim newlines.
+    return StringRef(" \t\v\f\r").contains(c);
+  });
+
+  if (line.empty())
+    // Even though the line is empty, it contains a pointer and
+    // a (zero) length. The pointer is used in the LexOrCompleteWord
+    // code completion.
+    return line;
+
+  StringRef word;
+  if (line.front() == '#') {
+    // A comment marker is always a one-character word of its own.
+    word = line.substr(0, 1);
+  } else {
+    word = line.take_until([](char c) {
+      // '\n' is not in the delimiter set, so it does not end a word.
+      return StringRef(" \t\v\f\r").contains(c);
+    });
+  }
+
+  line = line.drop_front(word.size());
+  return word;
+}
+
+// This is the StringSwitch-alike used by lexOrCompleteWord below. See that
+// function for details.
+template
+struct QueryParser::LexOrCompleteWord {
+  StringRef word;
+  StringSwitch stringSwitch;
+
+  QueryParser *queryParser;
+  // Set to the completion point offset in word, or StringRef::npos if
+  // completion point not in word.
+  size_t wordCompletionPos;
+
+  // Lexes a word and stores it in word. Returns a LexOrCompleteWord object
+  // that can be used like a llvm::StringSwitch, but adds cases as possible
+  // completions if the lexed word contains the completion point.
+  LexOrCompleteWord(QueryParser *queryParser, StringRef &outWord)
+      : word(queryParser->lexWord()), stringSwitch(word),
+        queryParser(queryParser), wordCompletionPos(StringRef::npos) {
+    outWord = word;
+    // Completion applies when a completion point is set and it does not lie
+    // past the end of the lexed word.
+    if (queryParser->completionPos &&
+        queryParser->completionPos <= word.data() + word.size()) {
+      // A completion point in front of the word is treated as offset 0.
+      if (queryParser->completionPos < word.data())
+        wordCompletionPos = 0;
+      else
+        wordCompletionPos = queryParser->completionPos - word.data();
+    }
+  }
+
+  // Without a completion point this forwards the case to the StringSwitch.
+  // With one, nothing is matched; instead, if caseStr is non-empty, flagged as
+  // a completion, and agrees with the typed word up to the completion point,
+  // the rest of caseStr plus a trailing space is recorded as a completion.
+  LexOrCompleteWord &Case(llvm::StringLiteral caseStr, const T &value,
+                          bool isCompletion = true) {
+
+    if (wordCompletionPos == StringRef::npos)
+      stringSwitch.Case(caseStr, value);
+    else if (caseStr.size() != 0 && isCompletion &&
+             wordCompletionPos <= caseStr.size() &&
+             caseStr.substr(0, wordCompletionPos) ==
+                 word.substr(0, wordCompletionPos)) {
+
+      queryParser->completions.push_back(llvm::LineEditor::Completion(
+          (caseStr.substr(wordCompletionPos) + " ").str(),
+          std::string(caseStr)));
+    }
+    return *this;
+  }
+
+  T Default(T value) { return stringSwitch.Default(value); }
+};
+
+// Finishes a query that takes no further arguments. If the rest of the line
+// (ignoring leading non-newline whitespace) starts with a newline, the
+// untrimmed remainder is stored in Q->remainingContent. Otherwise a trailing
+// '#' comment is skipped up to and past the following newline(s) and the check
+// is repeated; any other trailing word produces an InvalidQuery.
+QueryRef QueryParser::endQuery(QueryRef Q) {
+  StringRef extra = line;
+  StringRef extraTrimmed = extra.drop_while(
+      [](char c) { return StringRef(" \t\v\f\r").contains(c); });
+
+  // NOTE(review): '\r' is in the drop_while set above, so extraTrimmed can
+  // never start with '\r' and the "\r\n" arm below looks unreachable --
+  // confirm.
+  if ((!extraTrimmed.empty() && extraTrimmed[0] == '\n') ||
+      (extraTrimmed.size() >= 2 && extraTrimmed[0] == '\r' &&
+       extraTrimmed[1] == '\n'))
+    Q->remainingContent = extra;
+  else {
+    StringRef trailingWord = lexWord();
+    if (!trailingWord.empty() && trailingWord.front() == '#') {
+      line = line.drop_until([](char c) { return c == '\n'; });
+      line = line.drop_while([](char c) { return c == '\n'; });
+      return endQuery(Q);
+    }
+    if (!trailingWord.empty()) {
+      return new InvalidQuery("unexpected extra input: '" + extra + "'");
+    }
+  }
+  return Q;
+}
+
+namespace {
+
+// The command kinds doParse() distinguishes.
+enum ParsedQueryKind {
+  PQK_Invalid,
+  PQK_Comment,
+  PQK_NoOp,
+  PQK_Help,
+  PQK_Match,
+};
+
+// Renders the collected diagnostics with printToStreamFull() and wraps the
+// resulting text in an InvalidQuery.
+QueryRef makeInvalidQueryFromDiagnostics(const matcher::Diagnostics &diag) {
+  std::string ErrStr;
+  llvm::raw_string_ostream OS(ErrStr);
+  diag.printToStreamFull(OS);
+  return new InvalidQuery(OS.str());
+}
+} // namespace
+
+// Asks the matcher parser for completions at the completion point (expressed
+// as an offset from the start of `line`) and appends them to `completions`.
+// Always returns a default-constructed QueryRef.
+QueryRef QueryParser::completeMatcherExpression() {
+  std::vector comps =
+      matcher::Parser::completeExpression(line, completionPos - line.begin(),
+                                          nullptr, &QS.namedValues);
+  for (const auto &comp : comps) {
+    completions.emplace_back(comp.typedText, comp.matcherDecl);
+  }
+  return QueryRef();
+}
+
+// Lexes the command word at the front of `line` and builds the matching query
+// object. In completion mode the lexing step records command completions
+// instead of matching.
+QueryRef QueryParser::doParse() {
+
+  StringRef commandStr;
+  ParsedQueryKind qKind = LexOrCompleteWord(this, commandStr)
+                              .Case("", PQK_NoOp)
+                              .Case("#", PQK_Comment, /*isCompletion=*/false)
+                              .Case("help", PQK_Help)
+                              .Case("m", PQK_Match, /*isCompletion=*/false)
+                              .Case("match", PQK_Match)
+                              .Default(PQK_Invalid);
+
+  switch (qKind) {
+  case PQK_Comment:
+  case PQK_NoOp:
+    // Skip the rest of this line and the newline(s) after it, then parse
+    // whatever follows; with nothing left the result is a NoOpQuery.
+    line = line.drop_until([](char c) { return c == '\n'; });
+    line = line.drop_while([](char c) { return c == '\n'; });
+    if (line.empty())
+      return new NoOpQuery;
+    return doParse();
+
+  case PQK_Help:
+    return endQuery(new HelpQuery);
+
+  case PQK_Match: {
+    // In completion mode only completions are gathered; nothing is parsed.
+    if (completionPos) {
+      return completeMatcherExpression();
+    }
+
+    matcher::Diagnostics diag;
+    auto matcherSource = line.ltrim();
+    auto origMatcherSource = matcherSource;
+    std::optional matcher =
+        matcher::Parser::parseMatcherExpression(matcherSource, nullptr,
+                                                &QS.namedValues, &diag);
+    if (!matcher) {
+      return makeInvalidQueryFromDiagnostics(diag);
+    }
+    // Presumably parseMatcherExpression shrinks matcherSource to the text it
+    // did not consume (TODO confirm); the consumed prefix is kept as the
+    // query's source and the rest as its remaining content.
+    auto actualSource = origMatcherSource.slice(0, origMatcherSource.size() -
+                                                       matcherSource.size());
+    auto *Q = new MatchQuery(actualSource, *matcher);
+    Q->remainingContent = matcherSource;
+    return Q;
+  }
+
+  case PQK_Invalid:
+    return new InvalidQuery("unknown command: " + commandStr);
+  }
+
+  llvm_unreachable("Invalid query kind");
+}
+
+// Parses `line` into a query using a temporary parser.
+QueryRef QueryParser::parse(StringRef line, const QuerySession &QS) {
+  return QueryParser(line, QS).doParse();
+}
+
+// Returns the completions available at byte offset `pos` of `line`. The parse
+// result itself is discarded; only the completions collected on the way are
+// returned.
+std::vector
+QueryParser::complete(StringRef line, size_t pos, const QuerySession &QS) {
+  QueryParser queryParser(line, QS);
+  queryParser.completionPos = line.data() + pos;
+
+  queryParser.doParse();
+  return queryParser.completions;
+}
+
+} // namespace query
+} // namespace mlir
diff --git a/mlir/lib/Tools/mlir-query/Tool/CMakeLists.txt b/mlir/lib/Tools/mlir-query/Tool/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/Tool/CMakeLists.txt
@@ -0,0 +1,15 @@
+include_directories( ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-query)
+
+set(LLVM_LINK_COMPONENTS
+  lineeditor
+  Support
+  )
+
+add_mlir_library(MLIRQueryLib
+  MlirQueryMain.cpp
+
+  LINK_LIBS PRIVATE
+  MLIRIR
+  MLIRParser
+  MLIRQuery
+  )
diff --git a/mlir/lib/Tools/mlir-query/Tool/MlirQueryMain.cpp b/mlir/lib/Tools/mlir-query/Tool/MlirQueryMain.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/Tool/MlirQueryMain.cpp
@@ -0,0 +1,94 @@
+//===- MlirQueryMain.cpp - MLIR Query main --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the general framework of the MLIR query tool. It
+// parses the command line arguments, parses the MLIR file and outputs the query
+// results.
+// +//===----------------------------------------------------------------------===// + +#include "MlirQueryMain.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Tools/ParseUtilities.h" +#include "llvm/LineEditor/LineEditor.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" + +#include "llvm/Support/Debug.h" +using llvm::dbgs; + +#define DEBUG_TYPE "mlir-query" +#define DBGS() (dbgs() << '[' << DEBUG_TYPE << "] ") + +//===----------------------------------------------------------------------===// +// Query Parser +//===----------------------------------------------------------------------===// + +mlir::LogicalResult mlir::mlirQueryMain(int argc, char **argv, + MLIRContext &context) { + // Override the default '-h' and use the default PrintHelpMessage() which + // won't print options in categories. + static llvm::cl::opt help("h", llvm::cl::desc("Alias for -help"), + llvm::cl::Hidden); + + static llvm::cl::OptionCategory mlirQueryCategory("mlir-query options"); + + static llvm::cl::opt inputFilename( + llvm::cl::Positional, llvm::cl::desc(""), + llvm::cl::cat(mlirQueryCategory)); + + static llvm::cl::opt noImplicitModule{ + "no-implicit-module", + llvm::cl::desc( + "Disable implicit addition of a top-level module op during parsing"), + llvm::cl::init(false)}; + + llvm::cl::HideUnrelatedOptions(mlirQueryCategory); + + llvm::InitLLVM y(argc, argv); + + llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR test case query tool.\n"); + + if (help) { + llvm::cl::PrintHelpMessage(); + return success(); + } + + // Set up the input file. + std::string errorMessage; + auto file = openInputFile(inputFilename, &errorMessage); + if (!file) { + return failure(); + } + + auto sourceMgr = std::make_shared(); + sourceMgr->AddNewSourceBuffer(std::move(file), SMLoc()); + + // Parse the input MLIR file. 
+ OwningOpRef opRef = + parseSourceFileForTool(sourceMgr, &context, !noImplicitModule); + if (!opRef) + return failure(); + + mlir::query::QuerySession QS(opRef.get(), sourceMgr); + llvm::LineEditor LE("mlir-query"); + LE.setListCompleter([&QS](StringRef line, size_t pos) { + return mlir::query::QueryParser::complete(line, pos, QS); + }); + while (std::optional line = LE.readLine()) { + mlir::query::QueryRef queryRef = mlir::query::QueryParser::parse(*line, QS); + queryRef->run(llvm::outs(), QS); + llvm::outs().flush(); + if (QS.terminate) + break; + } + + return success(); +} diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt --- a/mlir/tools/CMakeLists.txt +++ b/mlir/tools/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(mlir-opt) add_subdirectory(mlir-parser-fuzzer) add_subdirectory(mlir-pdll-lsp-server) +add_subdirectory(mlir-query) add_subdirectory(mlir-reduce) add_subdirectory(mlir-shlib) add_subdirectory(mlir-spirv-cpu-runner) diff --git a/mlir/tools/mlir-query/CMakeLists.txt b/mlir/tools/mlir-query/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-query/CMakeLists.txt @@ -0,0 +1,20 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) + +if(MLIR_INCLUDE_TESTS) + set(test_libs + MLIRTestDialect + ) +endif() + +add_mlir_tool(mlir-query + mlir-query.cpp + ) +llvm_update_compile_flags(mlir-query) +target_link_libraries(mlir-query + PRIVATE + ${dialect_libs} + ${test_libs} + MLIRQueryLib + ) + +mlir_check_link_libraries(mlir-query) diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-query/mlir-query.cpp @@ -0,0 +1,38 @@ +//===- mlir-query.cpp - MLIR Query Driver -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the command line driver for mlir-query: it registers the dialects
+// and hands control to mlirQueryMain.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Tools/mlir-query/MlirQueryMain.h"
+
+using namespace mlir;
+
+// Forward declaration of the test dialect registration hook; only declared
+// when MLIR_INCLUDE_TESTS is defined.
+namespace test {
+#ifdef MLIR_INCLUDE_TESTS
+void registerTestDialect(DialectRegistry &);
+#endif
+} // namespace test
+
+// Registers all dialects (plus the test dialect when MLIR_INCLUDE_TESTS is
+// defined), creates a context that also accepts unregistered dialects, and
+// runs mlirQueryMain. The process exit code is the result of failed(): zero
+// on success, non-zero on failure.
+int main(int argc, char **argv) {
+
+  DialectRegistry registry;
+  registerAllDialects(registry);
+#ifdef MLIR_INCLUDE_TESTS
+  test::registerTestDialect(registry);
+#endif
+  MLIRContext context(registry);
+  context.allowUnregisteredDialects(true);
+
+  return failed(mlirQueryMain(argc, argv, context));
+}