diff --git a/llvm/include/llvm/Support/SourceMgr.h b/llvm/include/llvm/Support/SourceMgr.h --- a/llvm/include/llvm/Support/SourceMgr.h +++ b/llvm/include/llvm/Support/SourceMgr.h @@ -100,6 +100,9 @@ SourceMgr &operator=(SourceMgr &&) = default; ~SourceMgr() = default; + /// Return the include directories of this source manager. + ArrayRef getIncludeDirs() const { return IncludeDirectories; } + void setIncludeDirs(const std::vector &Dirs) { IncludeDirectories = Dirs; } @@ -147,6 +150,14 @@ return Buffers.size(); } + /// Takes the source buffers from the given source manager and appends them to + /// the current manager. \p SrcMgr is left with no buffers afterwards. + void takeSourceBuffersFrom(SourceMgr &SrcMgr) { + std::move(SrcMgr.Buffers.begin(), SrcMgr.Buffers.end(), + std::back_inserter(Buffers)); + SrcMgr.Buffers.clear(); + } + /// Search for a file with the specified name in the current directory or in /// one of the IncludeDirs. /// @@ -156,6 +167,17 @@ unsigned AddIncludeFile(const std::string &Filename, SMLoc IncludeLoc, std::string &IncludedFile); + /// Search for a file with the specified name in the current directory or in + /// one of the IncludeDirs, and try to open it **without** adding to the + /// SourceMgr. If the opened file is intended to be added to the source + /// manager, prefer `AddIncludeFile` instead. + /// + /// If no file is found, this returns an Error, otherwise it returns the + /// buffer of the opened file. The full path to the included file can be + /// found in \p IncludedFile. + ErrorOr> + OpenIncludeFile(const std::string &Filename, std::string &IncludedFile); + /// Return the ID of the buffer containing the specified location. /// /// 0 is returned if the buffer is not found.
diff --git a/llvm/lib/Support/SourceMgr.cpp b/llvm/lib/Support/SourceMgr.cpp --- a/llvm/lib/Support/SourceMgr.cpp +++ b/llvm/lib/Support/SourceMgr.cpp @@ -40,6 +40,17 @@ unsigned SourceMgr::AddIncludeFile(const std::string &Filename, SMLoc IncludeLoc, std::string &IncludedFile) { + ErrorOr> NewBufOrErr = + OpenIncludeFile(Filename, IncludedFile); + if (!NewBufOrErr) + return 0; + + return AddNewSourceBuffer(std::move(*NewBufOrErr), IncludeLoc); +} + +ErrorOr> +SourceMgr::OpenIncludeFile(const std::string &Filename, + std::string &IncludedFile) { IncludedFile = Filename; ErrorOr> NewBufOrErr = MemoryBuffer::getFile(IncludedFile); @@ -52,10 +63,7 @@ NewBufOrErr = MemoryBuffer::getFile(IncludedFile); } - if (!NewBufOrErr) - return 0; - - return AddNewSourceBuffer(std::move(*NewBufOrErr), IncludeLoc); + return NewBufOrErr; } unsigned SourceMgr::FindBufferContainingLoc(SMLoc Loc) const { diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -363,7 +363,8 @@ // A variadic type constraint. It expands to zero or more of the base type. This // class is used for supporting variadic operands/results. -class Variadic : TypeConstraint { +class Variadic : TypeConstraint { Type baseType = type; } @@ -379,7 +380,8 @@ // An optional type constraint. It expands to either zero or one of the base // type. This class is used for supporting optional operands/results. -class Optional : TypeConstraint { +class Optional : TypeConstraint { Type baseType = type; } diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h --- a/mlir/include/mlir/TableGen/Constraint.h +++ b/mlir/include/mlir/TableGen/Constraint.h @@ -54,6 +54,11 @@ // description is not provided, returns the TableGen def name. StringRef getSummary() const; + /// Returns the name of the TableGen def of this constraint.
In some cases + /// where the current def is anonymous, the name of the base def is used (e.g. + /// `Optional<>`/`Variadic<>` type constraints). + StringRef getDefName() const; + Kind getKind() const { return kind; } protected: diff --git a/mlir/include/mlir/Tools/PDLL/AST/Context.h b/mlir/include/mlir/Tools/PDLL/AST/Context.h --- a/mlir/include/mlir/Tools/PDLL/AST/Context.h +++ b/mlir/include/mlir/Tools/PDLL/AST/Context.h @@ -14,13 +14,17 @@ namespace mlir { namespace pdll { +namespace ods { +class Context; +} // namespace ods + namespace ast { /// This class represents the main context of the PDLL AST. It handles /// allocating all of the AST constructs, and manages all state necessary for /// the AST. class Context { public: - Context(); + explicit Context(ods::Context &odsContext); Context(const Context &) = delete; Context &operator=(const Context &) = delete; @@ -30,6 +34,10 @@ /// Return the storage uniquer used for AST types. StorageUniquer &getTypeUniquer() { return typeUniquer; } + /// Return the ODS context used by the AST. + ods::Context &getODSContext() { return odsContext; } + const ods::Context &getODSContext() const { return odsContext; } + /// Return the diagnostic engine of this context. DiagnosticEngine &getDiagEngine() { return diagEngine; } @@ -37,6 +45,9 @@ /// The diagnostic engine of this AST context. DiagnosticEngine diagEngine; + /// The ODS context used by the AST. This is a non-owning reference. + ods::Context &odsContext; + /// The allocator used for AST nodes, and other entities allocated within the /// context. llvm::BumpPtrAllocator allocator; diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Constraint.h b/mlir/include/mlir/Tools/PDLL/ODS/Constraint.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/PDLL/ODS/Constraint.h @@ -0,0 +1,98 @@ +//===- Constraint.h - MLIR PDLL ODS Constraints -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a PDLL description of ODS constraints. These are used to +// support the import of constraints defined outside of PDLL. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PDLL_ODS_CONSTRAINT_H_ +#define MLIR_TOOLS_PDLL_ODS_CONSTRAINT_H_ + +#include + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" + +namespace mlir { +namespace pdll { +namespace ods { + +//===----------------------------------------------------------------------===// +// Constraint +//===----------------------------------------------------------------------===// + +/// This class represents a generic, non-copyable ODS constraint: a name and a summary. +class Constraint { +public: + /// Return the name of this constraint. + StringRef getName() const { return name; } + + /// Return the summary of this constraint. + StringRef getSummary() const { return summary; } + +protected: + Constraint(StringRef name, StringRef summary) + : name(name.str()), summary(summary.str()) {} + Constraint(const Constraint &) = delete; + +private: + /// The name of the constraint. + std::string name; + /// A summary of the constraint. + std::string summary; +}; + +//===----------------------------------------------------------------------===// +// AttributeConstraint +//===----------------------------------------------------------------------===// + +/// This class represents a generic ODS Attribute constraint. +class AttributeConstraint : public Constraint { +public: + /// Return the name of the underlying C++ class of this constraint.
+ StringRef getCppClass() const { return cppClassName; } + +private: + AttributeConstraint(StringRef name, StringRef summary, StringRef cppClassName) + : Constraint(name, summary), cppClassName(cppClassName.str()) {} + + /// The C++ class of the constraint. + std::string cppClassName; + + /// Allow the ODS context to construct instances. + friend class Context; +}; + +//===----------------------------------------------------------------------===// +// TypeConstraint +//===----------------------------------------------------------------------===// + +/// This class represents a generic ODS Type constraint. +class TypeConstraint : public Constraint { +public: + /// Return the name of the underlying C++ class of this constraint. + StringRef getCppClass() const { return cppClassName; } + +private: + TypeConstraint(StringRef name, StringRef summary, StringRef cppClassName) + : Constraint(name, summary), cppClassName(cppClassName.str()) {} + + /// The C++ class of the constraint. + std::string cppClassName; + + /// Allow the ODS context to construct instances. + friend class Context; +}; + +} // namespace ods +} // namespace pdll +} // namespace mlir + +#endif // MLIR_TOOLS_PDLL_ODS_CONSTRAINT_H_ diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Context.h b/mlir/include/mlir/Tools/PDLL/ODS/Context.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/PDLL/ODS/Context.h @@ -0,0 +1,78 @@ +//===- Context.h - MLIR PDLL ODS Context ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PDLL_ODS_CONTEXT_H_ +#define MLIR_TOOLS_PDLL_ODS_CONTEXT_H_ + +#include + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" + +namespace llvm { +class SMLoc; +} // namespace llvm + +namespace mlir { +namespace pdll { +namespace ods { +class AttributeConstraint; +class Dialect; +class Operation; +class TypeConstraint; + +/// This class holds all registered ODS dialects, operations, and constraints. +class Context { +public: + Context(); + ~Context(); + + /// Insert a new attribute constraint into the context. Returns the inserted + /// constraint, or a previously inserted constraint with the same name. + const AttributeConstraint &insertAttributeConstraint(StringRef name, + StringRef summary, + StringRef cppClass); + + /// Insert a new type constraint into the context. Returns the inserted + /// constraint, or a previously inserted constraint with the same name. + const TypeConstraint &insertTypeConstraint(StringRef name, StringRef summary, + StringRef cppClass); + + /// Insert a new dialect into the context. Returns the inserted dialect, or a + /// previously inserted dialect with the same name. + Dialect &insertDialect(StringRef name); + + /// Lookup a dialect registered with the given name, or null if no dialect + /// with that name was inserted. + const Dialect *lookupDialect(StringRef name) const; + + /// Insert a new operation into the context. Returns the inserted operation, + /// and a boolean indicating if the operation was newly inserted (false if + /// the operation already existed). + std::pair + insertOperation(StringRef name, StringRef summary, StringRef desc, SMLoc loc); + + /// Lookup an operation registered with the given name, or null if no + /// operation with that name is registered.
+ const Operation *lookupOperation(StringRef name) const; + + /// Print the contents of this context to the provided stream. + void print(raw_ostream &os) const; + +private: + llvm::StringMap> attributeConstraints; + llvm::StringMap> dialects; + llvm::StringMap> typeConstraints; +}; +} // namespace ods +} // namespace pdll +} // namespace mlir + +#endif // MLIR_TOOLS_PDLL_ODS_CONTEXT_H_ diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h b/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/PDLL/ODS/Dialect.h @@ -0,0 +1,64 @@ +//===- Dialect.h - PDLL ODS Dialect -----------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PDLL_ODS_DIALECT_H_ +#define MLIR_TOOLS_PDLL_ODS_DIALECT_H_ + +#include + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" + +namespace mlir { +namespace pdll { +namespace ods { +class Operation; + +/// This class represents an ODS dialect, and contains information on the +/// constructs held within the dialect. +class Dialect { +public: + ~Dialect(); + + /// Return the name of this dialect. + StringRef getName() const { return name; } + + /// Insert a new operation into the dialect. Returns the inserted operation, + /// and a boolean indicating if the operation was newly inserted (false if + /// the operation already existed). + std::pair + insertOperation(StringRef name, StringRef summary, StringRef desc, SMLoc loc); + + /// Lookup an operation registered with the given name, or null if no + /// operation with that name is registered.
+ Operation *lookupOperation(StringRef name) const; + + /// Return a map of all of the operations registered to this dialect. + const llvm::StringMap> &getOperations() const { + return operations; + } + +private: + explicit Dialect(StringRef name); + + /// The name of the dialect. + std::string name; + + /// The operations defined by the dialect. + llvm::StringMap> operations; + + /// Allow the ODS context to construct dialects. + friend class Context; +}; +} // namespace ods +} // namespace pdll +} // namespace mlir + +#endif // MLIR_TOOLS_PDLL_ODS_DIALECT_H_ diff --git a/mlir/include/mlir/Tools/PDLL/ODS/Operation.h b/mlir/include/mlir/Tools/PDLL/ODS/Operation.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Tools/PDLL/ODS/Operation.h @@ -0,0 +1,189 @@ +//===- Operation.h - MLIR PDLL ODS Operation --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_PDLL_ODS_OPERATION_H_ +#define MLIR_TOOLS_PDLL_ODS_OPERATION_H_ + +#include + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/SMLoc.h" + +namespace mlir { +namespace pdll { +namespace ods { +class AttributeConstraint; +class TypeConstraint; + +//===----------------------------------------------------------------------===// +// VariableLengthKind +//===----------------------------------------------------------------------===// + +enum VariableLengthKind { Single, Optional, Variadic }; + +//===----------------------------------------------------------------------===// +// Attribute +//===----------------------------------------------------------------------===// + +/// This class provides an ODS representation of a
 specific operation attribute. +/// This includes the name, optionality, and more. +class Attribute { +public: + /// Return the name of this attribute. + StringRef getName() const { return name; } + + /// Return true if this attribute is optional. + bool isOptional() const { return optional; } + + /// Return the constraint of this attribute. + const AttributeConstraint &getConstraint() const { return constraint; } + +private: + Attribute(StringRef name, bool optional, + const AttributeConstraint &constraint) + : name(name.str()), optional(optional), constraint(constraint) {} + + /// The ODS name of the attribute. + std::string name; + + /// A flag indicating if the attribute is optional. + bool optional; + + /// The ODS constraint of this attribute (held by non-owning reference). + const AttributeConstraint &constraint; + + /// Allow access to the private constructor. + friend class Operation; +}; + +//===----------------------------------------------------------------------===// +// OperandOrResult +//===----------------------------------------------------------------------===// + +/// This class provides an ODS representation of a specific operation operand or +/// result. This includes the name, variable length flags, and more. +class OperandOrResult { +public: + /// Return the name of this value. + StringRef getName() const { return name; } + + /// Returns true if this value is variadic (Note this is false if the value is + /// Optional). + bool isVariadic() const { + return variableLengthKind == VariableLengthKind::Variadic; + } + + /// Returns the variable length kind of this value. + VariableLengthKind getVariableLengthKind() const { + return variableLengthKind; + } + + /// Return the constraint of this value.
+ const TypeConstraint &getConstraint() const { return constraint; } + +private: + OperandOrResult(StringRef name, VariableLengthKind variableLengthKind, + const TypeConstraint &constraint) + : name(name.str()), variableLengthKind(variableLengthKind), + constraint(constraint) {} + + /// The ODS name of this value. + std::string name; + + /// The variable length kind of this value. + VariableLengthKind variableLengthKind; + + /// The ODS constraint of this value (held by non-owning reference). + const TypeConstraint &constraint; + + /// Allow access to the private constructor. + friend class Operation; +}; + +//===----------------------------------------------------------------------===// +// Operation +//===----------------------------------------------------------------------===// + +/// This class provides an ODS representation of a specific operation. This +/// includes all of the information necessary for use by the PDL frontend for +/// generating code for a pattern rewrite. +class Operation { +public: + /// Return the source location of this operation. + SMRange getLoc() const { return location; } + + /// Append an attribute to this operation. + void appendAttribute(StringRef name, bool optional, + const AttributeConstraint &constraint) { + attributes.emplace_back(Attribute(name, optional, constraint)); + } + + /// Append an operand to this operation. + void appendOperand(StringRef name, VariableLengthKind variableLengthKind, + const TypeConstraint &constraint) { + operands.emplace_back( + OperandOrResult(name, variableLengthKind, constraint)); + } + + /// Append a result to this operation. + void appendResult(StringRef name, VariableLengthKind variableLengthKind, + const TypeConstraint &constraint) { + results.emplace_back(OperandOrResult(name, variableLengthKind, constraint)); + } + + /// Returns the name of the operation. + StringRef getName() const { return name; } + + /// Returns the summary of the operation.
+ StringRef getSummary() const { return summary; } + + /// Returns the description of the operation. + StringRef getDescription() const { return description; } + + /// Returns the attributes of this operation. + ArrayRef getAttributes() const { return attributes; } + + /// Returns the operands of this operation. + ArrayRef getOperands() const { return operands; } + + /// Returns the results of this operation. + ArrayRef getResults() const { return results; } + +private: + Operation(StringRef name, StringRef summary, StringRef desc, SMLoc loc); + + /// The name of the operation. + std::string name; + + /// The documentation of the operation. + std::string summary; + std::string description; + + /// The source location of this operation. + SMRange location; + + /// The operands of the operation. + SmallVector operands; + + /// The results of the operation. + SmallVector results; + + /// The attributes of the operation. + SmallVector attributes; + + /// Allow access to the private constructor. + friend class Dialect; +}; +} // namespace ods +} // namespace pdll +} // namespace mlir + +#endif // MLIR_TOOLS_PDLL_ODS_OPERATION_H_ diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp --- a/mlir/lib/TableGen/Constraint.cpp +++ b/mlir/lib/TableGen/Constraint.cpp @@ -57,6 +57,32 @@ return def->getName(); } +StringRef Constraint::getDefName() const { + // Functor used to check a base def in the case where the current def is + // anonymous. Use `getValue` (null for a missing field) instead of + // `getValueInit` (fatal error for a missing field), as not every anonymous + // def is guaranteed to define the base field.
+ auto checkBaseDefFn = [&](StringRef baseName) -> StringRef { + if (const auto *baseVal = def->getValue(baseName)) + if (const auto *init = dyn_cast<llvm::DefInit>(baseVal->getValue())) + return Constraint(init->getDef(), kind).getDefName(); + return def->getName(); + }; + + switch (kind) { + case CK_Attr: + if (def->isAnonymous()) + return checkBaseDefFn("baseAttr"); + return def->getName(); + case CK_Type: + if (def->isAnonymous()) + return checkBaseDefFn("baseType"); + return def->getName(); + default: + return def->getName(); + } +} + AppliedConstraint::AppliedConstraint(Constraint &&constraint, llvm::StringRef self, std::vector &&entities) diff --git a/mlir/lib/Tools/PDLL/AST/CMakeLists.txt b/mlir/lib/Tools/PDLL/AST/CMakeLists.txt --- a/mlir/lib/Tools/PDLL/AST/CMakeLists.txt +++ b/mlir/lib/Tools/PDLL/AST/CMakeLists.txt @@ -6,5 +6,6 @@ Types.cpp LINK_LIBS PUBLIC + MLIRPDLLODS MLIRSupport ) diff --git a/mlir/lib/Tools/PDLL/AST/Context.cpp b/mlir/lib/Tools/PDLL/AST/Context.cpp --- a/mlir/lib/Tools/PDLL/AST/Context.cpp +++ b/mlir/lib/Tools/PDLL/AST/Context.cpp @@ -12,7 +12,7 @@ using namespace mlir; using namespace mlir::pdll::ast; -Context::Context() { +Context::Context(ods::Context &odsContext) : odsContext(odsContext) { typeUniquer.registerSingletonStorageType(); typeUniquer.registerSingletonStorageType(); typeUniquer.registerSingletonStorageType(); diff --git a/mlir/lib/Tools/PDLL/CMakeLists.txt b/mlir/lib/Tools/PDLL/CMakeLists.txt --- a/mlir/lib/Tools/PDLL/CMakeLists.txt +++ b/mlir/lib/Tools/PDLL/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(AST) add_subdirectory(CodeGen) +add_subdirectory(ODS) add_subdirectory(Parser) diff --git a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp --- a/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp +++ b/mlir/lib/Tools/PDLL/CodeGen/MLIRGen.cpp @@ -17,6 +17,8 @@ #include "mlir/Tools/PDLL/AST/Context.h" #include "mlir/Tools/PDLL/AST/Nodes.h" #include "mlir/Tools/PDLL/AST/Types.h" +#include "mlir/Tools/PDLL/ODS/Context.h" +#include "mlir/Tools/PDLL/ODS/Operation.h"
#include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" @@ -33,7 +35,8 @@ public: CodeGen(MLIRContext *mlirContext, const ast::Context &context, const llvm::SourceMgr &sourceMgr) - : builder(mlirContext), sourceMgr(sourceMgr) { + : builder(mlirContext), odsContext(context.getODSContext()), + sourceMgr(sourceMgr) { // Make sure that the PDL dialect is loaded. mlirContext->loadDialect(); } @@ -117,6 +120,9 @@ llvm::ScopedHashTable>; VariableMapTy variables; + /// A reference to the ODS context, used to look up operation results. + const ods::Context &odsContext; + /// The source manager of the PDLL ast. const llvm::SourceMgr &sourceMgr; }; @@ -435,7 +441,28 @@ builder.getI32IntegerAttr(0)); return builder.create(loc, mlirType, parentExprs[0]); } - llvm_unreachable("unhandled operation member access expression"); + + assert(opType.getName() && "expected valid operation name"); + const ods::Operation *odsOp = odsContext.lookupOperation(*opType.getName()); + assert(odsOp && "expected valid ODS operation information"); + + // Find the result by index if the name starts with a digit, else by name. + ArrayRef results = odsOp->getResults(); + unsigned resultIndex = results.size(); + if (llvm::isDigit(name[0])) { + name.getAsInteger(/*Radix=*/10, resultIndex); + } else { + auto findFn = [&](const ods::OperandOrResult &result) { + return result.getName() == name; + }; + resultIndex = llvm::find_if(results, findFn) - results.begin(); + } + assert(resultIndex < results.size() && "invalid result index"); + + // Generate the result access. + IntegerAttr index = builder.getI32IntegerAttr(resultIndex); + return builder.create(loc, genType(expr->getType()), + parentExprs[0], index); } // Handle tuple based member access.
diff --git a/mlir/lib/Tools/PDLL/ODS/CMakeLists.txt b/mlir/lib/Tools/PDLL/ODS/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/PDLL/ODS/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_library(MLIRPDLLODS + Context.cpp + Dialect.cpp + Operation.cpp + + LINK_LIBS PUBLIC + MLIRSupport + ) diff --git a/mlir/lib/Tools/PDLL/ODS/Context.cpp b/mlir/lib/Tools/PDLL/ODS/Context.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/PDLL/ODS/Context.cpp @@ -0,0 +1,180 @@ +//===- Context.cpp --------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Tools/PDLL/ODS/Context.h" +#include "mlir/Tools/PDLL/ODS/Constraint.h" +#include "mlir/Tools/PDLL/ODS/Dialect.h" +#include "mlir/Tools/PDLL/ODS/Operation.h" +#include "llvm/Support/ScopedPrinter.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::pdll::ods; + +//===----------------------------------------------------------------------===// +// Context +//===----------------------------------------------------------------------===// + +Context::Context() = default; +Context::~Context() = default; + +const AttributeConstraint & +Context::insertAttributeConstraint(StringRef name, StringRef summary, + StringRef cppClass) { + std::unique_ptr &constraint = attributeConstraints[name]; + if (!constraint) { + constraint.reset(new AttributeConstraint(name, summary, cppClass)); + } else { + assert(constraint->getCppClass() == cppClass && + constraint->getSummary() == summary && + "constraint with the same name was already registered with a " + "different class"); + } + return *constraint; +} + +const TypeConstraint &Context::insertTypeConstraint(StringRef name, + StringRef
 summary, + StringRef cppClass) { + std::unique_ptr &constraint = typeConstraints[name]; + // NOTE(review): unlike insertAttributeConstraint, an existing entry is + // returned without checking that its summary/C++ class match. This may be + // deliberate if distinct constraints share a def name (e.g. anonymous + // `Optional<>`/`Variadic<>` wrappers) -- confirm before adding an assert. + if (!constraint) + constraint.reset(new TypeConstraint(name, summary, cppClass)); + return *constraint; +} + +Dialect &Context::insertDialect(StringRef name) { + std::unique_ptr &dialect = dialects[name]; + if (!dialect) + dialect.reset(new Dialect(name)); + return *dialect; +} + +const Dialect *Context::lookupDialect(StringRef name) const { + auto it = dialects.find(name); + return it == dialects.end() ? nullptr : &*it->second; +} + +std::pair Context::insertOperation(StringRef name, + StringRef summary, + StringRef desc, + SMLoc loc) { + std::pair dialectAndName = name.split('.'); + return insertDialect(dialectAndName.first) + .insertOperation(name, summary, desc, loc); +} + +const Operation *Context::lookupOperation(StringRef name) const { + std::pair dialectAndName = name.split('.'); + if (const Dialect *dialect = lookupDialect(dialectAndName.first)) + return dialect->lookupOperation(name); + return nullptr; +} + +/// Return (non-owning) pointers to the values of `map`, sorted by name. +template +SmallVector sortMapByName(const llvm::StringMap> &map) { + SmallVector storage; + for (auto &entry : map) + storage.push_back(entry.second.get()); + llvm::sort(storage, [](const auto &lhs, const auto &rhs) { + return lhs->getName() < rhs->getName(); + }); + return storage; +} + +void Context::print(raw_ostream &os) const { + auto printVariableLengthCst = [&](StringRef cst, VariableLengthKind kind) { + switch (kind) { + case VariableLengthKind::Optional: + os << "Optional<" << cst << ">"; + break; + case VariableLengthKind::Single: + os << cst; + break; + case VariableLengthKind::Variadic: + os << "Variadic<" << cst << ">"; + break; + } + }; + + llvm::ScopedPrinter printer(os); + llvm::DictScope odsScope(printer, "ODSContext"); + for (const Dialect *dialect : sortMapByName(dialects)) { + printer.startLine() << "Dialect `" << dialect->getName() << "` {\n"; + printer.indent(); + + for (const Operation *op : sortMapByName(dialect->getOperations())) {
+ printer.startLine() << "Operation `" << op->getName() << "` {\n"; + printer.indent(); + + // Attributes. + ArrayRef attributes = op->getAttributes(); + if (!attributes.empty()) { + printer.startLine() << "Attributes { "; + llvm::interleaveComma(attributes, os, [&](const Attribute &attr) { + os << attr.getName() << " : "; + + auto kind = attr.isOptional() ? VariableLengthKind::Optional + : VariableLengthKind::Single; + printVariableLengthCst(attr.getConstraint().getName(), kind); + }); + os << " }\n"; + } + + // Operands. + ArrayRef operands = op->getOperands(); + if (!operands.empty()) { + printer.startLine() << "Operands { "; + llvm::interleaveComma( + operands, os, [&](const OperandOrResult &operand) { + os << operand.getName() << " : "; + printVariableLengthCst(operand.getConstraint().getName(), + operand.getVariableLengthKind()); + }); + os << " }\n"; + } + + // Results. + ArrayRef results = op->getResults(); + if (!results.empty()) { + printer.startLine() << "Results { "; + llvm::interleaveComma(results, os, [&](const OperandOrResult &result) { + os << result.getName() << " : "; + printVariableLengthCst(result.getConstraint().getName(), + result.getVariableLengthKind()); + }); + os << " }\n"; + } + + printer.objectEnd(); + } + printer.objectEnd(); + } + for (const AttributeConstraint *cst : sortMapByName(attributeConstraints)) { + printer.startLine() << "AttributeConstraint `" << cst->getName() << "` {\n"; + printer.indent(); + + printer.startLine() << "Summary: " << cst->getSummary() << "\n"; + printer.startLine() << "CppClass: " << cst->getCppClass() << "\n"; + printer.objectEnd(); + } + for (const TypeConstraint *cst : sortMapByName(typeConstraints)) { + printer.startLine() << "TypeConstraint `" << cst->getName() << "` {\n"; + printer.indent(); + + printer.startLine() << "Summary: " << cst->getSummary() << "\n"; + printer.startLine() << "CppClass: " << cst->getCppClass() << "\n"; + printer.objectEnd(); + } + // `odsScope` prints the closing brace of the "ODSContext" object when it is + // destroyed, so no explicit `printer.objectEnd()` call is needed here. +} diff --git
a/mlir/lib/Tools/PDLL/ODS/Dialect.cpp b/mlir/lib/Tools/PDLL/ODS/Dialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/PDLL/ODS/Dialect.cpp @@ -0,0 +1,39 @@ +//===- Dialect.cpp --------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Tools/PDLL/ODS/Dialect.h" +#include "mlir/Tools/PDLL/ODS/Constraint.h" +#include "mlir/Tools/PDLL/ODS/Operation.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::pdll::ods; + +//===----------------------------------------------------------------------===// +// Dialect +//===----------------------------------------------------------------------===// + +Dialect::Dialect(StringRef name) : name(name.str()) {} +Dialect::~Dialect() = default; + +std::pair Dialect::insertOperation(StringRef name, + StringRef summary, + StringRef desc, + llvm::SMLoc loc) { + std::unique_ptr &operation = operations[name]; + if (operation) + return std::make_pair(&*operation, /*wasInserted=*/false); + + operation.reset(new Operation(name, summary, desc, loc)); + return std::make_pair(&*operation, /*wasInserted=*/true); +} + +Operation *Dialect::lookupOperation(StringRef name) const { + auto it = operations.find(name); + return it != operations.end() ? it->second.get() : nullptr; +} diff --git a/mlir/lib/Tools/PDLL/ODS/Operation.cpp b/mlir/lib/Tools/PDLL/ODS/Operation.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Tools/PDLL/ODS/Operation.cpp @@ -0,0 +1,26 @@ +//===- Operation.cpp ------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Tools/PDLL/ODS/Operation.h" +#include "mlir/Support/IndentedOstream.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::pdll::ods; + +//===----------------------------------------------------------------------===// +// Operation +//===----------------------------------------------------------------------===// + +Operation::Operation(StringRef name, StringRef summary, StringRef desc, + llvm::SMLoc loc) + : name(name.str()), summary(summary.str()), + location(loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)) { + llvm::raw_string_ostream descOS(description); + raw_indented_ostream(descOS).printReindented(desc.rtrim(" \t")); +} diff --git a/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt b/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt --- a/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt +++ b/mlir/lib/Tools/PDLL/Parser/CMakeLists.txt @@ -1,3 +1,8 @@ +set(LLVM_LINK_COMPONENTS + Support + TableGen +) + add_mlir_library(MLIRPDLLParser Lexer.cpp Parser.cpp @@ -5,4 +10,5 @@ LINK_LIBS PUBLIC MLIRPDLLAST MLIRSupport + MLIRTableGen ) diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -9,15 +9,26 @@ #include "mlir/Tools/PDLL/Parser/Parser.h" #include "Lexer.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/TableGen/Argument.h" +#include "mlir/TableGen/Attribute.h" +#include "mlir/TableGen/Constraint.h" +#include "mlir/TableGen/Format.h" +#include "mlir/TableGen/Operator.h" #include "mlir/Tools/PDLL/AST/Context.h" #include "mlir/Tools/PDLL/AST/Diagnostic.h" #include "mlir/Tools/PDLL/AST/Nodes.h" #include "mlir/Tools/PDLL/AST/Types.h" +#include "mlir/Tools/PDLL/ODS/Constraint.h" +#include "mlir/Tools/PDLL/ODS/Context.h" +#include
"mlir/Tools/PDLL/ODS/Operation.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/ScopedPrinter.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Parser.h" #include using namespace mlir; @@ -36,7 +47,8 @@ valueTy(ast::ValueType::get(ctx)), valueRangeTy(ast::ValueRangeType::get(ctx)), typeTy(ast::TypeType::get(ctx)), - typeRangeTy(ast::TypeRangeType::get(ctx)) {} + typeRangeTy(ast::TypeRangeType::get(ctx)), + attrTy(ast::AttributeType::get(ctx)) {} /// Try to parse a new module. Returns nullptr in the case of failure. FailureOr parseModule(); @@ -78,7 +90,7 @@ void popDeclScope() { curDeclScope = curDeclScope->getParentScope(); } /// Parse the body of an AST module. - LogicalResult parseModuleBody(SmallVector &decls); + LogicalResult parseModuleBody(SmallVectorImpl &decls); /// Try to convert the given expression to `type`. Returns failure and emits /// an error if a conversion is not viable. On failure, `noteAttachFn` is @@ -92,11 +104,34 @@ /// typed expression. ast::Expr *convertOpToValue(const ast::Expr *opExpr); + /// Lookup ODS information for the given operation, returns nullptr if no + /// information is found. + const ods::Operation *lookupODSOperation(Optional opName) { + return opName ? ctx.getODSContext().lookupOperation(*opName) : nullptr; + } + //===--------------------------------------------------------------------===// // Directives - LogicalResult parseDirective(SmallVector &decls); - LogicalResult parseInclude(SmallVector &decls); + LogicalResult parseDirective(SmallVectorImpl &decls); + LogicalResult parseInclude(SmallVectorImpl &decls); + LogicalResult parseTdInclude(StringRef filename, SMRange fileLoc, + SmallVectorImpl &decls); + + /// Process the records of a parsed tablegen include file. 
+ void processTdIncludeRecords(llvm::RecordKeeper &tdRecords, + SmallVectorImpl &decls); + + /// Create a user defined native constraint for a constraint imported from + /// ODS. + template + ast::Decl *createODSNativePDLLConstraintDecl(StringRef name, + StringRef codeBlock, SMRange loc, + ast::Type type); + template + ast::Decl * + createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, + SMRange loc, ast::Type type); //===--------------------------------------------------------------------===// // Decls @@ -340,13 +375,16 @@ MutableArrayRef results); LogicalResult validateOperationOperands(SMRange loc, Optional name, + const ods::Operation *odsOp, MutableArrayRef operands); LogicalResult validateOperationResults(SMRange loc, Optional name, + const ods::Operation *odsOp, MutableArrayRef results); - LogicalResult - validateOperationOperandsOrResults(SMRange loc, Optional name, - MutableArrayRef values, - ast::Type singleTy, ast::Type rangeTy); + LogicalResult validateOperationOperandsOrResults( + StringRef groupName, SMRange loc, Optional odsOpLoc, + Optional name, MutableArrayRef values, + ArrayRef odsValues, ast::Type singleTy, + ast::Type rangeTy); FailureOr createTupleExpr(SMRange loc, ArrayRef elements, ArrayRef elementNames); @@ -440,6 +478,7 @@ /// Cached types to simplify verification and expression creation. ast::Type valueTy, valueRangeTy; ast::Type typeTy, typeRangeTy; + ast::Type attrTy; /// A counter used when naming anonymous constraints and rewrites. unsigned anonymousDeclNameCounter = 0; @@ -459,7 +498,7 @@ return ast::Module::create(ctx, moduleLoc, decls); } -LogicalResult Parser::parseModuleBody(SmallVector &decls) { +LogicalResult Parser::parseModuleBody(SmallVectorImpl &decls) { while (curToken.isNot(Token::eof)) { if (curToken.is(Token::directive)) { if (failed(parseDirective(decls))) @@ -516,6 +555,32 @@ // Allow conversion to a single value by constraining the result range. 
if (type == valueTy) { + // If the operation is registered, we can verify if it can ever have a + // single result. + Optional opName = exprOpType.getName(); + if (const ods::Operation *odsOp = lookupODSOperation(opName)) { + if (odsOp->getResults().empty()) { + return emitConvertError()->attachNote( + llvm::formatv("see the definition of `{0}`, which was defined " + "with zero results", + odsOp->getName()), + odsOp->getLoc()); + } + + unsigned numSingleResults = llvm::count_if( + odsOp->getResults(), [](const ods::OperandOrResult &result) { + return result.getVariableLengthKind() == + ods::VariableLengthKind::Single; + }); + if (numSingleResults > 1) { + return emitConvertError()->attachNote( + llvm::formatv("see the definition of `{0}`, which was defined " + "with at least {1} results", + odsOp->getName(), numSingleResults), + odsOp->getLoc()); + } + } + expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr, valueTy); return success(); @@ -569,7 +634,7 @@ //===----------------------------------------------------------------------===// // Directives -LogicalResult Parser::parseDirective(SmallVector &decls) { +LogicalResult Parser::parseDirective(SmallVectorImpl &decls) { StringRef directive = curToken.getSpelling(); if (directive == "#include") return parseInclude(decls); @@ -577,7 +642,7 @@ return emitError("unknown directive `" + directive + "`"); } -LogicalResult Parser::parseInclude(SmallVector &decls) { +LogicalResult Parser::parseInclude(SmallVectorImpl &decls) { SMRange loc = curToken.getLoc(); consumeToken(Token::directive); @@ -607,7 +672,193 @@ return result; } - return emitError(fileLoc, "expected include filename to end with `.pdll`"); + // Otherwise, this must be a `.td` include. 
+  if (filename.endswith(".td"))
+    return parseTdInclude(filename, fileLoc, decls);
+
+  return emitError(fileLoc,
+                   "expected include filename to end with `.pdll` or `.td`");
+}
+
+LogicalResult Parser::parseTdInclude(StringRef filename, llvm::SMRange fileLoc,
+                                     SmallVectorImpl<ast::Decl *> &decls) {
+  llvm::SourceMgr &parserSrcMgr = lexer.getSourceMgr();
+
+  // Use the source manager to open the file, but don't yet add it.
+  std::string includedFile;
+  llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> includeBuffer =
+      parserSrcMgr.OpenIncludeFile(filename.str(), includedFile);
+  if (!includeBuffer)
+    return emitError(fileLoc, "unable to open include file `" + filename + "`");
+
+  // Context argument for the diagnostic handler installed below. It lives on
+  // this stack frame, so the handler must be uninstalled before returning.
+  struct DiagHandlerContext {
+    Parser &parser;
+    StringRef filename;
+    llvm::SMRange loc;
+  } handlerContext{*this, filename, fileLoc};
+
+  // Forward diagnostics reported through the global tablegen llvm::SrcMgr.
+  llvm::SrcMgr.setDiagHandler(
+      [](const llvm::SMDiagnostic &diag, void *rawHandlerContext) {
+        auto *ctx = reinterpret_cast<DiagHandlerContext *>(rawHandlerContext);
+        (void)ctx->parser.emitError(
+            ctx->loc,
+            llvm::formatv("error while processing include file `{0}`: {1}",
+                          ctx->filename, diag.getMessage()));
+      },
+      &handlerContext);
+
+  auto processFn = [&](llvm::RecordKeeper &records) {
+    processTdIncludeRecords(records, decls);
+
+    // After we are done processing, move all of the tablegen source buffers to
+    // the main parser source mgr. This allows for directly using source
+    // locations from the .td files without needing to remap them.
+    parserSrcMgr.takeSourceBuffersFrom(llvm::SrcMgr);
+    return false;
+  };
+  bool parseFailed = llvm::TableGenParseFile(
+      std::move(*includeBuffer), parserSrcMgr.getIncludeDirs(), processFn);
+  // Uninstall the handler: `handlerContext` is about to go out of scope and
+  // llvm::SrcMgr is a process-global that outlives this call.
+  llvm::SrcMgr.setDiagHandler(nullptr, nullptr);
+  return failure(parseFailed);
+}
+
+void Parser::processTdIncludeRecords(llvm::RecordKeeper &tdRecords,
+                                     SmallVectorImpl<ast::Decl *> &decls) {
+  // Return the length kind of the given value.
+ auto getLengthKind = [](const auto &value) { + if (value.isOptional()) + return ods::VariableLengthKind::Optional; + return value.isVariadic() ? ods::VariableLengthKind::Variadic + : ods::VariableLengthKind::Single; + }; + + // Insert a type constraint into the ODS context. + ods::Context &odsContext = ctx.getODSContext(); + auto addTypeConstraint = [&](const tblgen::NamedTypeConstraint &cst) + -> const ods::TypeConstraint & { + return odsContext.insertTypeConstraint(cst.constraint.getDefName(), + cst.constraint.getSummary(), + cst.constraint.getCPPClassName()); + }; + auto convertLocToRange = [&](llvm::SMLoc loc) -> llvm::SMRange { + return {loc, llvm::SMLoc::getFromPointer(loc.getPointer() + 1)}; + }; + + // Process the parsed tablegen records to build ODS information. + /// Operations. + for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Op")) { + tblgen::Operator op(def); + + bool inserted = false; + ods::Operation *odsOp = nullptr; + std::tie(odsOp, inserted) = + odsContext.insertOperation(op.getOperationName(), op.getSummary(), + op.getDescription(), op.getLoc().front()); + + // Ignore operations that have already been added. + if (!inserted) + continue; + + for (const tblgen::NamedAttribute &attr : op.getAttributes()) { + odsOp->appendAttribute( + attr.name, attr.attr.isOptional(), + odsContext.insertAttributeConstraint(attr.attr.getAttrDefName(), + attr.attr.getSummary(), + attr.attr.getStorageType())); + } + for (const tblgen::NamedTypeConstraint &operand : op.getOperands()) { + odsOp->appendOperand(operand.name, getLengthKind(operand), + addTypeConstraint(operand)); + } + for (const tblgen::NamedTypeConstraint &result : op.getResults()) { + odsOp->appendResult(result.name, getLengthKind(result), + addTypeConstraint(result)); + } + } + /// Attr constraints. 
+ for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Attr")) { + if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { + decls.push_back( + createODSNativePDLLConstraintDecl( + tblgen::AttrConstraint(def), + convertLocToRange(def->getLoc().front()), attrTy)); + } + } + /// Type constraints. + for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Type")) { + if (!def->isAnonymous() && !curDeclScope->lookup(def->getName())) { + decls.push_back( + createODSNativePDLLConstraintDecl( + tblgen::TypeConstraint(def), + convertLocToRange(def->getLoc().front()), typeTy)); + } + } + /// Interfaces. + ast::Type opTy = ast::OperationType::get(ctx); + for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) { + StringRef name = def->getName(); + if (def->isAnonymous() || curDeclScope->lookup(name) || + def->isSubClassOf("DeclareInterfaceMethods")) + continue; + SMRange loc = convertLocToRange(def->getLoc().front()); + + StringRef className = def->getValueAsString("cppClassName"); + StringRef cppNamespace = def->getValueAsString("cppNamespace"); + std::string codeBlock = + llvm::formatv("llvm::isa<{0}::{1}>(self)", cppNamespace, className) + .str(); + + if (def->isSubClassOf("OpInterface")) { + decls.push_back(createODSNativePDLLConstraintDecl( + name, codeBlock, loc, opTy)); + } else if (def->isSubClassOf("AttrInterface")) { + decls.push_back( + createODSNativePDLLConstraintDecl( + name, codeBlock, loc, attrTy)); + } else if (def->isSubClassOf("TypeInterface")) { + decls.push_back( + createODSNativePDLLConstraintDecl( + name, codeBlock, loc, typeTy)); + } + } +} + +template +ast::Decl * +Parser::createODSNativePDLLConstraintDecl(StringRef name, StringRef codeBlock, + SMRange loc, ast::Type type) { + // Build the single input parameter. 
+ ast::DeclScope *argScope = pushDeclScope(); + auto *paramVar = ast::VariableDecl::create( + ctx, ast::Name::create(ctx, "self", loc), type, + /*initExpr=*/nullptr, ast::ConstraintRef(ConstraintT::create(ctx, loc))); + argScope->add(paramVar); + popDeclScope(); + + // Build the native constraint. + auto *constraintDecl = ast::UserConstraintDecl::createNative( + ctx, ast::Name::create(ctx, name, loc), paramVar, + /*results=*/llvm::None, codeBlock, ast::TupleType::get(ctx)); + curDeclScope->add(constraintDecl); + return constraintDecl; +} + +template +ast::Decl * +Parser::createODSNativePDLLConstraintDecl(const tblgen::Constraint &constraint, + SMRange loc, ast::Type type) { + // Format the condition template. + tblgen::FmtContext fmtContext; + fmtContext.withSelf("self"); + std::string codeBlock = + tblgen::tgfmt(constraint.getConditionTemplate(), &fmtContext); + + return createODSNativePDLLConstraintDecl(constraint.getDefName(), + codeBlock, loc, type); } //===----------------------------------------------------------------------===// @@ -2302,9 +2553,29 @@ FailureOr Parser::validateMemberAccess(ast::Expr *parentExpr, StringRef name, SMRange loc) { ast::Type parentType = parentExpr->getType(); - if (parentType.isa()) { + if (ast::OperationType opType = parentType.dyn_cast()) { if (name == ast::AllResultsMemberAccessExpr::getMemberName()) return valueRangeTy; + + // Verify member access based on the operation type. + if (const ods::Operation *odsOp = lookupODSOperation(opType.getName())) { + auto results = odsOp->getResults(); + + // Handle indexed results. + unsigned index = 0; + if (llvm::isDigit(name[0]) && !name.getAsInteger(/*Radix=*/10, index) && + index < results.size()) { + return results[index].isVariadic() ? valueRangeTy : valueTy; + } + + // Handle named results. + const auto *it = llvm::find_if(results, [&](const auto &result) { + return result.getName() == name; + }); + if (it != results.end()) + return it->isVariadic() ? 
valueRangeTy : valueTy; + } + } else if (auto tupleType = parentType.dyn_cast()) { // Handle indexed results. unsigned index = 0; @@ -2331,9 +2602,10 @@ MutableArrayRef attributes, MutableArrayRef results) { Optional opNameRef = name->getName(); + const ods::Operation *odsOp = lookupODSOperation(opNameRef); // Verify the inputs operands. - if (failed(validateOperationOperands(loc, opNameRef, operands))) + if (failed(validateOperationOperands(loc, opNameRef, odsOp, operands))) return failure(); // Verify the attribute list. @@ -2348,7 +2620,7 @@ } // Verify the result types. - if (failed(validateOperationResults(loc, opNameRef, results))) + if (failed(validateOperationResults(loc, opNameRef, odsOp, results))) return failure(); return ast::OperationExpr::create(ctx, loc, name, operands, results, @@ -2357,21 +2629,28 @@ LogicalResult Parser::validateOperationOperands(SMRange loc, Optional name, + const ods::Operation *odsOp, MutableArrayRef operands) { - return validateOperationOperandsOrResults(loc, name, operands, valueTy, - valueRangeTy); + return validateOperationOperandsOrResults( + "operand", loc, odsOp ? odsOp->getLoc() : Optional(), name, + operands, odsOp ? odsOp->getOperands() : llvm::None, valueTy, + valueRangeTy); } LogicalResult Parser::validateOperationResults(SMRange loc, Optional name, + const ods::Operation *odsOp, MutableArrayRef results) { - return validateOperationOperandsOrResults(loc, name, results, typeTy, - typeRangeTy); + return validateOperationOperandsOrResults( + "result", loc, odsOp ? odsOp->getLoc() : Optional(), name, + results, odsOp ? 
odsOp->getResults() : llvm::None, typeTy, typeRangeTy); } LogicalResult Parser::validateOperationOperandsOrResults( - SMRange loc, Optional name, MutableArrayRef values, - ast::Type singleTy, ast::Type rangeTy) { + StringRef groupName, SMRange loc, Optional odsOpLoc, + Optional name, MutableArrayRef values, + ArrayRef odsValues, ast::Type singleTy, + ast::Type rangeTy) { // All operation types accept a single range parameter. if (values.size() == 1) { if (failed(convertExpressionTo(values[0], rangeTy))) @@ -2379,6 +2658,29 @@ return success(); } + /// If the operation has ODS information, we can more accurately verify the + /// values. + if (odsOpLoc) { + if (odsValues.size() != values.size()) { + return emitErrorAndNote( + loc, + llvm::formatv("invalid number of {0} groups for `{1}`; expected " + "{2}, but got {3}", + groupName, *name, odsValues.size(), values.size()), + *odsOpLoc, llvm::formatv("see the definition of `{0}` here", *name)); + } + auto diagFn = [&](ast::Diagnostic &diag) { + diag.attachNote(llvm::formatv("see the definition of `{0}` here", *name), + *odsOpLoc); + }; + for (unsigned i = 0, e = values.size(); i < e; ++i) { + ast::Type expectedType = odsValues[i].isVariadic() ? rangeTy : singleTy; + if (failed(convertExpressionTo(values[i], expectedType, diagFn))) + return failure(); + } + return success(); + } + // Otherwise, accept the value groups as they have been defined and just // ensure they are one of the expected types. 
for (ast::Expr *&valueExpr : values) { diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -1,4 +1,4 @@ -// RUN: mlir-pdll %s -I %S -split-input-file -x mlir | FileCheck %s +// RUN: mlir-pdll %s -I %S -I %S/../../../../include -split-input-file -x mlir | FileCheck %s //===----------------------------------------------------------------------===// // AttributeExpr @@ -55,6 +55,24 @@ // ----- +// Handle implicit "named" operation results access. + +#include "include/ops.td" + +// CHECK: pdl.pattern @OpResultMemberAccess +// CHECK: %[[OP0:.*]] = operation +// CHECK: %[[RES:.*]] = results 0 of %[[OP0]] -> !pdl.value +// CHECK: %[[RES1:.*]] = results 0 of %[[OP0]] -> !pdl.value +// CHECK: %[[RES2:.*]] = results 1 of %[[OP0]] -> !pdl.range +// CHECK: %[[RES3:.*]] = results 1 of %[[OP0]] -> !pdl.range +// CHECK: operation(%[[RES]], %[[RES1]], %[[RES2]], %[[RES3]] : !pdl.value, !pdl.value, !pdl.range, !pdl.range) +Pattern OpResultMemberAccess { + let op: Op; + erase op<>(op.0, op.result, op.1, op.var_result); +} + +// ----- + // CHECK: pdl.pattern @TupleMemberAccessNumber // CHECK: %[[FIRST:.*]] = operation "test.first" // CHECK: %[[SECOND:.*]] = operation "test.second" diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/include/ops.td b/mlir/test/mlir-pdll/CodeGen/MLIR/include/ops.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/include/ops.td @@ -0,0 +1,9 @@ +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; +} + +def OpWithResults : Op { + let results = (outs I64:$result, Variadic:$var_result); +} diff --git a/mlir/test/mlir-pdll/Parser/directive-failure.pdll b/mlir/test/mlir-pdll/Parser/directive-failure.pdll --- a/mlir/test/mlir-pdll/Parser/directive-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/directive-failure.pdll @@ -19,5 +19,5 @@ // ----- -// CHECK: expected 
include filename to end with `.pdll` +// CHECK: expected include filename to end with `.pdll` or `.td` #include "unknown_file.foo" diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -1,4 +1,4 @@ -// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s +// RUN: not mlir-pdll %s -I %S -I %S/../../../include -split-input-file 2>&1 | FileCheck %s //===----------------------------------------------------------------------===// // Reference Expr @@ -276,6 +276,26 @@ // ----- +#include "include/ops.td" + +Pattern { + // CHECK: invalid number of operand groups for `test.all_empty`; expected 0, but got 2 + // CHECK: see the definition of `test.all_empty` here + let foo = op(operand1: Value, operand2: Value); +} + +// ----- + +#include "include/ops.td" + +Pattern { + // CHECK: invalid number of result groups for `test.all_empty`; expected 0, but got 2 + // CHECK: see the definition of `test.all_empty` here + let foo = op -> (result1: Type, result2: Type); +} + +// ----- + //===----------------------------------------------------------------------===// // `type` Expr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -1,4 +1,4 @@ -// RUN: mlir-pdll %s -I %S -split-input-file | FileCheck %s +// RUN: mlir-pdll %s -I %S -I %S/../../../include -split-input-file | FileCheck %s //===----------------------------------------------------------------------===// // AttrExpr @@ -71,6 +71,25 @@ // ----- +#include "include/ops.td" + +// CHECK: Module +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `-MemberAccessExpr {{.*}} Member<0> Type +// CHECK: `-DeclRefExpr {{.*}} Type> +// CHECK: `-VariableDecl {{.*}} Name Type 
+// CHECK: `-MemberAccessExpr {{.*}} Member Type +// CHECK: `-DeclRefExpr {{.*}} Type> +Pattern { + let op: Op; + let firstEltIndex = op.0; + let firstEltName = op.result; + + erase op; +} + +// ----- + //===----------------------------------------------------------------------===// // OperationExpr //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-pdll/Parser/include/interfaces.td b/mlir/test/mlir-pdll/Parser/include/interfaces.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/include/interfaces.td @@ -0,0 +1,5 @@ +include "mlir/IR/OpBase.td" + +def TestAttrInterface : AttrInterface<"TestAttrInterface">; +def TestOpInterface : OpInterface<"TestOpInterface">; +def TestTypeInterface : TypeInterface<"TestTypeInterface">; diff --git a/mlir/test/mlir-pdll/Parser/include/ops.td b/mlir/test/mlir-pdll/Parser/include/ops.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/include/ops.td @@ -0,0 +1,26 @@ +include "include/interfaces.td" + +def Test_Dialect : Dialect { + let name = "test"; +} + +def OpAllEmpty : Op; + +def OpAllSingle : Op { + let arguments = (ins I64:$operand, I64Attr:$attr); + let results = (outs I64:$result); +} + +def OpAllOptional : Op { + let arguments = (ins Optional:$operand, OptionalAttr:$attr); + let results = (outs Optional:$result); +} + +def OpAllVariadic : Op { + let arguments = (ins Variadic:$operands); + let results = (outs Variadic:$results); +} + +def OpMultipleSingleResult : Op { + let results = (outs I64:$result, I64:$result2); +} diff --git a/mlir/test/mlir-pdll/Parser/include_td.pdll b/mlir/test/mlir-pdll/Parser/include_td.pdll new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-pdll/Parser/include_td.pdll @@ -0,0 +1,52 @@ +// RUN: mlir-pdll %s -I %S -I %S/../../../include -dump-ods 2>&1 | FileCheck %s + +#include "include/ops.td" + +// CHECK: Operation `test.all_empty` { +// CHECK-NEXT: } + +// CHECK: Operation `test.all_optional` { 
+// CHECK-NEXT: Attributes { attr : Optional } +// CHECK-NEXT: Operands { operand : Optional } +// CHECK-NEXT: Results { result : Optional } +// CHECK-NEXT: } + +// CHECK: Operation `test.all_single` { +// CHECK-NEXT: Attributes { attr : I64Attr } +// CHECK-NEXT: Operands { operand : I64 } +// CHECK-NEXT: Results { result : I64 } +// CHECK-NEXT: } + +// CHECK: Operation `test.all_variadic` { +// CHECK-NEXT: Operands { operands : Variadic } +// CHECK-NEXT: Results { results : Variadic } +// CHECK-NEXT: } + +// CHECK: AttributeConstraint `I64Attr` { +// CHECK-NEXT: Summary: 64-bit signless integer attribute +// CHECK-NEXT: CppClass: ::mlir::IntegerAttr +// CHECK-NEXT: } + +// CHECK: TypeConstraint `I64` { +// CHECK-NEXT: Summary: 64-bit signless integer +// CHECK-NEXT: CppClass: ::mlir::IntegerType +// CHECK-NEXT: } + +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self)> +// CHECK: `Inputs` +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Constraints` +// CHECK: `-AttrConstraintDecl + +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self)> +// CHECK: `Inputs` +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Constraints` +// CHECK: `-OpConstraintDecl +// CHECK: `-OpNameDecl + +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self)> +// CHECK: `Inputs` +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK: `Constraints` +// CHECK: `-TypeConstraintDecl {{.*}} diff --git a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll --- a/mlir/test/mlir-pdll/Parser/stmt-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/stmt-failure.pdll @@ -1,4 +1,4 @@ -// RUN: not mlir-pdll %s -split-input-file 2>&1 | FileCheck %s +// RUN: not mlir-pdll %s -I %S -I %S/../../../include -split-input-file 2>&1 | FileCheck %s // CHECK: expected top-level declaration, such as a `Pattern` 10 @@ -250,6 +250,28 @@ // ----- +#include "include/ops.td" + +Pattern { + // CHECK: unable to convert expression of type `Op` to the 
expected type of `Value` + // CHECK: see the definition of `test.all_empty`, which was defined with zero results + let value: Value = op; + erase _: Op; +} + +// ----- + +#include "include/ops.td" + +Pattern { + // CHECK: unable to convert expression of type `Op` to the expected type of `Value` + // CHECK: see the definition of `test.multiple_single_result`, which was defined with at least 2 results + let value: Value = op; + erase _: Op; +} + +// ----- + //===----------------------------------------------------------------------===// // `replace` //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-pdll/mlir-pdll.cpp b/mlir/tools/mlir-pdll/mlir-pdll.cpp --- a/mlir/tools/mlir-pdll/mlir-pdll.cpp +++ b/mlir/tools/mlir-pdll/mlir-pdll.cpp @@ -13,6 +13,7 @@ #include "mlir/Tools/PDLL/AST/Nodes.h" #include "mlir/Tools/PDLL/CodeGen/CPPGen.h" #include "mlir/Tools/PDLL/CodeGen/MLIRGen.h" +#include "mlir/Tools/PDLL/ODS/Context.h" #include "mlir/Tools/PDLL/Parser/Parser.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" @@ -35,16 +36,23 @@ static LogicalResult processBuffer(raw_ostream &os, std::unique_ptr chunkBuffer, - OutputType outputType, std::vector &includeDirs) { + OutputType outputType, std::vector &includeDirs, + bool dumpODS) { llvm::SourceMgr sourceMgr; sourceMgr.setIncludeDirs(includeDirs); sourceMgr.AddNewSourceBuffer(std::move(chunkBuffer), SMLoc()); - ast::Context astContext; + ods::Context odsContext; + ast::Context astContext(odsContext); FailureOr module = parsePDLAST(astContext, sourceMgr); if (failed(module)) return failure(); + // Print out the ODS information if requested. + if (dumpODS) + odsContext.print(llvm::errs()); + + // Generate the output. 
if (outputType == OutputType::AST) { (*module)->print(os); return success(); @@ -66,6 +74,10 @@ } int main(int argc, char **argv) { + // FIXME: This is necessary because we link in TableGen, which defines its + // options as static variables.. some of which overlap with our options. + llvm::cl::ResetCommandLineParser(); + llvm::cl::opt inputFilename( llvm::cl::Positional, llvm::cl::desc(""), llvm::cl::init("-"), llvm::cl::value_desc("filename")); @@ -78,6 +90,11 @@ "I", llvm::cl::desc("Directory of include files"), llvm::cl::value_desc("directory"), llvm::cl::Prefix); + llvm::cl::opt dumpODS( + "dump-ods", + llvm::cl::desc( + "Print out the parsed ODS information from the input file"), + llvm::cl::init(false)); llvm::cl::opt splitInputFile( "split-input-file", llvm::cl::desc("Split the input file into pieces and process each " @@ -118,7 +135,8 @@ // up into small pieces and checks each independently. auto processFn = [&](std::unique_ptr chunkBuffer, raw_ostream &os) { - return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs); + return processBuffer(os, std::move(chunkBuffer), outputType, includeDirs, + dumpODS); }; if (splitInputFile) { if (failed(splitAndProcessBuffer(std::move(inputFile), processFn,