diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -4,6 +4,7 @@ add_subdirectory(Linalg) add_subdirectory(LLVMIR) add_subdirectory(OpenMP) +add_subdirectory(PDL) add_subdirectory(Quant) add_subdirectory(SCF) add_subdirectory(Shape) diff --git a/mlir/include/mlir/Dialect/PDL/CMakeLists.txt b/mlir/include/mlir/Dialect/PDL/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/PDL/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/include/mlir/Dialect/PDL/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/PDL/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/PDL/IR/CMakeLists.txt @@ -0,0 +1,2 @@ +add_mlir_dialect(PDLOps pdl) +add_mlir_doc(PDLOps -gen-op-doc PDLOps Dialects/) diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDL.h b/mlir/include/mlir/Dialect/PDL/IR/PDL.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/PDL/IR/PDL.h @@ -0,0 +1,40 @@ +//===- PDL.h - Pattern Descriptor Language Dialect --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the dialect for the Pattern Descriptor Language. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PDL_IR_PDL_H_ +#define MLIR_DIALECT_PDL_IR_PDL_H_ + +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +namespace mlir { +namespace pdl { +//===----------------------------------------------------------------------===// +// PDL Dialect +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/PDL/IR/PDLOpsDialect.h.inc" + +//===----------------------------------------------------------------------===// +// PDL Dialect Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/PDL/IR/PDLOps.h.inc" + +} // end namespace pdl +} // end namespace mlir + +#endif // MLIR_DIALECT_PDL_IR_PDL_H_ diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLBase.td b/mlir/include/mlir/Dialect/PDL/IR/PDLBase.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLBase.td @@ -0,0 +1,96 @@ +//===- PDLBase.td - PDL base definitions -------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines base support for MLIR PDL operations. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PDL_IR_PDLBASE +#define MLIR_DIALECT_PDL_IR_PDLBASE + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// PDL Dialect +//===----------------------------------------------------------------------===// + +def PDL_Dialect : Dialect { + string summary = "High level pattern definition dialect"; + string description = [{ + PDL presents a high level abstraction for the rewrite pattern infrastructure + available in MLIR. This abstraction allows for representing patterns + transforming MLIR, as MLIR. This allows for applying all of the benefits + that the general MLIR infrastructure provides, to the infrastructure itself. + This means that pattern matching can be more easily verified for + correctness, targeted by frontends, and optimized. + + PDL abstracts over various different aspects of patterns and core MLIR data + structures. Patterns are specified via a `pdl.pattern` operation. These + operations contain a region body for the "matcher" code, and terminate with + a `pdl.rewrite` that either dispatches to an external rewriter or contains + a region for the rewrite specified via `pdl`. The types of values in `pdl` + are handle types to MLIR C++ types, with `!pdl.attribute`, `!pdl.operation`, + `!pdl.type`, and `!pdl.value` directly mapping to `mlir::Attribute`, + `mlir::Operation*`, `mlir::Type`, and `mlir::Value` respectively. + + An example pattern is shown below: + + ```mlir + // pdl.pattern contains metadata similar to a `RewritePattern`. + pdl.pattern : benefit(1) { + // External input operand values are specified via `pdl.input` operations. + // Result types are constrained via `pdl.type` operations.
+ + %resultType = pdl.type + %inputOperand = pdl.input + %root, %results = pdl.operation "foo.op"(%inputOperand) -> %resultType + pdl.rewrite(%root) { + pdl.replace %root with (%inputOperand) + } + } + ``` + + The above pattern simply replaces an operation with its first operand. Note + how the input operation is specified structurally, similarly to how it would + look in memory. This is a simple example and pdl provides support for many + other features such as applying external constraints or external generator + methods. These features and more are detailed below. + }]; + + let name = "pdl"; + let cppNamespace = "mlir::pdl"; +} + +//===----------------------------------------------------------------------===// +// PDL Types +//===----------------------------------------------------------------------===// + +class PDL_Handle : + DialectType()">, + underlying>, + BuildableType<"$_builder.getType<" # underlying # ">()">; + +// Handle for `mlir::Attribute`. +def PDL_Attribute : PDL_Handle<"mlir::pdl::AttributeType">; + +// Handle for `mlir::Operation*`. +def PDL_Operation : PDL_Handle<"mlir::pdl::OperationType">; + +// Handle for `mlir::Type`. +def PDL_Type : PDL_Handle<"mlir::pdl::TypeType">; + +// Handle for `mlir::Value`. +def PDL_Value : PDL_Handle<"mlir::pdl::ValueType">; + +// A positional value is a location on a pattern DAG, which may be an operation, +// an attribute, or an operand/result. +def PDL_PositionalValue : + AnyTypeOf<[PDL_Attribute, PDL_Operation, PDL_Type, PDL_Value], + "Positional Value">; + +#endif // MLIR_DIALECT_PDL_IR_PDLBASE diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLOps.td @@ -0,0 +1,448 @@ +//===- PDLOps.td - Pattern descriptor operations -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Pattern Descriptor Language dialect operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PDL_IR_PDLOPS +#define MLIR_DIALECT_PDL_IR_PDLOPS + +include "mlir/Dialect/PDL/IR/PDLBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/SymbolInterfaces.td" + +//===----------------------------------------------------------------------===// +// PDL Ops +//===----------------------------------------------------------------------===// + +class PDL_Op traits = []> + : Op { + let printer = [{ ::print(p, *this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; + let verifier = [{ return ::verify(*this); }]; +} + +//===----------------------------------------------------------------------===// +// pdl::ApplyConstraintOp +//===----------------------------------------------------------------------===// + +def PDL_ApplyConstraintOp + : PDL_Op<"apply_constraint", [HasParent<"pdl::PatternOp">]> { + let summary = "Apply a generic constraint to a set of provided entities"; + let description = [{ + `apply_constraint` operations apply a generic constraint, that has been + registered externally with the consumer of PDL, to a given set of entities. + The constraint is permitted to accept any number of constant valued + parameters. + + Example: + + ```mlir + // Apply `myConstraint` to the entities defined by `input`, `attr`, and + // `op`. `42`, `"abc"`, and `i32` are constant parameters passed to the + // constraint. 
+ pdl.apply_constraint "myConstraint"[42, "abc", i32](%input, %attr, %op : !pdl.value, !pdl.attribute, !pdl.operation) + ``` + }]; + + let arguments = (ins Variadic:$args, + ArrayAttr:$params, + StrAttr:$name); + let assemblyFormat = "$name $params `(` $args `:` type($args) `)` attr-dict"; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, " + "ValueRange args, ArrayRef params, " + "StringRef name", [{ + build(builder, state, args, builder.getArrayAttr(params), + builder.getStringAttr(name)); + }]>, + ]; +} + +//===----------------------------------------------------------------------===// +// pdl::AttributeOp +//===----------------------------------------------------------------------===// + +def PDL_AttributeOp : PDL_Op<"attribute"> { + let summary = "Define an input attribute in a pattern"; + let description = [{ + `pdl.attribute` operations capture named attribute edges into an operation. + Instances of this operation define, and partially constrain, attributes of a + given operation. A `pdl.attribute` may partially constrain the input by + specifying an expected attribute value type (via a `pdl.type` operation), or + a constant value for the attribute (via `val`). Only one of these may be set + for a given input, as the type of the constant value provides the type. When + defined within a `pdl.rewrite` region, the constant value must be specified. + + Example: + + ```mlir + // Define an attribute: + %attr = pdl.attribute + + // Define an attribute with an expected type: + %type = pdl.type : i32 + %attr = pdl.attribute : %type + + // Define an attribute with a constant value: + %attr = pdl.attribute "hello" + ``` + }]; + + let arguments = (ins Optional:$type, + OptionalAttr:$value); + let results = (outs PDL_Attribute:$attr); + let assemblyFormat = "attr-dict (`:` $type^)? 
($value^)?"; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, " + "Value type = Value()", [{ + build(builder, state, builder.getType(), type, + Attribute()); + }]>, + OpBuilder<"OpBuilder &builder, OperationState &state, Attribute attr", [{ + build(builder, state, builder.getType(), Value(), attr); + }]>, + ]; +} + +//===----------------------------------------------------------------------===// +// pdl::EraseOp +//===----------------------------------------------------------------------===// + +def PDL_EraseOp : PDL_Op<"erase", [HasParent<"pdl::RewriteOp">]> { + let summary = "Mark an input operation as `erased`"; + let description = [{ + `pdl.erase` operations are used within `pdl.rewrite` regions to specify that + an input operation should be marked as erased. The semantics of this + operation correspond with the `eraseOp` method on a `PatternRewriter`. + + Example: + + ```mlir + pdl.erase %root + ``` + }]; + let arguments = (ins PDL_Operation:$operation); + let assemblyFormat = "$operation attr-dict"; + let verifier = ?; +} + +//===----------------------------------------------------------------------===// +// pdl::InputOp +//===----------------------------------------------------------------------===// + +def PDL_InputOp : PDL_Op<"input", [HasParent<"pdl::PatternOp">]> { + let summary = "Define an input value in a pattern"; + let description = [{ + `pdl.input` operations capture external operand edges into an operation + node that originate from operations or block arguments not otherwise + specified within the pattern (e.g. via `pdl.operation`). These operations + define, and partially constrain, input operands of a given operation. + A `pdl.input` may partially constrain an input operand by specifying an + expected value type (via a `pdl.type` operation). 
+ + Example: + + ```mlir + // Define an input operand: + %operand = pdl.input + + // Define an input operand with an expected type: + %type = pdl.type : i32 + %operand = pdl.input : %type + ``` + }]; + + let arguments = (ins Optional:$type); + let results = (outs PDL_Value:$val); + let assemblyFormat = "(`:` $type^)? attr-dict"; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state", [{ + build(builder, state, builder.getType(), Value()); + }]>, + ]; +} + +//===----------------------------------------------------------------------===// +// pdl::OperationOp +//===----------------------------------------------------------------------===// + +def PDL_OperationOp + : PDL_Op<"operation", [AttrSizedOperandSegments, NoSideEffect]> { + let summary = "Define an operation within a pattern"; + let description = [{ + `pdl.operation` operations define operation nodes within a pattern. Within + a match sequence, i.e. when directly nested within a `pdl.pattern`, these + operations correspond to input operations, or those that already exist + within the MLIR module. Inside of a `pdl.rewrite`, these operations + correspond to operations that should be created as part of the replacement + sequence. + + `pdl.operation`s are composed of a name, and a set of attribute, operand, + and result type values, that map to those that would be on a + constructed instance of that operation. The results of a `pdl.operation` are + a handle to the operation itself, and a handle to each of the operation + result values. + + When used within a matching context, the name of the operation may be + omitted. + + When used within a rewriting context, i.e. when defined within a + `pdl.rewrite`, all of the result types must be "inferrable". This means that + the type must be attributable to either a constant type value or the result + type of another entity, such as an attribute, the result of a render, or the + result type of another operation.
If the result type value does not meet any + of these criteria, the operation must provide the `InferTypeOpInterface` to + ensure that the result types can be inferred. + + Example: + + ```mlir + // Define an instance of a `foo.op` operation. + %op, %results:4 = pdl.operation "foo.op"(%arg0, %arg1) {"attrA" = %attr0} -> %type, %type, %type, %type + ``` + }]; + + let arguments = (ins OptionalAttr:$name, + Variadic:$operands, + Variadic:$attributes, + StrArrayAttr:$attributeNames, + Variadic:$types); + let results = (outs PDL_Operation:$op, + Variadic:$results); + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, " + "Optional name = llvm::None, " + "ValueRange operandValues = llvm::None, " + "ArrayRef attrNames = llvm::None, " + "ValueRange attrValues = llvm::None, " + "ValueRange resultTypes = llvm::None", [{ + auto nameAttr = name ? builder.getStringAttr(*name) : StringAttr(); + build(builder, state, builder.getType(), {}, nameAttr, + operandValues, attrValues, builder.getStrArrayAttr(attrNames), + resultTypes); + state.types.append(resultTypes.size(), builder.getType()); + }]>, + ]; + let extraClassDeclaration = [{ + /// Returns true if the operation type referenced supports result type + /// inference. + bool hasTypeInference(); + }]; +} + +//===----------------------------------------------------------------------===// +// pdl::PatternOp +//===----------------------------------------------------------------------===// + +def PDL_PatternOp : PDL_Op<"pattern", [IsolatedFromAbove, Symbol]> { + let summary = "Define a rewrite pattern"; + let description = [{ + `pdl.pattern` operations provide a transformable representation for a + `RewritePattern`. The attributes on this operation correspond to the various + metadata on a `RewritePattern`, such as the benefit. The match section of + the pattern is specified within the region body, with the rewrite provided + by a terminating `pdl.rewrite`.
+ + Example: + + ```mlir + // Provide a pattern matching "foo.op" that replaces the root with its + // input. + pdl.pattern : benefit(1) { + %resultType = pdl.type + %inputOperand = pdl.input + %root, %results = pdl.operation "foo.op"(%inputOperand) -> (%resultType) + pdl.rewrite(%root) { + pdl.replace %root with (%inputOperand) + } + } + ``` + }]; + + let arguments = (ins OptionalAttr:$rootKind, + Confined:$benefit, + OptionalAttr:$sym_name); + + let regions = (region SizedRegion<1>:$body); + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, " + "Optional rootKind = llvm::None, " + "Optional benefit = 1, " + "Optional name = llvm::None">, + ]; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + /// A PatternOp may optionally define a symbol. + bool isOptionalSymbol() { return true; } + + /// Returns the rewrite operation of this pattern. + RewriteOp getRewriter(); + + /// Return the root operation kind that this pattern matches, or None if + /// there isn't a specific root. + Optional getRootKind(); + }]; +} + +//===----------------------------------------------------------------------===// +// pdl::RenderOp +//===----------------------------------------------------------------------===// + +def PDL_RenderOp : PDL_Op<"render", [HasParent<"pdl::RewriteOp">]> { + let summary = "Call a custom renderer to create an `Attribute`, `Operation`, " + "`Type`, or `Value`"; + let description = [{ + `pdl.render` operations invoke a render function, that has been registered + externally with the consumer of PDL, to create an `Attribute`, `Operation`, + `Type`, or `Value`. The renderer function must produce a value of the + specified return type, and may accept any number of arguments and constant + attribute parameters. 
+ + Example: + + ```mlir + %ret = pdl.render "myRenderer"[42, "gt"](%arg0, %arg1) : !pdl.attribute + ``` + }]; + + let arguments = (ins StrAttr:$name, Variadic:$arguments, + ArrayAttr:$constantParams); + let results = (outs PDL_PositionalValue:$result); + let assemblyFormat = [{ + $name $constantParams (`(` $arguments^ `:` type($arguments) `)`)? + `:` type($result) attr-dict + }]; + let verifier = ?; +} + +//===----------------------------------------------------------------------===// +// pdl::ReplaceOp +//===----------------------------------------------------------------------===// + +def PDL_ReplaceOp : PDL_Op<"replace", [ + AttrSizedOperandSegments, HasParent<"pdl::RewriteOp"> + ]> { + let summary = "Mark an input operation as `replaced`"; + let description = [{ + `pdl.replace` operations are used within `pdl.rewrite` regions to specify + that an input operation should be marked as replaced. The semantics of this + operation correspond with the `replaceOp` method on a `PatternRewriter`. The + set of replacement values can be either: + * a single `Operation` (`replOperation` should be populated) + - The operation will be replaced with the results of this operation. + * a set of `Value`s (`replValues` should be populated) + - The operation will be replaced with these values. + + Example: + + ```mlir + // Replace root node with 2 values: + pdl.replace %root with (%val0, %val1) + + // Replace root with another operation: + pdl.replace %root with %otherOp + ``` + }]; + let arguments = (ins PDL_Operation:$operation, + Optional:$replOperation, + Variadic:$replValues); + let assemblyFormat = [{ + $operation `with` (`(` $replValues^ `)`)? ($replOperation^)? 
attr-dict + }]; +} + +//===----------------------------------------------------------------------===// +// pdl::RewriteOp +//===----------------------------------------------------------------------===// + +def PDL_RewriteOp : PDL_Op<"rewrite", [ + Terminator, HasParent<"pdl::PatternOp">, + SingleBlockImplicitTerminator<"pdl::RewriteEndOp"> + ]> { + let summary = "Specify the rewrite of a matched pattern"; + let description = [{ + `pdl.rewrite` operations terminate the region of a `pdl.pattern` and specify + the rewrite of a `pdl.pattern`, on the specified root operation. The + rewrite is specified either via a string name (`name`) to an external + rewrite function, or via the region body. The rewrite region, if specified, + must contain a single block and terminate via the `pdl.rewrite_end` + operation. + + Example: + + ```mlir + // Specify an external rewrite function: + pdl.rewrite "myExternalRewriter"(%root) + + // Specify the rewrite inline using PDL: + pdl.rewrite(%root) { + %op = pdl.operation "foo.op"(%arg0, %arg1) + pdl.replace %root with %op + } + ``` + }]; + + let arguments = (ins PDL_Operation:$root, + OptionalAttr:$name); + let regions = (region AnyRegion:$body); +} + +def PDL_RewriteEndOp : PDL_Op<"rewrite_end", [Terminator, + HasParent<"pdl::RewriteOp">]> { + let summary = "Implicit terminator of a `pdl.rewrite` region"; + let description = [{ + `pdl.rewrite_end` operations terminate the region of a `pdl.rewrite`. + }]; + let assemblyFormat = "attr-dict"; + let verifier = ?; +} + +//===----------------------------------------------------------------------===// +// pdl::TypeOp +//===----------------------------------------------------------------------===// + +def PDL_TypeOp : PDL_Op<"type"> { + let summary = "Define a type handle within a pattern"; + let description = [{ + `pdl.type` operations capture result type constraints of an `Attributes`, + `Values`, and `Operations`. 
Instances of this operation define, and + partially constrain, result types of a given entity. A `pdl.type` may + partially constrain the result by specifying a constant `Type`. + + Example: + + ```mlir + // Define a type: + %type = pdl.type + + // Define a type with a constant value: + %type = pdl.type : i32 + ``` + }]; + + let arguments = (ins OptionalAttr:$type); + let results = (outs PDL_Type:$result); + let assemblyFormat = "attr-dict (`:` $type^)?"; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Type ty = Type()", [{ + build(builder, state, builder.getType(), + ty ? TypeAttr::get(ty) : TypeAttr()); + }]>, + ]; +} + +#endif // MLIR_DIALECT_PDL_IR_PDLOPS diff --git a/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/PDL/IR/PDLTypes.h @@ -0,0 +1,70 @@ +//===- PDLTypes.h - Pattern Descriptor Language Types -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the types for the Pattern Descriptor Language dialect.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_PDL_IR_PDLTYPES_H_ +#define MLIR_DIALECT_PDL_IR_PDLTYPES_H_ + +#include "mlir/IR/Types.h" + +namespace mlir { +namespace pdl { +//===----------------------------------------------------------------------===// +// PDL Dialect Types +//===----------------------------------------------------------------------===// + +namespace PDLTypes { +enum Kind : unsigned { + Attribute = Type::FIRST_PDL_TYPE, + Operation, + Type, + Value +}; +} // end namespace PDLTypes + +/// A base type used to define the necessary getters/setters for the types that +/// have a single kind. +template +struct PDLTypeBase : public Type::TypeBase { + using BaseT = Type::TypeBase; + using Base = PDLTypeBase; + + using BaseT::BaseT; + + static TypeT get(MLIRContext *ctx) { return BaseT::get(ctx, K); } + static bool kindof(unsigned kind) { return kind == getKind(); } + static constexpr PDLTypes::Kind getKind() { return K; } +}; + +/// This type represents a handle to an `mlir::Attribute`. +struct AttributeType : public PDLTypeBase { + using Base::Base; +}; + +/// This type represents a handle to an `mlir::Operation*`. +struct OperationType : public PDLTypeBase { + using Base::Base; +}; + +/// This type represents a handle to an `mlir::Type`. +struct TypeType : public PDLTypeBase { + using Base::Base; +}; + +/// This type represents a handle to an `mlir::Value`. 
+struct ValueType : public PDLTypeBase { + using Base::Base; +}; + +} // end namespace pdl +} // end namespace mlir + +#endif // MLIR_DIALECT_PDL_IR_PDLTYPES_H_ diff --git a/mlir/include/mlir/IR/DialectSymbolRegistry.def b/mlir/include/mlir/IR/DialectSymbolRegistry.def --- a/mlir/include/mlir/IR/DialectSymbolRegistry.def +++ b/mlir/include/mlir/IR/DialectSymbolRegistry.def @@ -26,6 +26,7 @@ DEFINE_SYM_KIND_RANGE(XLA_HLO) // XLA HLO dialect DEFINE_SYM_KIND_RANGE(SHAPE) // Shape dialect DEFINE_SYM_KIND_RANGE(TF_FRAMEWORK) // TF Framework dialect +DEFINE_SYM_KIND_RANGE(PDL) // Pattern Descriptor Dialect // The following ranges are reserved for experimenting with MLIR dialects in a // private context without having to register them here. diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -289,6 +289,18 @@ /// Parse a '->' token if present virtual ParseResult parseOptionalArrow() = 0; + /// Parse a `{` token. + virtual ParseResult parseLBrace() = 0; + + /// Parse a `{` token if present. + virtual ParseResult parseOptionalLBrace() = 0; + + /// Parse a `}` token. + virtual ParseResult parseRBrace() = 0; + + /// Parse a `}` token if present. + virtual ParseResult parseOptionalRBrace() = 0; + /// Parse a `:` token. virtual ParseResult parseColon() = 0; @@ -304,6 +316,9 @@ /// Parse a `=` token. virtual ParseResult parseEqual() = 0; + /// Parse a `=` token if present. + virtual ParseResult parseOptionalEqual() = 0; + /// Parse a '<' token. virtual ParseResult parseLess() = 0; @@ -344,6 +359,9 @@ /// Parse a `)` token if present. virtual ParseResult parseOptionalRParen() = 0; + /// Parses a '?' if present. + virtual ParseResult parseOptionalQuestion() = 0; + /// Parse a `[` token. 
virtual ParseResult parseLSquare() = 0; @@ -363,6 +381,26 @@ // Attribute Parsing //===--------------------------------------------------------------------===// + /// Parse an arbitrary attribute of a given type and return it in result. + virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; + + /// Parse an attribute of a specific kind and type. + template + ParseResult parseAttribute(AttrType &result, Type type = {}) { + llvm::SMLoc loc = getCurrentLocation(); + + // Parse any kind of attribute. + Attribute attr; + if (parseAttribute(attr, type)) + return failure(); + + // Check for the right kind of attribute. + if (!(result = attr.dyn_cast())) + return emitError(loc, "invalid kind of attribute specified"); + + return success(); + } + /// Parse an arbitrary attribute and return it in result. This also adds the /// attribute to the specified attribute list with the specified name. ParseResult parseAttribute(Attribute &result, StringRef attrName, @@ -377,13 +415,6 @@ return parseAttribute(result, Type(), attrName, attrs); } - /// Parse an arbitrary attribute of a given type and return it in result. This - /// also adds the attribute to the specified attribute list with the specified - /// name. - virtual ParseResult parseAttribute(Attribute &result, Type type, - StringRef attrName, - NamedAttrList &attrs) = 0; - /// Parse an optional attribute. virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type, @@ -395,7 +426,9 @@ return parseOptionalAttribute(result, Type(), attrName, attrs); } - /// Parse an attribute of a specific kind and type. + /// Parse an arbitrary attribute of a given type and return it in result. This + /// also adds the attribute to the specified attribute list with the specified + /// name. template ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, NamedAttrList &attrs) { @@ -403,7 +436,7 @@ // Parse any kind of attribute.
Attribute attr; - if (parseAttribute(attr, type, attrName, attrs)) + if (parseAttribute(attr, type)) return failure(); // Check for the right kind of attribute. @@ -411,6 +444,7 @@ if (!result) return emitError(loc, "invalid kind of attribute specified"); + attrs.append(attrName, result); return success(); } diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -23,6 +23,7 @@ #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SDBM/SDBMDialect.h" @@ -44,6 +45,7 @@ registerDialect(); registerDialect(); registerDialect(); + registerDialect(); registerDialect(); registerDialect(); registerDialect(); diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -4,6 +4,7 @@ add_subdirectory(Linalg) add_subdirectory(LLVMIR) add_subdirectory(OpenMP) +add_subdirectory(PDL) add_subdirectory(Quant) add_subdirectory(SCF) add_subdirectory(SDBM) diff --git a/mlir/lib/Dialect/PDL/CMakeLists.txt b/mlir/lib/Dialect/PDL/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/PDL/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/mlir/lib/Dialect/PDL/IR/CMakeLists.txt b/mlir/lib/Dialect/PDL/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/PDL/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(MLIRPDL + PDL.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/PDL + + DEPENDS + MLIRPDLOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRInferTypeOpInterface + MLIRSideEffectInterfaces + ) diff --git a/mlir/lib/Dialect/PDL/IR/PDL.cpp b/mlir/lib/Dialect/PDL/IR/PDL.cpp new file 
mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/PDL/IR/PDL.cpp @@ -0,0 +1,514 @@ +//===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "llvm/ADT/StringSwitch.h" + +using namespace mlir; +using namespace mlir::pdl; + +//===----------------------------------------------------------------------===// +// PDLDialect +//===----------------------------------------------------------------------===// + +void PDLDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" + >(); + addTypes(); +} + +Type PDLDialect::parseType(DialectAsmParser &parser) const { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return Type(); + + Builder &builder = parser.getBuilder(); + Type result = llvm::StringSwitch(keyword) + .Case("attribute", builder.getType()) + .Case("operation", builder.getType()) + .Case("type", builder.getType()) + .Case("value", builder.getType()) + .Default(Type()); + if (!result) + parser.emitError(parser.getNameLoc(), "invalid 'pdl' type: `") + << keyword << "'"; + return result; +} + +void PDLDialect::printType(Type type, DialectAsmPrinter &printer) const { + if (type.isa()) + printer << "attribute"; + else if (type.isa()) + printer << "operation"; + else if (type.isa()) + printer << "type"; + else if (type.isa()) + printer << "value"; + else + llvm_unreachable("unknown 'pdl' type"); +} + +/// Verifies that the given operation is used by a "binding" pdl operation
+/// within the main matcher body of a `pdl.pattern`, if defined there. +static LogicalResult +verifyHasBindingUseInMatcher(Operation *op, + StringRef bindableContextStr = "`pdl.operation`") { + // If the parent operation is not a pattern, there is nothing to do. + if (!isa(op->getParentOp())) + return success(); + Block *matcherBlock = op->getBlock(); + for (Operation *user : op->getUsers()) { + if (user->getBlock() != matcherBlock) + continue; + if (isa(user)) + return success(); + } + return op->emitOpError() + << "expected a bindable (i.e. " << bindableContextStr + << ") user when defined in the matcher body of a `pdl.pattern`"; +} + +//===----------------------------------------------------------------------===// +// pdl::ApplyConstraintOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(ApplyConstraintOp op) { + if (op.getNumOperands() == 0) + return op.emitOpError("expected at least one argument"); + return success(); +} + +//===----------------------------------------------------------------------===// +// pdl::AttributeOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(AttributeOp op) { + Value attrType = op.type(); + Optional attrValue = op.value(); + + if (!attrValue && isa(op.getParentOp())) + return op.emitOpError("expected constant value when specified within a " + "`pdl.rewrite`"); + if (attrValue && attrType) + return op.emitOpError("expected only one of [`type`, `value`] to be set"); + return verifyHasBindingUseInMatcher(op); +} + +//===----------------------------------------------------------------------===// +// pdl::InputOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(InputOp op) { + return verifyHasBindingUseInMatcher(op); +} + +//===----------------------------------------------------------------------===// +// pdl::OperationOp
+//===----------------------------------------------------------------------===// + +static ParseResult parseOperationOp(OpAsmParser &p, OperationState &state) { + Builder &builder = p.getBuilder(); + + // Parse the optional operation name. + bool startsWithOperands = succeeded(p.parseOptionalLParen()); + bool startsWithAttributes = + !startsWithOperands && succeeded(p.parseOptionalLBrace()); + bool startsWithOpName = false; + if (!startsWithAttributes && !startsWithOperands) { + StringAttr opName; + OptionalParseResult opNameResult = + p.parseOptionalAttribute(opName, "name", state.attributes); + startsWithOpName = opNameResult.hasValue(); + if (startsWithOpName && failed(*opNameResult)) + return failure(); + } + + // Parse the operands. + SmallVector operands; + if (startsWithOperands || + (!startsWithAttributes && succeeded(p.parseOptionalLParen()))) { + if (p.parseOperandList(operands) || p.parseRParen() || + p.resolveOperands(operands, builder.getType(), + state.operands)) + return failure(); + } + + // Parse the attributes. + SmallVector attrNames; + if (startsWithAttributes || succeeded(p.parseOptionalLBrace())) { + SmallVector attrOps; + do { + StringAttr nameAttr; + OpAsmParser::OperandType operand; + if (p.parseAttribute(nameAttr) || p.parseEqual() || + p.parseOperand(operand)) + return failure(); + attrNames.push_back(nameAttr); + attrOps.push_back(operand); + } while (succeeded(p.parseOptionalComma())); + + if (p.parseRBrace() || + p.resolveOperands(attrOps, builder.getType(), + state.operands)) + return failure(); + } + state.addAttribute("attributeNames", builder.getArrayAttr(attrNames)); + state.addTypes(builder.getType()); + + // Parse the result types.
+ SmallVector opResultTypes; + if (succeeded(p.parseOptionalArrow())) { + if (p.parseOperandList(opResultTypes) || + p.resolveOperands(opResultTypes, builder.getType(), + state.operands)) + return failure(); + state.types.append(opResultTypes.size(), builder.getType()); + } + + if (p.parseOptionalAttrDict(state.attributes)) + return failure(); + + int32_t operandSegmentSizes[] = {static_cast(operands.size()), + static_cast(attrNames.size()), + static_cast(opResultTypes.size())}; + state.addAttribute("operand_segment_sizes", + builder.getI32VectorAttr(operandSegmentSizes)); + return success(); +} + +static void print(OpAsmPrinter &p, OperationOp op) { + p << "pdl.operation "; + if (Optional name = op.name()) + p << '"' << *name << '"'; + + auto operandValues = op.operands(); + if (!operandValues.empty()) + p << '(' << operandValues << ')'; + + // Emit the optional attributes. + ArrayAttr attrNames = op.attributeNames(); + if (!attrNames.empty()) { + Operation::operand_range attrArgs = op.attributes(); + p << " {"; + interleaveComma(llvm::seq(0, attrNames.size()), p, + [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; }); + p << '}'; + } + + // Print the result type constraints of the operation. + if (!op.results().empty()) + p << " -> " << op.types(); + p.printOptionalAttrDict(op.getAttrs(), + {"attributeNames", "name", "operand_segment_sizes"}); +} + +/// Verifies that the result types of this operation, defined within a +/// `pdl.rewrite`, can be inferred. +static LogicalResult verifyResultTypesAreInferrable(OperationOp op, + ResultRange opResults, + OperandRange resultTypes) { + // Functor that returns true if the given use can be used to infer a type. + Block *rewriterBlock = op.getOperation()->getBlock(); + auto canInferTypeFromUse = [&](OpOperand &use) { + // If the use is within a ReplaceOp and isn't the operation being replaced + // (i.e. is not the first operand of the replacement), we can infer a type.
+ ReplaceOp replOpUser = dyn_cast(use.getOwner()); + if (!replOpUser || use.getOperandNumber() == 0) + return false; + // Make sure the replaced operation was defined before this one. + Operation *replacedOp = replOpUser.operation().getDefiningOp(); + return replacedOp->getBlock() != rewriterBlock || + replacedOp->isBeforeInBlock(op); + }; + + // Check to see if the uses of the operation itself can be used to infer + // types. + if (llvm::any_of(op.op().getUses(), canInferTypeFromUse)) + return success(); + + // Otherwise, make sure each of the types can be inferred. + for (int i : llvm::seq(0, opResults.size())) { + Operation *resultTypeOp = resultTypes[i].getDefiningOp(); + assert(resultTypeOp && "expected valid result type operation"); + + // If the op was defined by a render, it is guaranteed to be usable. + if (isa(resultTypeOp)) + continue; + + // If the type is already constrained, there is nothing to do. + TypeOp typeOp = cast(resultTypeOp); + if (typeOp.type()) + continue; + + // If the type operation was defined in the matcher and constrains the + // result of an input operation, it can be used. + auto constrainsInputOp = [rewriterBlock](Operation *user) { + return user->getBlock() != rewriterBlock && isa(user); + }; + if (llvm::any_of(typeOp.getResult().getUsers(), constrainsInputOp)) + continue; + + // Otherwise, check to see if any uses of the result can infer the type.
+ if (llvm::any_of(opResults[i].getUses(), canInferTypeFromUse)) + continue; + return op + .emitOpError("must have inferable or constrained result types when " + "nested within `pdl.rewrite`") + .attachNote() + .append("result type #", i, " was not constrained"); + } + return success(); +} + +static LogicalResult verify(OperationOp op) { + bool isWithinRewrite = isa(op.getParentOp()); + if (isWithinRewrite && !op.name()) + return op.emitOpError("must have an operation name when nested within " + "a `pdl.rewrite`"); + ArrayAttr attributeNames = op.attributeNames(); + auto attributeValues = op.attributes(); + if (attributeNames.size() != attributeValues.size()) { + return op.emitOpError() + << "expected the same number of attribute values and attribute " + "names, got " + << attributeNames.size() << " names and " << attributeValues.size() + << " values"; + } + + OperandRange resultTypes = op.types(); + auto opResults = op.results(); + if (resultTypes.size() != opResults.size()) { + return op.emitOpError() << "expected the same number of result values and " + "result type constraints, got " + << opResults.size() << " results and " + << resultTypes.size() << " constraints"; + } + + // If the operation is within a rewrite body and doesn't have type inference, + // ensure that the result types can be resolved.
+ if (isWithinRewrite && !op.hasTypeInference()) { + if (failed(verifyResultTypesAreInferrable(op, opResults, resultTypes))) + return failure(); + } + + return verifyHasBindingUseInMatcher(op, "`pdl.operation` or `pdl.rewrite`"); +} + +bool OperationOp::hasTypeInference() { + Optional opName = name(); + if (!opName) + return false; + + OperationName name(*opName, getContext()); + if (const AbstractOperation *op = name.getAbstractOperation()) + return op->getInterface(); + return false; +} + +//===----------------------------------------------------------------------===// +// pdl::PatternOp +//===----------------------------------------------------------------------===// + +static ParseResult parsePatternOp(OpAsmParser &p, OperationState &state) { + StringAttr name; + p.parseOptionalSymbolName(name, SymbolTable::getSymbolAttrName(), + state.attributes); + + // Parse the benefit. + IntegerAttr benefitAttr; + if (p.parseColon() || p.parseKeyword("benefit") || p.parseLParen() || + p.parseAttribute(benefitAttr, p.getBuilder().getIntegerType(16), + "benefit", state.attributes) || + p.parseRParen()) + return failure(); + + // Parse the pattern body.
+ if (p.parseOptionalAttrDictWithKeyword(state.attributes) || + p.parseRegion(*state.addRegion(), None, None)) + return failure(); + return success(); +} + +static void print(OpAsmPrinter &p, PatternOp op) { + p << "pdl.pattern"; + if (Optional name = op.sym_name()) { + p << ' '; + p.printSymbolName(*name); + } + p << " : benefit("; + p.printAttributeWithoutType(op.benefitAttr()); + p << ")"; + + p.printOptionalAttrDictWithKeyword( + op.getAttrs(), {"benefit", "rootKind", SymbolTable::getSymbolAttrName()}); + p.printRegion(op.body()); +} + +static LogicalResult verify(PatternOp pattern) { + Region &body = pattern.body(); + auto *term = body.front().getTerminator(); + if (!isa(term)) { + return pattern.emitOpError("expected body to terminate with `pdl.rewrite`") + .attachNote(term->getLoc()) + .append("see terminator defined here"); + } + + // Check that every operation nested within the pattern body belongs to the + // `pdl` dialect. + WalkResult result = body.walk([&](Operation *op) -> WalkResult { + if (!isa_and_nonnull(op->getDialect())) { + pattern + .emitOpError("expected only `pdl` operations within the pattern body") + .attachNote(op->getLoc()) + .append("see non-`pdl` operation defined here"); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return failure(result.wasInterrupted()); +} + +void PatternOp::build(OpBuilder &builder, OperationState &state, + Optional rootKind, Optional benefit, + Optional name) { + build(builder, state, + rootKind ? builder.getStringAttr(*rootKind) : StringAttr(), + builder.getI16IntegerAttr(benefit ? *benefit : 0), + name ? builder.getStringAttr(*name) : StringAttr()); + builder.createBlock(state.addRegion()); +} + +/// Returns the rewrite operation of this pattern. +RewriteOp PatternOp::getRewriter() { + return cast(body().front().getTerminator()); +} + +/// Return the root operation kind that this pattern matches, or None if
+/// there isn't a specific root. +Optional PatternOp::getRootKind() { + OperationOp rootOp = cast(getRewriter().root().getDefiningOp()); + return rootOp.name(); +} + +//===----------------------------------------------------------------------===// +// pdl::ReplaceOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(ReplaceOp op) { + auto sourceOp = cast(op.operation().getDefiningOp()); + auto sourceOpResults = sourceOp.results(); + auto replValues = op.replValues(); + + if (Value replOpVal = op.replOperation()) { + auto replOp = cast(replOpVal.getDefiningOp()); + auto replOpResults = replOp.results(); + if (sourceOpResults.size() != replOpResults.size()) { + return op.emitOpError() + << "expected source operation to have the same number of results " + "as the replacement operation, replacement operation provided " + << replOpResults.size() << " but expected " + << sourceOpResults.size(); + } + + if (!replValues.empty()) { + return op.emitOpError() << "expected no replacement values to be provided" + " when the replacement operation is present"; + } + + return success(); + } + + if (sourceOpResults.size() != replValues.size()) { + return op.emitOpError() + << "expected source operation to have the same number of results " + "as the provided replacement values, found " + << replValues.size() << " replacement values but expected " + << sourceOpResults.size(); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// pdl::RewriteOp +//===----------------------------------------------------------------------===// + +static ParseResult parseRewriteOp(OpAsmParser &p, OperationState &state) { + // If the first token isn't a '(', this is an external rewrite. + StringAttr nameAttr; + if (failed(p.parseOptionalLParen())) { + if (p.parseAttribute(nameAttr, "name", state.attributes) || p.parseLParen()) + return failure(); + } + + // Parse the root operand.
+ OpAsmParser::OperandType rootOperand; + if (p.parseOperand(rootOperand) || p.parseRParen() || + p.resolveOperand(rootOperand, p.getBuilder().getType(), + state.operands)) + return failure(); + + // If this isn't an external rewrite, parse the region body. + Region &rewriteRegion = *state.addRegion(); + if (!nameAttr) { + if (p.parseRegion(rewriteRegion, /*arguments=*/llvm::None, + /*argTypes=*/llvm::None)) + return failure(); + RewriteOp::ensureTerminator(rewriteRegion, p.getBuilder(), state.location); + } + return success(); +} + +static void print(OpAsmPrinter &p, RewriteOp op) { + p << "pdl.rewrite"; + if (Optional name = op.name()) { + p << " \"" << *name << "\"(" << op.root() << ")"; + return; + } + + p << "(" << op.root() << ")"; + p.printRegion(op.body(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/false); +} + +static LogicalResult verify(RewriteOp op) { + Region &rewriteRegion = op.body(); + if (llvm::hasNItemsOrMore(rewriteRegion, 2)) { + return op.emitOpError() + << "expected rewrite region when specified to have a single block"; + } + return success(); +} + +//===----------------------------------------------------------------------===// +// pdl::TypeOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(TypeOp op) { + return verifyHasBindingUseInMatcher( + op, "`pdl.attribute`, `pdl.input`, or `pdl.operation`"); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace pdl { + +#define GET_OP_CLASSES +#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc" + +} // end namespace pdl +} // end namespace mlir diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -914,6 +914,26 @@ return success(parser.consumeIf(Token::arrow)); } +
/// Parse a '{' token. + ParseResult parseLBrace() override { + return parser.parseToken(Token::l_brace, "expected '{'"); + } + + /// Parse a '{' token if present. + ParseResult parseOptionalLBrace() override { + return success(parser.consumeIf(Token::l_brace)); + } + + /// Parse a `}` token. + ParseResult parseRBrace() override { + return parser.parseToken(Token::r_brace, "expected '}'"); + } + + /// Parse a `}` token if present. + ParseResult parseOptionalRBrace() override { + return success(parser.consumeIf(Token::r_brace)); + } + /// Parse a `:` token. ParseResult parseColon() override { return parser.parseToken(Token::colon, "expected ':'"); @@ -944,6 +964,11 @@ return parser.parseToken(Token::equal, "expected '='"); } + /// Parse a `=` token if present. + ParseResult parseOptionalEqual() override { + return success(parser.consumeIf(Token::equal)); + } + /// Parse a '<' token. ParseResult parseLess() override { return parser.parseToken(Token::less, "expected '<'"); @@ -974,6 +999,11 @@ return success(parser.consumeIf(Token::r_paren)); } + /// Parse a '?' token if present. + ParseResult parseOptionalQuestion() override { + return success(parser.consumeIf(Token::question)); + } + /// Parse a `[` token. ParseResult parseLSquare() override { return parser.parseToken(Token::l_square, "expected '['"); @@ -998,17 +1028,10 @@ // Attribute Parsing //===--------------------------------------------------------------------===// - /// Parse an arbitrary attribute of a given type and return it in result. This - /// also adds the attribute to the specified attribute list with the specified - /// name. - ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, - NamedAttrList &attrs) override { + /// Parse an arbitrary attribute of a given type and return it in result.
+ ParseResult parseAttribute(Attribute &result, Type type) override { result = parser.parseAttribute(type); - if (!result) - return failure(); - - attrs.push_back(parser.builder.getNamedAttr(attrName, result)); - return success(); + return success(static_cast(result)); } /// Parse an optional attribute. diff --git a/mlir/test/Dialect/PDL/invalid.mlir b/mlir/test/Dialect/PDL/invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/PDL/invalid.mlir @@ -0,0 +1,205 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +//===----------------------------------------------------------------------===// +// pdl::ApplyConstraintOp +//===----------------------------------------------------------------------===// + +pdl.pattern : benefit(1) { + %op = pdl.operation "foo.op" + + // expected-error@below {{expected at least one argument}} + "pdl.apply_constraint"() {name = "foo", params = []} : () -> () + pdl.rewrite "rewriter"(%op) +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl::AttributeOp +//===----------------------------------------------------------------------===// + +pdl.pattern : benefit(1) { + %type = pdl.type + + // expected-error@below {{expected only one of [`type`, `value`] to be set}} + %attr = pdl.attribute : %type 10 + + %op, %result = pdl.operation "foo.op" {"attr" = %attr} -> %type + pdl.rewrite "rewriter"(%op) +} + +// ----- + +pdl.pattern : benefit(1) { + %op = pdl.operation "foo.op" + pdl.rewrite(%op) { + %type = pdl.type + + // expected-error@below {{expected constant value when specified within a `pdl.rewrite`}} + %attr = pdl.attribute : %type + } +} + +// ----- + +pdl.pattern : benefit(1) { + %op = pdl.operation "foo.op" + pdl.rewrite(%op) { + // expected-error@below {{expected constant value when specified within a `pdl.rewrite`}} + %attr = pdl.attribute + } +} + +// ----- + +pdl.pattern : benefit(1) { + // expected-error@below {{expected a bindable (i.e.
`pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + %unused = pdl.attribute + + %op = pdl.operation "foo.op" + pdl.rewrite "rewriter"(%op) +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl::InputOp +//===----------------------------------------------------------------------===// + +pdl.pattern : benefit(1) { + // expected-error@below {{expected a bindable (i.e. `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + %unused = pdl.input + + %op = pdl.operation "foo.op" + pdl.rewrite "rewriter"(%op) +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl::OperationOp +//===----------------------------------------------------------------------===// + +pdl.pattern : benefit(1) { + %op = pdl.operation "foo.op" + pdl.rewrite(%op) { + // expected-error@below {{must have an operation name when nested within a `pdl.rewrite`}} + %newOp = pdl.operation + } +} + +// ----- + +pdl.pattern : benefit(1) { + // expected-error@below {{expected the same number of attribute values and attribute names, got 1 names and 0 values}} + %op = "pdl.operation"() { + attributeNames = ["attr"], + operand_segment_sizes = dense<0> : vector<3xi32> + } : () -> (!pdl.operation) + pdl.rewrite "rewriter"(%op) +} + +// ----- + +pdl.pattern : benefit(1) { + %op = pdl.operation "foo.op"() + pdl.rewrite (%op) { + %type = pdl.type + + // expected-error@below {{op must have inferable or constrained result types when nested within `pdl.rewrite`}} + // expected-note@below {{result type #0 was not constrained}} + %newOp, %result = pdl.operation "foo.op" -> %type + } +} + +// ----- + +pdl.pattern : benefit(1) { + // expected-error@below {{expected a bindable (i.e.
`pdl.operation` or `pdl.rewrite`) user when defined in the matcher body of a `pdl.pattern`}} + %unused = pdl.operation "foo.op" + + %op = pdl.operation "foo.op" + pdl.rewrite "rewriter"(%op) +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl::PatternOp +//===----------------------------------------------------------------------===// + +// expected-error@below {{expected body to terminate with `pdl.rewrite`}} +pdl.pattern : benefit(1) { + // expected-note@below {{see terminator defined here}} + return +} + +// ----- + +// expected-error@below {{expected only `pdl` operations within the pattern body}} +pdl.pattern : benefit(1) { + // expected-note@below {{see non-`pdl` operation defined here}} + "foo.other_op"() : () -> () + + %root = pdl.operation "foo.op" + pdl.rewrite "foo"(%root) +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl::ReplaceOp +//===----------------------------------------------------------------------===// + +pdl.pattern : benefit(1) { + %root = pdl.operation "foo.op" + pdl.rewrite (%root) { + %type = pdl.type : i32 + %newOp, %newResult = pdl.operation "foo.op" -> %type + + // expected-error@below {{to have the same number of results as the replacement operation}} + pdl.replace %root with %newOp + } +} + +// ----- + +pdl.pattern : benefit(1) { + %type = pdl.type : i32 + %root, %oldResult = pdl.operation "foo.op" -> %type + pdl.rewrite (%root) { + %newOp, %newResult = pdl.operation "foo.op" -> %type + + // expected-error@below {{expected no replacement values to be provided when the replacement operation is present}} + "pdl.replace"(%root, %newOp, %newResult) { + operand_segment_sizes = dense<1> : vector<3xi32> + } : (!pdl.operation, !pdl.operation, !pdl.value) -> () + } +} + +// ----- + +pdl.pattern : benefit(1) { + %root = pdl.operation "foo.op" + pdl.rewrite (%root) { + %type = pdl.type : i32 + %newOp, %newResult = pdl.operation
"foo.op" -> %type + + // expected-error@below {{to have the same number of results as the provided replacement values}} + pdl.replace %root with (%newResult) + } +} + +// ----- + +//===----------------------------------------------------------------------===// +// pdl::TypeOp +//===----------------------------------------------------------------------===// + +pdl.pattern : benefit(1) { + // expected-error@below {{expected a bindable (i.e. `pdl.attribute`, `pdl.input`, or `pdl.operation`) user when defined in the matcher body of a `pdl.pattern`}} + %unused = pdl.type + + %op = pdl.operation "foo.op" + pdl.rewrite "rewriter"(%op) +} diff --git a/mlir/test/Dialect/PDL/ops.mlir b/mlir/test/Dialect/PDL/ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/PDL/ops.mlir @@ -0,0 +1,62 @@ +// RUN: mlir-opt -split-input-file %s | mlir-opt +// Verify the printed output can be parsed. +// RUN: mlir-opt %s | mlir-opt +// Verify the generic form can be parsed. +// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt + +// ----- + +pdl.pattern @operations : benefit(1) { + // Operation with attributes and results. + %attribute = pdl.attribute + %type = pdl.type + %op0, %op0_result = pdl.operation {"attr" = %attribute} -> %type + + // Operation with input. + %input = pdl.input + %root = pdl.operation(%op0_result, %input) + pdl.rewrite "rewriter"(%root) +} + +// ----- + +// Check that the result type of an operation within a rewrite can be inferred +// from the replacement operation of a pdl.replace. +pdl.pattern @infer_type_from_operation_replace : benefit(1) { + %type1 = pdl.type : i32 + %type2 = pdl.type + %root, %results:2 = pdl.operation -> %type1, %type2 + pdl.rewrite(%root) { + %type3 = pdl.type + %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3 + pdl.replace %root with %newOp + } +} + +// ----- + +// Check that the result type of an operation within a rewrite can be inferred +// from the replacement values of a pdl.replace.
+pdl.pattern @infer_type_from_result_replace : benefit(1) { + %type1 = pdl.type : i32 + %type2 = pdl.type + %root, %results:2 = pdl.operation -> %type1, %type2 + pdl.rewrite(%root) { + %type3 = pdl.type + %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type3 + pdl.replace %root with (%newResults#0, %newResults#1) + } +} + +// ----- + +// Check that the result type of an operation within a rewrite can be inferred +// from a type used within the matcher. +pdl.pattern @infer_type_from_type_used_in_match : benefit(1) { + %type1 = pdl.type : i32 + %type2 = pdl.type + %root, %results:2 = pdl.operation -> %type1, %type2 + pdl.rewrite(%root) { + %newOp, %newResults:2 = pdl.operation "foo.op" -> %type1, %type2 + } +} diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -226,7 +226,7 @@ // If there is only one character, this must either be punctuation or a // single character bare identifier. if (value.size() == 1) - return isalpha(front) || StringRef("_:,=<>()[]").contains(front); + return isalpha(front) || StringRef("_:,=<>()[]?").contains(front); // Check the punctuation that are larger than a single character. if (value == "->") @@ -586,7 +586,8 @@ .Case("(", "LParen()") .Case(")", "RParen()") .Case("[", "LSquare()") - .Case("]", "RSquare()"); + .Case("]", "RSquare()") + .Case("?", "Question()"); } /// Generate the storage code required for parsing the given element.