diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -52,7 +52,7 @@ }]; let arguments = (ins StrAttr:$name, - Variadic:$args, + Variadic:$args, OptionalAttr:$constParams); let assemblyFormat = [{ $name ($constParams^)? `(` $args `:` type($args) `)` attr-dict @@ -136,9 +136,9 @@ }]; let arguments = (ins StrAttr:$name, - Variadic:$args, + Variadic:$args, OptionalAttr:$constParams); - let results = (outs PDL_PositionalValue:$result); + let results = (outs PDL_AnyType:$result); let assemblyFormat = [{ $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `:` type($result) attr-dict @@ -403,7 +403,7 @@ let arguments = (ins PDL_Operation:$root, OptionalAttr:$name, - Variadic:$externalArgs, + Variadic:$externalArgs, OptionalAttr:$externalConstParams); let regions = (region AnyRegion:$body); let assemblyFormat = [{ diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h --- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h @@ -19,6 +19,18 @@ // PDL Dialect Types //===----------------------------------------------------------------------===// +namespace mlir { +namespace pdl { +/// This class represents the base class of all PDL types. 
+class PDLType : public Type { +public: + using Type::Type; + + static bool classof(Type type); +}; +} // namespace pdl +} // namespace mlir + #define GET_TYPEDEF_CLASSES #include "mlir/Dialect/PDL/IR/PDLOpsTypes.h.inc" diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td --- a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.td @@ -19,7 +19,8 @@ // PDL Types //===----------------------------------------------------------------------===// -class PDL_Type : TypeDef { +class PDL_Type + : TypeDef { let mnemonic = typeMnemonic; } @@ -47,6 +48,29 @@ }]; } +//===----------------------------------------------------------------------===// +// pdl::RangeType +//===----------------------------------------------------------------------===// + +def PDL_Range : PDL_Type<"Range", "range"> { + let summary = "Range of PDL handles for a given sub-type"; + let description = [{ + This type represents a range of handles to instances of the given PDL + element type, i.e. `Attribute`, `Operation`, `Type`, or `Value`. + }]; + let parameters = (ins "Type":$elementType); + + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$elementType), [{ + return Base::get(elementType.getContext(), elementType); + }], [{ + return Base::getChecked($_loc, elementType); + }]>, + ]; + let genVerifyInvariantsDecl = 1; + let skipDefaultBuilders = 1; +} + //===----------------------------------------------------------------------===// // pdl::TypeType //===----------------------------------------------------------------------===// @@ -75,10 +99,8 @@ // Additional Type Constraints //===----------------------------------------------------------------------===// -// A positional value is a location on a pattern DAG, which may be an attribute, -// operation, or operand/result. 
-def PDL_PositionalValue : - AnyTypeOf<[PDL_Attribute, PDL_Operation, PDL_Type, PDL_Value], - "Positional Value">; +def PDL_AnyType : Type< + CPred<"$_self.isa<::mlir::pdl::PDLType>()">, "pdl type", + "::mlir::pdl::PDLType">; #endif // MLIR_DIALECT_PDL_IR_PDLTYPES diff --git a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td --- a/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td +++ b/mlir/include/mlir/Dialect/PDLInterp/IR/PDLInterpOps.td @@ -113,7 +113,7 @@ }]; let arguments = (ins StrAttr:$name, - Variadic:$args, + Variadic:$args, OptionalAttr:$constParams); let assemblyFormat = [{ $name ($constParams^)? `(` $args `:` type($args) `)` attr-dict `->` @@ -151,7 +151,7 @@ }]; let arguments = (ins StrAttr:$name, PDL_Operation:$root, - Variadic:$args, + Variadic:$args, OptionalAttr:$constParams); let assemblyFormat = [{ $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? `on` $root @@ -178,8 +178,7 @@ ``` }]; - let arguments = (ins PDL_PositionalValue:$lhs, - PDL_PositionalValue:$rhs); + let arguments = (ins PDL_AnyType:$lhs, PDL_AnyType:$rhs); let assemblyFormat = "operands `:` type($lhs) attr-dict `->` successors"; } @@ -374,9 +373,9 @@ }]; let arguments = (ins StrAttr:$name, - Variadic:$args, + Variadic:$args, OptionalAttr:$constParams); - let results = (outs PDL_PositionalValue:$result); + let results = (outs PDL_AnyType:$result); let assemblyFormat = [{ $name ($constParams^)? (`(` $args^ `:` type($args) `)`)? 
`:` type($result) attr-dict @@ -691,7 +690,7 @@ ``` }]; - let arguments = (ins PDL_PositionalValue:$value); + let arguments = (ins PDL_AnyType:$value); let assemblyFormat = "$value `:` type($value) attr-dict `->` successors"; } @@ -716,7 +715,7 @@ ``` }]; - let arguments = (ins Variadic:$inputs, + let arguments = (ins Variadic:$inputs, Variadic:$matchedOps, SymbolRefAttr:$rewriter, OptionalAttr:$rootKind, diff --git a/mlir/lib/Dialect/PDL/IR/CMakeLists.txt b/mlir/lib/Dialect/PDL/IR/CMakeLists.txt --- a/mlir/lib/Dialect/PDL/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/PDL/IR/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRPDL PDL.cpp + PDLTypes.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/PDL diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp --- a/mlir/lib/Dialect/PDL/IR/PDL.cpp +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -10,10 +10,8 @@ #include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/DialectImplementation.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "llvm/ADT/StringSwitch.h" -#include "llvm/ADT/TypeSwitch.h" using namespace mlir; using namespace mlir::pdl; @@ -430,27 +428,3 @@ #define GET_OP_CLASSES #include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" - -//===----------------------------------------------------------------------===// -// TableGen'd type method definitions -//===----------------------------------------------------------------------===// - -#define GET_TYPEDEF_CLASSES -#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc" - -Type PDLDialect::parseType(DialectAsmParser &parser) const { - StringRef keyword; - if (parser.parseKeyword(&keyword)) - return Type(); - if (Type type = generatedTypeParser(getContext(), parser, keyword)) - return type; - - parser.emitError(parser.getNameLoc(), "invalid 'pdl' type: `") - << keyword << "'"; - return Type(); -} - -void PDLDialect::printType(Type type, DialectAsmPrinter 
&printer) const {
-  if (failed(generatedTypePrinter(type, printer)))
-    llvm_unreachable("unknown 'pdl' type");
-}
diff --git a/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/PDL/IR/PDLTypes.cpp
@@ -0,0 +1,111 @@
+//===- PDLTypes.cpp - Pattern Descriptor Language Types -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::pdl;
+
+//===----------------------------------------------------------------------===//
+// TableGen'd type method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/PDL/IR/PDLOpsTypes.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// PDLDialect
+//===----------------------------------------------------------------------===//
+
+/// Parse a type keyword and hand it to the generated type parser. Returns a
+/// null Type on failure; when the generated parser yields no type, an
+/// "invalid 'pdl' type" diagnostic is emitted first.
+static Type parsePDLType(DialectAsmParser &parser) {
+  StringRef keyword;
+  if (parser.parseKeyword(&keyword))
+    return Type();
+  if (Type type = generatedTypeParser(parser.getBuilder().getContext(), parser,
+                                      keyword))
+    return type;
+
+  // FIXME: This ends up with a double error being emitted if `RangeType` also
+  // emits an error. We should rework the `generatedTypeParser` to better
+  // support when the keyword is valid but the individual type parser itself
+  // emits an error.
+  parser.emitError(parser.getNameLoc(), "invalid 'pdl' type: `")
+      << keyword << "'";
+  return Type();
+}
+
+/// Parse a type registered to the PDL dialect; forwards to parsePDLType.
+Type PDLDialect::parseType(DialectAsmParser &parser) const {
+  return parsePDLType(parser);
+}
+
+/// Print a type registered to the PDL dialect. A type that the generated
+/// printer does not handle is treated as unreachable.
+void PDLDialect::printType(Type type, DialectAsmPrinter &printer) const {
+  if (failed(generatedTypePrinter(type, printer)))
+    llvm_unreachable("unknown 'pdl' type");
+}
+
+//===----------------------------------------------------------------------===//
+// PDL Types
+//===----------------------------------------------------------------------===//
+
+/// Any type whose dialect is the PDL dialect is considered a PDLType.
+bool PDLType::classof(Type type) {
+  return llvm::isa<PDLDialect>(type.getDialect());
+}
+
+//===----------------------------------------------------------------------===//
+// RangeType
+//===----------------------------------------------------------------------===//
+
+/// Parse the `<element-type>` part of a range type. The element is parsed with
+/// parsePDLType and is rejected with a diagnostic if it is itself a range.
+Type RangeType::parse(MLIRContext *context, DialectAsmParser &parser) {
+  if (parser.parseLess())
+    return Type();
+
+  llvm::SMLoc elementLoc = parser.getCurrentLocation();
+  Type elementType = parsePDLType(parser);
+  if (!elementType || parser.parseGreater())
+    return Type();
+
+  if (elementType.isa<RangeType>()) {
+    parser.emitError(elementLoc)
+        << "element of pdl.range cannot be another range, but got "
+        << elementType;
+    return Type();
+  }
+  return RangeType::get(elementType);
+}
+
+/// Print this type as `range<element-type>`.
+void RangeType::print(DialectAsmPrinter &printer) const {
+  printer << "range<";
+  generatedTypePrinter(getElementType(), printer);
+  printer << ">";
+}
+
+/// Reject element types that are not PDL types, as well as nested ranges.
+LogicalResult RangeType::verifyConstructionInvariants(Location loc,
+                                                      Type elementType) {
+  if (!elementType.isa<PDLType>() || elementType.isa<RangeType>()) {
+    return emitError(loc)
+           << "expected element of pdl.range to be one of [!pdl.attribute, "
+              "!pdl.operation, !pdl.type, !pdl.value], but got "
+           << elementType;
+  }
+  return success();
+}
diff --git a/mlir/test/Dialect/PDL/invalid-types.mlir b/mlir/test/Dialect/PDL/invalid-types.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/PDL/invalid-types.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+//===----------------------------------------------------------------------===//
+// pdl::RangeType
+//===----------------------------------------------------------------------===//
+
+// expected-error@+2 {{element of pdl.range cannot be another range, but got '!pdl.range<value>'}}
+// expected-error@+1 {{invalid 'pdl' type}}
+#invalid_element = !pdl.range<range<value>>