diff --git a/mlir/include/mlir/Query/Matcher/ErrorBuilder.h b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/ErrorBuilder.h
@@ -0,0 +1,63 @@
+//===--- ErrorBuilder.h - Helper for building error messages ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// ErrorBuilder to manage error messages.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_ERRORBUILDER_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_ERRORBUILDER_H
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include <initializer_list>
+
+namespace mlir::query::matcher::internal {
+class Diagnostics;
+
+// Represents the line and column numbers in a source query.
+struct SourceLocation {
+  unsigned line{};   // 0 when unknown; diagnostics then omit the location.
+  unsigned column{}; // 0 when unknown; diagnostics then omit the location.
+};
+
+// Represents a range in a source query, defined by its start and end locations.
+struct SourceRange {
+  SourceLocation start{};
+  SourceLocation end{};
+};
+
+// All error kinds reported by the matcher parser and registry.
+enum class ErrorType {
+  None,
+
+  // Parser Errors
+  ParserFailedToBuildMatcher,
+  ParserInvalidToken,
+  ParserNoCloseParen,
+  ParserNoCode,
+  ParserNoComma,
+  ParserNoOpenParen,
+  ParserNotAMatcher,
+  ParserOverloadedType,
+  ParserStringError,
+  ParserTrailingCode,
+
+  // Registry Errors
+  RegistryMatcherNotFound,
+  RegistryValueNotFound,
+  RegistryWrongArgCount,
+  RegistryWrongArgType
+};
+// Reports 'errorType' at 'range' to 'error'; 'errorTexts' fill $0, $1, ...
+void addError(Diagnostics *error, SourceRange range, ErrorType errorType,
+              std::initializer_list<llvm::Twine> errorTexts);
+
+} // namespace mlir::query::matcher::internal
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_ERRORBUILDER_H
diff --git a/mlir/include/mlir/Query/Matcher/Marshallers.h b/mlir/include/mlir/Query/Matcher/Marshallers.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/Marshallers.h
@@ -0,0 +1,199 @@
+//===--- Marshallers.h - Generic matcher function marshallers ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains function templates and classes to wrap matcher construct
+// functions. It provides a collection of template functions and classes that
+// present a generic marshalling layer on top of matcher construct functions.
+// The registry uses these to export all marshaller constructors with a uniform
+// interface. This mechanism takes inspiration from clang-query.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
+
+#include "ErrorBuilder.h"
+#include "VariantValue.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir::query::matcher::internal {
+
+// Helper template class for jumping from argument type to the correct is/get
+// functions in VariantValue. This is used for verifying and extracting the
+// matcher arguments.
+template <class T>
+struct ArgTypeTraits;
+template <class T>
+struct ArgTypeTraits<const T &> : public ArgTypeTraits<T> {};
+
+template <>
+struct ArgTypeTraits<llvm::StringRef> {
+
+  static bool hasCorrectType(const VariantValue &value) {
+    return value.isString();
+  }
+
+  static const llvm::StringRef &get(const VariantValue &value) {
+    return value.getString();
+  }
+
+  static ArgKind getKind() { return ArgKind::String; }
+
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::nullopt;
+  }
+};
+
+template <>
+struct ArgTypeTraits<DynMatcher> {
+
+  static bool hasCorrectType(const VariantValue &value) {
+    return value.isMatcher();
+  }
+
+  static DynMatcher get(const VariantValue &value) {
+    return *value.getMatcher().getDynMatcher();
+  }
+
+  static ArgKind getKind() { return ArgKind::Matcher; }
+
+  static std::optional<std::string> getBestGuess(const VariantValue &) {
+    return std::nullopt;
+  }
+};
+
+// Interface for generic matcher descriptor.
+// Offers a create() method that constructs the matcher from the provided
+// arguments.
+class MatcherDescriptor {
+public:
+  virtual ~MatcherDescriptor() = default;
+  virtual VariantMatcher create(SourceRange nameRange,
+                                const llvm::ArrayRef<ParserValue> args,
+                                Diagnostics *error) const = 0;
+
+  // Returns the number of arguments accepted by the matcher.
+  virtual unsigned getNumArgs() const = 0;
+
+  // Append the set of argument types accepted for argument 'argNo' to
+  // 'argKinds'.
+  virtual void getArgKinds(unsigned argNo,
+                           std::vector<ArgKind> &argKinds) const = 0;
+};
+
+class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
+public:
+  using MarshallerType = VariantMatcher (*)(void (*matcherFunc)(),
+                                            llvm::StringRef matcherName,
+                                            SourceRange nameRange,
+                                            llvm::ArrayRef<ParserValue> args,
+                                            Diagnostics *error);
+
+  // 'marshaller' unpacks the arguments and calls 'matcherFunc', the matcher
+  // construct function that matcher expressions use to build the matcher.
+  // 'matcherName' is copied, so it may refer to a temporary string.
+  FixedArgCountMatcherDescriptor(MarshallerType marshaller,
+                                 void (*matcherFunc)(),
+                                 llvm::StringRef matcherName,
+                                 llvm::ArrayRef<ArgKind> argKinds)
+      : marshaller(marshaller), matcherFunc(matcherFunc),
+        matcherName(matcherName.str()), argKinds(argKinds.vec()) {}
+
+  VariantMatcher create(SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
+                        Diagnostics *error) const override {
+    return marshaller(matcherFunc, matcherName, nameRange, args, error);
+  }
+
+  unsigned getNumArgs() const override { return argKinds.size(); }
+
+  void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
+    kinds.push_back(argKinds[argNo]);
+  }
+
+private:
+  const MarshallerType marshaller;
+  void (*const matcherFunc)();
+  const std::string matcherName; // Owned: callers may pass a temporary.
+  const std::vector<ArgKind> argKinds;
+};
+
+// Checks the argument count; reports RegistryWrongArgCount on mismatch.
+inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
+                          llvm::ArrayRef<ParserValue> args,
+                          Diagnostics *error) {
+  if (args.size() != expectedArgCount) {
+    addError(error, nameRange, ErrorType::RegistryWrongArgCount,
+             {llvm::Twine(expectedArgCount), llvm::Twine(args.size())});
+    return false;
+  }
+  return true;
+}
+
+// Checks argument 'Index' has type ArgType; reports RegistryWrongArgType if not.
+template <typename ArgType, size_t Index>
+inline bool checkArgTypeAtIndex(llvm::StringRef matcherName,
+                                llvm::ArrayRef<ParserValue> args,
+                                Diagnostics *error) {
+  if (!ArgTypeTraits<ArgType>::hasCorrectType(args[Index].value)) {
+    addError(error, args[Index].range, ErrorType::RegistryWrongArgType,
+             {llvm::Twine(matcherName), llvm::Twine(Index + 1)});
+    return false;
+  }
+  return true;
+}
+
+// Marshaller function for a fixed number of arguments.
+template <typename ReturnType, typename... ArgTypes, size_t... Is>
+inline VariantMatcher
+matcherMarshallFixedImpl(void (*matcherFunc)(), llvm::StringRef matcherName,
+                         SourceRange nameRange,
+                         llvm::ArrayRef<ParserValue> args, Diagnostics *error,
+                         std::index_sequence<Is...>) {
+  using FuncType = ReturnType (*)(ArgTypes...);
+
+  // Check if the argument count matches the expected count
+  if (!checkArgCount(nameRange, sizeof...(ArgTypes), args, error))
+    return VariantMatcher();
+
+  // Check if each argument at the corresponding index has the correct type
+  if ((... && checkArgTypeAtIndex<ArgTypes, Is>(matcherName, args, error))) {
+    ReturnType matcherFn = reinterpret_cast<FuncType>(matcherFunc)(
+        ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
+    return VariantMatcher::SingleMatcher(
+        *DynMatcher::constructDynMatcherFromMatcherFn(matcherFn));
+  }
+
+  return VariantMatcher();
+}
+
+template <typename ReturnType, typename... ArgTypes>
+inline VariantMatcher
+matcherMarshallFixed(void (*matcherFunc)(), llvm::StringRef matcherName,
+                     SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
+                     Diagnostics *error) {
+  return matcherMarshallFixedImpl<ReturnType, ArgTypes...>(
+      matcherFunc, matcherName, nameRange, args, error,
+      std::index_sequence_for<ArgTypes...>{});
+}
+
+// Fixed number of arguments overload
+template <typename ReturnType, typename... ArgTypes>
+std::unique_ptr<MatcherDescriptor>
+makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
+                        llvm::StringRef matcherName) {
+  // Create a vector of argument kinds
+  std::vector<ArgKind> argKinds = {ArgTypeTraits<ArgTypes>::getKind()...};
+  return std::make_unique<FixedArgCountMatcherDescriptor>(
+      matcherMarshallFixed<ReturnType, ArgTypes...>,
+      reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
+}
+
+} // namespace mlir::query::matcher::internal
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchFinder.h b/mlir/include/mlir/Query/Matcher/MatchFinder.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/MatchFinder.h
@@ -0,0 +1,43 @@
+//===- MatchFinder.h - ------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the MatchFinder class, which is used to find operations
+// that match a given matcher.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
+
+#include "MatchersInternal.h"
+#include <vector>
+
+namespace mlir::query::matcher {
+
+// MatchFinder is used to find all operations that match a given matcher.
+class MatchFinder {
+public:
+  // Returns every operation visited by `root->walk` that `matcher` accepts,
+  // in walk order.
+  static std::vector<Operation *> getMatches(Operation *root,
+                                             const DynMatcher &matcher) {
+    std::vector<Operation *> matches;
+
+    // Simple match finding with walk.
+    root->walk([&](Operation *subOp) {
+      if (matcher.match(subOp))
+        matches.push_back(subOp);
+    });
+
+    return matches;
+  }
+};
+
+} // namespace mlir::query::matcher
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
diff --git a/mlir/include/mlir/Query/Matcher/MatchersInternal.h b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/MatchersInternal.h
@@ -0,0 +1,72 @@
+//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements the base layer of the matcher framework.
+//
+// Matchers are methods that return a Matcher which provides a method
+// match(Operation *op).
+//
+// The matcher functions are defined in include/mlir/IR/Matchers.h.
+// This file contains the wrapper classes needed to construct matchers for
+// mlir-query.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
+
+#include "mlir/IR/Matchers.h"
+#include "llvm/ADT/IntrusiveRefCntPtr.h"
+
+namespace mlir::query::matcher {
+
+// Generic interface for matchers on an MLIR operation.
+class MatcherInterface
+    : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
+public:
+  virtual ~MatcherInterface() = default;
+
+  virtual bool match(Operation *op) = 0;
+};
+
+// MatcherFnImpl takes a matcher function object (stored by copy) and
+// implements MatcherInterface by forwarding to its match() method.
+template <typename MatcherFn>
+class MatcherFnImpl : public MatcherInterface {
+public:
+  MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
+  bool match(Operation *op) override { return matcherFn.match(op); }
+
+private:
+  MatcherFn matcherFn;
+};
+
+// DynMatcher wraps a MatcherInterface implementation and provides a match()
+// method that redirects calls to the underlying implementation.
+class DynMatcher {
+public:
+  // Takes ownership of the provided implementation pointer.
+  DynMatcher(MatcherInterface *implementation)
+      : implementation(implementation) {}
+
+  template <typename MatcherFn>
+  static std::unique_ptr<DynMatcher>
+  constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
+    auto impl = std::make_unique<MatcherFnImpl<MatcherFn>>(matcherFn);
+    return std::make_unique<DynMatcher>(impl.release());
+  }
+
+  bool match(Operation *op) const { return implementation->match(op); }
+
+private:
+  llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
+};
+
+} // namespace mlir::query::matcher
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
diff --git a/mlir/include/mlir/Query/Matcher/Registry.h b/mlir/include/mlir/Query/Matcher/Registry.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/Registry.h
@@ -0,0 +1,51 @@
+//===--- Registry.h - Matcher Registry --------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Registry class to manage the registry of matchers using a map.
+//
+// This class provides a convenient interface for registering and accessing
+// matcher constructors using a string-based map.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRY_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRY_H
+
+#include "Marshallers.h"
+#include "llvm/ADT/StringMap.h"
+#include <string>
+
+namespace mlir::query::matcher {
+
+using ConstructorMap =
+    llvm::StringMap<std::unique_ptr<const internal::MatcherDescriptor>>;
+
+class Registry {
+public:
+  Registry() = default;
+  ~Registry() = default;
+
+  const ConstructorMap &constructors() const { return constructorMap; }
+
+  template <typename MatcherType>
+  void registerMatcher(const std::string &name, MatcherType matcher) {
+    registerMatcherDescriptor(name,
+                              internal::makeMatcherAutoMarshall(matcher, name));
+  }
+
+private:
+  void registerMatcherDescriptor(
+      llvm::StringRef matcherName,
+      std::unique_ptr<internal::MatcherDescriptor> callback);
+
+  ConstructorMap constructorMap;
+};
+
+} // namespace mlir::query::matcher
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRY_H
diff --git a/mlir/include/mlir/Query/Matcher/VariantValue.h b/mlir/include/mlir/Query/Matcher/VariantValue.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Query/Matcher/VariantValue.h
@@ -0,0 +1,128 @@
+//===--- VariantValue.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Supports all the types required for dynamic Matcher construction.
+// Used by the registry to construct matchers in a generic way.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_VARIANTVALUE_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_VARIANTVALUE_H
+
+#include "ErrorBuilder.h"
+#include "MatchersInternal.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir::query::matcher {
+
+// All types that VariantValue can contain.
+enum class ArgKind { Matcher, String };
+
+// A variant matcher object to abstract simple and complex matchers into a
+// single object type.
+class VariantMatcher {
+  class MatcherOps;
+
+  // Payload interface to be specialized by each matcher type. It follows a
+  // similar interface as VariantMatcher itself.
+  class Payload {
+  public:
+    virtual ~Payload();
+    virtual std::optional<DynMatcher> getDynMatcher() const = 0;
+    virtual std::string getTypeAsString() const = 0;
+  };
+
+public:
+  // A null matcher.
+  VariantMatcher();
+
+  // Clones the provided matcher.
+  static VariantMatcher SingleMatcher(DynMatcher matcher);
+
+  // Makes the matcher the "null" matcher.
+  void reset();
+
+  // Checks if the matcher is null.
+  bool isNull() const { return !value; }
+
+  // Returns the wrapped matcher.
+  std::optional<DynMatcher> getDynMatcher() const;
+
+  // String representation of the type of the value.
+  std::string getTypeAsString() const;
+
+private:
+  explicit VariantMatcher(std::shared_ptr<Payload> value)
+      : value(std::move(value)) {}
+
+  class SinglePayload;
+
+  std::shared_ptr<const Payload> value;
+};
+
+// Variant value class with a tagged union with value type semantics. It is used
+// by the registry as the return value and argument type for the matcher factory
+// methods. It can be constructed from any of the supported types:
+// - StringRef
+// - VariantMatcher
+class VariantValue {
+public:
+  VariantValue() : type(ValueType::Nothing) {}
+
+  VariantValue(const VariantValue &other);
+  ~VariantValue();
+  VariantValue &operator=(const VariantValue &other);
+
+  // Specific constructors for each supported type.
+  VariantValue(const llvm::StringRef string);
+  VariantValue(const VariantMatcher &matcher);
+
+  // String value functions.
+  bool isString() const;
+  const llvm::StringRef &getString() const;
+  void setString(const llvm::StringRef &string);
+
+  // Matcher value functions.
+  bool isMatcher() const;
+  const VariantMatcher &getMatcher() const;
+  void setMatcher(const VariantMatcher &matcher);
+
+  // String representation of the type of the value.
+  std::string getTypeAsString() const;
+
+private:
+  void reset();
+
+  // All supported value types.
+  enum class ValueType {
+    Nothing,
+    String,
+    Matcher,
+  };
+
+  // Storage for the active value; 'type' records which member is valid.
+  union AllValues {
+    llvm::StringRef *String;
+    VariantMatcher *Matcher;
+  };
+
+  ValueType type;
+  AllValues value;
+};
+
+// A VariantValue instance annotated with its parser context.
+struct ParserValue {
+  ParserValue() {}
+  llvm::StringRef text;
+  internal::SourceRange range;
+  VariantValue value;
+};
+
+} // namespace mlir::query::matcher
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_VARIANTVALUE_H
diff --git a/mlir/include/mlir/Query/Query.h b/mlir/include/mlir/Query/Query.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Query/Query.h
@@ -0,0 +1,109 @@
+//===--- Query.h ------------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_QUERY_H
+#define MLIR_TOOLS_MLIRQUERY_QUERY_H
+
+#include "Matcher/VariantValue.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/LineEditor/LineEditor.h"
+#include <string>
+
+namespace mlir::query {
+
+enum class QueryKind { Invalid, NoOp, Help, Match, Quit };
+
+class QuerySession;
+
+struct Query : llvm::RefCountedBase<Query> {
+  Query(QueryKind kind) : kind(kind) {}
+  virtual ~Query();
+
+  // Perform the query on qs and print output to os.
+  virtual mlir::LogicalResult run(llvm::raw_ostream &os,
+                                  QuerySession &qs) const = 0;
+
+  llvm::StringRef remainingContent;
+  const QueryKind kind;
+};
+
+using QueryRef = llvm::IntrusiveRefCntPtr<Query>;
+
+QueryRef parse(llvm::StringRef line, const QuerySession &qs);
+
+std::vector<llvm::LineEditor::Completion>
+complete(llvm::StringRef line, size_t pos, const QuerySession &qs);
+
+// Any query which resulted in a parse error. The error message is in errStr.
+struct InvalidQuery : Query {
+  InvalidQuery(const llvm::Twine &errStr)
+      : Query(QueryKind::Invalid), errStr(errStr.str()) {}
+  mlir::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+
+  std::string errStr;
+
+  static bool classof(const Query *query) {
+    return query->kind == QueryKind::Invalid;
+  }
+};
+
+// No-op query (i.e. a blank line).
+struct NoOpQuery : Query {
+  NoOpQuery() : Query(QueryKind::NoOp) {}
+  mlir::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+
+  static bool classof(const Query *query) {
+    return query->kind == QueryKind::NoOp;
+  }
+};
+
+// Query for "help".
+struct HelpQuery : Query {
+  HelpQuery() : Query(QueryKind::Help) {}
+  mlir::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+
+  static bool classof(const Query *query) {
+    return query->kind == QueryKind::Help;
+  }
+};
+
+// Query for "quit".
+struct QuitQuery : Query {
+  QuitQuery() : Query(QueryKind::Quit) {}
+  mlir::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+
+  static bool classof(const Query *query) {
+    return query->kind == QueryKind::Quit;
+  }
+};
+
+// Query for "match MATCHER".
+struct MatchQuery : Query {
+  MatchQuery(llvm::StringRef source, const matcher::DynMatcher &matcher)
+      : Query(QueryKind::Match), matcher(matcher), source(source) {}
+  mlir::LogicalResult run(llvm::raw_ostream &os,
+                          QuerySession &qs) const override;
+
+  const matcher::DynMatcher matcher;
+
+  llvm::StringRef source;
+
+  static bool classof(const Query *query) {
+    return query->kind == QueryKind::Match;
+  }
+};
+
+} // namespace mlir::query
+
+#endif // MLIR_TOOLS_MLIRQUERY_QUERY_H
diff --git a/mlir/include/mlir/Query/QuerySession.h b/mlir/include/mlir/Query/QuerySession.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Query/QuerySession.h
@@ -0,0 +1,47 @@
+//===--- QuerySession.h -----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
+#define MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
+
+#include "Matcher/VariantValue.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/SourceMgr.h"
+
+namespace mlir::query {
+
+namespace matcher {
+class Registry;
+} // namespace matcher
+
+// Represents the state for a particular mlir-query session.
+class QuerySession {
+public:
+  QuerySession(Operation *rootOp, llvm::SourceMgr &sourceMgr, unsigned bufferId,
+               const matcher::Registry &matcherRegistry)
+      : rootOp(rootOp), sourceMgr(sourceMgr), bufferId(bufferId),
+        matcherRegistry(matcherRegistry) {}
+
+  Operation *getRootOp() const { return rootOp; }
+  llvm::SourceMgr &getSourceManager() const { return sourceMgr; }
+  unsigned getBufferId() const { return bufferId; }
+  const matcher::Registry &getRegistryData() const { return matcherRegistry; }
+
+  llvm::StringMap<matcher::VariantValue> namedValues;
+  bool terminate = false;
+
+private:
+  Operation *rootOp;
+  llvm::SourceMgr &sourceMgr;
+  unsigned bufferId;
+  const matcher::Registry &matcherRegistry;
+};
+
+} // namespace mlir::query
+
+#endif // MLIR_TOOLS_MLIRQUERY_QUERYSESSION_H
diff --git a/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h b/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Tools/mlir-query/MlirQueryMain.h
@@ -0,0 +1,30 @@
+//===- MlirQueryMain.h - MLIR Query main ----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Main entry function for mlir-query for when built as standalone
+// binary.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MLIRQUERYMAIN_H
+#define MLIR_TOOLS_MLIRQUERY_MLIRQUERYMAIN_H
+
+#include "mlir/Query/Matcher/Registry.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+
+class MLIRContext;
+// Entry point of mlir-query; 'matcherRegistry' supplies the usable matchers.
+LogicalResult
+mlirQueryMain(int argc, char **argv, MLIRContext &context,
+              const mlir::query::matcher::Registry &matcherRegistry);
+
+} // namespace mlir
+
+#endif // MLIR_TOOLS_MLIRQUERY_MLIRQUERYMAIN_H
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -11,6 +11,7 @@
 add_subdirectory(Interfaces)
 add_subdirectory(Parser)
 add_subdirectory(Pass)
