diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDL.h b/mlir/include/mlir/Dialect/PDL/IR/PDL.h --- a/mlir/include/mlir/Dialect/PDL/IR/PDL.h +++ b/mlir/include/mlir/Dialect/PDL/IR/PDL.h @@ -13,11 +13,7 @@ #ifndef MLIR_DIALECT_PDL_IR_PDL_H_ #define MLIR_DIALECT_PDL_IR_PDL_H_ -#include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/SymbolTable.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" //===----------------------------------------------------------------------===// // PDL Dialect @@ -25,12 +21,4 @@ #include "mlir/Dialect/PDL/IR/PDLOpsDialect.h.inc" -//===----------------------------------------------------------------------===// -// PDL Dialect Operations -//===----------------------------------------------------------------------===// - -#define GET_OP_CLASSES -#include "mlir/Dialect/PDL/IR/PDLOps.h.inc" - - #endif // MLIR_DIALECT_PDL_IR_PDL_H_ diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLBase.td b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td rename from mlir/include/mlir/Dialect/PDL/IR/PDLBase.td rename to mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLBase.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLDialect.td @@ -1,4 +1,4 @@ -//===- PDLBase.td - PDL base definitions -------------------*- tablegen -*-===// +//===- PDLDialect.td - PDL dialect definition --------------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// // -// Defines base support for MLIR PDL operations. +// Defines the MLIR PDL dialect. 
// //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_PDL_IR_PDLBASE -#define MLIR_DIALECT_PDL_IR_PDLBASE +#ifndef MLIR_DIALECT_PDL_IR_PDLDIALECT +#define MLIR_DIALECT_PDL_IR_PDLDIALECT include "mlir/IR/OpBase.td" @@ -66,31 +66,4 @@ let cppNamespace = "::mlir::pdl"; } -//===----------------------------------------------------------------------===// -// PDL Types -//===----------------------------------------------------------------------===// - -class PDL_Handle : - DialectType()">, - underlying>, - BuildableType<"$_builder.getType<" # underlying # ">()">; - -// Handle for `mlir::Attribute`. -def PDL_Attribute : PDL_Handle<"mlir::pdl::AttributeType">; - -// Handle for `mlir::Operation*`. -def PDL_Operation : PDL_Handle<"mlir::pdl::OperationType">; - -// Handle for `mlir::Type`. -def PDL_Type : PDL_Handle<"mlir::pdl::TypeType">; - -// Handle for `mlir::Value`. -def PDL_Value : PDL_Handle<"mlir::pdl::ValueType">; - -// A positional value is a location on a pattern DAG, which may be an operation, -// an attribute, or an operand/result. -def PDL_PositionalValue : - AnyTypeOf<[PDL_Attribute, PDL_Operation, PDL_Type, PDL_Value], - "Positional Value">; - -#endif // MLIR_DIALECT_PDL_IR_PDLBASE +#endif // MLIR_DIALECT_PDL_IR_PDLDIALECT diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDL.h b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.h copy from mlir/include/mlir/Dialect/PDL/IR/PDL.h copy to mlir/include/mlir/Dialect/PDL/IR/PDLOps.h --- a/mlir/include/mlir/Dialect/PDL/IR/PDL.h +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.h @@ -1,4 +1,4 @@ -//===- PDL.h - Pattern Descriptor Language Dialect --------------*- C++ -*-===// +//===- PDLOps.h - Pattern Descriptor Language Operations --------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,25 +6,19 @@ // //===----------------------------------------------------------------------===// // -// This file defines the dialect for the Pattern Descriptor Language. +// This file defines the operations for the Pattern Descriptor Language dialect. // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_PDL_IR_PDL_H_ -#define MLIR_DIALECT_PDL_IR_PDL_H_ +#ifndef MLIR_DIALECT_PDL_IR_PDLOPS_H_ +#define MLIR_DIALECT_PDL_IR_PDLOPS_H_ +#include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/Dialect.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -//===----------------------------------------------------------------------===// -// PDL Dialect -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/PDL/IR/PDLOpsDialect.h.inc" - //===----------------------------------------------------------------------===// // PDL Dialect Operations //===----------------------------------------------------------------------===// @@ -32,5 +26,4 @@ #define GET_OP_CLASSES #include "mlir/Dialect/PDL/IR/PDLOps.h.inc" - -#endif // MLIR_DIALECT_PDL_IR_PDL_H_ +#endif // MLIR_DIALECT_PDL_IR_PDLOPS_H_ diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -13,7 +13,7 @@ #ifndef MLIR_DIALECT_PDL_IR_PDLOPS #define MLIR_DIALECT_PDL_IR_PDLOPS -include "mlir/Dialect/PDL/IR/PDLBase.td" +include "mlir/Dialect/PDL/IR/PDLTypes.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/SymbolInterfaces.td" diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h --- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h @@ -1,4 +1,4 @@ -//===- PDL.h
- Pattern Descriptor Language Types ----------------*- C++ -*-===// +//===- PDLTypes.h - Pattern Descriptor Language Types -----------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -15,33 +15,11 @@ #include "mlir/IR/Types.h" -namespace mlir { -namespace pdl { //===----------------------------------------------------------------------===// // PDL Dialect Types //===----------------------------------------------------------------------===// -/// This type represents a handle to an `mlir::Attribute`. -struct AttributeType : public Type::TypeBase { - using Base::Base; -}; - -/// This type represents a handle to an `mlir::Operation*`. -struct OperationType : public Type::TypeBase { - using Base::Base; -}; - -/// This type represents a handle to an `mlir::Type`. -struct TypeType : public Type::TypeBase { - using Base::Base; -}; - -/// This type represents a handle to an `mlir::Value`. -struct ValueType : public Type::TypeBase { - using Base::Base; -}; - -} // end namespace pdl -} // end namespace mlir +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/PDL/IR/PDLOpsTypes.h.inc" #endif // MLIR_DIALECT_PDL_IR_PDLTYPES_H_ diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td @@ -0,0 +1,84 @@ +//===- PDLTypes.td - Pattern descriptor types --------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Pattern Descriptor Language dialect types. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PDL_IR_PDLTYPES +#define MLIR_DIALECT_PDL_IR_PDLTYPES + +include "mlir/Dialect/PDL/IR/PDLDialect.td" + +//===----------------------------------------------------------------------===// +// PDL Types +//===----------------------------------------------------------------------===// + +class PDL_Type : TypeDef { + let mnemonic = typeMnemonic; +} + +//===----------------------------------------------------------------------===// +// pdl::AttributeType +//===----------------------------------------------------------------------===// + +def PDL_Attribute : PDL_Type<"Attribute", "attribute"> { + let summary = "PDL handle to an `mlir::Attribute`"; + let description = [{ + This type represents a handle to an instance of an `mlir::Attribute`, bound + to a value that is usable within a PDL pattern or rewrite. + }]; +} + +//===----------------------------------------------------------------------===// +// pdl::OperationType +//===----------------------------------------------------------------------===// + +def PDL_Operation : PDL_Type<"Operation", "operation"> { + let summary = "PDL handle to an `mlir::Operation *`"; + let description = [{ + This type represents a handle to an instance of an `mlir::Operation *`, + bound to a value that is usable within a PDL pattern or rewrite. + }]; +} + +//===----------------------------------------------------------------------===// +// pdl::TypeType +//===----------------------------------------------------------------------===// + +def PDL_Type : PDL_Type<"Type", "type"> { + let summary = "PDL handle to an `mlir::Type`"; + let description = [{ + This type represents a handle to an instance of an `mlir::Type`, bound to a + value that is usable within a PDL pattern or rewrite. 
+ }]; +} + +//===----------------------------------------------------------------------===// +// pdl::ValueType +//===----------------------------------------------------------------------===// + +def PDL_Value : PDL_Type<"Value", "value"> { + let summary = "PDL handle to an `mlir::Value`"; + let description = [{ + This type represents a handle to an instance of an `mlir::Value`, bound to a + value that is usable within a PDL pattern or rewrite. + }]; +} + +//===----------------------------------------------------------------------===// +// Additional Type Constraints +//===----------------------------------------------------------------------===// + +// A positional value is a location on a pattern DAG, which may be an attribute, +// operation, or operand/result. +def PDL_PositionalValue : + AnyTypeOf<[PDL_Attribute, PDL_Operation, PDL_Type, PDL_Value], + "Positional Value">; + +#endif // MLIR_DIALECT_PDL_IR_PDLTYPES diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterp.h @@ -15,6 +15,7 @@ #define MLIR_DIALECT_PDLINTERP_IR_PDLINTERP_H_ #include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -13,7 +13,7 @@ #ifndef MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS #define MLIR_DIALECT_PDLINTERP_IR_PDLINTERPOPS -include "mlir/Dialect/PDL/IR/PDLBase.td" +include "mlir/Dialect/PDL/IR/PDLTypes.td" include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// diff --git 
a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.h @@ -15,7 +15,7 @@ #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_ #include "Predicate.h" -#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDL/IR/PDLOps.h" #include "llvm/ADT/MapVector.h" namespace mlir { diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -7,11 +7,13 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::pdl; @@ -25,38 +27,10 @@ #define GET_OP_LIST #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" >(); - addTypes(); -} - -Type PDLDialect::parseType(DialectAsmParser &parser) const { - StringRef keyword; - if (parser.parseKeyword(&keyword)) - return Type(); - - Builder &builder = parser.getBuilder(); - Type result = StringSwitch(keyword) - .Case("attribute", builder.getType()) - .Case("operation", builder.getType()) - .Case("type", builder.getType()) - .Case("value", builder.getType()) - .Default(Type()); - if (!result) - parser.emitError(parser.getNameLoc(), "invalid 'pdl' type: `") - << keyword << "'"; - return result; -} - -void PDLDialect::printType(Type type, DialectAsmPrinter &printer) const { - if (type.isa()) - printer << "attribute"; - else if (type.isa()) - printer << "operation"; - else if (type.isa()) - printer << "type"; - else if (type.isa()) - printer << "value"; - else - llvm_unreachable("unknown 'pdl' type"); + 
addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc" + >(); } /// Returns true if the given operation is used by a "binding" pdl operation @@ -456,3 +430,27 @@ #define GET_OP_CLASSES #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" + +//===----------------------------------------------------------------------===// +// TableGen'd type method definitions +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc" + +Type PDLDialect::parseType(DialectAsmParser &parser) const { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return Type(); + if (Type type = generatedTypeParser(getContext(), parser, keyword)) + return type; + + parser.emitError(parser.getNameLoc(), "invalid 'pdl' type: `") + << keyword << "'"; + return Type(); +} + +void PDLDialect::printType(Type type, DialectAsmPrinter &printer) const { + if (failed(generatedTypePrinter(type, printer))) + llvm_unreachable("unknown 'pdl' type"); +} diff --git a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp --- a/mlir/lib/Rewrite/FrozenRewritePatternList.cpp +++ b/mlir/lib/Rewrite/FrozenRewritePatternList.cpp @@ -9,7 +9,7 @@ #include "mlir/Rewrite/FrozenRewritePatternList.h" #include "ByteCode.h" #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" -#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp @@ -537,12 +537,21 @@ os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* " "ctxt, " "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n"; - for (const TypeDef &type : types) - if 
(type.getMnemonic()) + for (const TypeDef &type : types) { + if (type.getMnemonic()) { os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return " - "{0}::{1}::parse(ctxt, parser);\n", + "{0}::{1}::", type.getDialect().getCppNamespace(), type.getCppClassName()); + + // If the type has no parameters and no parser code, just invoke a normal + // `get`. + if (type.getNumParameters() == 0 && !type.getParserCode()) + os << "get(ctxt);\n"; + else + os << "parse(ctxt, parser);\n"; + } + } os << " return ::mlir::Type();\n"; os << "}\n\n"; @@ -551,17 +560,27 @@ os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type " "type, " "::mlir::DialectAsmPrinter& printer) {\n" - << " ::mlir::LogicalResult found = ::mlir::success();\n" - << " ::llvm::TypeSwitch<::mlir::Type>(type)\n"; - for (const TypeDef &type : types) - if (type.getMnemonic()) - os << formatv(" .Case<{0}::{1}>([&](::mlir::Type t) {{ " - "t.dyn_cast<{0}::{1}>().print(printer); })\n", - type.getDialect().getCppNamespace(), - type.getCppClassName()); - os << " .Default([&found](::mlir::Type) { found = ::mlir::failure(); " - "});\n" - << " return found;\n" + << " return ::llvm::TypeSwitch<::mlir::Type, " + "::mlir::LogicalResult>(type)\n"; + for (const TypeDef &type : types) { + if (type.getMnemonic()) { + StringRef cppNamespace = type.getDialect().getCppNamespace(); + StringRef cppClassName = type.getCppClassName(); + + os << formatv(" .Case<{0}::{1}>([&]({0}::{1} t) {{\n ", + cppNamespace, cppClassName); + + // If the type has no parameters and no custom printer code, just print + // the mnemonic (mirrors the parser's `get(ctxt)` fast path above). + if (type.getNumParameters() == 0 && !type.getPrinterCode()) + os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace, + cppClassName); + else + os << "t.print(printer);"; + os << "\n return ::mlir::success();\n })\n"; + } + } + os << " .Default([](::mlir::Type) { return ::mlir::failure(); });\n" << "}\n\n"; }