+add_subdirectory(Query)
 add_subdirectory(Reducer)
 add_subdirectory(Rewrite)
 add_subdirectory(Support)
diff --git a/mlir/lib/Query/CMakeLists.txt b/mlir/lib/Query/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Query/CMakeLists.txt
@@ -0,0 +1,12 @@
+add_mlir_library(MLIRQuery
+  Query.cpp
+  QueryParser.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Query
+
+  LINK_LIBS PUBLIC
+  MLIRQueryMatcher
+  )
+
+add_subdirectory(Matcher)
diff --git a/mlir/lib/Query/Matcher/CMakeLists.txt b/mlir/lib/Query/Matcher/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Query/Matcher/CMakeLists.txt
@@ -0,0 +1,10 @@
+add_mlir_library(MLIRQueryMatcher
+  Parser.cpp
+  RegistryManager.cpp
+  VariantValue.cpp
+  Diagnostics.cpp
+  ErrorBuilder.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Query/Matcher
+  )
diff --git a/mlir/lib/Query/Matcher/Diagnostics.h b/mlir/lib/Query/Matcher/Diagnostics.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Query/Matcher/Diagnostics.h
@@ -0,0 +1,82 @@
+//===--- Diagnostics.h - Helper class for error diagnostics -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Diagnostics class to manage error messages. Implementation shares similarity
+// to clang-query Diagnostics.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H
+
+#include "mlir/Query/Matcher/ErrorBuilder.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+#include <string>
+#include <vector>
+
+namespace mlir::query::matcher::internal {
+
+// Diagnostics class to manage error messages.
+class Diagnostics {
+public:
+  // Helper stream class for constructing error messages.
+  class ArgStream {
+  public:
+    ArgStream(std::vector<std::string> *out) : out(out) {}
+    template <class T>
+    ArgStream &operator<<(const T &arg) {
+      return operator<<(llvm::Twine(arg));
+    }
+    ArgStream &operator<<(const llvm::Twine &arg);
+
+  private:
+    std::vector<std::string> *out;
+  };
+
+  // Add an error message with the specified range and error type.
+  // Returns an ArgStream object to allow constructing the error message using
+  // the << operator.
+  ArgStream addError(SourceRange range, ErrorType error);
+
+  // Print all error messages to the specified output stream.
+  void print(llvm::raw_ostream &os) const;
+
+private:
+  // Information stored for one frame of the context.
+  struct ContextFrame {
+    SourceRange range;
+    std::vector<std::string> args;
+  };
+
+  // Information stored for each error found.
+  struct ErrorContent {
+    std::vector<ContextFrame> contextStack;
+    struct Message {
+      SourceRange range;
+      ErrorType type;
+      std::vector<std::string> args;
+    };
+    std::vector<Message> messages;
+  };
+
+  void printMessage(const ErrorContent::Message &message,
+                    const llvm::Twine &prefix, llvm::raw_ostream &os) const;
+
+  void printErrorContent(const ErrorContent &content,
+                         llvm::raw_ostream &os) const;
+
+  std::vector<ContextFrame> contextStack;
+  std::vector<ErrorContent> errorValues;
+};
+
+} // namespace mlir::query::matcher::internal
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_DIAGNOSTICS_H
diff --git a/mlir/lib/Query/Matcher/Diagnostics.cpp b/mlir/lib/Query/Matcher/Diagnostics.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Query/Matcher/Diagnostics.cpp
@@ -0,0 +1,131 @@
+//===- Diagnostics.cpp ----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Diagnostics.h"
+#include "mlir/Query/Matcher/ErrorBuilder.h"
+#include "llvm/Support/ErrorHandling.h"
+
+namespace mlir::query::matcher::internal {
+
+Diagnostics::ArgStream &
+Diagnostics::ArgStream::operator<<(const llvm::Twine &arg) {
+  out->push_back(arg.str());
+  return *this;
+}
+
+Diagnostics::ArgStream Diagnostics::addError(SourceRange range,
+                                             ErrorType error) {
+  errorValues.emplace_back();
+  ErrorContent &last = errorValues.back();
+  last.contextStack = contextStack;
+  last.messages.emplace_back();
+  last.messages.back().range = range;
+  last.messages.back().type = error;
+  return ArgStream(&last.messages.back().args);
+}
+
+// Returns the message template for `type`; `$N` refers to the N-th argument.
+static llvm::StringRef errorTypeToFormatString(ErrorType type) {
+  switch (type) {
+  case ErrorType::RegistryMatcherNotFound:
+    return "Matcher not found: $0";
+  case ErrorType::RegistryWrongArgCount:
+    return "Incorrect argument count. (Expected = $0) != (Actual = $1)";
+  case ErrorType::RegistryWrongArgType:
+    // checkArgTypeAtIndex passes {matcherName, 1-based argument index}.
+    return "Incorrect type for arg $1 of matcher $0.";
+  case ErrorType::RegistryValueNotFound:
+    return "Value not found: $0";
+
+  case ErrorType::ParserStringError:
+    return "Error parsing string token: <$0>";
+  case ErrorType::ParserNoOpenParen:
+    return "Error parsing matcher. Found token <$0> while looking for '('.";
+  case ErrorType::ParserNoCloseParen:
+    return "Error parsing matcher. Found end-of-code while looking for ')'.";
+  case ErrorType::ParserNoComma:
+    return "Error parsing matcher. Found token <$0> while looking for ','.";
+  case ErrorType::ParserNoCode:
+    return "End of code found while looking for token.";
+  case ErrorType::ParserNotAMatcher:
+    return "Input value is not a matcher expression.";
+  case ErrorType::ParserInvalidToken:
+    return "Invalid token <$0> found when looking for a value.";
+  case ErrorType::ParserTrailingCode:
+    return "Unexpected end of code.";
+  case ErrorType::ParserOverloadedType:
+    return "Input value has unresolved overloaded type: $0";
+  case ErrorType::ParserFailedToBuildMatcher:
+    return "Failed to build matcher: $0.";
+
+  case ErrorType::None:
+    return "";
+  }
+  llvm_unreachable("Unknown ErrorType value.");
+}
+
+// Writes `formatString` to `os`, replacing each `$N` (N a single digit) with
+// args[N]. Missing arguments print as "<Argument_Not_Provided>".
+static void formatErrorString(llvm::StringRef formatString,
+                              llvm::ArrayRef<std::string> args,
+                              llvm::raw_ostream &os) {
+  while (!formatString.empty()) {
+    auto [literal, rest] = formatString.split('$');
+    os << literal;
+    if (rest.empty())
+      break;
+
+    const char next = rest.front();
+    formatString = rest.drop_front();
+    if (next >= '0' && next <= '9') {
+      const unsigned index = next - '0';
+      if (index < args.size())
+        os << args[index];
+      else
+        os << "<Argument_Not_Provided>";
+    }
+  }
+}
+
+// Prefixes the message with "line:column: " when the range start is known.
+static void maybeAddLineAndColumn(SourceRange range, llvm::raw_ostream &os) {
+  if (range.start.line > 0 && range.start.column > 0)
+    os << range.start.line << ":" << range.start.column << ": ";
+}
+
+void Diagnostics::printMessage(
+    const Diagnostics::ErrorContent::Message &message,
+    const llvm::Twine &prefix, llvm::raw_ostream &os) const {
+  maybeAddLineAndColumn(message.range, os);
+  os << prefix;
+  formatErrorString(errorTypeToFormatString(message.type), message.args, os);
+}
+
+void Diagnostics::printErrorContent(const Diagnostics::ErrorContent &content,
+                                    llvm::raw_ostream &os) const {
+  if (content.messages.size() == 1) {
+    printMessage(content.messages[0], "", os);
+  } else {
+    for (size_t i = 0, e = content.messages.size(); i != e; ++i) {
+      if (i != 0)
+        os << "\n";
+      printMessage(content.messages[i],
+                   "Candidate " + llvm::Twine(i + 1) + ": ", os);
+    }
+  }
+}
+
+void Diagnostics::print(llvm::raw_ostream &os) const {
+  for (const ErrorContent &error : errorValues) {
+    if (&error != &errorValues.front())
+      os << "\n";
+    printErrorContent(error, os);
+  }
+}
+
+} // namespace mlir::query::matcher::internal
diff --git a/mlir/lib/Query/Matcher/ErrorBuilder.cpp b/mlir/lib/Query/Matcher/ErrorBuilder.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Query/Matcher/ErrorBuilder.cpp
@@ -0,0 +1,25 @@
+//===--- ErrorBuilder.cpp - Helper for building error messages ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Matcher/ErrorBuilder.h"
+#include "Diagnostics.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include <initializer_list>
+
+namespace mlir::query::matcher::internal {
+
+// Records an error of kind errorType at range in error, then appends each
+// entry of errorTexts, in order, as an argument of that error's message.
+void addError(Diagnostics *error, SourceRange range, ErrorType errorType,
+              std::initializer_list<llvm::Twine> errorTexts) {
+  Diagnostics::ArgStream argStream = error->addError(range, errorType);
+  for (const llvm::Twine &errorText : errorTexts)
+    argStream << errorText;
+}
+
+} // namespace mlir::query::matcher::internal
diff --git a/mlir/lib/Query/Matcher/Parser.h b/mlir/lib/Query/Matcher/Parser.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Query/Matcher/Parser.h
@@ -0,0 +1,188 @@
+//===--- Parser.h - Matcher expression parser -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Simple matcher expression parser.
+//
+// This file contains the Parser class, which is responsible for parsing
+// expressions in a specific format: matcherName(Arg0, Arg1, ..., ArgN). The
+// parser can also interpret simple types, like strings.
+//
+// The actual processing of the matchers is handled by a Sema object that is
+// provided to the parser.
+//
+// The grammar for the supported expressions is as follows:
+// <Expression>        := <StringLiteral> | <MatcherExpression>
+// <StringLiteral>     := "quoted string"
+// <MatcherExpression> := <MatcherName>(<ArgumentList>)
+// <MatcherName>       := [a-zA-Z]+
+// <ArgumentList>      := <Expression> | <Expression>,<ArgumentList>
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_PARSER_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_PARSER_H
+
+#include "Diagnostics.h"
+#include "RegistryManager.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include <memory>
+#include <vector>
+
+namespace mlir::query::matcher::internal {
+
+// Matcher expression parser.
+class Parser {
+public:
+  // Different possible tokens.
+  enum class TokenKind {
+    Eof,
+    NewLine,
+    OpenParen,
+    CloseParen,
+    Comma,
+    Period,
+    Literal,
+    Ident,
+    InvalidChar,
+    CodeCompletion,
+    Error
+  };
+
+  // Interface connecting the parser to the matcher registry. The parser
+  // creates a RegistrySema over the Registry given to its static entry points
+  // and routes every matcher token through it.
+  class Sema {
+  public:
+    virtual ~Sema();
+
+    // Process a matcher expression. The caller takes ownership of the Matcher
+    // object returned.
+    virtual VariantMatcher
+    actOnMatcherExpression(MatcherCtor ctor, SourceRange nameRange,
+                           llvm::ArrayRef<ParserValue> args,
+                           Diagnostics *error) = 0;
+
+    // Look up the constructor of the matcher named matcherName, if any.
+    virtual std::optional<MatcherCtor>
+    lookupMatcherCtor(llvm::StringRef matcherName) = 0;
+
+    // Compute the list of completion types for context (the stack of
+    // enclosing matcher constructors and their current argument index).
+    virtual std::vector<ArgKind> getAcceptedCompletionTypes(
+        llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context);
+
+    // Compute the list of completions that match any of acceptedTypes.
+    virtual std::vector<MatcherCompletion>
+    getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes);
+  };
+
+  // An implementation of the Sema interface that uses the matcher registry to
+  // process tokens.
+  class RegistrySema : public Parser::Sema {
+  public:
+    RegistrySema(const Registry &matcherRegistry)
+        : matcherRegistry(matcherRegistry) {}
+    ~RegistrySema() override;
+
+    std::optional<MatcherCtor>
+    lookupMatcherCtor(llvm::StringRef matcherName) override;
+
+    VariantMatcher actOnMatcherExpression(MatcherCtor ctor,
+                                          SourceRange nameRange,
+                                          llvm::ArrayRef<ParserValue> args,
+                                          Diagnostics *error) override;
+
+    std::vector<ArgKind> getAcceptedCompletionTypes(
+        llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) override;
+
+    std::vector<MatcherCompletion>
+    getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) override;
+
+  private:
+    const Registry &matcherRegistry;
+  };
+
+  using NamedValueMap = llvm::StringMap<VariantValue>;
+
+  // Methods to parse a matcher expression and return a DynMatcher object,
+  // transferring ownership to the caller.
+  static std::optional<DynMatcher>
+  parseMatcherExpression(llvm::StringRef &matcherCode,
+                         const Registry &matcherRegistry,
+                         const NamedValueMap *namedValues, Diagnostics *error);
+  static std::optional<DynMatcher>
+  parseMatcherExpression(llvm::StringRef &matcherCode,
+                         const Registry &matcherRegistry, Diagnostics *error) {
+    return parseMatcherExpression(matcherCode, matcherRegistry, nullptr, error);
+  }
+
+  // Methods to parse any expression supported by this parser.
+  static bool parseExpression(llvm::StringRef &code,
+                              const Registry &matcherRegistry,
+                              const NamedValueMap *namedValues,
+                              VariantValue *value, Diagnostics *error);
+
+  static bool parseExpression(llvm::StringRef &code,
+                              const Registry &matcherRegistry,
+                              VariantValue *value, Diagnostics *error) {
+    return parseExpression(code, matcherRegistry, nullptr, value, error);
+  }
+
+  // Methods to complete an expression at a given offset.
+  static std::vector<MatcherCompletion>
+  completeExpression(llvm::StringRef &code, unsigned completionOffset,
+                     const Registry &matcherRegistry,
+                     const NamedValueMap *namedValues);
+  static std::vector<MatcherCompletion>
+  completeExpression(llvm::StringRef &code, unsigned completionOffset,
+                     const Registry &matcherRegistry) {
+    return completeExpression(code, completionOffset, matcherRegistry, nullptr);
+  }
+
+private:
+  class CodeTokenizer;
+  struct ScopedContextEntry;
+  struct TokenInfo;
+
+  Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
+         const NamedValueMap *namedValues, Diagnostics *error);
+
+  bool parseExpressionImpl(VariantValue *value);
+
+  bool parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
+                        const TokenInfo &nameToken, TokenInfo &endToken);
+
+  bool parseMatcherExpressionImpl(const TokenInfo &nameToken,
+                                  const TokenInfo &openToken,
+                                  std::optional<MatcherCtor> ctor,
+                                  VariantValue *value);
+
+  bool parseIdentifierPrefixImpl(VariantValue *value);
+
+  void addCompletion(const TokenInfo &compToken,
+                     const MatcherCompletion &completion);
+  void addExpressionCompletions();
+
+  std::vector<MatcherCompletion>
+  getNamedValueCompletions(llvm::ArrayRef<ArgKind> acceptedTypes);
+
+  CodeTokenizer *const tokenizer;
+  std::unique_ptr<Sema> sema;
+  const NamedValueMap *const namedValues;
+  Diagnostics *const error;
+
+  // Stack of (matcher constructor, index of the argument being parsed) for
+  // the matcher calls enclosing the current parse position.
+  using ContextStackTy = std::vector<std::pair<MatcherCtor, unsigned>>;
+
+  ContextStackTy contextStack;
+  std::vector<MatcherCompletion> completions;
+};
+
+} // namespace mlir::query::matcher::internal
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_PARSER_H
diff --git a/mlir/lib/Query/Matcher/Parser.cpp b/mlir/lib/Query/Matcher/Parser.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Query/Matcher/Parser.cpp
@@ -0,0 +1,540 @@
+//===- Parser.cpp - Matcher expression parser -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Recursive parser implementation for the matcher expression grammar.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Parser.h"
+
+#include "llvm/ADT/StringExtras.h"
+
+#include <vector>
+
+namespace mlir::query::matcher::internal {
+
+// Simple structure to hold information for one token from the parser.
+struct Parser::TokenInfo {
+  TokenInfo() = default;
+
+  // Method to set the kind and text of the token
+  void set(TokenKind newKind, llvm::StringRef newText) {
+    kind = newKind;
+    text = newText;
+  }
+
+  llvm::StringRef text;
+  TokenKind kind = TokenKind::Eof;
+  SourceRange range;
+  VariantValue value;
+};
+
+// Splits matcher code into tokens with one token of lookahead, tracking the
+// line/column of each token and optionally emitting a CodeCompletion token at
+// a given offset.
+class Parser::CodeTokenizer {
+public:
+  // Tokenizes matcherCode, reporting lexical errors to error.
+  explicit CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error)
+      : code(matcherCode), startOfLine(matcherCode), error(error) {
+    nextToken = getNextToken();
+  }
+
+  // As above, but a CodeCompletion token is produced once the tokenizer
+  // reaches codeCompletionOffset bytes into matcherCode.
+  CodeTokenizer(llvm::StringRef matcherCode, Diagnostics *error,
+                unsigned codeCompletionOffset)
+      : code(matcherCode), startOfLine(matcherCode), error(error),
+        codeCompletionLocation(matcherCode.data() + codeCompletionOffset) {
+    nextToken = getNextToken();
+  }
+
+  // Peek at next token without consuming it
+  const TokenInfo &peekNextToken() const { return nextToken; }
+
+  // Consume and return the next token
+  TokenInfo consumeNextToken() {
+    TokenInfo thisToken = nextToken;
+    nextToken = getNextToken();
+    return thisToken;
+  }
+
+  // Skip any newline tokens and return the first non-newline token, which is
+  // left unconsumed. Returned by reference to avoid copying its value.
+  const TokenInfo &skipNewlines() {
+    while (nextToken.kind == TokenKind::NewLine)
+      nextToken = getNextToken();
+    return nextToken;
+  }
+
+  // Consume and return next token, ignoring newlines. Eof is returned without
+  // being consumed.
+  TokenInfo consumeNextTokenIgnoreNewlines() {
+    skipNewlines();
+    return nextToken.kind == TokenKind::Eof ? nextToken : consumeNextToken();
+  }
+
+  // Return kind of next token
+  TokenKind nextTokenKind() const { return nextToken.kind; }
+
+private:
+  // Helper function to get the first character as a new StringRef and drop it
+  // from the original string
+  llvm::StringRef firstCharacterAndDrop(llvm::StringRef &str) {
+    assert(!str.empty());
+    llvm::StringRef firstChar = str.substr(0, 1);
+    str = str.drop_front();
+    return firstChar;
+  }
+
+  // Get next token, consuming whitespaces and handling different token types
+  TokenInfo getNextToken() {
+    consumeWhitespace();
+    TokenInfo result;
+    result.range.start = currentLocation();
+
+    // Code completion case
+    if (codeCompletionLocation && codeCompletionLocation <= code.data()) {
+      result.set(TokenKind::CodeCompletion,
+                 llvm::StringRef(codeCompletionLocation, 0));
+      codeCompletionLocation = nullptr;
+      return result;
+    }
+
+    // End of file case
+    if (code.empty()) {
+      result.set(TokenKind::Eof, "");
+      return result;
+    }
+
+    // Switch to handle specific characters
+    switch (code[0]) {
+    case '#':
+      // Comment: skip to (but not past) the end of the line.
+      code = code.drop_until([](char c) { return c == '\n'; });
+      return getNextToken();
+    case ',':
+      result.set(TokenKind::Comma, firstCharacterAndDrop(code));
+      break;
+    case '.':
+      result.set(TokenKind::Period, firstCharacterAndDrop(code));
+      break;
+    case '\n':
+      ++line;
+      startOfLine = code.drop_front();
+      result.set(TokenKind::NewLine, firstCharacterAndDrop(code));
+      break;
+    case '(':
+      result.set(TokenKind::OpenParen, firstCharacterAndDrop(code));
+      break;
+    case ')':
+      result.set(TokenKind::CloseParen, firstCharacterAndDrop(code));
+      break;
+    case '"':
+    case '\'':
+      consumeStringLiteral(&result);
+      break;
+    default:
+      parseIdentifierOrInvalid(&result);
+      break;
+    }
+
+    result.range.end = currentLocation();
+    return result;
+  }
+
+  // Consume a string literal, handle escape sequences and missing closing
+  // quote.
+  void consumeStringLiteral(TokenInfo *result) {
+    bool inEscape = false;
+    const char marker = code[0];
+    for (size_t length = 1; length < code.size(); ++length) {
+      if (inEscape) {
+        inEscape = false;
+        continue;
+      }
+      if (code[length] == '\\') {
+        inEscape = true;
+        continue;
+      }
+      if (code[length] == marker) {
+        result->kind = TokenKind::Literal;
+        result->text = code.substr(0, length + 1);
+        result->value = code.substr(1, length - 1);
+        code = code.drop_front(length + 1);
+        return;
+      }
+    }
+    llvm::StringRef errorText = code;
+    code = code.drop_front(code.size());
+    SourceRange range;
+    range.start = result->range.start;
+    range.end = currentLocation();
+    error->addError(range, ErrorType::ParserStringError) << errorText;
+    result->kind = TokenKind::Error;
+  }
+
+  // Lexes a run of ASCII letters/digits as an identifier; any other character
+  // becomes a one-character InvalidChar token. llvm::isAlnum is used instead
+  // of ::isalnum, which has undefined behavior for negative char values
+  // (e.g. bytes of non-ASCII UTF-8 input) and depends on the C locale.
+  void parseIdentifierOrInvalid(TokenInfo *result) {
+    if (llvm::isAlnum(code[0])) {
+      // Parse an identifier
+      size_t tokenLength = 1;
+
+      while (true) {
+        // A code completion location in/immediately after an identifier will
+        // cause the portion of the identifier before the code completion
+        // location to become a code completion token.
+        if (codeCompletionLocation == code.data() + tokenLength) {
+          codeCompletionLocation = nullptr;
+          result->kind = TokenKind::CodeCompletion;
+          result->text = code.substr(0, tokenLength);
+          code = code.drop_front(tokenLength);
+          return;
+        }
+        if (tokenLength == code.size() || !llvm::isAlnum(code[tokenLength]))
+          break;
+        ++tokenLength;
+      }
+      result->kind = TokenKind::Ident;
+      result->text = code.substr(0, tokenLength);
+      code = code.drop_front(tokenLength);
+    } else {
+      result->kind = TokenKind::InvalidChar;
+      result->text = code.substr(0, 1);
+      code = code.drop_front(1);
+    }
+  }
+
+  // Consume all leading whitespace from code, except newlines
+  void consumeWhitespace() {
+    code = code.drop_while(
+        [](char c) { return llvm::StringRef(" \t\v\f\r").contains(c); });
+  }
+
+  // Returns the current location in the source code
+  SourceLocation currentLocation() {
+    SourceLocation location;
+    location.line = line;
+    location.column = code.data() - startOfLine.data() + 1;
+    return location;
+  }
+
+  llvm::StringRef code;
+  llvm::StringRef startOfLine;
+  unsigned line = 1;
+  Diagnostics *error;
+  TokenInfo nextToken;
+  const char *codeCompletionLocation = nullptr;
+};
+
+Parser::Sema::~Sema() = default;
+
+std::vector<ArgKind> Parser::Sema::getAcceptedCompletionTypes(
+    llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
+  return {};
+}
+
+std::vector<MatcherCompletion>
+Parser::Sema::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) {
+  return {};
+}
+
+// Scope guard that pushes (ctor, argument 0) onto the parser's context stack
+// and pops it on destruction. Copying would pop the stack twice, so it is
+// disabled.
+struct Parser::ScopedContextEntry {
+  Parser *parser;
+
+  ScopedContextEntry(Parser *parser, MatcherCtor c) : parser(parser) {
+    parser->contextStack.emplace_back(c, 0u);
+  }
+
+  ScopedContextEntry(const ScopedContextEntry &) = delete;
+  ScopedContextEntry &operator=(const ScopedContextEntry &) = delete;
+
+  ~ScopedContextEntry() { parser->contextStack.pop_back(); }
+
+  // Advance the argument index of the innermost context entry.
+  void nextArg() { ++parser->contextStack.back().second; }
+};
+
+// Parse and validate expressions starting with an identifier.
+// This function can parse named values and matchers. In case of failure, it
+// will try to determine the user's intent to give an appropriate error message.
+bool Parser::parseIdentifierPrefixImpl(VariantValue *value) {
+  const TokenInfo nameToken = tokenizer->consumeNextToken();
+
+  if (tokenizer->nextTokenKind() != TokenKind::OpenParen) {
+    // Not a call: first try to resolve the identifier as a named value. Any
+    // value that is present (string or matcher) is the expression's result.
+    VariantValue namedValue =
+        namedValues ? namedValues->lookup(nameToken.text) : VariantValue();
+    if (namedValue.isString() || namedValue.isMatcher()) {
+      *value = namedValue;
+      return true;
+    }
+
+    // A newline between a matcher name and its '(' is rejected.
+    if (tokenizer->nextTokenKind() == TokenKind::NewLine) {
+      error->addError(tokenizer->peekNextToken().range,
+                      ErrorType::ParserNoOpenParen)
+          << "NewLine";
+      return false;
+    }
+
+    // If the syntax is correct and the name is not a matcher either, report
+    // an unknown named value.
+    if ((tokenizer->nextTokenKind() == TokenKind::Comma ||
+         tokenizer->nextTokenKind() == TokenKind::CloseParen ||
+         tokenizer->nextTokenKind() == TokenKind::Eof) &&
+        !sema->lookupMatcherCtor(nameToken.text)) {
+      error->addError(nameToken.range, ErrorType::RegistryValueNotFound)
+          << nameToken.text;
+      return false;
+    }
+    // Otherwise, fall back to the matcher parser, which reports the missing
+    // '(' below.
+  }
+
+  tokenizer->skipNewlines();
+
+  assert(nameToken.kind == TokenKind::Ident);
+  TokenInfo openToken = tokenizer->consumeNextToken();
+  if (openToken.kind != TokenKind::OpenParen) {
+    error->addError(openToken.range, ErrorType::ParserNoOpenParen)
+        << openToken.text;
+    return false;
+  }
+
+  std::optional<MatcherCtor> ctor = sema->lookupMatcherCtor(nameToken.text);
+
+  // Parse as a matcher expression.
+  return parseMatcherExpressionImpl(nameToken, openToken, ctor, value);
+}
+
+// Parse the comma-separated arguments of a matcher call up to and including
+// the closing ')', which is stored in endToken. The context stack records ctor
+// and the current argument index for code completion.
+bool Parser::parseMatcherArgs(std::vector<ParserValue> &args, MatcherCtor ctor,
+                              const TokenInfo &nameToken, TokenInfo &endToken) {
+  ScopedContextEntry sce(this, ctor);
+
+  while (tokenizer->nextTokenKind() != TokenKind::Eof) {
+    if (tokenizer->nextTokenKind() == TokenKind::CloseParen) {
+      // end of args.
+      endToken = tokenizer->consumeNextToken();
+      break;
+    }
+
+    if (!args.empty()) {
+      // We must find a , token to continue.
+      TokenInfo commaToken = tokenizer->consumeNextToken();
+      if (commaToken.kind != TokenKind::Comma) {
+        error->addError(commaToken.range, ErrorType::ParserNoComma)
+            << commaToken.text;
+        return false;
+      }
+    }
+
+    ParserValue argValue;
+    tokenizer->skipNewlines();
+
+    argValue.text = tokenizer->peekNextToken().text;
+    argValue.range = tokenizer->peekNextToken().range;
+    if (!parseExpressionImpl(&argValue.value)) {
+      return false;
+    }
+
+    tokenizer->skipNewlines();
+    args.push_back(argValue);
+    sce.nextArg();
+  }
+
+  return true;
+}
+
+// Parse and validate a matcher expression.
+bool Parser::parseMatcherExpressionImpl(const TokenInfo &nameToken,
+                                        const TokenInfo &openToken,
+                                        std::optional<MatcherCtor> ctor,
+                                        VariantValue *value) {
+  if (!ctor) {
+    error->addError(nameToken.range, ErrorType::RegistryMatcherNotFound)
+        << nameToken.text;
+    // Do not return here. We need to continue to give completion suggestions.
+  }
+
+  std::vector<ParserValue> args;
+  TokenInfo endToken;
+
+  tokenizer->skipNewlines();
+
+  if (!parseMatcherArgs(args, ctor.value_or(nullptr), nameToken, endToken)) {
+    return false;
+  }
+
+  // Check for the missing closing parenthesis
+  if (endToken.kind != TokenKind::CloseParen) {
+    error->addError(openToken.range, ErrorType::ParserNoCloseParen)
+        << nameToken.text;
+    return false;
+  }
+
+  if (!ctor)
+    return false;
+  // Merge the start and end infos.
+  SourceRange matcherRange = nameToken.range;
+  matcherRange.end = endToken.range.end;
+  VariantMatcher result =
+      sema->actOnMatcherExpression(*ctor, matcherRange, args, error);
+  if (result.isNull())
+    return false;
+  *value = result;
+  return true;
+}
+
+// If the prefix of this completion matches the completion token, add it to
+// completions minus the prefix.
+void Parser::addCompletion(const TokenInfo &compToken,
+                           const MatcherCompletion &completion) {
+  if (llvm::StringRef(completion.typedText).startswith(compToken.text)) {
+    completions.emplace_back(completion.typedText.substr(compToken.text.size()),
+                             completion.matcherDecl);
+  }
+}
+
+// Offers every named value as a completion ("<type> <name>").
+// NOTE(review): acceptedTypes is currently not used to filter the values.
+std::vector<MatcherCompletion>
+Parser::getNamedValueCompletions(llvm::ArrayRef<ArgKind> acceptedTypes) {
+  if (!namedValues)
+    return {};
+
+  std::vector<MatcherCompletion> result;
+  for (const auto &entry : *namedValues) {
+    std::string decl =
+        (entry.getValue().getTypeAsString() + " " + entry.getKey()).str();
+    result.emplace_back(entry.getKey(), decl);
+  }
+  return result;
+}
+
+// Consumes the CodeCompletion token and collects matcher and named-value
+// completions that extend its text.
+void Parser::addExpressionCompletions() {
+  const TokenInfo compToken = tokenizer->consumeNextTokenIgnoreNewlines();
+  assert(compToken.kind == TokenKind::CodeCompletion);
+
+  // We cannot complete code if there is an invalid element on the context
+  // stack.
+  for (const auto &entry : contextStack) {
+    if (!entry.first)
+      return;
+  }
+
+  auto acceptedTypes = sema->getAcceptedCompletionTypes(contextStack);
+  for (const auto &completion : sema->getMatcherCompletions(acceptedTypes)) {
+    addCompletion(compToken, completion);
+  }
+
+  for (const auto &completion : getNamedValueCompletions(acceptedTypes)) {
+    addCompletion(compToken, completion);
+  }
+}
+
+// Parse an <Expression>
+bool Parser::parseExpressionImpl(VariantValue *value) {
+  switch (tokenizer->nextTokenKind()) {
+  case TokenKind::Literal:
+    *value = tokenizer->consumeNextToken().value;
+    return true;
+  case TokenKind::Ident:
+    return parseIdentifierPrefixImpl(value);
+  case TokenKind::CodeCompletion:
+    addExpressionCompletions();
+    return false;
+  case TokenKind::Eof:
+    error->addError(tokenizer->consumeNextToken().range,
+                    ErrorType::ParserNoCode);
+    return false;
+
+  case TokenKind::Error:
+    // This error was already reported by the tokenizer.
+    return false;
+  case TokenKind::NewLine:
+  case TokenKind::OpenParen:
+  case TokenKind::CloseParen:
+  case TokenKind::Comma:
+  case TokenKind::Period:
+  case TokenKind::InvalidChar: {
+    // Braced so the local's scope cannot be jumped into by a later case.
+    const TokenInfo token = tokenizer->consumeNextToken();
+    error->addError(token.range, ErrorType::ParserInvalidToken)
+        << (token.kind == TokenKind::NewLine ? "NewLine" : token.text);
+    return false;
+  }
+  }
+
+  llvm_unreachable("Unknown token kind.");
+}
+
+Parser::Parser(CodeTokenizer *tokenizer, const Registry &matcherRegistry,
+               const NamedValueMap *namedValues, Diagnostics *error)
+    : tokenizer(tokenizer),
+      sema(std::make_unique<RegistrySema>(matcherRegistry)),
+      namedValues(namedValues), error(error) {}
+
+Parser::RegistrySema::~RegistrySema() = default;
+
+std::optional<MatcherCtor>
+Parser::RegistrySema::lookupMatcherCtor(llvm::StringRef matcherName) {
+  return RegistryManager::lookupMatcherCtor(matcherName, matcherRegistry);
+}
+
+VariantMatcher Parser::RegistrySema::actOnMatcherExpression(
+    MatcherCtor ctor, SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
+    Diagnostics *error) {
+  return RegistryManager::constructMatcher(ctor, nameRange, args, error);
+}
+
+std::vector<ArgKind> Parser::RegistrySema::getAcceptedCompletionTypes(
+    llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
+  return RegistryManager::getAcceptedCompletionTypes(context);
+}
+
+std::vector<MatcherCompletion> Parser::RegistrySema::getMatcherCompletions(
+    llvm::ArrayRef<ArgKind> acceptedTypes) {
+  return RegistryManager::getMatcherCompletions(acceptedTypes, matcherRegistry);
+}
+
+// Parses one expression from code. Succeeds only if the expression is followed
+// by end of input or a newline; anything else is reported as trailing code.
+bool Parser::parseExpression(llvm::StringRef &code,
+                             const Registry &matcherRegistry,
+                             const NamedValueMap *namedValues,
+                             VariantValue *value, Diagnostics *error) {
+  CodeTokenizer tokenizer(code, error);
+  Parser parser(&tokenizer, matcherRegistry, namedValues, error);
+  if (!parser.parseExpressionImpl(value))
+    return false;
+  // Bind by reference: the lookahead token owns a VariantValue that need not
+  // be copied just to inspect its kind and range.
+  const TokenInfo &nextToken = tokenizer.peekNextToken();
+  if (nextToken.kind != TokenKind::Eof &&
+      nextToken.kind != TokenKind::NewLine) {
+    error->addError(nextToken.range, ErrorType::ParserTrailingCode);
+    return false;
+  }
+  return true;
+}
+
+// Parses code up to completionOffset and returns the completions collected
+// there. Parse errors are collected into a local Diagnostics and not reported.
+std::vector<MatcherCompletion>
+Parser::completeExpression(llvm::StringRef &code, unsigned completionOffset,
+                           const Registry &matcherRegistry,
+                           const NamedValueMap *namedValues) {
+  Diagnostics error;
+  CodeTokenizer tokenizer(code, &error, completionOffset);
+  Parser parser(&tokenizer, matcherRegistry, namedValues, &error);
+  VariantValue dummy;
+  parser.parseExpressionImpl(&dummy);
+
+  return parser.completions;
+}
+
+std::optional<DynMatcher> Parser::parseMatcherExpression(
+    llvm::StringRef &code, const Registry &matcherRegistry,
+    const NamedValueMap *namedValues, Diagnostics *error) {
+  VariantValue value;
+  if (!parseExpression(code, matcherRegistry, namedValues, &value, error))
+    return std::nullopt;
+  if (!value.isMatcher()) {
+    error->addError(SourceRange(), ErrorType::ParserNotAMatcher);
+    return std::nullopt;
+  }
+  std::optional<DynMatcher> result = value.getMatcher().getDynMatcher();
+  if (!result) {
+    error->addError(SourceRange(), ErrorType::ParserOverloadedType)
+        << value.getTypeAsString();
+  }
+  return result;
+}
+
+} // namespace mlir::query::matcher::internal
diff --git a/mlir/lib/Query/Matcher/RegistryManager.h b/mlir/lib/Query/Matcher/RegistryManager.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Query/Matcher/RegistryManager.h
@@ -0,0 +1,70 @@
+//===--- RegistryManager.h - Matcher registry -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// RegistryManager to manage registry of all known matchers.
+//
+// The registry provides a generic interface to construct any matcher by name.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRYMANAGER_H
+#define MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRYMANAGER_H
+
+#include "Diagnostics.h"
+#include "mlir/Query/Matcher/Marshallers.h"
+#include "mlir/Query/Matcher/Registry.h"
+#include "mlir/Query/Matcher/VariantValue.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/StringRef.h"
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace mlir::query::matcher {
+
+// Handle to a registered matcher descriptor (non-owning).
+using MatcherCtor = const internal::MatcherDescriptor *;
+
+// One code-completion candidate for a matcher.
+struct MatcherCompletion {
+  MatcherCompletion() = default;
+  MatcherCompletion(llvm::StringRef typedText, llvm::StringRef matcherDecl)
+      : typedText(typedText.str()), matcherDecl(matcherDecl.str()) {}
+
+  bool operator==(const MatcherCompletion &other) const {
+    return typedText == other.typedText && matcherDecl == other.matcherDecl;
+  }
+
+  // The text to type to select this matcher.
+  std::string typedText;
+
+  // The "declaration" of the matcher, with type information.
+  std::string matcherDecl;
+};
+
+// Static-only facade over a Registry: lookup, completion and construction.
+class RegistryManager {
+public:
+  RegistryManager() = delete;
+
+  static std::optional<MatcherCtor>
+  lookupMatcherCtor(llvm::StringRef matcherName,
+                    const Registry &matcherRegistry);
+
+  static std::vector<ArgKind> getAcceptedCompletionTypes(
+      llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context);
+
+  static std::vector<MatcherCompletion>
+  getMatcherCompletions(ArrayRef<ArgKind> acceptedTypes,
+                        const Registry &matcherRegistry);
+
+  static VariantMatcher constructMatcher(MatcherCtor ctor,
+                                         internal::SourceRange nameRange,
+                                         ArrayRef<ParserValue> args,
+                                         internal::Diagnostics *error);
+};
+
+} // namespace mlir::query::matcher
+
+#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_REGISTRYMANAGER_H
diff --git a/mlir/lib/Query/Matcher/RegistryManager.cpp b/mlir/lib/Query/Matcher/RegistryManager.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Query/Matcher/RegistryManager.cpp
@@ -0,0 +1,139 @@
+//===- RegistryManager.cpp - Matcher registry -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Registry map populated at static initialization time.
+//
+//===----------------------------------------------------------------------===//
+
+#include "RegistryManager.h"
+#include "mlir/Query/Matcher/Registry.h"
+
+#include <set>
+#include <utility>
+
+namespace mlir::query::matcher {
+namespace {
+
+// This is needed because these matchers are defined as overloaded functions.
+using IsConstantOp = detail::constant_op_matcher();
+using HasOpAttrName = detail::AttrOpMatcher(llvm::StringRef);
+using HasOpName = detail::NameOpMatcher(llvm::StringRef);
+
+// Enum to string for autocomplete.
+static std::string asArgString(ArgKind kind) {
+  switch (kind) {
+  case ArgKind::Matcher:
+    return "Matcher";
+  case ArgKind::String:
+    return "String";
+  }
+  llvm_unreachable("Unhandled ArgKind");
+}
+
+} // namespace
+
+// Registers callback under matcherName; the name must not already be taken.
+void Registry::registerMatcherDescriptor(
+    llvm::StringRef matcherName,
+    std::unique_ptr<internal::MatcherDescriptor> callback) {
+  assert(!constructorMap.contains(matcherName));
+  constructorMap[matcherName] = std::move(callback);
+}
+
+std::optional<MatcherCtor>
+RegistryManager::lookupMatcherCtor(llvm::StringRef matcherName,
+                                   const Registry &matcherRegistry) {
+  auto it = matcherRegistry.constructors().find(matcherName);
+  return it == matcherRegistry.constructors().end()
+             ? std::optional<MatcherCtor>()
+             : it->second.get();
+}
+
+std::vector<ArgKind> RegistryManager::getAcceptedCompletionTypes(
+    llvm::ArrayRef<std::pair<MatcherCtor, unsigned>> context) {
+  // Seed the set with Matcher (a top-level expression may be any matcher), then
+  // add the argument kinds accepted at each enclosing call's current argument
+  // position.
+  // NOTE(review): kinds accumulate across all context levels instead of being
+  // narrowed to the innermost argument; confirm this is intended.
+  std::set<ArgKind> typeSet;
+  typeSet.insert(ArgKind::Matcher);
+
+  for (const auto &ctxEntry : context) {
+    MatcherCtor ctor = ctxEntry.first;
+    unsigned argNumber = ctxEntry.second;
+    std::vector<ArgKind> nextTypeSet;
+
+    if (argNumber < ctor->getNumArgs())
+      ctor->getArgKinds(argNumber, nextTypeSet);
+
+    typeSet.insert(nextTypeSet.begin(), nextTypeSet.end());
+  }
+
+  return std::vector<ArgKind>(typeSet.begin(), typeSet.end());
+}
+
+// Builds one completion per registered matcher. typedText is "name(" plus ")"
+// for nullary matchers or an opening '"' when the first argument takes a
+// String. matcherDecl is "Matcher: name(kinds...)".
+std::vector<MatcherCompletion>
+RegistryManager::getMatcherCompletions(llvm::ArrayRef<ArgKind> acceptedTypes,
+                                       const Registry &matcherRegistry) {
+  std::vector<MatcherCompletion> completions;
+
+  // Search the registry for acceptable matchers.
+  for (const auto &m : matcherRegistry.constructors()) {
+    const internal::MatcherDescriptor &matcher = *m.getValue();
+    llvm::StringRef name = m.getKey();
+
+    unsigned numArgs = matcher.getNumArgs();
+    std::vector<std::vector<ArgKind>> argKinds(numArgs);
+
+    // Argument kinds are only collected when a Matcher is acceptable at the
+    // completion point; otherwise every per-argument list stays empty.
+    for (const ArgKind &kind : acceptedTypes) {
+      if (kind != ArgKind::Matcher)
+        continue;
+
+      for (unsigned arg = 0; arg != numArgs; ++arg)
+        matcher.getArgKinds(arg, argKinds[arg]);
+    }
+
+    std::string decl;
+    llvm::raw_string_ostream os(decl);
+
+    std::string typedText = std::string(name);
+    os << "Matcher: " << name << "(";
+
+    for (const std::vector<ArgKind> &arg : argKinds) {
+      if (&arg != &argKinds[0])
+        os << ", ";
+
+      // Alternative kinds for one argument are separated by '|'.
+      bool firstArgKind = true;
+      for (const ArgKind &argKind : arg) {
+        if (!firstArgKind)
+          os << "|";
+
+        firstArgKind = false;
+        os << asArgString(argKind);
+      }
+    }
+
+    os << ")";
+    typedText += "(";
+
+    // argKinds[0] may be empty (see above), so check before indexing it.
+    if (argKinds.empty())
+      typedText += ")";
+    else if (!argKinds[0].empty() && argKinds[0][0] == ArgKind::String)
+      typedText += "\"";
+
+    completions.emplace_back(typedText, os.str());
+  }
+
+  return completions;
+}
+
+VariantMatcher RegistryManager::constructMatcher(
+    MatcherCtor ctor, internal::SourceRange nameRange,
+    llvm::ArrayRef<ParserValue> args, internal::Diagnostics *error) {
+  return ctor->create(nameRange, args, error);
+}
+
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Query/Matcher/VariantValue.cpp
@@ -0,0 +1,132 @@
+//===--- VariantValue.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implementation of VariantMatcher and VariantValue.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Query/Matcher/VariantValue.h"
+
+namespace mlir::query::matcher {
+
+VariantMatcher::Payload::~Payload() = default;
+
+// Payload holding exactly one DynMatcher.
+class VariantMatcher::SinglePayload : public VariantMatcher::Payload {
+public:
+  explicit SinglePayload(DynMatcher matcher) : matcher(std::move(matcher)) {}
+
+  std::optional<DynMatcher> getDynMatcher() const override { return matcher; }
+
+  std::string getTypeAsString() const override { return "Matcher"; }
+
+private:
+  DynMatcher matcher;
+};
+
+VariantMatcher::VariantMatcher() = default;
+
+VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) {
+  return VariantMatcher(std::make_shared<SinglePayload>(std::move(matcher)));
+}
+
+std::optional<DynMatcher> VariantMatcher::getDynMatcher() const {
+  return value ? value->getDynMatcher() : std::nullopt;
+}
+
+void VariantMatcher::reset() { value.reset(); }
+
+// Delegates to the payload, so SinglePayload::getTypeAsString is actually used.
+// An empty VariantMatcher reports "<Nothing>".
+std::string VariantMatcher::getTypeAsString() const {
+  return value ? value->getTypeAsString() : "<Nothing>";
+}
+
+VariantValue::VariantValue(const VariantValue &other)
+    : type(ValueType::Nothing) {
+  *this = other;
+}
+
+VariantValue::VariantValue(const llvm::StringRef string)
+    : type(ValueType::String) {
+  value.String = new llvm::StringRef(string);
+}
+
+VariantValue::VariantValue(const VariantMatcher &matcher)
+    : type(ValueType::Matcher) {
+  value.Matcher = new VariantMatcher(matcher);
+}
+
+VariantValue::~VariantValue() { reset(); }
+
+VariantValue &VariantValue::operator=(const VariantValue &other) {
+  if (this == &other)
+    return *this;
+  reset();
+  switch (other.type) {
+  case ValueType::String:
+    setString(other.getString());
+    break;
+  case ValueType::Matcher:
+    setMatcher(other.getMatcher());
+    break;
+  case ValueType::Nothing:
+    type = ValueType::Nothing;
+    break;
+  }
+  return *this;
+}
+
+// Frees the heap-allocated active member (if any) and leaves the value empty.
+void VariantValue::reset() {
+  switch (type) {
+  case ValueType::String:
+    delete value.String;
+    break;
+  case ValueType::Matcher:
+    delete value.Matcher;
+    break;
+  // Cases that do nothing.
+  case ValueType::Nothing:
+    break;
+  }
+  type = ValueType::Nothing;
+}
+
+bool VariantValue::isString() const { return type == ValueType::String; }
+
+const llvm::StringRef &VariantValue::getString() const {
+  assert(isString());
+  return *value.String;
+}
+
+void VariantValue::setString(const llvm::StringRef &newValue) {
+  reset();
+  type = ValueType::String;
+  value.String = new llvm::StringRef(newValue);
+}
+
+bool VariantValue::isMatcher() const { return type == ValueType::Matcher; }
+
+const VariantMatcher &VariantValue::getMatcher() const {
+  assert(isMatcher());
+  return *value.Matcher;
+}
+
+void VariantValue::setMatcher(const VariantMatcher &newValue) {
+  reset();
+  type = ValueType::Matcher;
+  value.Matcher = new VariantMatcher(newValue);
+}
+
+std::string VariantValue::getTypeAsString() const {
+  switch (type) {
+  case ValueType::String:
+    return "String";
+  case ValueType::Matcher:
+    return "Matcher";
+  case ValueType::Nothing:
+    return "Nothing";
+  }
+  llvm_unreachable("Invalid Type");
+}
+
+} // namespace mlir::query::matcher
diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Query/Query.cpp
@@ -0,0 +1,82 @@
+//===---- Query.cpp - -----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Query/Query.h" +#include "QueryParser.h" +#include "mlir/Query/Matcher/MatchFinder.h" +#include "mlir/Query/QuerySession.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir::query { + +QueryRef parse(llvm::StringRef line, const QuerySession &qs) { + return QueryParser::parse(line, qs); +} + +std::vector +complete(llvm::StringRef line, size_t pos, const QuerySession &qs) { + return QueryParser::complete(line, pos, qs); +} + +static void printMatch(llvm::raw_ostream &os, QuerySession &qs, Operation *op, + const std::string &binding) { + auto fileLoc = op->getLoc()->findInstanceOf(); + auto smloc = qs.getSourceManager().FindLocForLineAndColumn( + qs.getBufferId(), fileLoc.getLine(), fileLoc.getColumn()); + qs.getSourceManager().PrintMessage(os, smloc, llvm::SourceMgr::DK_Note, + "\"" + binding + "\" binds here"); +} + +Query::~Query() = default; + +mlir::LogicalResult InvalidQuery::run(llvm::raw_ostream &os, + QuerySession &qs) const { + os << errStr << "\n"; + return mlir::failure(); +} + +mlir::LogicalResult NoOpQuery::run(llvm::raw_ostream &os, + QuerySession &qs) const { + return mlir::success(); +} + +mlir::LogicalResult HelpQuery::run(llvm::raw_ostream &os, + QuerySession &qs) const { + os << "Available commands:\n\n" + " match MATCHER, m MATCHER " + "Match the mlir against the given matcher.\n" + " quit " + "Terminates the query session.\n\n"; + return mlir::success(); +} + +mlir::LogicalResult QuitQuery::run(llvm::raw_ostream &os, + QuerySession &qs) const { + qs.terminate = true; + return mlir::success(); +} + +mlir::LogicalResult MatchQuery::run(llvm::raw_ostream &os, + QuerySession &qs) const { + int matchCount = 0; + std::vector matches = + matcher::MatchFinder().getMatches(qs.getRootOp(), matcher); + os << 
"\n"; + for (Operation *op : matches) { + os << "Match #" << ++matchCount << ":\n\n"; + // Placeholder "root" binding for the initial draft. + printMatch(os, qs, op, "root"); + } + os << matchCount << (matchCount == 1 ? " match.\n\n" : " matches.\n\n"); + + return mlir::success(); +} + +} // namespace mlir::query diff --git a/mlir/lib/Query/QueryParser.h b/mlir/lib/Query/QueryParser.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Query/QueryParser.h @@ -0,0 +1,59 @@ +//===--- QueryParser.h - ----------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRQUERY_QUERYPARSER_H +#define MLIR_TOOLS_MLIRQUERY_QUERYPARSER_H + +#include "Matcher/Parser.h" +#include "mlir/Query/Query.h" +#include "mlir/Query/QuerySession.h" + +#include "llvm/ADT/StringRef.h" +#include "llvm/LineEditor/LineEditor.h" + +namespace mlir::query { + +class QuerySession; + +class QueryParser { +public: + // Parse line as a query and return a QueryRef representing the query, which + // may be an InvalidQuery. + static QueryRef parse(llvm::StringRef line, const QuerySession &qs); + + static std::vector + complete(llvm::StringRef line, size_t pos, const QuerySession &qs); + +private: + QueryParser(llvm::StringRef line, const QuerySession &qs) + : line(line), completionPos(nullptr), qs(qs) {} + + llvm::StringRef lexWord(); + + template + struct LexOrCompleteWord; + + QueryRef completeMatcherExpression(); + + QueryRef endQuery(QueryRef queryRef); + + // Parse [begin, end) and returns a reference to the parsed query object, + // which may be an InvalidQuery if a parse error occurs. 
+ QueryRef doParse(); + + llvm::StringRef line; + + const char *completionPos; + std::vector completions; + + const QuerySession &qs; +}; + +} // namespace mlir::query + +#endif // MLIR_TOOLS_MLIRQUERY_QUERYPARSER_H diff --git a/mlir/lib/Query/QueryParser.cpp b/mlir/lib/Query/QueryParser.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Query/QueryParser.cpp @@ -0,0 +1,217 @@ +//===---- QueryParser.cpp - mlir-query command parser ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "QueryParser.h" +#include "llvm/ADT/StringSwitch.h" + +namespace mlir::query { + +// Lex any amount of whitespace followed by a "word" (any sequence of +// non-whitespace characters) from the start of region [begin,end). If no word +// is found before end, return StringRef(). begin is adjusted to exclude the +// lexed region. +llvm::StringRef QueryParser::lexWord() { + line = line.drop_while([](char c) { + // Don't trim newlines. + return llvm::StringRef(" \t\v\f\r").contains(c); + }); + + if (line.empty()) + // Even though the line is empty, it contains a pointer and + // a (zero) length. The pointer is used in the LexOrCompleteWord + // code completion. + return line; + + llvm::StringRef word; + if (line.front() == '#') { + word = line.substr(0, 1); + } else { + word = line.take_until([](char c) { + // Don't trim newlines. + return llvm::StringRef(" \t\v\f\r").contains(c); + }); + } + + line = line.drop_front(word.size()); + return word; +} + +// This is the StringSwitch-alike used by LexOrCompleteWord below. See that +// function for details. 
+template <typename T>
+struct QueryParser::LexOrCompleteWord {
+  llvm::StringRef word;
+  llvm::StringSwitch<T> stringSwitch;
+
+  QueryParser *queryParser;
+  // Set to the completion point offset in word, or StringRef::npos if
+  // completion point not in word.
+  size_t wordCompletionPos;
+
+  // Lexes a word and stores it in word. Returns a LexOrCompleteWord object
+  // that can be used like a llvm::StringSwitch, but adds cases as possible
+  // completions if the lexed word contains the completion point.
+  LexOrCompleteWord(QueryParser *queryParser, llvm::StringRef &outWord)
+      : word(queryParser->lexWord()), stringSwitch(word),
+        queryParser(queryParser), wordCompletionPos(llvm::StringRef::npos) {
+    outWord = word;
+    if (queryParser->completionPos &&
+        queryParser->completionPos <= word.data() + word.size()) {
+      if (queryParser->completionPos < word.data())
+        wordCompletionPos = 0;
+      else
+        wordCompletionPos = queryParser->completionPos - word.data();
+    }
+  }
+
+  LexOrCompleteWord &Case(llvm::StringLiteral caseStr, const T &value,
+                          bool isCompletion = true) {
+
+    if (wordCompletionPos == llvm::StringRef::npos)
+      stringSwitch.Case(caseStr, value);
+    else if (!caseStr.empty() && isCompletion &&
+             wordCompletionPos <= caseStr.size() &&
+             caseStr.substr(0, wordCompletionPos) ==
+                 word.substr(0, wordCompletionPos)) {
+
+      queryParser->completions.emplace_back(
+          (caseStr.substr(wordCompletionPos) + " ").str(),
+          std::string(caseStr));
+    }
+    return *this;
+  }
+
+  T Default(T value) { return stringSwitch.Default(value); }
+};
+
+QueryRef QueryParser::endQuery(QueryRef queryRef) {
+  llvm::StringRef extra = line;
+  llvm::StringRef extraTrimmed = extra.drop_while(
+      [](char c) { return llvm::StringRef(" \t\v\f\r").contains(c); });
+
+  if ((!extraTrimmed.empty() && extraTrimmed[0] == '\n') ||
+      (extraTrimmed.size() >= 2 && extraTrimmed[0] == '\r' &&
+       extraTrimmed[1] == '\n'))
+    queryRef->remainingContent = extra;
+  else {
+    llvm::StringRef trailingWord = lexWord();
+    if (!trailingWord.empty() && trailingWord.front() == '#') {
+      line = line.drop_until([](char c) { return c == '\n'; });
+      line = line.drop_while([](char c) { return c == '\n'; });
+      return endQuery(queryRef);
+    }
+    if (!trailingWord.empty()) {
+      return new InvalidQuery("unexpected extra input: '" + extra + "'");
+    }
+  }
+  return queryRef;
+}
+
+namespace {
+
+enum class ParsedQueryKind {
+  Invalid,
+  Comment,
+  NoOp,
+  Help,
+  Match,
+  Quit,
+};
+
+QueryRef
+makeInvalidQueryFromDiagnostics(const matcher::internal::Diagnostics &diag) {
+  std::string errStr;
+  llvm::raw_string_ostream os(errStr);
+  diag.print(os);
+  return new InvalidQuery(os.str());
+}
+} // namespace
+
+QueryRef QueryParser::completeMatcherExpression() {
+  std::vector<matcher::MatcherCompletion> comps =
+      matcher::internal::Parser::completeExpression(
+          line, completionPos - line.begin(), qs.getRegistryData(),
+          &qs.namedValues);
+  for (const auto &comp : comps) {
+    completions.emplace_back(comp.typedText, comp.matcherDecl);
+  }
+  return QueryRef();
+}
+
+QueryRef QueryParser::doParse() {
+
+  llvm::StringRef commandStr;
+  ParsedQueryKind qKind =
+      LexOrCompleteWord<ParsedQueryKind>(this, commandStr)
+          .Case("", ParsedQueryKind::NoOp)
+          .Case("#", ParsedQueryKind::Comment, /*isCompletion=*/false)
+          .Case("help", ParsedQueryKind::Help)
+          .Case("m", ParsedQueryKind::Match, /*isCompletion=*/false)
+          .Case("match", ParsedQueryKind::Match)
+          .Case("q", ParsedQueryKind::Quit, /*isCompletion=*/false)
+          .Case("quit", ParsedQueryKind::Quit)
+          .Default(ParsedQueryKind::Invalid);
+
+  switch (qKind) {
+  case ParsedQueryKind::Comment:
+  case ParsedQueryKind::NoOp:
+    line = line.drop_until([](char c) { return c == '\n'; });
+    line = line.drop_while([](char c) { return c == '\n'; });
+    if (line.empty())
+      return new NoOpQuery;
+    return doParse();
+
+  case ParsedQueryKind::Help:
+    return endQuery(new HelpQuery);
+
+  case ParsedQueryKind::Quit:
+    return endQuery(new QuitQuery);
+
+  case ParsedQueryKind::Match: {
+    if (completionPos) {
+      return completeMatcherExpression();
+    }
+
+    matcher::internal::Diagnostics diag;
+    auto matcherSource = line.ltrim();
+    auto origMatcherSource = matcherSource;
+    std::optional<matcher::DynMatcher> matcher =
+        matcher::internal::Parser::parseMatcherExpression(
+            matcherSource, qs.getRegistryData(), &qs.namedValues, &diag);
+    if (!matcher) {
+      return makeInvalidQueryFromDiagnostics(diag);
+    }
+    auto actualSource = origMatcherSource.slice(0, origMatcherSource.size() -
+                                                       matcherSource.size());
+    QueryRef query = new MatchQuery(actualSource, *matcher);
+    query->remainingContent = matcherSource;
+    return query;
+  }
+
+  case ParsedQueryKind::Invalid:
+    return new InvalidQuery("unknown command: " + commandStr);
+  }
+
+  llvm_unreachable("Invalid query kind");
+}
+
+QueryRef QueryParser::parse(llvm::StringRef line, const QuerySession &qs) {
+  return QueryParser(line, qs).doParse();
+}
+
+std::vector<llvm::LineEditor::Completion>
+QueryParser::complete(llvm::StringRef line, size_t pos,
+                      const QuerySession &qs) {
+  QueryParser queryParser(line, qs);
+  queryParser.completionPos = line.data() + pos;
+
+  queryParser.doParse();
+  return queryParser.completions;
+}
+
+} // namespace mlir::query
diff --git a/mlir/lib/Tools/CMakeLists.txt b/mlir/lib/Tools/CMakeLists.txt
--- a/mlir/lib/Tools/CMakeLists.txt
+++ b/mlir/lib/Tools/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_subdirectory(mlir-lsp-server)
 add_subdirectory(mlir-opt)
 add_subdirectory(mlir-pdll-lsp-server)
+add_subdirectory(mlir-query)
 add_subdirectory(mlir-reduce)
 add_subdirectory(mlir-tblgen)
 add_subdirectory(mlir-translate)
diff --git a/mlir/lib/Tools/mlir-query/CMakeLists.txt b/mlir/lib/Tools/mlir-query/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/CMakeLists.txt
@@ -0,0 +1,13 @@
+set(LLVM_LINK_COMPONENTS
+  lineeditor
+  )
+
+add_mlir_library(MLIRQueryLib
+  MlirQueryMain.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-query
+
+  LINK_LIBS PUBLIC
+  MLIRQuery
+  )
diff --git a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp
b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp
@@ -0,0 +1,118 @@
+//===- MlirQueryMain.cpp - MLIR Query main --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the general framework of the MLIR query tool. It
+// parses the command line arguments, parses the MLIR file and outputs the query
+// results.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Tools/mlir-query/MlirQueryMain.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Query/Query.h"
+#include "mlir/Query/QuerySession.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/LineEditor/LineEditor.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/SourceMgr.h"
+
+//===----------------------------------------------------------------------===//
+// mlir-query driver
+//===----------------------------------------------------------------------===//
+
+mlir::LogicalResult
+mlir::mlirQueryMain(int argc, char **argv, MLIRContext &context,
+                    const mlir::query::matcher::Registry &matcherRegistry) {
+
+  // Override the default '-h' and use the default PrintHelpMessage() which
+  // won't print options in categories.
+  static llvm::cl::opt<bool> help("h", llvm::cl::desc("Alias for -help"),
+                                  llvm::cl::Hidden);
+
+  static llvm::cl::OptionCategory mlirQueryCategory("mlir-query options");
+
+  static llvm::cl::list<std::string> commands(
+      "c", llvm::cl::desc("Specify command to run"),
+      llvm::cl::value_desc("command"), llvm::cl::cat(mlirQueryCategory));
+
+  static llvm::cl::opt<std::string> inputFilename(
+      llvm::cl::Positional, llvm::cl::desc("<input file>"),
+      llvm::cl::cat(mlirQueryCategory));
+
+  static llvm::cl::opt<bool> noImplicitModule{
+      "no-implicit-module",
+      llvm::cl::desc(
+          "Disable implicit addition of a top-level module op during parsing"),
+      llvm::cl::init(false), llvm::cl::cat(mlirQueryCategory)};
+
+  static llvm::cl::opt<bool> allowUnregisteredDialects(
+      "allow-unregistered-dialect",
+      llvm::cl::desc("Allow operation with no registered dialects"),
+      llvm::cl::init(false), llvm::cl::cat(mlirQueryCategory));
+
+  llvm::cl::HideUnrelatedOptions(mlirQueryCategory);
+
+  llvm::InitLLVM y(argc, argv);
+
+  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR test case query tool.\n");
+
+  if (help) {
+    llvm::cl::PrintHelpMessage();
+    return mlir::success();
+  }
+
+  // Set up the input file.
+  std::string errorMessage;
+  auto file = openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return mlir::failure();
+  }
+
+  auto sourceMgr = llvm::SourceMgr();
+  auto bufferId = sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+
+  context.allowUnregisteredDialects(allowUnregisteredDialects);
+
+  // Parse the input MLIR file.
+  OwningOpRef<Operation *> opRef =
+      noImplicitModule ? parseSourceFile(sourceMgr, &context)
+                       : parseSourceFile<mlir::ModuleOp>(sourceMgr, &context);
+  if (!opRef)
+    return mlir::failure();
+
+  mlir::query::QuerySession qs(opRef.get(), sourceMgr, bufferId,
+                               matcherRegistry);
+  if (!commands.empty()) {
+    for (auto &command : commands) {
+      mlir::query::QueryRef queryRef = mlir::query::parse(command, qs);
+      if (mlir::failed(queryRef->run(llvm::outs(), qs)))
+        return mlir::failure();
+      // Honor `quit` in -c mode the same way the interactive loop does.
+      if (qs.terminate)
+        break;
+    }
+  } else {
+    llvm::LineEditor le("mlir-query");
+    le.setListCompleter([&qs](llvm::StringRef line, size_t pos) {
+      return mlir::query::complete(line, pos, qs);
+    });
+    while (std::optional<std::string> line = le.readLine()) {
+      mlir::query::QueryRef queryRef = mlir::query::parse(*line, qs);
+      (void)queryRef->run(llvm::outs(), qs);
+      llvm::outs().flush();
+      if (qs.terminate)
+        break;
+    }
+  }
+
+  return mlir::success();
+}
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -104,6 +104,7 @@
   mlir-pdll-lsp-server
   mlir-opt
   mlir-pdll
+  mlir-query
   mlir-reduce
   mlir-tblgen
   mlir-translate
diff --git a/mlir/test/mlir-query/simple-test.mlir b/mlir/test/mlir-query/simple-test.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-query/simple-test.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-query %s -c "m isConstantOp()" | FileCheck %s
+
+// CHECK: {{.*}}.mlir:5:13: note: "root" binds here
+func.func @simple1() {
+  %c1_i32 = arith.constant 1 : i32
+  return
+}
+
+// CHECK: {{.*}}.mlir:12:11: note: "root" binds here
+// CHECK: {{.*}}.mlir:13:11: note: "root" binds here
+func.func @simple2() {
+  %cst1 = arith.constant 1.0 : f32
+  %cst2 = arith.constant 2.0 : f32
+  %add = arith.addf %cst1, %cst2 : f32
+  return
+}
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -2,6 +2,7 @@
 add_subdirectory(mlir-opt)
 add_subdirectory(mlir-parser-fuzzer)
 add_subdirectory(mlir-pdll-lsp-server)
+add_subdirectory(mlir-query)
add_subdirectory(mlir-reduce) add_subdirectory(mlir-shlib) add_subdirectory(mlir-spirv-cpu-runner) diff --git a/mlir/tools/mlir-query/CMakeLists.txt b/mlir/tools/mlir-query/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-query/CMakeLists.txt @@ -0,0 +1,20 @@ +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) + +if(MLIR_INCLUDE_TESTS) + set(test_libs + MLIRTestDialect + ) +endif() + +add_mlir_tool(mlir-query + mlir-query.cpp + ) +llvm_update_compile_flags(mlir-query) +target_link_libraries(mlir-query + PRIVATE + ${dialect_libs} + ${test_libs} + MLIRQueryLib + ) + +mlir_check_link_libraries(mlir-query) diff --git a/mlir/tools/mlir-query/mlir-query.cpp b/mlir/tools/mlir-query/mlir-query.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-query/mlir-query.cpp @@ -0,0 +1,63 @@ +//===- mlir-query.cpp - MLIR Query Driver ---------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is a command line utility that queries a file from/to MLIR using one +// of the registered queries. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/InitAllDialects.h" +#include "mlir/Query/Matcher/Registry.h" +#include "mlir/Tools/mlir-query/MlirQueryMain.h" + +using namespace mlir; + +// This is needed because these matchers are defined as overloaded functions. 
+using HasOpAttrName = detail::AttrOpMatcher(StringRef); +using HasOpName = detail::NameOpMatcher(StringRef); +using IsConstantOp = detail::constant_op_matcher(); + +namespace test { +#ifdef MLIR_INCLUDE_TESTS +void registerTestDialect(DialectRegistry &); +#endif +} // namespace test + +int main(int argc, char **argv) { + + DialectRegistry dialectRegistry; + registerAllDialects(dialectRegistry); + + query::matcher::Registry matcherRegistry; + + // Matchers registered in alphabetical order for consistency: + matcherRegistry.registerMatcher("hasOpAttrName", + static_cast(m_Attr)); + matcherRegistry.registerMatcher("hasOpName", static_cast(m_Op)); + matcherRegistry.registerMatcher("isConstantOp", + static_cast(m_Constant)); + matcherRegistry.registerMatcher("isNegInfFloat", m_NegInfFloat); + matcherRegistry.registerMatcher("isNegZeroFloat", m_NegZeroFloat); + matcherRegistry.registerMatcher("isNonZero", m_NonZero); + matcherRegistry.registerMatcher("isOne", m_One); + matcherRegistry.registerMatcher("isOneFloat", m_OneFloat); + matcherRegistry.registerMatcher("isPosInfFloat", m_PosInfFloat); + matcherRegistry.registerMatcher("isPosZeroFloat", m_PosZeroFloat); + matcherRegistry.registerMatcher("isZero", m_Zero); + matcherRegistry.registerMatcher("isZeroFloat", m_AnyZeroFloat); + +#ifdef MLIR_INCLUDE_TESTS + test::registerTestDialect(dialectRegistry); +#endif + MLIRContext context(dialectRegistry); + + return failed(mlirQueryMain(argc, argv, context, matcherRegistry)); +}