diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 82dc6a456f29..72b3b1ab41f5 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1,2479 +1,2481 @@ //===-- OpBase.td - Base op definition file ----------------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This is the base operation definition file. // //===----------------------------------------------------------------------===// #ifndef OP_BASE #define OP_BASE //===----------------------------------------------------------------------===// // Common utilities for defining TableGen mechanisms //===----------------------------------------------------------------------===// // A workaround for the inability to define functions in Tablegen. // // The template parameter defines a string that can be extracted from an // instance of this class by accessing the "result" member. Subclasses can take // their own template parameters as function "arguments" and use them to // populate result. 
// For example, if it didn't already exist, a concat function could be defined // like: // // class StrConcat strings> : // StrFunc // // and then called like // // StrConcat<["a", "b", "c"]>.result // // to get the string "abc" class StrFunc { string result = r; } // Concatenates a list of strings with a separator (default ", ") class StrJoin strings, string sep = ", "> : StrFunc; // Concatenates a list of integers into a string with a separator (default ", ") class StrJoinInt integers, string sep = ", "> : StrJoin(i)), sep>; //===----------------------------------------------------------------------===// // Predicate definitions //===----------------------------------------------------------------------===// // Base class for logical predicates. // // Predicates are used to compose constraints (see next section for details). // There are two categories of predicates: // // 1. CPred: the primitive leaf predicate. // 2. Compound predicate: a predicate composed from child predicates using // predicate combiners ("conjunction", "disjunction", "negation" or // "substitution"). class Pred; // A logical predicate wrapping any C expression. // // This is the basis for composing more complex predicates. It is the "atom" // predicate from the perspective of TableGen and the "interface" between // TableGen and C++. What is inside is already C++ code, which will be treated // as opaque strings with special placeholders to be substituted. // // ## Special placeholders // // Special placeholders can be used to refer to entities in the context where // this predicate is used. They serve as "hooks" to the enclosing environment. // The following special placeholders are supported in constraints for an op: // // * `$_builder` will be replaced by a mlir::Builder instance. // * `$_op` will be replaced by the current operation. // * `$_self` will be replaced with the entity this predicate is attached to. 
// E.g., `BoolAttr` is an attribute constraint that wraps a // `CPred<"$_self.isa()">` (see the following sections for details). // Then for `F32:$attr`,`$_self` will be replaced by `$attr`. // For type constraints, it's a little bit special since we want the // constraints on each type definition to read naturally and we want to attach // type constraints directly to an operand/result, $_self will be replaced // by the operand/result's type. E.g., for `F32` in `F32:$operand`, its // `$_self` will be expanded as `getOperand(...).getType()`. class CPred : Pred { code predExpr = "(" # pred # ")"; } // Kinds of predicate combiners. These must closely match the predicates // implemented by the C++ backend (tblgen::PredCombinerKind). class PredCombinerKind; def PredCombinerAnd : PredCombinerKind; def PredCombinerOr : PredCombinerKind; def PredCombinerNot : PredCombinerKind; def PredCombinerSubstLeaves : PredCombinerKind; def PredCombinerConcat : PredCombinerKind; // A predicate that combines other predicates as defined by PredCombinerKind. // Instantiated below. class CombinedPred c> : Pred { PredCombinerKind kind = k; list children = c; } // Predicate combiners // A predicate that holds if all of its children hold. Always holds for zero // children. class And children> : CombinedPred; // A predicate that holds if any of its children hold. Never holds for zero // children. class Or children> : CombinedPred; // A predicate that holds if its child does not. class Neg : CombinedPred; // A predicate that substitutes "pat" with "repl" in predicate calls of the // leaves of the predicate tree (i.e., not CombinedPred). // // This is plain string substitution without regular expressions or captures. // New predicates with more complex logic can be introduced should the need // arise. class SubstLeaves : CombinedPred { string pattern = pat; string replacement = repl; } // A predicate that prepends `pre` and appends `suf` to the final predicate // string composed from `child`. 
This is plain string concatenation and there // will be no substitution happening for `pre` and `suf`. class Concat : CombinedPred { string prefix = pre; string suffix = suf; } //===----------------------------------------------------------------------===// // Constraint definitions //===----------------------------------------------------------------------===// // TODO: Merge Constraints into Pred. // Base class for named constraints. // // An op's operands/attributes/results can have various requirements, e.g., // having certain types, having values inside a certain range, and so on. // Besides, for a graph rewrite rule, the source pattern used to match against // the existing graph has conditions, like the op's operand must be of a more // constrained subtype, the attribute must have a certain value, and so on. // // These requirements and conditions are modeled using this class. Records of // this class are used to generate verification code in op verifier, and // matching code in pattern matcher. // // Constraints are predicates with descriptive names, to facilitate inspection, // provide nice error messages, etc. class Constraint { // The predicates that this constraint requires. Pred predicate = pred; // User-readable description used in error reporting messages. If empty, a // generic message will be used. string description = desc; } // Subclasses used to differentiate different constraint kinds. These are used // as markers for the TableGen backend to handle different constraint kinds // differently if needed. Constraints not deriving from the following subclasses // are considered as uncategorized constraints. // Subclass for constraints on a type. class TypeConstraint : Constraint; // Subclass for constraints on an attribute. class AttrConstraint : Constraint; // Subclass for constraints on a region. class RegionConstraint : Constraint; // Subclass for constraints on a successor. 
class SuccessorConstraint : Constraint; // How to use these constraint categories: // // * Use TypeConstraint to specify // * Constraints on an op's operand/result definition // * Further constraints to match an op's operand/result in source pattern // // * Use Attr (a subclass for AttrConstraint) for // * Constraints on an op's attribute definition // * Use AttrConstraint to specify // * Further constraints to match an op's attribute in source pattern // // * Use uncategorized constraint to specify // * Multi-entity constraints in rewrite rules //===----------------------------------------------------------------------===// // Common predicates //===----------------------------------------------------------------------===// // Whether a type is a VectorType. def IsVectorTypePred : CPred<"$_self.isa<::mlir::VectorType>()">; // Whether a type is a TensorType. def IsTensorTypePred : CPred<"$_self.isa<::mlir::TensorType>()">; // Whether a type is a MemRefType. def IsMemRefTypePred : CPred<"$_self.isa<::mlir::MemRefType>()">; // Whether a type is an UnrankedMemRefType. def IsUnrankedMemRefTypePred : CPred<"$_self.isa<::mlir::UnrankedMemRefType>()">; // Whether a type is a ShapedType. def IsShapedTypePred : CPred<"$_self.isa<::mlir::ShapedType>()">; // For a ShapedType, verify that it has a static shape. def HasStaticShapePred : CPred<"$_self.cast<::mlir::ShapedType>().hasStaticShape()">; // Whether a type is a TupleType. def IsTupleTypePred : CPred<"$_self.isa<::mlir::TupleType>()">; //===----------------------------------------------------------------------===// // Dialect definitions //===----------------------------------------------------------------------===// class Dialect { // The name of the dialect. string name = ?; // Short summary of the dialect. string summary = ?; // The description of the dialect. string description = ?; // A list of dialects this dialect will load on construction as dependencies. 
// These are dialects that this dialect may involve in canonicalization // patterns or interfaces. list dependentDialects = []; // The C++ namespace that ops of this dialect should be placed into. // // By default, uses the name of the dialect as the only namespace. To avoid // placing in any namespace, use "". To specify nested namespaces, use "::" // as the delimiter, e.g., given "A::B", ops will be placed in // `namespace A { namespace B { } }`. // // Note that this works in conjunction with dialect C++ code. Depending on how // the generated files are included into the dialect, you may want to specify // a full namespace path or a partial one. string cppNamespace = name; // An optional code block containing extra declarations to place in the // dialect declaration. code extraClassDeclaration = ""; // If this dialect overrides the hook for materializing constants. bit hasConstantMaterializer = 0; // If this dialect overrides the hook for verifying operation attributes. bit hasOperationAttrVerify = 0; // If this dialect overrides the hook for verifying region argument // attributes. bit hasRegionArgAttrVerify = 0; // If this dialect overrides the hook for verifying region result attributes. bit hasRegionResultAttrVerify = 0; } //===----------------------------------------------------------------------===// // Type definitions //===----------------------------------------------------------------------===// // A type, carries type constraints. class Type : TypeConstraint { string typeDescription = ""; string builderCall = ""; } // Allows providing an alternative name and description to an existing type def. class TypeAlias : Type { let typeDescription = t.typeDescription; let builderCall = t.builderCall; } // A type of a specific dialect. class DialectType : Type { Dialect dialect = d; } // A variadic type constraint. It expands to zero or more of the base type. This // class is used for supporting variadic operands/results. 
class Variadic : TypeConstraint { Type baseType = type; } // An optional type constraint. It expands to either zero or one of the base // type. This class is used for supporting optional operands/results. class Optional : TypeConstraint { Type baseType = type; } // A type that can be constructed using MLIR::Builder. // Note that this does not "inherit" from Type because it would require // duplicating Type subclasses for buildable and non-buildable cases to avoid // diamond "inheritance". // TODO: we may extend this to a more general 'Buildable' trait, making some // Types and some Attrs buildable. class BuildableType { // The builder call to invoke (if specified) to construct the BuildableType. code builderCall = builder; } // Any type at all. def AnyType : Type, "any type">; // None type def NoneType : Type()">, "none type">, BuildableType<"$_builder.getType<::mlir::NoneType>()">; // Any type from the given list class AnyTypeOf allowedTypes, string description = ""> : Type< // Satisfy any of the allowed type's condition Or, !if(!eq(description, ""), StrJoin.result, description)>; // Integer types. // Any integer type irrespective of its width and signedness semantics. def AnyInteger : Type()">, "integer">; // Any integer type (regardless of signedness semantics) of a specific width. class AnyI : Type, width # "-bit integer"> { int bitwidth = width; } class AnyIntOfWidths widths> : AnyTypeOf), StrJoinInt.result # "-bit integer">; def AnyI1 : AnyI<1>; def AnyI8 : AnyI<8>; def AnyI16 : AnyI<16>; def AnyI32 : AnyI<32>; def AnyI64 : AnyI<64>; // Any signless integer type irrespective of its width. def AnySignlessInteger : Type< CPred<"$_self.isSignlessInteger()">, "signless integer">; // Signless integer type of a specific width. 
class I : Type, width # "-bit signless integer">, BuildableType<"$_builder.getIntegerType(" # width # ")"> { int bitwidth = width; } class SignlessIntOfWidths widths> : AnyTypeOf), StrJoinInt.result # "-bit signless integer">; def I1 : I<1>; def I8 : I<8>; def I16 : I<16>; def I32 : I<32>; def I64 : I<64>; // Any signed integer type irrespective of its width. def AnySignedInteger : Type< CPred<"$_self.isSignedInteger()">, "signed integer">; // Signed integer type of a specific width. class SI : Type, width # "-bit signed integer">, BuildableType< "$_builder.getIntegerType(" # width # ", /*isSigned=*/true)"> { int bitwidth = width; } class SignedIntOfWidths widths> : AnyTypeOf), StrJoinInt.result # "-bit signed integer">; def SI1 : SI<1>; def SI8 : SI<8>; def SI16 : SI<16>; def SI32 : SI<32>; def SI64 : SI<64>; // Any unsigned integer type irrespective of its width. def AnyUnsignedInteger : Type< CPred<"$_self.isUnsignedInteger()">, "unsigned integer">; // Unsigned integer type of a specific width. class UI : Type, width # "-bit unsigned integer">, BuildableType< "$_builder.getIntegerType(" # width # ", /*isSigned=*/false)"> { int bitwidth = width; } class UnsignedIntOfWidths widths> : AnyTypeOf), StrJoinInt.result # "-bit unsigned integer">; def UI1 : UI<1>; def UI8 : UI<8>; def UI16 : UI<16>; def UI32 : UI<32>; def UI64 : UI<64>; // Index type. def Index : Type()">, "index">, BuildableType<"$_builder.getIndexType()">; // Floating point types. // Any float type irrespective of its width. def AnyFloat : Type()">, "floating-point">; // Float type of a specific width. 
class F : Type, width # "-bit float">, BuildableType<"$_builder.getF" # width # "Type()"> { int bitwidth = width; } class FloatOfWidths widths> : AnyTypeOf), StrJoinInt.result # "-bit float">; def F16 : F<16>; def F32 : F<32>; def F64 : F<64>; def BF16 : Type, "bfloat16 type">, BuildableType<"$_builder.getBF16Type()">; class Complex : Type()">, SubstLeaves<"$_self", "$_self.cast<::mlir::ComplexType>().getElementType()", type.predicate>]>, "complex type with " # type.description # " elements"> { Type elementType = type; } def AnyComplex : Type()">, "complex-type">; class OpaqueType : Type, description>, BuildableType<"::mlir::OpaqueType::get($_builder.getIdentifier(\"" # dialect # "\"), \"" # name # "\", $_builder.getContext())">; // Function Type // Any function type. def FunctionType : Type()">, "function type">; // A container type is a type that has another type embedded within it. class ContainerType : // First, check the container predicate. Then, substitute the extracted // element into the element type checker. Type(elementTypeCall), etype.predicate>]>, descr # " of " # etype.description # " values"> { // The type of elements in the container. Type elementType = etype; // Call to retrieve. code getElementTypeCall = elementTypeCall; } class ShapedContainerType allowedTypes, Pred containerPred, string descr> : ContainerType, containerPred, "$_self.cast<::mlir::ShapedType>().getElementType()", descr>; // Whether a shaped type is ranked. def HasRankPred : CPred<"$_self.cast<::mlir::ShapedType>().hasRank()">; // Whether a shaped type has one of the specified ranks. class HasAnyRankOfPred ranks> : And<[ HasRankPred, Or().getRank() == }] # rank>)>]>; // Vector types. 
class VectorOf allowedTypes> : ShapedContainerType; // Whether the rank of a vector is from the given // `allowedRanks` list class IsVectorOfRankPred allowedRanks> : And<[IsVectorTypePred, Or().getRank() == }] # allowedlength>)>]>; // Any vector where the rank is from the given `allowedRanks` list class VectorOfRank allowedRanks> : Type< IsVectorOfRankPred, " of ranks " # StrJoinInt.result>; // Any vector where the rank is from the given `allowedRanks` list and the type // is from the given `allowedTypes` list class VectorOfRankAndType allowedRanks, list allowedTypes> : Type< And<[VectorOf.predicate, VectorOfRank.predicate]>, VectorOf.description # VectorOfRank.description>; // Whether the number of elements of a vector is from the given // `allowedLengths` list class IsVectorOfLengthPred allowedLengths> : And<[IsVectorTypePred, Or().getNumElements() == }] # allowedlength>)>]>; // Any vector where the number of elements is from the given // `allowedLengths` list class VectorOfLength allowedLengths> : Type< IsVectorOfLengthPred, " of length " # StrJoinInt.result>; // Any vector where the number of elements is from the given // `allowedLengths` list and the type is from the given `allowedTypes` // list class VectorOfLengthAndType allowedLengths, list allowedTypes> : Type< And<[VectorOf.predicate, VectorOfLength.predicate]>, VectorOf.description # VectorOfLength.description>; def AnyVector : VectorOf<[AnyType]>; // Shaped types. def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">; // Tensor types. // Any tensor type whose element type is from the given `allowedTypes` list class TensorOf allowedTypes> : ShapedContainerType; def AnyTensor : TensorOf<[AnyType]>; def AnyRankedTensor : ShapedContainerType<[AnyType], And<[IsTensorTypePred, HasRankPred]>, "ranked tensor">; // TODO: Have an easy way to add another constraint to a type. 
class StaticShapeTensorOf allowedTypes> : Type.predicate, HasStaticShapePred]>, "statically shaped " # TensorOf.description>; def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>; def I1Tensor : TensorOf<[I1]>; def I8Tensor : TensorOf<[I8]>; def I16Tensor : TensorOf<[I16]>; def I32Tensor : TensorOf<[I32]>; def I64Tensor : TensorOf<[I64]>; def IndexTensor: TensorOf<[Index]>; def BF16Tensor : TensorOf<[BF16]>; def F16Tensor : TensorOf<[F16]>; def F32Tensor : TensorOf<[F32]>; def F64Tensor : TensorOf<[F64]>; // Ranked tensor type with one of the specified types and ranks. class TensorRankOf allowedTypes, list ranks> : Type.predicate, HasAnyRankOfPred]>, StrJoin.result # " " # TensorOf.description>; class 0DTensorOf allowedTypes> : TensorRankOf; class 1DTensorOf allowedTypes> : TensorRankOf; class 2DTensorOf allowedTypes> : TensorRankOf; class 3DTensorOf allowedTypes> : TensorRankOf; class 4DTensorOf allowedTypes> : TensorRankOf; // Unranked Memref type def AnyUnrankedMemRef : ShapedContainerType<[AnyType], IsUnrankedMemRefTypePred, "unranked.memref">; // Memref type. // Memrefs are blocks of data with fixed type and rank. class MemRefOf allowedTypes> : ShapedContainerType; def AnyMemRef : MemRefOf<[AnyType]>; def AnyRankedOrUnrankedMemRef: AnyTypeOf<[AnyUnrankedMemRef, AnyMemRef]>; // Memref declarations handle any memref, independent of rank, size, (static or // dynamic), layout, or memory space. def I1MemRef : MemRefOf<[I1]>; def I8MemRef : MemRefOf<[I8]>; def I16MemRef : MemRefOf<[I16]>; def I32MemRef : MemRefOf<[I32]>; def I64MemRef : MemRefOf<[I64]>; def BF16MemRef : MemRefOf<[BF16]>; def F16MemRef : MemRefOf<[F16]>; def F32MemRef : MemRefOf<[F32]>; def F64MemRef : MemRefOf<[F64]>; // TODO: Have an easy way to add another constraint to a type. 
class MemRefRankOf allowedTypes, list ranks> : Type.predicate, HasAnyRankOfPred]>, StrJoin.result # " " # MemRefOf.description>; class StaticShapeMemRefOf allowedTypes> : Type.predicate, HasStaticShapePred]>, "statically shaped " # MemRefOf.description>; def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>; // For a MemRefType, verify that it has strides. def HasStridesPred : CPred<[{ isStrided($_self.cast<::mlir::MemRefType>()) }]>; class StridedMemRefOf allowedTypes> : Type.predicate, HasStridesPred]>, "strided " # MemRefOf.description>; def AnyStridedMemRef : StridedMemRefOf<[AnyType]>; class AnyStridedMemRefOfRank : Type.predicate]>, AnyStridedMemRef.description # " of rank " # rank>; // This represents a generic tuple without any constraints on element type. def AnyTuple : Type; // A container type that has other types embedded in it, but (unlike // ContainerType) can hold elements with a mix of types. Requires a call that // produces a list of all elements' types. class MixedContainerType : Type< And<[ containerPred, Concat< "::llvm::all_of(" # elementTypesCall # ", [](Type t) { return ", SubstLeaves<"$_self", "t", etype.predicate>, "; })" > ]>, descr # " with any combination of " # etype.description # " values"> { // The type of elements in the container. Type elementType = etype; // Call to retrieve. code getElementTypesCall = elementTypesCall; } // A Tuple that holds a mix of elements of the allowed types. class TupleOf allowedTypes> : MixedContainerType, IsTupleTypePred, "$_self.cast<::mlir::TupleType>().getTypes()", "tuple">; // A Tuple with arbitrary nesting, where all elements are a mix of the allowed // types. 
class NestedTupleOf allowedTypes> : MixedContainerType, IsTupleTypePred, "getFlattenedTypes($_self.cast<::mlir::TupleType>())", "nested tuple">; //===----------------------------------------------------------------------===// // Common type constraints //===----------------------------------------------------------------------===// // Type constraint for bool-like types: bools, vectors of bools, tensors of // bools. def BoolLike : TypeConstraint.predicate, TensorOf<[I1]>.predicate]>, "bool-like">; // Type constraint for signless-integer-like types: signless integers, indices, // vectors of signless integers, tensors of signless integers. def SignlessIntegerLike : TypeConstraint.predicate, TensorOf<[AnySignlessInteger]>.predicate]>, "signless-integer-like">; // Type constraint for float-like types: floats, vectors or tensors thereof. def FloatLike : TypeConstraint.predicate, TensorOf<[AnyFloat]>.predicate]>, "floating-point-like">; // Type constraint for signless-integer-like or float-like types. def SignlessIntegerOrFloatLike : TypeConstraint, "signless-integer-like or floating-point-like">; //===----------------------------------------------------------------------===// // Attribute definitions //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // Base attribute definition // Base class for all attributes. class Attr : AttrConstraint { code storageType = ?; // The backing mlir::Attribute type code returnType = ?; // The underlying C++ value type // The call expression to convert from the storage type to the return // type. For example, an enum can be stored as an int but returned as an // enum class. // // Format: $_self will be expanded to the attribute. // // For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will // expand to `getAttrOfType("val").getValue().getSExtValue()`. 
code convertFromStorage = "$_self.getValue()"; // The call expression to build an attribute from a constant value. // // Format: $0 will be expanded to the constant value of the attribute. // // For example, `$_builder.getStringAttr("$0")` for `StringAttr:"foo"` will // expand to `builder.getStringAttr("foo")`. string constBuilderCall = ?; // Default value for attribute. // Requires a constBuilderCall defined. string defaultValue = ?; // The value type of this attribute. This corresponds to the mlir::Type that // this attribute returns via `getType()`. Type valueType = ?; // Whether the attribute is optional. Typically requires a custom // convertFromStorage method to handle the case where the attribute is // not present. bit isOptional = 0; // What is the base-level Attr instantiation that this Attr is built upon. // Unset means this is a base-level Attr. // // This field is used by attribute wrapper classes (DefaultValuedAttr, // OptionalAttr, etc.) to retrieve the base-level attribute definition. // This can be used for getting its name; otherwise, we will see // "anonymous_" as the attribute def name because of template // instantiation. // TODO(b/132458159): deduplicate the fields in attribute wrapper classes. Attr baseAttr = ?; // The fully-qualified C++ namespace where the generated class lives. string cppNamespace = ""; } // An attribute of a specific dialect. class DialectAttr : Attr { Dialect dialect = d; let cppNamespace = d.cppNamespace; } //===----------------------------------------------------------------------===// // Attribute modifier definition // Decorates an attribute to have an (unvalidated) default value if not present. class DefaultValuedAttr : Attr { // Construct this attribute with the input attribute and change only // the default value. // Note: this has to be kept up to date with Attr above. 
let storageType = attr.storageType; let returnType = attr.returnType; let convertFromStorage = attr.convertFromStorage; let constBuilderCall = attr.constBuilderCall; let defaultValue = val; let valueType = attr.valueType; let baseAttr = attr; } // Decorates an attribute as optional. The return type of the generated // attribute accessor method will be Optional<>. class OptionalAttr : Attr { // Rewrite the attribute to be optional. // Note: this has to be kept up to date with Attr above. let storageType = attr.storageType; let returnType = "::llvm::Optional<" # attr.returnType #">"; let convertFromStorage = "$_self ? " # returnType # "(" # attr.convertFromStorage # ") : (::llvm::None)"; let valueType = attr.valueType; let isOptional = 1; let baseAttr = attr; } //===----------------------------------------------------------------------===// // Primitive attribute kinds // A generic attribute that must be constructed around a specific buildable type // `attrValType`. Backed by MLIR attribute kind `attrKind`. class TypedAttrBase : Attr { let constBuilderCall = "$_builder.get" # attrKind # "(" # attrValType.builderCall # ", $0)"; let storageType = "::mlir::" # attrKind; let valueType = attrValType; } // Any attribute. def AnyAttr : Attr, "any attribute"> { let storageType = "::mlir::Attribute"; let returnType = "::mlir::Attribute"; let convertFromStorage = "$_self"; let constBuilderCall = "$0"; } def BoolAttr : Attr()">, "bool attribute"> { let storageType = [{ ::mlir::BoolAttr }]; let returnType = [{ bool }]; let valueType = I1; let constBuilderCall = "$_builder.getBoolAttr($0)"; } // Index attribute. def IndexAttr : TypedAttrBase< Index, "IntegerAttr", And<[CPred<"$_self.isa<::mlir::IntegerAttr>()">, CPred<"$_self.cast<::mlir::IntegerAttr>().getType()" ".isa<::mlir::IndexType>()">]>, "index attribute"> { let returnType = [{ ::llvm::APInt }]; } // Base class for any integer (regardless of signedness semantics) attributes // of fixed width. 
class AnyIntegerAttrBase : TypedAttrBase< attrValType, "IntegerAttr", And<[CPred<"$_self.isa<::mlir::IntegerAttr>()">, CPred<"$_self.cast<::mlir::IntegerAttr>().getType()." "isInteger(" # attrValType.bitwidth # ")">]>, descr> { let returnType = [{ ::llvm::APInt }]; let constBuilderCall = ?; } def AnyI1Attr : AnyIntegerAttrBase; def AnyI8Attr : AnyIntegerAttrBase; def AnyI16Attr : AnyIntegerAttrBase; def AnyI32Attr : AnyIntegerAttrBase; def AnyI64Attr : AnyIntegerAttrBase; def APIntAttr : Attr()">, "arbitrary integer attribute"> { let storageType = [{ ::mlir::IntegerAttr }]; let returnType = [{ ::mlir::APInt }]; } // Base class for signless integer attributes of fixed width. class SignlessIntegerAttrBase : TypedAttrBase< attrValType, "IntegerAttr", And<[CPred<"$_self.isa<::mlir::IntegerAttr>()">, CPred<"$_self.cast<::mlir::IntegerAttr>().getType()." "isSignlessInteger(" # attrValType.bitwidth # ")">]>, descr> { let returnType = [{ ::llvm::APInt }]; } // Base class for signless integer attributes of fixed width that have a // corresponding C++ type. class TypedSignlessIntegerAttrBase : SignlessIntegerAttrBase { let returnType = retType; let convertFromStorage = "$_self.getValue().getZExtValue()"; } def I1Attr : TypedSignlessIntegerAttrBase< I1, "bool", "1-bit signless integer attribute">; def I8Attr : TypedSignlessIntegerAttrBase< I8, "uint8_t", "8-bit signless integer attribute">; def I16Attr : TypedSignlessIntegerAttrBase< I16, "uint16_t", "16-bit signless integer attribute">; def I32Attr : TypedSignlessIntegerAttrBase< I32, "uint32_t", "32-bit signless integer attribute">; def I64Attr : TypedSignlessIntegerAttrBase< I64, "uint64_t", "64-bit signless integer attribute">; // Base class for signed integer attributes of fixed width. class SignedIntegerAttrBase : TypedAttrBase< attrValType, "IntegerAttr", And<[CPred<"$_self.isa<::mlir::IntegerAttr>()">, CPred<"$_self.cast<::mlir::IntegerAttr>().getType()." 
"isSignedInteger(" # attrValType.bitwidth # ")">]>, descr> { let returnType = [{ ::llvm::APInt }]; } // Base class for signed integer attributes of fixed width that have a // corresponding C++ type. class TypedSignedIntegerAttrBase : SignedIntegerAttrBase { let returnType = retType; let convertFromStorage = "$_self.getValue().getSExtValue()"; } def SI1Attr : TypedSignedIntegerAttrBase< SI1, "bool", "1-bit signed integer attribute">; def SI8Attr : TypedSignedIntegerAttrBase< SI8, "int8_t", "8-bit signed integer attribute">; def SI16Attr : TypedSignedIntegerAttrBase< SI16, "int16_t", "16-bit signed integer attribute">; def SI32Attr : TypedSignedIntegerAttrBase< SI32, "int32_t", "32-bit signed integer attribute">; def SI64Attr : TypedSignedIntegerAttrBase< SI64, "int64_t", "64-bit signed integer attribute">; // Base class for unsigned integer attributes of fixed width. class UnsignedIntegerAttrBase : TypedAttrBase< attrValType, "IntegerAttr", And<[CPred<"$_self.isa<::mlir::IntegerAttr>()">, CPred<"$_self.cast<::mlir::IntegerAttr>().getType()." "isUnsignedInteger(" # attrValType.bitwidth # ")">]>, descr> { let returnType = [{ ::llvm::APInt }]; } // Base class for unsigned integer attributes of fixed width that have a // corresponding C++ type. class TypedUnsignedIntegerAttrBase : UnsignedIntegerAttrBase { let returnType = retType; let convertFromStorage = "$_self.getValue().getZExtValue()"; } def UI1Attr : TypedUnsignedIntegerAttrBase< UI1, "bool", "1-bit unsigned integer attribute">; def UI8Attr : TypedUnsignedIntegerAttrBase< UI8, "uint8_t", "8-bit unsigned integer attribute">; def UI16Attr : TypedUnsignedIntegerAttrBase< UI16, "uint16_t", "16-bit unsigned integer attribute">; def UI32Attr : TypedUnsignedIntegerAttrBase< UI32, "uint32_t", "32-bit unsigned integer attribute">; def UI64Attr : TypedUnsignedIntegerAttrBase< UI64, "uint64_t", "64-bit unsigned integer attribute">; // Base class for float attributes of fixed width. 
class FloatAttrBase : TypedAttrBase()">, CPred<"$_self.cast<::mlir::FloatAttr>().getType().isF" # attrValType.bitwidth # "()">]>, descr> { let returnType = [{ ::llvm::APFloat }]; } def F32Attr : FloatAttrBase; def F64Attr : FloatAttrBase; // An attribute backed by a string type. class StringBasedAttr : Attr { let constBuilderCall = "$_builder.getStringAttr(\"$0\")"; let storageType = [{ ::mlir::StringAttr }]; let returnType = [{ ::llvm::StringRef }]; let valueType = NoneType; } def StrAttr : StringBasedAttr()">, "string attribute">; // A string attribute that represents the name of a symbol. def SymbolNameAttr : StringBasedAttr()">, "string attribute">; // String attribute that has a specific value type. class TypedStrAttr : StringBasedAttr()">, "string attribute"> { let valueType = ty; } // Base class for attributes containing types. Example: // def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute"> // defines a type attribute containing an integer type. class TypeAttrBase : Attr()">, CPred<"$_self.cast<::mlir::TypeAttr>().getValue().isa<" # retType # ">()">]>, description> { let storageType = [{ ::mlir::TypeAttr }]; let returnType = retType; let valueType = NoneType; let convertFromStorage = "$_self.getValue().cast<" # retType # ">()"; } def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute">; // The mere presence of unit attributes has a meaning. Therefore, unit // attributes are always treated as optional and accessors to them return // "true" if the attribute is present and "false" otherwise. def UnitAttr : Attr()">, "unit attribute"> { let storageType = [{ ::mlir::UnitAttr }]; let constBuilderCall = "$_builder.getUnitAttr()"; let convertFromStorage = "$_self != nullptr"; let returnType = "bool"; let valueType = NoneType; let isOptional = 1; } //===----------------------------------------------------------------------===// // Enum attribute kinds // Additional information for an enum attribute case. 
class EnumAttrCaseInfo {
  // The C++ enumerant symbol.
  string symbol = sym;

  // The C++ enumerant value.
  // If less than zero, there will be no explicit discriminator values assigned
  // to enumerators in the generated enum class.
  int value = intVal;

  // The string representation of the enumerant. May be the same as symbol.
  string str = strVal;
}

// An enum attribute case stored with StringAttr.
// The predicate compares the stored string against `sym`.
class StrEnumAttrCase : EnumAttrCaseInfo,
    StringBasedAttr<
      CPred<"$_self.cast<::mlir::StringAttr>().getValue() == \"" # sym # "\"">,
      "case " # sym>;

// An enum attribute case stored with IntegerAttr, which has an integer value,
// its representation as a string and a C++ symbol name which may be different.
class IntEnumAttrCaseBase : EnumAttrCaseInfo,
    SignlessIntegerAttrBase {
  let predicate =
    CPred<"$_self.cast<::mlir::IntegerAttr>().getInt() == " # intVal>;
}

// Cases of integer enum attributes with a specific type. By default, the string
// representation is the same as the C++ symbol name.
class I32EnumAttrCase : IntEnumAttrCaseBase;
class I64EnumAttrCase : IntEnumAttrCaseBase;

// A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the
// ordinal number of the bit that is set. It is the 32-bit integer with only
// one bit set.
class BitEnumAttrCase : EnumAttrCaseInfo,
    SignlessIntegerAttrBase {
  // Holds when the stored value has the bit of this case set.
  let predicate = CPred<
    "$_self.cast<::mlir::IntegerAttr>().getValue().getZExtValue() & "
    # val # "u">;
}

// Additional information for an enum attribute.
class EnumAttrInfo cases> {
  // The C++ enum class name.
  string className = name;

  // List of all accepted cases.
  list enumerants = cases;

  // The following fields are only used by the EnumsGen backend to generate
  // an enum class definition and conversion utility functions.

  // The underlying type for the C++ enum class. An empty string means the
  // underlying type is not explicitly specified.
  string underlyingType = "";

  // The name of the utility function that converts a value of the underlying
  // type to the corresponding symbol. It will have the following signature:
  //
  // ```c++
  // llvm::Optional<{className}> {fn-name}({underlying-type});
  // ```
  string underlyingToSymbolFnName = "symbolize" # name;

  // The name of the utility function that converts a string to the
  // corresponding symbol. It will have the following signature:
  //
  // ```c++
  // llvm::Optional<{className}> {fn-name}(llvm::StringRef);
  // ```
  string stringToSymbolFnName = "symbolize" # name;

  // The name of the utility function that converts a symbol to the
  // corresponding string. It will have the following signature:
  //
  // ```c++
  // {symbolToStringFnRetType} {fn-name}({className});
  // ```
  string symbolToStringFnName = "stringify" # name;
  string symbolToStringFnRetType = "::llvm::StringRef";

  // The name of the utility function that returns the max enum value used
  // within the enum class. It will have the following signature:
  //
  // ```c++
  // static constexpr unsigned {fn-name}();
  // ```
  string maxEnumValFnName = "getMaxEnumValFor" # name;
}

// An enum attribute backed by StringAttr.
//
// Op attributes of this kind are stored as StringAttr. Extra verification will
// be generated on the string though: only the symbols of the allowed cases are
// permitted as the string value.
class StrEnumAttr cases> :
    EnumAttrInfo,
    StringBasedAttr<
      And<[StrAttr.predicate, Or]>,
      // Fall back to a generated description when none is given.
      !if(!empty(description), "allowed string cases: " #
          StrJoin.result, description)>;

// An enum attribute backed by IntegerAttr.
//
// Op attributes of this kind are stored as IntegerAttr. Extra verification will
// be generated on the integer though: only the values of the allowed cases are
// permitted as the integer value.
class IntEnumAttr cases> :
    EnumAttrInfo,
    SignlessIntegerAttrBase.result, description)> {
  let predicate = And<[
    SignlessIntegerAttrBase.predicate,
    Or]>;
}

// Integer enum attribute built with getI32IntegerAttr(); the C++ enum class
// uses `uint32_t` as its underlying type.
class I32EnumAttr cases> :
    IntEnumAttr {
  let returnType = cppNamespace # "::" # name;
  let underlyingType = "uint32_t";
  let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
  let constBuilderCall =
          "$_builder.getI32IntegerAttr(static_cast($0))";
}

// Integer enum attribute built with getI64IntegerAttr(); the C++ enum class
// uses `uint64_t` as its underlying type.
class I64EnumAttr cases> :
    IntEnumAttr {
  let returnType = cppNamespace # "::" # name;
  let underlyingType = "uint64_t";
  let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
  let constBuilderCall =
          "$_builder.getI64IntegerAttr(static_cast($0))";
}

// A bit enum stored with 32-bit IntegerAttr.
//
// Op attributes of this kind are stored as IntegerAttr. Extra verification will
// be generated on the integer to make sure only allowed bits are set. Besides,
// helper methods are generated to parse a string separated with a specified
// delimiter to a symbol and vice versa.
class BitEnumAttr cases> :
    EnumAttrInfo, SignlessIntegerAttrBase {
  let predicate = And<[
    I32Attr.predicate,
    // Make sure we don't have unknown bits set.
    CPred<"!($_self.cast<::mlir::IntegerAttr>().getValue().getZExtValue() & (~("
          # StrJoin.result #
          ")))">
  ]>;

  let returnType = cppNamespace # "::" # name;
  let underlyingType = "uint32_t";
  let convertFromStorage = "static_cast<" # returnType # ">($_self.getInt())";
  let constBuilderCall =
          "$_builder.getI32IntegerAttr(static_cast($0))";

  // We need to return a string because we may concatenate symbols for multiple
  // bits together.
  let symbolToStringFnRetType = "std::string";

  // The delimiter used to separate bit enum cases in strings.
  string separator = "|";
}

//===----------------------------------------------------------------------===//
// Composite attribute kinds

// Base class for attributes stored and returned as ::mlir::DictionaryAttr.
class DictionaryAttrBase : Attr {
  let storageType = [{ ::mlir::DictionaryAttr }];
  let returnType = [{ ::mlir::DictionaryAttr }];
  let valueType = NoneType;
  let convertFromStorage = "$_self";
}

def DictionaryAttr : DictionaryAttrBase()">,
                               "dictionary of named attribute values">;

// Base class for attributes stored and returned as ::mlir::ElementsAttr.
class ElementsAttrBase : Attr {
  let storageType = [{ ::mlir::ElementsAttr }];
  let returnType = [{ ::mlir::ElementsAttr }];
  let convertFromStorage = "$_self";
}

def ElementsAttr : ElementsAttrBase()">,
                                    "constant vector/tensor attribute">;

// Base class for elements attributes stored as ::mlir::DenseIntElementsAttr;
// `condition` is an extra predicate on top of the storage-kind check.
class IntElementsAttrBase : ElementsAttrBase()">,
       condition]>,
  description> {
  let storageType = [{ ::mlir::DenseIntElementsAttr }];
  let returnType = [{ ::mlir::DenseIntElementsAttr }];

  let convertFromStorage = "$_self";
}

def IndexElementsAttr
    : IntElementsAttrBase() .getType() .getElementType() .isIndex()}]>,
                          "index elements attribute">;

// Elements attribute whose element type is a `width`-bit integer, as checked
// by isInteger(width).
class AnyIntElementsAttr : IntElementsAttrBase<
  CPred<"$_self.cast<::mlir::DenseIntElementsAttr>().getType()."
        "getElementType().isInteger(" # width # ")">,
  width # "-bit integer elements attribute">;

def AnyI32ElementsAttr : AnyIntElementsAttr<32>;
def AnyI64ElementsAttr : AnyIntElementsAttr<64>;

// Elements attribute whose element type is a `width`-bit signless integer, as
// checked by isSignlessInteger(width).
class SignlessIntElementsAttr : IntElementsAttrBase<
  CPred<"$_self.cast<::mlir::DenseIntElementsAttr>().getType()."
        "getElementType().isSignlessInteger(" # width # ")">,
  width # "-bit signless integer elements attribute"> {

  // Note that this only constructs a scalar elements attribute: the tensor
  // type is built with an empty shape `{}`.
  let constBuilderCall = "::mlir::DenseElementsAttr::get("
    "::mlir::RankedTensorType::get({}, "
                                  "$_builder.getIntegerType(" # width # ")), "
    "::llvm::makeArrayRef($0)).cast<::mlir::DenseIntElementsAttr>()";
}

def I32ElementsAttr : SignlessIntElementsAttr<32>;
def I64ElementsAttr : SignlessIntElementsAttr<64>;

// A `width`-bit signless integer elements attribute. The attribute should be
// ranked and have a shape as specified in `dims`.
class RankedSignlessIntElementsAttr dims> :
    SignlessIntElementsAttr {
  // Check that this has the specified shape.
  let predicate = And<[
    SignlessIntElementsAttr.predicate,
    CPred<"$_self.cast<::mlir::DenseIntElementsAttr>().getType().getShape() == "
        "::mlir::ArrayRef({" # StrJoinInt.result # "})">]>;

  let description = width # "-bit signless int elements attribute of shape [" #
                    StrJoinInt.result # "]";

  let constBuilderCall = "::mlir::DenseIntElementsAttr::get("
    "::mlir::RankedTensorType::get({" # StrJoinInt.result #
    "}, $_builder.getIntegerType(" # width # ")), ::llvm::makeArrayRef($0))";
}

class RankedI32ElementsAttr dims> :
    RankedSignlessIntElementsAttr<32, dims>;
class RankedI64ElementsAttr dims> :
    RankedSignlessIntElementsAttr<64, dims>;

// A dense floating point elements attribute whose element type is the
// `width`-bit float type (checked through isF<width>()).
class FloatElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<::mlir::DenseFPElementsAttr>() &&"
      "$_self.cast<::mlir::DenseElementsAttr>().getType()."
      "getElementType().isF" # width # "()">,
  width # "-bit float elements attribute"> {

  let storageType = [{ ::mlir::DenseElementsAttr }];
  let returnType = [{ ::mlir::DenseElementsAttr }];

  // Note that this only constructs a scalar elements attribute: the tensor
  // type is built with an empty shape `{}`.
  let constBuilderCall = "::mlir::DenseElementsAttr::get("
    "::mlir::RankedTensorType::get({}, $_builder.getF" # width # "Type()),"
    "::llvm::makeArrayRef($0))";
  let convertFromStorage = "$_self";
}

def F64ElementsAttr : FloatElementsAttr<64>;

// A `width`-bit floating point elements attribute. The attribute should be
// ranked and have a shape as specified in `dims`.
class RankedFloatElementsAttr dims> : ElementsAttrBase<
  CPred<"$_self.isa<::mlir::DenseFPElementsAttr>() &&"
      "$_self.cast<::mlir::DenseFPElementsAttr>().getType()."
      "getElementType().isF" # width # "() && "
      // Check that this is ranked and has the specified shape.
      "$_self.cast<::mlir::DenseFPElementsAttr>().getType().hasRank() && "
      "$_self.cast<::mlir::DenseFPElementsAttr>().getType().getShape() == "
      "::mlir::ArrayRef({" # StrJoinInt.result # "})">,
  width # "-bit float elements attribute of shape [" #
  StrJoinInt.result # "]"> {

  let storageType = [{ ::mlir::DenseFPElementsAttr }];
  let returnType = [{ ::mlir::DenseFPElementsAttr }];

  let constBuilderCall = "::mlir::DenseElementsAttr::get("
    "::mlir::RankedTensorType::get({" # StrJoinInt.result #
    "}, $_builder.getF" # width # "Type()), "
    "::llvm::makeArrayRef($0)).cast<::mlir::DenseFPElementsAttr>()";
  let convertFromStorage = "$_self";
}

class RankedF32ElementsAttr dims> : RankedFloatElementsAttr<32, dims>;
class RankedF64ElementsAttr dims> : RankedFloatElementsAttr<64, dims>;

// A dense elements attribute holding strings.
def StringElementsAttr : ElementsAttrBase<
  CPred<"$_self.isa<::mlir::DenseStringElementsAttr>()" >,
  "string elements attribute"> {

  let storageType = [{ ::mlir::DenseElementsAttr }];
  let returnType = [{ ::mlir::DenseElementsAttr }];

  let convertFromStorage = "$_self";
}

// Attributes containing affine maps.
def AffineMapAttr : Attr<
CPred<"$_self.isa<::mlir::AffineMapAttr>()">, "AffineMap attribute"> {
  let storageType = [{::mlir::AffineMapAttr }];
  let returnType = [{ ::mlir::AffineMap }];
  let valueType = Index;
  let constBuilderCall = "::mlir::AffineMapAttr::get($0)";
}

// Base class for array attributes.
class ArrayAttrBase : Attr {
  let storageType = [{ ::mlir::ArrayAttr }];
  let returnType = [{ ::mlir::ArrayAttr }];
  let valueType = NoneType;
  let convertFromStorage = "$_self";
}

def ArrayAttr : ArrayAttrBase()">,
                              "array attribute">;

// Base class for array attributes whose elements are of the same kind.
// `element` specifies the element attribute kind stored in this array.
class TypedArrayAttrBase: ArrayAttrBase<
    And<[
      // Guarantee this is an ArrayAttr first
      CPred<"$_self.isa<::mlir::ArrayAttr>()">,
      // Guarantee all elements satisfy the constraints from `element`
      Concat<"::llvm::all_of($_self.cast<::mlir::ArrayAttr>(), "
                            "[](::mlir::Attribute attr) { return ",
             SubstLeaves<"$_self", "attr", element.predicate>,
             "; })">]>,
    description> {
  let constBuilderCall = "$_builder.getArrayAttr($0)";

  // The attribute kind every element of the array has to satisfy.
  Attr elementAttr = element;
}

// Typed array attributes. Each one overrides the generic constant builder
// with the matching Builder helper.
def AffineMapArrayAttr : TypedArrayAttrBase {
  let constBuilderCall = "$_builder.getAffineMapArrayAttr($0)";
}
def BoolArrayAttr : TypedArrayAttrBase {
  let constBuilderCall = "$_builder.getBoolArrayAttr($0)";
}
def I32ArrayAttr : TypedArrayAttrBase {
  let constBuilderCall = "$_builder.getI32ArrayAttr($0)";
}
def I64ArrayAttr : TypedArrayAttrBase {
  let constBuilderCall = "$_builder.getI64ArrayAttr($0)";
}
def F32ArrayAttr : TypedArrayAttrBase {
  let constBuilderCall = "$_builder.getF32ArrayAttr($0)";
}
def F64ArrayAttr : TypedArrayAttrBase {
  let constBuilderCall = "$_builder.getF64ArrayAttr($0)";
}
def StrArrayAttr : TypedArrayAttrBase {
  let constBuilderCall = "$_builder.getStrArrayAttr($0)";
}
def TypeArrayAttr : TypedArrayAttrBase {
  let constBuilderCall = "$_builder.getTypeArrayAttr($0)";
}

// Attribute information for an Attribute field within a StructAttr.
class StructFieldAttr {
  // Name of this field in the StructAttr.
  string name = thisName;

  // Attribute type wrapped by the struct attr.
  Attr type = thisType;
}

// Structured attribute that wraps a DictionaryAttr and provides both a
// validation method and set of accessors for a fixed set of fields. This is
// useful when representing data that would normally be in a structure.
class StructAttr attributes> :
    DictionaryAttrBase()">,
        "DictionaryAttr with field(s): " #
        StrJoin.result #
        " (each field having its own constraints)"> {
  // Name for this StructAttr.
  string className = name;

  // Return type should match the name of the structure.
  let returnType = d.cppNamespace # "::" # name;

  // Storage type should match the name of the structure.
  let storageType = d.cppNamespace # "::" # name;

  // The dialect this StructAttr belongs to.
  Dialect dialect = d;

  let cppNamespace = d.cppNamespace;

  // List of fields that the StructAttr contains.
  list fields = attributes;
}

// Attributes containing symbol references.
def SymbolRefAttr : Attr()">,
                        "symbol reference attribute"> {
  let storageType = [{ ::mlir::SymbolRefAttr }];
  let returnType = [{ ::mlir::SymbolRefAttr }];
  let valueType = NoneType;
  let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
  let convertFromStorage = "$_self";
}

// A flat symbol reference; the accessor returns the referenced name as a
// StringRef via getValue().
def FlatSymbolRefAttr : Attr()">,
                                            "flat symbol reference attribute"> {
  let storageType = [{ ::mlir::FlatSymbolRefAttr }];
  let returnType = [{ ::llvm::StringRef }];
  let valueType = NoneType;
  let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
  let convertFromStorage = "$_self.getValue()";
}

def SymbolRefArrayAttr :
  TypedArrayAttrBase {
  // No constant builder is provided for arrays of symbol references.
  let constBuilderCall = ?;
}

//===----------------------------------------------------------------------===//
// Derived attribute kinds

// DerivedAttr are attributes whose value is computed from properties
// of the operation. They do not require additional storage and are
// materialized as needed.
// Note: All derived attributes should be materializable as an Attribute. E.g.,
// do not use DerivedAttr for things that could not have been stored as
// Attribute.
//
class DerivedAttr :
  Attr, "derived attribute"> {
  let returnType = ret;
  code body = b;

  // Specify how to convert from the derived attribute to an attribute.
  //
  // ## Special placeholders
  //
  // Special placeholders can be used to refer to entities during conversion:
  //
  // * `$_builder` will be replaced by a mlir::Builder instance.
  // * `$_ctx` will be replaced by the MLIRContext* instance.
  // * `$_self` will be replaced with the derived attribute (value produces
  //    `returnType`).
  let convertFromStorage = convert;
}

// Derived attribute that returns a mlir::Type.
// The return type is spelled fully qualified, like the conversion below, so
// it does not depend on the namespace the generated code lives in.
class DerivedTypeAttr : DerivedAttr<"::mlir::Type", body> {
  let convertFromStorage = "::mlir::TypeAttr::get($_self)";
}

//===----------------------------------------------------------------------===//
// Constant attribute kinds

// Represents a constant attribute of specific Attr type. A constant
// attribute can be specified only of attributes that have a constant
// builder call defined. The constant value is specified as a string.
//
// If used as a constraint, it generates a matcher on a constant attribute by
// using the constant value builder of the attribute and the value.
class ConstantAttr : AttrConstraint<
    // `$0` in the attribute's constant builder call is replaced by `val`.
    CPred<"$_self == " # !subst("$0", val, attribute.constBuilderCall)>,
    "constant attribute " # val> {
  Attr attr = attribute;
  string value = val;
}

class ConstF32Attr : ConstantAttr;
def ConstBoolAttrFalse : ConstantAttr;
def ConstBoolAttrTrue : ConstantAttr;
def ConstUnitAttr : ConstantAttr;

//===----------------------------------------------------------------------===//
// Common attribute constraints
//===----------------------------------------------------------------------===//

// A general mechanism to further confine the given `attr` with all the
// `constraints`. This allows composing complex constraints out of a series
// of more primitive ones.
class Confined constraints> : Attr<
    And,
    // The description is `attr`'s description followed by the description of
    // every constraint, separated by spaces.
    !foldl(/*init*/attr.description, /*list*/constraints,
           prev, cur, prev # " " # cur.description)> {
  // Everything else is forwarded unchanged from the confined attribute.
  let storageType = attr.storageType;
  let returnType = attr.returnType;
  let convertFromStorage = attr.convertFromStorage;
  let constBuilderCall = attr.constBuilderCall;
  let defaultValue = attr.defaultValue;
  let valueType = attr.valueType;
  let isOptional = attr.isOptional;

  let baseAttr = attr;
}

// An AttrConstraint that holds if all attr constraints specified in
// 'constraints' hold.
class AllAttrConstraintsOf constraints> : AttrConstraint<
    And,
    // The description joins the descriptions of all constraints with " and ".
    !foldl(/*init*/!head(constraints).description, /*list*/!tail(constraints),
           prev, cur, prev # " and " # cur.description)> {
}

// Holds when the integer attribute's value, read with getInt(), is >= `n`.
class IntMinValue : AttrConstraint<
    CPred<"$_self.cast<::mlir::IntegerAttr>().getInt() >= " # n>,
    "whose minimum value is " # n>;

// Holds when the integer attribute's value, read with getInt(), is <= `n`.
class IntMaxValue : AttrConstraint<
    CPred<"$_self.cast<::mlir::IntegerAttr>().getInt() <= " # n>,
    "whose maximum value is " # n>;

// Holds when the integer attribute's value is not negative.
def IntNonNegative : AttrConstraint<
    CPred<"!$_self.cast<::mlir::IntegerAttr>().getValue().isNegative()">,
    "whose value is non-negative">;

// Holds when the integer attribute's value is strictly positive.
// The cast target is spelled out and fully qualified, as in IntNonNegative.
def IntPositive : AttrConstraint<
    CPred<"$_self.cast<::mlir::IntegerAttr>().getValue().isStrictlyPositive()">,
    "whose value is positive">;

// Holds when the array attribute has at least `n` elements.
class ArrayMinCount : AttrConstraint<
    CPred<"$_self.cast<::mlir::ArrayAttr>().size() >= " # n>,
    "with at least " # n # " elements">;

// Holds when the array attribute has exactly `n` elements.
class ArrayCount : AttrConstraint<
    CPred<"$_self.cast<::mlir::ArrayAttr>().size() == " # n>,
    "with exactly " # n # " elements">;

// Holds when the array attribute has more than `index` elements and the
// element at `index`, read as an IntegerAttr, equals `value`.
class IntArrayNthElemEq : AttrConstraint<
    And<[
      // The size check comes first so the element access below is in bounds.
      CPred<"$_self.cast<::mlir::ArrayAttr>().size() > " # index>,
      CPred<"$_self.cast<::mlir::ArrayAttr>()[" # index # "]"
        ".cast<::mlir::IntegerAttr>().getInt() == " # value>
       ]>,
    "whose " # index # "-th element must be " # value>;

// Holds when the array attribute has more than `index` elements and the
// element at `index`, read as an IntegerAttr, is at least `min`.
class IntArrayNthElemMinValue : AttrConstraint<
    And<[
      // The size check comes first so the element access below is in bounds.
      CPred<"$_self.cast<::mlir::ArrayAttr>().size() > " # index>,
      CPred<"$_self.cast<::mlir::ArrayAttr>()[" # index # "]"
        ".cast<::mlir::IntegerAttr>().getInt() >= " # min>
        ]>,
    "whose " # index # "-th element must be at least " # min>;

// Holds when the attribute is null.
def IsNullAttr : AttrConstraint<
    CPred<"!$_self">, "empty attribute (for optional attributes)">;

// An attribute constraint on FlatSymbolRefAttr that requires that the
// reference point to an op of `opClass` within the closest parent with a symbol
// table.
// TODO: Add support for nested symbol references.
class ReferToOp : AttrConstraint<
    // The lookup result may be null; isa_and_nonnull is spelled fully
    // qualified like the other C++ names in this predicate so that it does not
    // depend on the namespace of the generated code.
    CPred<"::llvm::isa_and_nonnull<" # opClass # ">("
            "::mlir::SymbolTable::lookupNearestSymbolFrom("
              "&$_op, $_self.cast<::mlir::FlatSymbolRefAttr>().getValue()))">,
    "referencing to a '" # opClass # "' symbol">;

//===----------------------------------------------------------------------===//
// Region definitions
//===----------------------------------------------------------------------===//

class Region :
    RegionConstraint;

// Any region.
def AnyRegion : Region, "any region">;

// A region with the given number of blocks.
class SizedRegion : Region<
  CPred<"::llvm::hasNItems($_self, " # numBlocks # ")">,
  "region with " # numBlocks # " blocks">;

// A variadic region constraint. It expands to zero or more of the base region.
class VariadicRegion
  : Region;

//===----------------------------------------------------------------------===//
// Successor definitions
//===----------------------------------------------------------------------===//

class Successor :
    SuccessorConstraint;

// Any successor.
def AnySuccessor : Successor;

// A variadic successor constraint. It expands to zero or more of the base
// successor.
class VariadicSuccessor
  : Successor;

//===----------------------------------------------------------------------===//
// OpTrait definitions
//===----------------------------------------------------------------------===//

// OpTrait represents a trait regarding an op.
class OpTrait;

// NativeOpTrait corresponds to the MLIR C++ OpTrait mechanism. The
// purpose to wrap around C++ symbol string with this class is to make
// traits specified for ops in TableGen less alien and more integrated.
class NativeOpTrait : OpTrait {
  string trait = "::mlir::OpTrait::" # prop;
}

// ParamNativeOpTrait corresponds to the template-parameterized traits in the
// C++ implementation. MLIR uses nested class templates to implement such
// traits leading to constructs of the form "TraitName<Parameters>::Impl". Use
// the value in `prop` as the trait name and the value in `params` as
// parameters to construct the native trait class name.
class ParamNativeOpTrait
    : NativeOpTrait::Impl">;

// GenInternalOpTrait is an op trait that does not have direct C++ mapping but
// affects op definition generator internals, like how op builders and
// operand/attribute/result getters are generated.
class GenInternalOpTrait : OpTrait {
  string trait = "::mlir::OpTrait::" # prop;
}

// PredOpTrait is an op trait implemented by way of a predicate on the op.
class PredOpTrait : OpTrait {
  string description = descr;
  Pred predicate = pred;
}

// Op defines an affine scope.
def AffineScope : NativeOpTrait<"AffineScope">;
// Op defines an automatic allocation scope.
def AutomaticAllocationScope : NativeOpTrait<"AutomaticAllocationScope">;
// Op supports operand broadcast behavior.
def ResultsBroadcastableShape :
  NativeOpTrait<"ResultsBroadcastableShape">;
// X op Y == Y op X
def Commutative  : NativeOpTrait<"IsCommutative">;
// op op X == X
def Involution : NativeOpTrait<"IsInvolution">;
// Op behaves like a constant.
def ConstantLike : NativeOpTrait<"ConstantLike">;
// Op behaves like a function.
def FunctionLike : NativeOpTrait<"FunctionLike">;
// Op is isolated from above.
def IsolatedFromAbove : NativeOpTrait<"IsIsolatedFromAbove">;
// Op results are float or vectors/tensors thereof.
def ResultsAreFloatLike : NativeOpTrait<"ResultsAreFloatLike">;
// Op has the same operand type.
def SameTypeOperands : NativeOpTrait<"SameTypeOperands">;
// Op has same shape for all operands.
def SameOperandsShape : NativeOpTrait<"SameOperandsShape">;
// Op has same operand and result shape.
def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">;
// Op has the same operand and result type.
def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
// Op has the same element type (or type itself, if scalar) for all operands.
def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
// Op has the same operand and result element type (or type itself, if scalar).
def SameOperandsAndResultElementType :
  NativeOpTrait<"SameOperandsAndResultElementType">;
// Op is a terminator.
def Terminator : NativeOpTrait<"IsTerminator">;
// Op can be safely normalized in the presence of MemRefs with
// non-identity maps.
def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;

// Op's regions have a single block with the specified terminator.
class SingleBlockImplicitTerminator
    : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>;

// Op's parent operation is the provided one.
class HasParent
    : ParamNativeOpTrait<"HasParent", op>;

// Variant of HasParent taking a list of ops (`ops`); it expands to the same
// "HasParent" native trait.
class ParentOneOf ops>
    : ParamNativeOpTrait<"HasParent", StrJoin.result>;

// Op result type is derived from the first attribute. If the attribute is a
// subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
// attribute content is used.
def FirstAttrDerivedResultType :
  GenInternalOpTrait<"FirstAttrDerivedResultType">;

// TODO: Turn the following into normal traits and generate verification for
// them.

// All variadic operands of the op have the same number of values.
// A variadic operand contains an array of values whose array size is only
// known at runtime. This trait requires all variadic operands of an op
// to have the same array size.
def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">;
// All variadic results of the op have the same number of values.
// A variadic result contains an array of values whose array size is only
// known at runtime. This trait requires all variadic results of an op
// to have the same array size.
def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;

// Uses an attribute named `operand_segment_sizes` to specify how many actual
// operand each ODS-declared operand (variadic or not) corresponds to.
// This trait is used for ops that have multiple variadic operands but do
// not know statically their size relationship. The attribute must be a 1D
// vector that has the same number of elements as the number of ODS declared
// operands. That means even if some operands are non-variadic, the attribute
// still needs to have an element for its size, which is always 1.
def AttrSizedOperandSegments : NativeOpTrait<"AttrSizedOperandSegments">;
// Similar to AttrSizedOperandSegments, but used for results. The attribute
// should be named as `result_segment_sizes`.
def AttrSizedResultSegments  : NativeOpTrait<"AttrSizedResultSegments">;

// Op attached regions have no arguments.
def NoRegionArguments : NativeOpTrait<"NoRegionArguments">;

//===----------------------------------------------------------------------===//
// OpInterface definitions
//===----------------------------------------------------------------------===//

// Marker used to identify the argument list for an op or interface method.
def ins;

// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
// C++. The purpose to wrap around C++ symbol string with this class is to make
// interfaces specified for ops in TableGen less alien and more integrated.
class OpInterfaceTrait
    : NativeOpTrait<""> {
  // The trait is the `Trait` member nested in the interface class `name`; this
  // overrides the value derived from the empty NativeOpTrait argument.
  let trait = name # "::Trait";

  // Specify the body of the verification function. `$_op` will be replaced with
  // the operation being verified.
  code verify = verifyBody;

  // An optional code block containing extra declarations to place in the
  // interface trait declaration.
  code extraTraitClassDeclaration = "";
}

// This class represents a single, optionally static, interface method.
// Note: non-static interface methods have an implicit parameter, either
// $_op/$_attr/$_type corresponding to an instance of the derived value.
class InterfaceMethod {
  // A human-readable description of what this method does.
  string description = desc;

  // The name of the interface method.
  string name = methodName;

  // The C++ type-name of the return type.
  string returnType = retTy;

  // A dag of strings that correspond to the arguments of the method.
  dag arguments = args;

  // An optional body to the method.
  code body = methodBody;

  // An optional default implementation of the method.
  code defaultBody = defaultImplementation;
}

// This class represents a single static interface method.
class StaticInterfaceMethod
    : InterfaceMethod;

// Interface represents a base interface.
class Interface {
  // A human-readable description of what this interface does.
  string description = "";

  // The name given to the C++ interface class.
  string cppClassName = name;

  // The C++ namespace that this interface should be placed into.
  //
  // To specify nested namespaces, use "::" as the delimiter, e.g., given
  // "A::B", the interface will be placed in
  // `namespace A { namespace B { <interface> } }`.
  string cppNamespace = "";

  // The list of methods defined by this interface.
  list methods = [];

  // An optional code block containing extra declarations to place in the
  // interface declaration.
  code extraClassDeclaration = "";
}

// AttrInterface represents an interface registered to an attribute.
class AttrInterface : Interface {
  // An optional code block containing extra declarations to place in the
  // interface trait declaration.
  code extraTraitClassDeclaration = "";
}

// OpInterface represents an interface registered to an operation.
class OpInterface : Interface, OpInterfaceTrait;

// TypeInterface represents an interface registered to a type.
class TypeInterface : Interface {
  // An optional code block containing extra declarations to place in the
  // interface trait declaration.
  code extraTraitClassDeclaration = "";
}

// Whether to declare the op interface methods in the op's header. This class
// simply wraps an OpInterface but is used to indicate that the method
// declarations should be generated. This class takes an optional set of methods
// that should have declarations generated even if the method has a default
// implementation.
class DeclareOpInterfaceMethods overridenMethods = []>
      : OpInterface {
    // Mirror every field of the wrapped interface.
    let description = interface.description;
    let cppClassName = interface.cppClassName;
    let cppNamespace = interface.cppNamespace;
    let methods = interface.methods;

    // This field contains a set of method names that should always have their
    // declarations generated. This allows for generating declarations for
    // methods with default implementations that need to be overridden.
    list alwaysOverriddenMethods = overridenMethods;
}

//===----------------------------------------------------------------------===//
// Op definitions
//===----------------------------------------------------------------------===//

// Marker used to identify the result list for an op.
def outs;

// Marker used to identify the region list for an op.
def region;

// Marker used to identify the successor list for an op.
def successor;

// Class for defining a custom builder.
//
// TableGen generates several generic builders for each op by default (see
// comment in the `Op` class). If the default generated ones cannot cover
// some use case, custom builders can be defined using instances of this class.
//
// The signature of the builder is always
//
// ```c++
// static void build(OpBuilder &builder, OperationState &state,
//                   ...) {
//   ...
// }
// ```
//
// To define a custom builder, the parameter list (*excluding* the
// `OpBuilder &builder, OperationState &state` part) and body should be passed
// in as separate template arguments to this class. This is because we generate
// op declaration and definition into separate files. If an empty string is
// passed in for `body`, then *only* the builder declaration will be
// generated; this provides a way to define complicated builders entirely
// in C++.
class OpBuilder {
  // The builder's parameter list.
  string params = p;
  // The builder's body.
  code body = b;
}

// A base decorator class that may optionally be added to OpVariables.
class OpVariableDecorator;

// Class for providing additional information on the variables, i.e. arguments
// and results, of an operation.
class OpVariable varDecorators = []> {
  // The constraint, either attribute or type, of the argument.
  Constraint constraint = varConstraint;

  // A description for the argument.
  string description = desc;

  // The list of decorators for this variable, e.g. side effects.
  list decorators = varDecorators;
}
// Named OpVariable subclasses; presumably Arg is for op arguments and Res for
// op results -- confirm against the ODS backend.
class Arg decorators = []>
  : OpVariable;
class Res decorators = []>
  : OpVariable;

// Base class for all ops.
class Op props = []> {
  // The dialect of the op.
  Dialect opDialect = dialect;

  // The mnemonic of the op.
  string opName = mnemonic;

  // One-line human-readable description of what the op does.
  string summary = "";

  // Additional, longer human-readable description of what the op does.
  string description = "";

  // Dag containing the arguments of the op. Defaults to 0 arguments.
  dag arguments = (ins);

  // The list of results of the op. Defaults to 0 results.
  dag results = (outs);

  // The list of regions of the op. Defaults to 0 regions.
  dag regions = (region);

  // The list of successors of the op. Defaults to 0 successors.
  dag successors = (successor);

  // Attribute getters can be added to the op by adding an Attr member
  // with the name and type of the attribute. E.g., adding int attribute
  // with name "value" and type "i32":
  //   I32Attr value;

  // Define the hooks used for building, parsing, printing, verification.

  // Custom builder.
  // In addition to the custom builder provided here, and unless
  // skipDefaultBuilders is set, two default builders are generated, with the
  // following signatures (parameter names below are illustrative):
  //
  // ```c++
  // static void build(OpBuilder &, OperationState &odsState,
  //                   Type result0Type, Type result1Type, ...,
  //                   Value operand0, Value operand1, ...,
  //                   Attribute attr0, Attribute attr1, ...);
  // ```
  // * where the attributes follow the same declaration order as in the op.
  //
  // ```c++
  // static void build(OpBuilder &, OperationState &odsState,
  //                   TypeRange resultTypes,
  //                   ValueRange operands,
  //                   ArrayRef attributes);
  // ```
  list builders = ?;

  // Avoid generating default build functions.  Custom builders must be
  // provided.
  bit skipDefaultBuilders = 0;

  // Custom parser.
  code parser = ?;

  // Custom printer.
  code printer = ?;

  // Custom assembly format.
  string assemblyFormat = ?;

  // Custom verifier.
  code verifier = ?;

  // Whether this op has associated canonicalization patterns.
  // TODO: figure out a better way to write canonicalization patterns in
  // TableGen rules directly instead of using this marker and C++
  // implementations.
  bit hasCanonicalizer = 0;

  // Whether this op has a folder.
  bit hasFolder = 0;

  // Op traits.
  // Note: The list of traits will be uniqued by ODS.
  list traits = props;

  // Additional code that will be added to the public part of the generated
  // C++ code of the op declaration.
  code extraClassDeclaration = ?;
}

// The arguments of an op.
class Arguments {
  dag arguments = args;
}

// The results of an op.
class Results { dag results = rets; } //===----------------------------------------------------------------------===// // Common value constraints //===----------------------------------------------------------------------===// def HasNoUseOf: Constraint< CPred<"$_self.use_empty()">, "has no use">; //===----------------------------------------------------------------------===// // Common op type constraints //===----------------------------------------------------------------------===// // These traits are for verifying properties of an op that require knowledge of // multiple arguments or results. For verifying properties of a single argument // or result, prefer operand type constraints. // These traits often require including "mlir/IR/TypeUtilities.h". // TODO: Improve the autogenerated error messages. class Rank : StrFunc<"$" # name # ".getType().cast<::mlir::ShapedType>().getRank()">; class Shape : StrFunc<"$" # name # ".getType().cast<::mlir::ShapedType>().getShape()">; class ElementCount : StrFunc<"$" # name # ".getType().cast<::mlir::ShapedType>()" ".getNumElements()">; class ElementType : StrFunc<"getElementTypeOrSelf($" # name # ")">; class AllMatchPred values> : CPred<"::llvm::is_splat(::llvm::makeArrayRef({" # StrJoin.result #"}))">; class AllMatch values, string description> : PredOpTrait>; // TODO: Only works for non-variadic. 
class AllMatchSameOperatorPred names, string operator> : AllMatchPred; class AllMatchSameOperatorTrait names, string operator, string description> : PredOpTrait< "all of {" # StrJoin.result # "} have same " # description, AllMatchSameOperatorPred> { list values = names; } class AllElementCountsMatch names> : AllMatchSameOperatorTrait.result, "element count">; class AllElementTypesMatch names> : AllMatchSameOperatorTrait.result, "element type">; class AllRanksMatch names> : AllMatchSameOperatorTrait.result, "rank">; class AllShapesMatch names> : AllMatchSameOperatorTrait.result, "shape">; class AllTypesMatch names> : AllMatchSameOperatorTrait; // A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`. class TypesMatchWith : PredOpTrait> { string lhs = lhsArg; string rhs = rhsArg; string transformer = transform; } // Type Constraint operand `idx`'s Element type is `type`. class TCopVTEtIs : And<[ CPred<"$_op.getNumOperands() > " # idx>, SubstLeaves<"$_self", "$_op.getOperand(" # idx # ").getType()", IsShapedTypePred>, SubstLeaves<"$_self", "getElementTypeOrSelf($_op.getOperand(" # idx # "))", type.predicate>]>; // Predicate to verify that a named argument or result's element type matches a // given type. class TypeIsPred : SubstLeaves<"$_self", "$" # name # ".getType()", type.predicate>; class TypeIs : PredOpTrait< "'" # name # "' is " # type.description, TypeIsPred>; // Predicate to verify that a named argument or result's element type matches a // given type. class ElementTypeIsPred : And<[ SubstLeaves<"$_self", "$" # name # ".getType()", IsShapedTypePred>, SubstLeaves<"$_self", "getElementTypeOrSelf($" # name # ")", type.predicate>]>; class ElementTypeIs : PredOpTrait< "'" # name # "' is " # type.description, ElementTypeIsPred>; // Predicate to verify that the i'th operand and the j'th operand have the same // elemental type. // Type Constraint operand `i`'s Element type is Same As operand `j`'s Element // type. 
class TCopVTEtIsSameAs : And<[ CPred<"$_op.getNumOperands() > " # !if(!gt(i,j),i,j)>, SubstLeaves<"$_self", "$_op.getOperand(" # i # ").getType()", IsShapedTypePred>, SubstLeaves<"$_self", "$_op.getOperand(" # j # ").getType()", IsShapedTypePred>, CPred<"::mlir::getElementTypeOrSelf($_op.getOperand(" # i # ")) == " "::mlir::getElementTypeOrSelf($_op.getOperand(" # j # "))">]>; // Predicate to verify that the i'th result and the j'th operand exist and has // shaped types. class TCOpResIsShapedTypePred : And<[ CPred<"$_op.getNumResults() > " # i>, CPred<"$_op.getNumOperands() > " # j>, SubstLeaves<"$_self", "$_op.getResult(" # i # ").getType()", IsShapedTypePred>, SubstLeaves<"$_self", "$_op.getOperand(" # j # ").getType()", IsShapedTypePred>]>; // Predicate to verify that the i'th result and the j'th operand have the same // type. class TCresIsSameAsOpBase : CPred<"$_op.getResult(" # i # ").getType() == " "$_op.getOperand(" # j # ").getType()">; // Basic Predicate to verify that the i'th result and the j'th operand have the // same elemental type. class TCresVTEtIsSameAsOpBase : CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")) == " "getElementTypeOrSelf($_op.getOperand(" # j # "))">; // Predicate to verify that the i'th result and the j'th operand have the same // elemental type. // Type Constraint result`i`'s Element type is Same As Operand `j`'s Element // type. class TCresVTEtIsSameAsOp : And<[ TCOpResIsShapedTypePred, TCresVTEtIsSameAsOpBase]>; // Predicate to verify that the opId'th operand can be broadcasted to the type // of the resId'th result. class TCOpIsBroadcastableToRes : And<[ TCOpResIsShapedTypePred, CPred<"::mlir::OpTrait::util::getBroadcastedType(" "$_op.getOperand(" # opId # ").getType(), " "$_op.getResult(" # resId # ").getType())">]>; // Predicate to verify that all the operands at the given `indices` // have the same element type. // Type Constraint operands' Element type are all Same At the given `indices`. 
// We query the operands' types into a list and check they are all the same. // Precondition: // 1) all operands involved are of shaped type and // 2) the indices are not out of range. class TCopVTEtAreSameAt indices> : CPred< "::llvm::is_splat(::llvm::map_range(" "::mlir::ArrayRef({" # StrJoinInt.result # "}), " "[this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); " "}))">; //===----------------------------------------------------------------------===// // Pattern definitions //===----------------------------------------------------------------------===// // Marker used to identify the delta value added to the default benefit value. def addBenefit; // Base class for op+ -> op+ rewrite rules. These allow declaratively // specifying rewrite rules. // // A rewrite rule contains two components: a source pattern and one or more // result patterns. Each pattern is specified as a (recursive) DAG node (tree) // in the form of `(node arg0, arg1, ...)`. // // The `node` is normally an MLIR op, but it can also be one of the directives // listed later in this section. // // ## Symbol binding // // In the source pattern, `argN` can be used to specify matchers (e.g., using // type/attribute type constraints, etc.) and bound to a name for later use. // We can also bind names to op instances to reference them later in // multi-entity constraints. // // In the result pattern, `argN` can be used to refer to a previously bound // name, with potential transformations (e.g., using tAttr, etc.). `argN` can // itself be a nested DAG node. We can also bind names to ops to reference // them later in other result patterns. // // For example, // // ``` // def : Pattern<(OneResultOp1:$op1 $arg0, $arg1), // [(OneResultOp2:$op2 $arg0, $arg1), // (OneResultOp3 $op2 (OneResultOp4))], // [(HasStaticShapePred $op1)]>; // ``` // // `$argN` is bound to the `OneResultOp1`'s N-th argument and used later to // build `OneResultOp2`. 
`$op1` is bound to `OneResultOp1` and used to // check whether the result's shape is static. `$op2` is bound to // `OneResultOp2` and used to build `OneResultOp3`. // // ## Multi-result op // // To create multi-result ops in result pattern, you can use a syntax similar // to uni-result op, and it will act as a value pack for all results: // // ``` // def : Pattern<(ThreeResultOp ...), // [(TwoResultOp ...), (OneResultOp ...)]>; // ``` // // Then `TwoResultOp` will replace the first two values of `ThreeResultOp`. // // You can also use `$__N` to explicitly access the N-th result. // ``` // def : Pattern<(FiveResultOp ...), // [(TwoResultOp1:$res1__1 ...), (replaceWithValue $res1__0), // (TwoResultOp2:$res2 ...), (replaceWithValue $res2__1)]>; // ``` // // Then the values generated by `FiveResultOp` will be replaced by // // * `FiveResultOp`#0: `TwoResultOp1`#1 // * `FiveResultOp`#1: `TwoResultOp1`#0 // * `FiveResultOp`#2: `TwoResultOp2`#0 // * `FiveResultOp`#3: `TwoResultOp2`#1 // * `FiveResultOp`#4: `TwoResultOp2`#1 class Pattern results, list preds = [], dag benefitAdded = (addBenefit 0)> { dag sourcePattern = source; // Result patterns. Each result pattern is expected to replace one result // of the root op in the source pattern. In the case of more result patterns // than needed to replace the source op, only the last N results generated // by the last N result pattern is used to replace a N-result source op. // So that the beginning result patterns can be used to generate additional // ops to aid building the results used for replacement. list resultPatterns = results; // Multi-entity constraints. Each constraint here involves multiple entities // matched in source pattern and places further constraints on them as a // whole. list constraints = preds; // The delta value added to the default benefit value. The default value is // the number of ops in the source pattern. 
The rule with the highest final // benefit value will be applied first if multiple rules match. // This delta value can be either positive or negative. dag benefitDelta = benefitAdded; } // Form of a pattern which produces a single result. class Pat preds = [], dag benefitAdded = (addBenefit 0)> : Pattern; // Native code call wrapper. This allows invoking an arbitrary C++ expression // to create an op operand/attribute or replace an op result. // // ## Placeholders // // If used as a DAG leaf, i.e., `(... NativeCodeCall<"...">:$arg, ...)`, // the wrapped expression can take special placeholders listed below: // // * `$_builder` will be replaced by the current `mlir::PatternRewriter`. // * `$_self` will be replaced with the entity this transformer is attached to. // E.g., with the definition `def transform : NativeCodeCall<"$_self...">`, // `$_self` in `transform:$attr` will be replaced by the value for `$attr`. // // If used as a DAG node, i.e., `(NativeCodeCall<"..."> <arg0>, ..., <argN>)`, // then positional placeholders are also supported; placeholder `$N` in the // wrapped C++ expression will be replaced by `<argN>`. class NativeCodeCall { string expression = expr; } +def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($0->getResult(0), m_Constant(&$1)))">; + //===----------------------------------------------------------------------===// // Rewrite directives //===----------------------------------------------------------------------===// // Directive used in result pattern to specify the location of the generated // op. This directive must be used as the last argument to the op creation // DAG construct. The arguments to location must be previously captured symbols. def location; // Directive used in result pattern to indicate that no new ops are generated, // so to replace the matched DAG with an existing SSA value. 
def replaceWithValue; //===----------------------------------------------------------------------===// // Data type generation //===----------------------------------------------------------------------===// // Define a new type belonging to a dialect and called 'name'. class TypeDef { Dialect dialect = owningdialect; string cppClassName = name # "Type"; // Short summary of the type. string summary = ?; // The longer description of this type. string description = ?; // Name of storage class to generate or use. string storageClass = name # "TypeStorage"; // Namespace (within the dialect's C++ namespace) in which the storage class // resides. string storageNamespace = "detail"; // Specify if the storage class is to be generated. bit genStorageClass = 1; // Specify that the generated storage class has a constructor which is written // in C++. bit hasStorageCustomConstructor = 0; // The list of parameters for this type. Parameters will become both // parameters to the get() method and storage class member variables. // // The format of this dag is: // (ins // "<c++ type>":$param1Name, // "<c++ type>":$param2Name, // TypeParameter<"c++ type", "param description">:$param3Name) // TypeParameters (or more likely one of their subclasses) are required to add // more information about the parameter, specifically: // - Documentation // - Code to allocate the parameter (if allocation is needed in the storage // class constructor) // // For example: // (ins // "int":$width, // ArrayRefParameter<"bool", "list of bools">:$yesNoArray) // // (ArrayRefParameter is a subclass of TypeParameter which has allocation code // for re-allocating ArrayRefs. It is defined below.) dag parameters = (ins); // Use the lowercased name as the keyword for parsing/printing. Specify only // if you want tblgen to generate declarations and/or definitions of // printer/parser for this type. string mnemonic = ?; // If 'mnemonic' specified, // If null, generate just the declarations. 
// If a non-empty code block, just use that code as the definition code. // Error if an empty code block. code printer = ?; code parser = ?; // If set, generate accessors for each Type parameter. bit genAccessors = 1; // Generate the verifyConstructionInvariants declaration and getChecked // method. bit genVerifyInvariantsDecl = 0; // Extra code to include in the class declaration. code extraClassDeclaration = [{}]; } // 'Parameters' should be subclasses of this or simple strings (which is a // shorthand for TypeParameter<"C++Type">). class TypeParameter { // Custom memory allocation code for storage constructor. code allocator = ?; // The C++ type of this parameter. string cppType = type; // A description of this parameter. string description = desc; // The format string for the asm syntax (documentation only). string syntax = ?; } // For StringRefs, which require allocation. class StringRefParameter : TypeParameter<"::llvm::StringRef", desc> { let allocator = [{$_dst = $_allocator.copyInto($_self);}]; } // For standard ArrayRefs, which require allocation. class ArrayRefParameter : TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { let allocator = [{$_dst = $_allocator.copyInto($_self);}]; } // For classes which require allocation and have their own allocateInto method. class SelfAllocationParameter : TypeParameter { let allocator = [{$_dst = $_self.allocateInto($_allocator);}]; } // For ArrayRefs which contain things which allocate themselves. 
class ArrayRefOfSelfAllocationParameter : TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { let allocator = [{ llvm::SmallVector<}] # arrayOf # [{, 4> tmpFields; for (size_t i = 0, e = $_self.size(); i < e; ++i) tmpFields.push_back($_self[i].allocateInto($_allocator)); $_dst = $_allocator.copyInto(ArrayRef<}] # arrayOf # [{>(tmpFields)); }]; } #endif // OP_BASE diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h index 4fc2ae762a66..98c5d9b18f5d 100644 --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -1,441 +1,451 @@ //===- Pattern.h - Pattern wrapper class ------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Pattern wrapper class to simplify using TableGen Record defining a MLIR // Pattern. // //===----------------------------------------------------------------------===// #ifndef MLIR_TABLEGEN_PATTERN_H_ #define MLIR_TABLEGEN_PATTERN_H_ #include "mlir/Support/LLVM.h" #include "mlir/TableGen/Argument.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" #include namespace llvm { class DagInit; class Init; class Record; } // end namespace llvm namespace mlir { namespace tblgen { // Mapping from TableGen Record to Operator wrapper object. // // We allocate each wrapper object in heap to make sure the pointer to it is // valid throughout the lifetime of this map. This is important because this map // is shared among multiple patterns to avoid creating the wrapper object for // the same op again and again. But this map will continuously grow. 
using RecordOperatorMap = DenseMap>; class Pattern; // Wrapper class providing helper methods for accessing TableGen DAG leaves // used inside Patterns. This class is lightweight and designed to be used like // values. // // A TableGen DAG construct is of the syntax // `(operator, arg0, arg1, ...)`. // // This class provides getters to retrieve `arg*` as tblgen:: wrapper objects // for handy helper methods. It only works on `arg*`s that are not nested DAG // constructs. class DagLeaf { public: explicit DagLeaf(const llvm::Init *def) : def(def) {} // Returns true if this DAG leaf is not specified in the pattern. That is, it // places no further constraints/transforms and just carries over the original // value. bool isUnspecified() const; // Returns true if this DAG leaf is matching an operand. That is, it specifies // a type constraint. bool isOperandMatcher() const; // Returns true if this DAG leaf is matching an attribute. That is, it // specifies an attribute constraint. bool isAttrMatcher() const; // Returns true if this DAG leaf is wrapping native code call. bool isNativeCodeCall() const; // Returns true if this DAG leaf is specifying a constant attribute. bool isConstantAttr() const; // Returns true if this DAG leaf is specifying an enum attribute case. bool isEnumAttrCase() const; // Returns true if this DAG leaf is specifying a string attribute. bool isStringAttr() const; // Returns this DAG leaf as a constraint. Asserts if fails. Constraint getAsConstraint() const; // Returns this DAG leaf as an constant attribute. Asserts if fails. ConstantAttr getAsConstantAttr() const; // Returns this DAG leaf as an enum attribute case. // Precondition: isEnumAttrCase() EnumAttrCase getAsEnumAttrCase() const; // Returns the matching condition template inside this DAG leaf. Assumes the // leaf is an operand/attribute matcher and asserts otherwise. std::string getConditionTemplate() const; // Returns the native code call template inside this DAG leaf. 
// Precondition: isNativeCodeCall() StringRef getNativeCodeTemplate() const; // Returns the string associated with the leaf. // Precondition: isStringAttr() std::string getStringAttr() const; void print(raw_ostream &os) const; private: // Returns true if the TableGen Init `def` in this DagLeaf is a DefInit and // also a subclass of the given `superclass`. bool isSubClassOf(StringRef superclass) const; const llvm::Init *def; }; // Wrapper class providing helper methods for accessing TableGen DAG constructs // used inside Patterns. This class is lightweight and designed to be used like // values. // // A TableGen DAG construct is of the syntax // `(operator, arg0, arg1, ...)`. // // When used inside Patterns, `operator` corresponds to some dialect op, or // a known list of verbs that defines special transformation actions. This // `arg*` can be a nested DAG construct. This class provides getters to // retrieve `operator` and `arg*` as tblgen:: wrapper objects for handy helper // methods. // // A null DagNode contains a nullptr and converts to false implicitly. class DagNode { public: explicit DagNode(const llvm::DagInit *node) : node(node) {} // Implicit bool converter that returns true if this DagNode is not a null // DagNode. operator bool() const { return node != nullptr; } // Returns the symbol bound to this DAG node. StringRef getSymbol() const; // Returns the operator wrapper object corresponding to the dialect op matched // by this DAG. The operator wrapper will be queried from the given `mapper` // and created in it if not existing. Operator &getDialectOp(RecordOperatorMap *mapper) const; // Returns the number of operations recursively involved in the DAG tree // rooted from this node. int getNumOps() const; // Returns the number of immediate arguments to this DAG node. int getNumArgs() const; // Returns true if the `index`-th argument is a nested DAG construct. 
bool isNestedDagArg(unsigned index) const; // Gets the `index`-th argument as a nested DAG construct if possible. Returns // null DagNode otherwise. DagNode getArgAsNestedDag(unsigned index) const; // Gets the `index`-th argument as a DAG leaf. DagLeaf getArgAsLeaf(unsigned index) const; // Returns the specified name of the `index`-th argument. StringRef getArgName(unsigned index) const; // Returns true if this DAG construct means to replace with an existing SSA // value. bool isReplaceWithValue() const; // Returns whether this DAG represents the location of an op creation. bool isLocationDirective() const; // Returns true if this DAG node is wrapping native code call. bool isNativeCodeCall() const; // Returns true if this DAG node is an operation. bool isOperation() const; // Returns the native code call template inside this DAG node. // Precondition: isNativeCodeCall() StringRef getNativeCodeTemplate() const; void print(raw_ostream &os) const; private: const llvm::DagInit *node; // nullptr means null DagNode }; // A class for maintaining information for symbols bound in patterns and // providing methods for resolving them according to specific use cases. // // Symbols can be bound to // // * Op arguments and op results in the source pattern and // * Op results in result patterns. // // Symbols can be referenced in result patterns and additional constraints to // the pattern. // // For example, in // // ``` // def : Pattern< // (SrcOp:$results1 $arg0, $arg1), // [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>; // ``` // // `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to // `SrcOp`. `$results2` is bound to `ResOp1`. `$results2` is referenced to build // `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`. // // If a symbol binds to a multi-result op and it does not have the `__N` // suffix, the symbol is expanded to represent all results generated by the // multi-result op. 
If the symbol has a `__N` suffix, then it will expand to // only the N-th *static* result as declared in ODS, and that can still // correspond to multiple *dynamic* values if the N-th *static* result is // variadic. // // This class keeps track of such symbols and resolves them into their bound // values in a suitable way. class SymbolInfoMap { public: explicit SymbolInfoMap(ArrayRef loc) : loc(loc) {} // Class for information regarding a symbol. class SymbolInfo { public: // Returns a string for defining a variable named as `name` to store the // value bound by this symbol. std::string getVarDecl(StringRef name) const; // Returns a variable name for the symbol named as `name`. std::string getVarName(StringRef name) const; private: // Allow SymbolInfoMap to access private methods. friend class SymbolInfoMap; // What kind of entity this symbol represents: // * Attr: op attribute // * Operand: op operand // * Result: op result // * Value: a value not attached to an op (e.g., from NativeCodeCall) enum class Kind : uint8_t { Attr, Operand, Result, Value }; // Creates a SymbolInfo instance. `index` is only used for `Attr` and // `Operand` so should be `llvm::None` for `Result` and `Value` kinds. SymbolInfo(const Operator *op, Kind kind, Optional index); // Static methods for creating SymbolInfo. static SymbolInfo getAttr(const Operator *op, int index) { return SymbolInfo(op, Kind::Attr, index); } + static SymbolInfo getAttr() { + return SymbolInfo(nullptr, Kind::Attr, llvm::None); + } static SymbolInfo getOperand(const Operator *op, int index) { return SymbolInfo(op, Kind::Operand, index); } static SymbolInfo getResult(const Operator *op) { return SymbolInfo(op, Kind::Result, llvm::None); } static SymbolInfo getValue() { return SymbolInfo(nullptr, Kind::Value, llvm::None); } // Returns the number of static values this symbol corresponds to. // A static value is an operand/result declared in ODS. 
Normally a symbol // only represents one static value, but symbols bound to op results can // represent more than one if the op is a multi-result op. int getStaticValueCount() const; // Returns a string containing the C++ expression for referencing this // symbol as a value (if this symbol represents one static value) or a value // range (if this symbol represents multiple static values). `name` is the // name of the C++ variable that this symbol is bound to. `index` should only // be used for indexing results. `fmt` is used to format each value. // `separator` is used to separate values if this is a value range. std::string getValueAndRangeUse(StringRef name, int index, const char *fmt, const char *separator) const; // Returns a string containing the C++ expression for referencing this // symbol as a value range regardless of how many static values this symbol // represents. `name` is the name of the C++ variable that this symbol is // bound to. `index` should only be used for indexing results. `fmt` is // used to format each value. `separator` is used to separate values in the // range. std::string getAllRangeUse(StringRef name, int index, const char *fmt, const char *separator) const; const Operator *op; // The op where the bound entity belongs Kind kind; // The kind of the bound entity // The argument index (for `Attr` and `Operand` only) Optional argIndex; // Alternative name for the symbol. It is used in case the name // is not unique. Applicable for `Operand` only. Optional alternativeName; }; using BaseT = std::unordered_multimap; // Iterators for accessing all symbols. using iterator = BaseT::iterator; iterator begin() { return symbolInfoMap.begin(); } iterator end() { return symbolInfoMap.end(); } // Const iterators for accessing all symbols. 
using const_iterator = BaseT::const_iterator; const_iterator begin() const { return symbolInfoMap.begin(); } const_iterator end() const { return symbolInfoMap.end(); } // Binds the given `symbol` to the `argIndex`-th argument of the given `op`. // Returns false if `symbol` is already bound and symbols are not operands. bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex); // Binds the given `symbol` to the results of the given `op`. Returns false if // `symbol` is already bound. bool bindOpResult(StringRef symbol, const Operator &op); // Registers the given `symbol` as bound to a value. Returns false if `symbol` // is already bound. bool bindValue(StringRef symbol); + // Registers the given `symbol` as bound to an attribute. Returns false if `symbol` + // is already bound. + bool bindAttr(StringRef symbol); + // Returns true if the given `symbol` is bound. bool contains(StringRef symbol) const; // Returns an iterator to the information of the given symbol named as `key`. const_iterator find(StringRef key) const; // Returns an iterator to the information of the given symbol named as `key`, // with index `argIndex` for operator `op`. const_iterator findBoundSymbol(StringRef key, const Operator &op, int argIndex) const; // Returns the bounds of a range that includes all the elements which // bind to the `key`. std::pair getRangeOfEqualElements(StringRef key); // Returns the number of times the symbol named as `key` was used. int count(StringRef key) const; // Returns the number of static values the given `symbol` corresponds to. // A static value is an operand/result declared in ODS. Normally a symbol only // represents one static value, but symbols bound to op results can represent // more than one if the op is a multi-result op. 
int getStaticValueCount(StringRef symbol) const; // Returns a string containing the C++ expression for referencing this // symbol as a value (if this symbol represents one static value) or a value // range (if this symbol represents multiple static values). `fmt` is used to // format each value. `separator` is used to separate values if `symbol` // represents a value range. std::string getValueAndRangeUse(StringRef symbol, const char *fmt = "{0}", const char *separator = ", ") const; // Returns a string containing the C++ expression for referencing this // symbol as a value range regardless of how many static values this symbol // represents. `fmt` is used to format each value. `separator` is used to // separate values in the range. std::string getAllRangeUse(StringRef symbol, const char *fmt = "{0}", const char *separator = ", ") const; // Assign alternative unique names to Operands that have equal names. void assignUniqueAlternativeNames(); // Splits the given `symbol` into a value pack name and an index. Returns the // value pack name and writes the index to `index` on success. Returns // `symbol` itself if it does not contain an index. // // We can use `name__N` to access the `N`-th value in the value pack bound to // `name`. `name` is typically the results of a multi-result op. static StringRef getValuePackName(StringRef symbol, int *index = nullptr); private: BaseT symbolInfoMap; // Pattern instantiation location. This is intended to be used as parameter // to PrintFatalError() to report errors. ArrayRef loc; }; // Wrapper class providing helper methods for accessing MLIR Pattern defined // in TableGen. This class should closely reflect what is defined as class // `Pattern` in TableGen. This class contains maps so it is not intended to be // used as values. class Pattern { public: explicit Pattern(const llvm::Record *def, RecordOperatorMap *mapper); // Returns the source pattern to match.
DagNode getSourcePattern() const; // Returns the number of result patterns generated by applying this rewrite // rule. int getNumResultPatterns() const; // Returns the DAG tree root node of the `index`-th result pattern. DagNode getResultPattern(unsigned index) const; // Collects all symbols bound in the source pattern into `infoMap`. void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap); // Collects all symbols bound in result patterns into `infoMap`. void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap); // Returns the op that the root node of the source pattern matches. const Operator &getSourceRootOp(); // Returns the operator wrapper object corresponding to the given `node`'s DAG // operator. Operator &getDialectOp(DagNode node); // Returns the constraints. std::vector getConstraints() const; // Returns the benefit score of the pattern. int getBenefit() const; using IdentifierLine = std::pair; // Returns the file location of the pattern (buffer identifier + line number // pair). std::vector getLocation() const; private: + // Helper function to verify variable binding. + void verifyBind(bool result, StringRef symbolName); + // Recursively collects all bound symbols inside the DAG tree rooted // at `tree` and updates the given `infoMap`. void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern); // The TableGen definition of this pattern. const llvm::Record &def; // All operators. // TODO: we need a proper context manager, like MLIRContext, for managing the // lifetime of shared entities.
RecordOperatorMap *recordOpMap; }; } // end namespace tblgen } // end namespace mlir #endif // MLIR_TABLEGEN_PATTERN_H_ diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp index 448f70359bd0..7044677fad36 100644 --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -1,666 +1,731 @@ //===- Pattern.cpp - Pattern wrapper class --------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Pattern wrapper class to simplify using TableGen Record defining a MLIR // Pattern. // //===----------------------------------------------------------------------===// #include "mlir/TableGen/Pattern.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #define DEBUG_TYPE "mlir-tblgen-pattern" using namespace mlir; using namespace tblgen; using llvm::formatv; //===----------------------------------------------------------------------===// // DagLeaf //===----------------------------------------------------------------------===// bool DagLeaf::isUnspecified() const { return dyn_cast_or_null(def); } bool DagLeaf::isOperandMatcher() const { // Operand matchers specify a type constraint. return isSubClassOf("TypeConstraint"); } bool DagLeaf::isAttrMatcher() const { // Attribute matchers specify an attribute constraint. 
return isSubClassOf("AttrConstraint"); } bool DagLeaf::isNativeCodeCall() const { return isSubClassOf("NativeCodeCall"); } bool DagLeaf::isConstantAttr() const { return isSubClassOf("ConstantAttr"); } bool DagLeaf::isEnumAttrCase() const { return isSubClassOf("EnumAttrCaseInfo"); } bool DagLeaf::isStringAttr() const { return isa(def); } Constraint DagLeaf::getAsConstraint() const { assert((isOperandMatcher() || isAttrMatcher()) && "the DAG leaf must be operand or attribute"); return Constraint(cast(def)->getDef()); } ConstantAttr DagLeaf::getAsConstantAttr() const { assert(isConstantAttr() && "the DAG leaf must be constant attribute"); return ConstantAttr(cast(def)); } EnumAttrCase DagLeaf::getAsEnumAttrCase() const { assert(isEnumAttrCase() && "the DAG leaf must be an enum attribute case"); return EnumAttrCase(cast(def)); } std::string DagLeaf::getConditionTemplate() const { return getAsConstraint().getConditionTemplate(); } llvm::StringRef DagLeaf::getNativeCodeTemplate() const { assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); return cast(def)->getDef()->getValueAsString("expression"); } std::string DagLeaf::getStringAttr() const { assert(isStringAttr() && "the DAG leaf must be string attribute"); return def->getAsUnquotedString(); } bool DagLeaf::isSubClassOf(StringRef superclass) const { if (auto *defInit = dyn_cast_or_null(def)) return defInit->getDef()->isSubClassOf(superclass); return false; } void DagLeaf::print(raw_ostream &os) const { if (def) def->print(os); } //===----------------------------------------------------------------------===// // DagNode //===----------------------------------------------------------------------===// bool DagNode::isNativeCodeCall() const { if (auto *defInit = dyn_cast_or_null(node->getOperator())) return defInit->getDef()->isSubClassOf("NativeCodeCall"); return false; } bool DagNode::isOperation() const { return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective(); } llvm::StringRef 
DagNode::getNativeCodeTemplate() const { assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); return cast(node->getOperator()) ->getDef() ->getValueAsString("expression"); } llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); } Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { llvm::Record *opDef = cast(node->getOperator())->getDef(); auto it = mapper->find(opDef); if (it != mapper->end()) return *it->second; return *mapper->try_emplace(opDef, std::make_unique(opDef)) .first->second; } int DagNode::getNumOps() const { int count = isReplaceWithValue() ? 0 : 1; for (int i = 0, e = getNumArgs(); i != e; ++i) { if (auto child = getArgAsNestedDag(i)) count += child.getNumOps(); } return count; } int DagNode::getNumArgs() const { return node->getNumArgs(); } bool DagNode::isNestedDagArg(unsigned index) const { return isa(node->getArg(index)); } DagNode DagNode::getArgAsNestedDag(unsigned index) const { return DagNode(dyn_cast_or_null(node->getArg(index))); } DagLeaf DagNode::getArgAsLeaf(unsigned index) const { assert(!isNestedDagArg(index)); return DagLeaf(node->getArg(index)); } StringRef DagNode::getArgName(unsigned index) const { return node->getArgNameStr(index); } bool DagNode::isReplaceWithValue() const { auto *dagOpDef = cast(node->getOperator())->getDef(); return dagOpDef->getName() == "replaceWithValue"; } bool DagNode::isLocationDirective() const { auto *dagOpDef = cast(node->getOperator())->getDef(); return dagOpDef->getName() == "location"; } void DagNode::print(raw_ostream &os) const { if (node) node->print(os); } //===----------------------------------------------------------------------===// // SymbolInfoMap //===----------------------------------------------------------------------===// StringRef SymbolInfoMap::getValuePackName(StringRef symbol, int *index) { StringRef name, indexStr; int idx = -1; std::tie(name, indexStr) = symbol.rsplit("__"); if (indexStr.consumeInteger(10, idx)) { // The second 
part is not an index; we return the whole symbol as-is. return symbol; } if (index) { *index = idx; } return name; } SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind, Optional index) : op(op), kind(kind), argIndex(index) {} int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { switch (kind) { case Kind::Attr: case Kind::Operand: case Kind::Value: return 1; case Kind::Result: return op->getNumResults(); } llvm_unreachable("unknown kind"); } std::string SymbolInfoMap::SymbolInfo::getVarName(StringRef name) const { return alternativeName.hasValue() ? alternativeName.getValue() : name.str(); } std::string SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { LLVM_DEBUG(llvm::dbgs() << "getVarDecl for '" << name << "': "); switch (kind) { case Kind::Attr: { - auto type = - op->getArg(*argIndex).get()->attr.getStorageType(); - return std::string(formatv("{0} {1};\n", type, name)); + if (op) { + auto type = + op->getArg(*argIndex).get()->attr.getStorageType(); + return std::string(formatv("{0} {1};\n", type, name)); + } + // TODO(suderman): Use a more exact type when available. + return std::string(formatv("Attribute {0};\n", name)); } case Kind::Operand: { // Use operand range for captured operands (to support potential variadic // operands). return std::string( formatv("::mlir::Operation::operand_range {0}(op0->getOperands());\n", getVarName(name))); } case Kind::Value: { return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name)); } case Kind::Result: { // Use the op itself for captured results. 
return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name)); } } llvm_unreachable("unknown kind"); } std::string SymbolInfoMap::SymbolInfo::getValueAndRangeUse( StringRef name, int index, const char *fmt, const char *separator) const { LLVM_DEBUG(llvm::dbgs() << "getValueAndRangeUse for '" << name << "': "); switch (kind) { case Kind::Attr: { assert(index < 0); auto repl = formatv(fmt, name); LLVM_DEBUG(llvm::dbgs() << repl << " (Attr)\n"); return std::string(repl); } case Kind::Operand: { assert(index < 0); auto *operand = op->getArg(*argIndex).get(); // If this operand is variadic, then return a range. Otherwise, return the // value itself. if (operand->isVariableLength()) { auto repl = formatv(fmt, name); LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicOperand)\n"); return std::string(repl); } auto repl = formatv(fmt, formatv("(*{0}.begin())", name)); LLVM_DEBUG(llvm::dbgs() << repl << " (SingleOperand)\n"); return std::string(repl); } case Kind::Result: { // If `index` is greater than zero, then we are referencing a specific // result of a multi-result op. The result can still be variadic. if (index >= 0) { std::string v = std::string(formatv("{0}.getODSResults({1})", name, index)); if (!op->getResult(index).isVariadic()) v = std::string(formatv("(*{0}.begin())", v)); auto repl = formatv(fmt, v); LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); return std::string(repl); } // If this op has no result at all but still we bind a symbol to it, it // means we want to capture the op itself. if (op->getNumResults() == 0) { LLVM_DEBUG(llvm::dbgs() << name << " (Op)\n"); return std::string(name); } // We are referencing all results of the multi-result op. A specific result // can either be a value or a range. Then join them with `separator`. 
SmallVector values; values.reserve(op->getNumResults()); for (int i = 0, e = op->getNumResults(); i < e; ++i) { std::string v = std::string(formatv("{0}.getODSResults({1})", name, i)); if (!op->getResult(i).isVariadic()) { v = std::string(formatv("(*{0}.begin())", v)); } values.push_back(std::string(formatv(fmt, v))); } auto repl = llvm::join(values, separator); LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); return repl; } case Kind::Value: { assert(index < 0); assert(op == nullptr); auto repl = formatv(fmt, name); LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); return std::string(repl); } } llvm_unreachable("unknown kind"); } std::string SymbolInfoMap::SymbolInfo::getAllRangeUse( StringRef name, int index, const char *fmt, const char *separator) const { LLVM_DEBUG(llvm::dbgs() << "getAllRangeUse for '" << name << "': "); switch (kind) { case Kind::Attr: case Kind::Operand: { assert(index < 0 && "only allowed for symbol bound to result"); auto repl = formatv(fmt, name); LLVM_DEBUG(llvm::dbgs() << repl << " (Operand/Attr)\n"); return std::string(repl); } case Kind::Result: { if (index >= 0) { auto repl = formatv(fmt, formatv("{0}.getODSResults({1})", name, index)); LLVM_DEBUG(llvm::dbgs() << repl << " (SingleResult)\n"); return std::string(repl); } // We are referencing all results of the multi-result op. Each result should // have a value range, and then join them with `separator`. 
SmallVector values; values.reserve(op->getNumResults()); for (int i = 0, e = op->getNumResults(); i < e; ++i) { values.push_back(std::string( formatv(fmt, formatv("{0}.getODSResults({1})", name, i)))); } auto repl = llvm::join(values, separator); LLVM_DEBUG(llvm::dbgs() << repl << " (VariadicResult)\n"); return repl; } case Kind::Value: { assert(index < 0 && "only allowed for symbol bound to result"); assert(op == nullptr); auto repl = formatv(fmt, formatv("{{{0}}", name)); LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); return std::string(repl); } } llvm_unreachable("unknown kind"); } bool SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op, int argIndex) { StringRef name = getValuePackName(symbol); if (name != symbol) { auto error = formatv( "symbol '{0}' with trailing index cannot bind to op argument", symbol); PrintFatalError(loc, error); } auto symInfo = op.getArg(argIndex).is() ? SymbolInfo::getAttr(&op, argIndex) : SymbolInfo::getOperand(&op, argIndex); std::string key = symbol.str(); if (symbolInfoMap.count(key)) { // Only non unique name for the operand is supported. if (symInfo.kind != SymbolInfo::Kind::Operand) { return false; } // Cannot add new operand if there is already non operand with the same // name. 
if (symbolInfoMap.find(key)->second.kind != SymbolInfo::Kind::Operand) { return false; } } symbolInfoMap.emplace(key, symInfo); return true; } bool SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { std::string name = getValuePackName(symbol).str(); auto inserted = symbolInfoMap.emplace(name, SymbolInfo::getResult(&op)); return symbolInfoMap.count(inserted->first) == 1; } bool SymbolInfoMap::bindValue(StringRef symbol) { auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue()); return symbolInfoMap.count(inserted->first) == 1; } +bool SymbolInfoMap::bindAttr(StringRef symbol) { + auto inserted = symbolInfoMap.emplace(symbol, SymbolInfo::getAttr()); + return symbolInfoMap.count(inserted->first) == 1; +} + bool SymbolInfoMap::contains(StringRef symbol) const { return find(symbol) != symbolInfoMap.end(); } SymbolInfoMap::const_iterator SymbolInfoMap::find(StringRef key) const { std::string name = getValuePackName(key).str(); return symbolInfoMap.find(name); } SymbolInfoMap::const_iterator SymbolInfoMap::findBoundSymbol(StringRef key, const Operator &op, int argIndex) const { std::string name = getValuePackName(key).str(); auto range = symbolInfoMap.equal_range(name); for (auto it = range.first; it != range.second; ++it) { if (it->second.op == &op && it->second.argIndex == argIndex) { return it; } } return symbolInfoMap.end(); } std::pair SymbolInfoMap::getRangeOfEqualElements(StringRef key) { std::string name = getValuePackName(key).str(); return symbolInfoMap.equal_range(name); } int SymbolInfoMap::count(StringRef key) const { std::string name = getValuePackName(key).str(); return symbolInfoMap.count(name); } int SymbolInfoMap::getStaticValueCount(StringRef symbol) const { StringRef name = getValuePackName(symbol); if (name != symbol) { // If there is a trailing index inside symbol, it references just one // static value. return 1; } // Otherwise, find how many it represents by querying the symbol's info. 
return find(name)->second.getStaticValueCount(); } std::string SymbolInfoMap::getValueAndRangeUse(StringRef symbol, const char *fmt, const char *separator) const { int index = -1; StringRef name = getValuePackName(symbol, &index); auto it = symbolInfoMap.find(name.str()); if (it == symbolInfoMap.end()) { auto error = formatv("referencing unbound symbol '{0}'", symbol); PrintFatalError(loc, error); } return it->second.getValueAndRangeUse(name, index, fmt, separator); } std::string SymbolInfoMap::getAllRangeUse(StringRef symbol, const char *fmt, const char *separator) const { int index = -1; StringRef name = getValuePackName(symbol, &index); auto it = symbolInfoMap.find(name.str()); if (it == symbolInfoMap.end()) { auto error = formatv("referencing unbound symbol '{0}'", symbol); PrintFatalError(loc, error); } return it->second.getAllRangeUse(name, index, fmt, separator); } void SymbolInfoMap::assignUniqueAlternativeNames() { llvm::StringSet<> usedNames; for (auto symbolInfoIt = symbolInfoMap.begin(); symbolInfoIt != symbolInfoMap.end();) { auto range = symbolInfoMap.equal_range(symbolInfoIt->first); auto startRange = range.first; auto endRange = range.second; auto operandName = symbolInfoIt->first; int startSearchIndex = 0; for (++startRange; startRange != endRange; ++startRange) { // Current operand name is not unique, find a unique one // and set the alternative name. 
for (int i = startSearchIndex;; ++i) { std::string alternativeName = operandName + std::to_string(i); if (!usedNames.contains(alternativeName) && symbolInfoMap.count(alternativeName) == 0) { usedNames.insert(alternativeName); startRange->second.alternativeName = alternativeName; startSearchIndex = i + 1; break; } } } symbolInfoIt = endRange; } } //===----------------------------------------------------------------------===// // Pattern //==----------------------------------------------------------------------===// Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) : def(*def), recordOpMap(mapper) {} DagNode Pattern::getSourcePattern() const { return DagNode(def.getValueAsDag("sourcePattern")); } int Pattern::getNumResultPatterns() const { auto *results = def.getValueAsListInit("resultPatterns"); return results->size(); } DagNode Pattern::getResultPattern(unsigned index) const { auto *results = def.getValueAsListInit("resultPatterns"); return DagNode(cast(results->getElement(index))); } void Pattern::collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap) { LLVM_DEBUG(llvm::dbgs() << "start collecting source pattern bound symbols\n"); collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); LLVM_DEBUG(llvm::dbgs() << "done collecting source pattern bound symbols\n"); LLVM_DEBUG(llvm::dbgs() << "start assigning alternative names for symbols\n"); infoMap.assignUniqueAlternativeNames(); LLVM_DEBUG(llvm::dbgs() << "done assigning alternative names for symbols\n"); } void Pattern::collectResultPatternBoundSymbols(SymbolInfoMap &infoMap) { LLVM_DEBUG(llvm::dbgs() << "start collecting result pattern bound symbols\n"); for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { auto pattern = getResultPattern(i); collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); } LLVM_DEBUG(llvm::dbgs() << "done collecting result pattern bound symbols\n"); } const Operator &Pattern::getSourceRootOp() { return 
getSourcePattern().getDialectOp(recordOpMap); } Operator &Pattern::getDialectOp(DagNode node) { return node.getDialectOp(recordOpMap); } std::vector Pattern::getConstraints() const { auto *listInit = def.getValueAsListInit("constraints"); std::vector ret; ret.reserve(listInit->size()); for (auto it : *listInit) { auto *dagInit = dyn_cast(it); if (!dagInit) - PrintFatalError(def.getLoc(), "all elements in Pattern multi-entity " - "constraints should be DAG nodes"); + PrintFatalError(&def, "all elements in Pattern multi-entity " + "constraints should be DAG nodes"); std::vector entities; entities.reserve(dagInit->arg_size()); for (auto *argName : dagInit->getArgNames()) { if (!argName) { PrintFatalError( - def.getLoc(), + &def, "operands to additional constraints can only be symbol references"); } entities.push_back(std::string(argName->getValue())); } ret.emplace_back(cast(dagInit->getOperator())->getDef(), dagInit->getNameStr(), std::move(entities)); } return ret; } int Pattern::getBenefit() const { // The initial benefit value is a heuristic with number of ops in the source // pattern. 
int initBenefit = getSourcePattern().getNumOps(); llvm::DagInit *delta = def.getValueAsDag("benefitDelta"); if (delta->getNumArgs() != 1 || !isa(delta->getArg(0))) { - PrintFatalError(def.getLoc(), + PrintFatalError(&def, "The 'addBenefit' takes and only takes one integer value"); } return initBenefit + dyn_cast(delta->getArg(0))->getValue(); } std::vector Pattern::getLocation() const { std::vector> result; result.reserve(def.getLoc().size()); for (auto loc : def.getLoc()) { unsigned buf = llvm::SrcMgr.FindBufferContainingLoc(loc); assert(buf && "invalid source location"); result.emplace_back( llvm::SrcMgr.getBufferInfo(buf).Buffer->getBufferIdentifier(), llvm::SrcMgr.getLineAndColumn(loc, buf).first); } return result; } +void Pattern::verifyBind(bool result, StringRef symbolName) { + if (!result) { + auto err = formatv("symbol '{0}' bound more than once", symbolName); + PrintFatalError(&def, err); + } +} + void Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern) { auto treeName = tree.getSymbol(); - if (!tree.isOperation()) { + auto numTreeArgs = tree.getNumArgs(); + + if (tree.isNativeCodeCall()) { if (!treeName.empty()) { PrintFatalError( - def.getLoc(), - formatv("binding symbol '{0}' to non-operation unsupported right now", - treeName)); + &def, + formatv( + "binding symbol '{0}' to native code call unsupported right now", + treeName)); } - return; - } - auto &op = getDialectOp(tree); - auto numOpArgs = op.getNumArgs(); - auto numTreeArgs = tree.getNumArgs(); - - // The pattern might have the last argument specifying the location. - bool hasLocDirective = false; - if (numTreeArgs != 0) { - if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) - hasLocDirective = lastArg.isLocationDirective(); - } + for (int i = 0; i != numTreeArgs; ++i) { + if (auto treeArg = tree.getArgAsNestedDag(i)) { + // This DAG node argument is a DAG node itself. Go inside recursively. 
+ collectBoundSymbols(treeArg, infoMap, isSrcPattern); + continue; + } - if (numOpArgs != numTreeArgs - hasLocDirective) { - auto err = formatv("op '{0}' argument number mismatch: " - "{1} in pattern vs. {2} in definition", - op.getOperationName(), numTreeArgs, numOpArgs); - PrintFatalError(def.getLoc(), err); - } + if (!isSrcPattern) + continue; - // The name attached to the DAG node's operator is for representing the - // results generated from this op. It should be remembered as bound results. - if (!treeName.empty()) { - LLVM_DEBUG(llvm::dbgs() - << "found symbol bound to op result: " << treeName << '\n'); - if (!infoMap.bindOpResult(treeName, op)) - PrintFatalError(def.getLoc(), - formatv("symbol '{0}' bound more than once", treeName)); - } - - for (int i = 0; i != numTreeArgs; ++i) { - if (auto treeArg = tree.getArgAsNestedDag(i)) { - // This DAG node argument is a DAG node itself. Go inside recursively. - collectBoundSymbols(treeArg, infoMap, isSrcPattern); - } else if (isSrcPattern) { - // We can only bind symbols to op arguments in source pattern. Those + // We can only bind symbols to arguments in source pattern. Those // symbols are referenced in result patterns. auto treeArgName = tree.getArgName(i); + // `$_` is a special symbol meaning ignore the current argument. 
if (!treeArgName.empty() && treeArgName != "_") { - LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " - << treeArgName << '\n'); - if (!infoMap.bindOpArgument(treeArgName, op, i)) { - auto err = formatv("symbol '{0}' bound more than once", treeArgName); - PrintFatalError(def.getLoc(), err); + if (tree.isNestedDagArg(i)) { + auto err = formatv("cannot bind '{0}' for nested native call arg", + treeArgName); + PrintFatalError(&def, err); } + + DagLeaf leaf = tree.getArgAsLeaf(i); + auto constraint = leaf.getAsConstraint(); + bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || + leaf.isConstantAttr() || + constraint.getKind() == Constraint::Kind::CK_Attr; + + if (isAttr) { + verifyBind(infoMap.bindAttr(treeArgName), treeArgName); + continue; + } + + verifyBind(infoMap.bindValue(treeArgName), treeArgName); } } + + return; + } + + if (tree.isOperation()) { + auto &op = getDialectOp(tree); + auto numOpArgs = op.getNumArgs(); + + // The pattern might have the last argument specifying the location. + bool hasLocDirective = false; + if (numTreeArgs != 0) { + if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) + hasLocDirective = lastArg.isLocationDirective(); + } + + if (numOpArgs != numTreeArgs - hasLocDirective) { + auto err = formatv("op '{0}' argument number mismatch: " + "{1} in pattern vs. {2} in definition", + op.getOperationName(), numTreeArgs, numOpArgs); + PrintFatalError(&def, err); + } + + // The name attached to the DAG node's operator is for representing the + // results generated from this op. It should be remembered as bound results. + if (!treeName.empty()) { + LLVM_DEBUG(llvm::dbgs() + << "found symbol bound to op result: " << treeName << '\n'); + verifyBind(infoMap.bindOpResult(treeName, op), treeName); + } + + for (int i = 0; i != numTreeArgs; ++i) { + if (auto treeArg = tree.getArgAsNestedDag(i)) { + // This DAG node argument is a DAG node itself. Go inside recursively. 
+ collectBoundSymbols(treeArg, infoMap, isSrcPattern); + continue; + } + + if (isSrcPattern) { + // We can only bind symbols to op arguments in source pattern. Those + // symbols are referenced in result patterns. + auto treeArgName = tree.getArgName(i); + // `$_` is a special symbol meaning ignore the current argument. + if (!treeArgName.empty() && treeArgName != "_") { + LLVM_DEBUG(llvm::dbgs() << "found symbol bound to op argument: " + << treeArgName << '\n'); + verifyBind(infoMap.bindOpArgument(treeArgName, op, i), treeArgName); + } + } + } + return; + } + + if (!treeName.empty()) { + PrintFatalError( + &def, formatv("binding symbol '{0}' to non-operation/native code call " + "unsupported right now", + treeName)); } + return; } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 3bfb82495ce1..d34e997644a5 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1,875 +1,879 @@ //===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "TestDialect.h" #include "TestTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringSwitch.h" using namespace mlir; void mlir::registerTestDialect(DialectRegistry ®istry) { registry.insert(); } //===----------------------------------------------------------------------===// // TestDialect Interfaces //===----------------------------------------------------------------------===// namespace { // Test support for interacting with the AsmPrinter. struct TestOpAsmInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; void getAsmResultNames(Operation *op, OpAsmSetValueNameFn setNameFn) const final { if (auto asmOp = dyn_cast(op)) setNameFn(asmOp, "result"); } void getAsmBlockArgumentNames(Block *block, OpAsmSetValueNameFn setNameFn) const final { auto op = block->getParentOp(); auto arrayAttr = op->getAttrOfType("arg_names"); if (!arrayAttr) return; auto args = block->getArguments(); auto e = std::min(arrayAttr.size(), args.size()); for (unsigned i = 0; i < e; ++i) { if (auto strAttr = arrayAttr[i].dyn_cast()) setNameFn(args[i], strAttr.getValue()); } } }; struct TestDialectFoldInterface : public DialectFoldInterface { using DialectFoldInterface::DialectFoldInterface; /// Registered hook to check if the given region, which is attached to an /// operation that is *not* isolated from above, should be used when /// materializing constants. bool shouldMaterializeInto(Region *region) const final { // If this is a one region operation, then insert into it. 
return isa(region->getParentOp()); } }; /// This class defines the interface for handling inlining with standard /// operations. struct TestInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; //===--------------------------------------------------------------------===// // Analysis Hooks //===--------------------------------------------------------------------===// bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final { // Inlining into test dialect regions is legal. return true; } bool isLegalToInline(Operation *, Region *, BlockAndValueMapping &) const final { return true; } bool shouldAnalyzeRecursively(Operation *op) const final { // Analyze recursively if this is not a functional region operation, it // froms a separate functional scope. return !isa(op); } //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, ArrayRef valuesToRepl) const final { // Only handle "test.return" here. auto returnOp = dyn_cast(op); if (!returnOp) return; // Replace the values directly with the return operands. assert(returnOp.getNumOperands() == valuesToRepl.size()); for (const auto &it : llvm::enumerate(returnOp.getOperands())) valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } /// Attempt to materialize a conversion for a type mismatch between a call /// from this dialect, and a callable region. This method should generate an /// operation that takes 'input' as the only operand, and produces a single /// result of 'resultType'. If a conversion can not be generated, nullptr /// should be returned. 
Operation *materializeCallConversion(OpBuilder &builder, Value input, Type resultType, Location conversionLoc) const final { // Only allow conversion for i16/i32 types. if (!(resultType.isSignlessInteger(16) || resultType.isSignlessInteger(32)) || !(input.getType().isSignlessInteger(16) || input.getType().isSignlessInteger(32))) return nullptr; return builder.create(conversionLoc, resultType, input); } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// void TestDialect::initialize() { addOperations< #define GET_OP_LIST #include "TestOps.cpp.inc" >(); addInterfaces(); addTypes(); allowUnknownOperations(); } static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, llvm::SetVector &stack) { StringRef typeTag; if (failed(parser.parseKeyword(&typeTag))) return Type(); auto genType = generatedTypeParser(ctxt, parser, typeTag); if (genType != Type()) return genType; if (typeTag == "test_type") return TestType::get(parser.getBuilder().getContext()); if (typeTag != "test_rec") return Type(); StringRef name; if (parser.parseLess() || parser.parseKeyword(&name)) return Type(); auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); // If this type already has been parsed above in the stack, expect just the // name. if (stack.contains(rec)) { if (failed(parser.parseGreater())) return Type(); return rec; } // Otherwise, parse the body and update the type. 
if (failed(parser.parseComma())) return Type(); stack.insert(rec); Type subtype = parseTestType(ctxt, parser, stack); stack.pop_back(); if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) return Type(); return rec; } Type TestDialect::parseType(DialectAsmParser &parser) const { llvm::SetVector stack; return parseTestType(getContext(), parser, stack); } static void printTestType(Type type, DialectAsmPrinter &printer, llvm::SetVector &stack) { if (succeeded(generatedTypePrinter(type, printer))) return; if (type.isa()) { printer << "test_type"; return; } auto rec = type.cast(); printer << "test_rec<" << rec.getName(); if (!stack.contains(rec)) { printer << ", "; stack.insert(rec); printTestType(rec.getBody(), printer, stack); stack.pop_back(); } printer << ">"; } void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { llvm::SetVector stack; printTestType(type, printer, stack); } LogicalResult TestDialect::verifyOperationAttribute(Operation *op, NamedAttribute namedAttr) { if (namedAttr.first == "test.invalid_attr") return op->emitError() << "invalid to use 'test.invalid_attr'"; return success(); } LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op, unsigned regionIndex, unsigned argIndex, NamedAttribute namedAttr) { if (namedAttr.first == "test.invalid_attr") return op->emitError() << "invalid to use 'test.invalid_attr'"; return success(); } LogicalResult TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, unsigned resultIndex, NamedAttribute namedAttr) { if (namedAttr.first == "test.invalid_attr") return op->emitError() << "invalid to use 'test.invalid_attr'"; return success(); } //===----------------------------------------------------------------------===// // TestBranchOp //===----------------------------------------------------------------------===// Optional TestBranchOp::getMutableSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); return 
targetOperandsMutable(); } //===----------------------------------------------------------------------===// // TestFoldToCallOp //===----------------------------------------------------------------------===// namespace { struct FoldToCallOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(FoldToCallOp op, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, TypeRange(), op.calleeAttr(), ValueRange()); return success(); } }; } // end anonymous namespace void FoldToCallOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } //===----------------------------------------------------------------------===// // Test Format* operations //===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===// // Parsing static ParseResult parseCustomDirectiveOperands( OpAsmParser &parser, OpAsmParser::OperandType &operand, Optional &optOperand, SmallVectorImpl &varOperands) { if (parser.parseOperand(operand)) return failure(); if (succeeded(parser.parseOptionalComma())) { optOperand.emplace(); if (parser.parseOperand(*optOperand)) return failure(); } if (parser.parseArrow() || parser.parseLParen() || parser.parseOperandList(varOperands) || parser.parseRParen()) return failure(); return success(); } static ParseResult parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, Type &optOperandType, SmallVectorImpl &varOperandTypes) { if (parser.parseColon()) return failure(); if (parser.parseType(operandType)) return failure(); if (succeeded(parser.parseOptionalComma())) { if (parser.parseType(optOperandType)) return failure(); } if (parser.parseArrow() || parser.parseLParen() || parser.parseTypeList(varOperandTypes) || parser.parseRParen()) return failure(); return success(); } static ParseResult 
parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType, Type optOperandType, const SmallVectorImpl &varOperandTypes) { if (parser.parseKeyword("type_refs_capture")) return failure(); Type operandType2, optOperandType2; SmallVector varOperandTypes2; if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, varOperandTypes2)) return failure(); if (operandType != operandType2 || optOperandType != optOperandType2 || varOperandTypes != varOperandTypes2) return failure(); return success(); } static ParseResult parseCustomDirectiveOperandsAndTypes( OpAsmParser &parser, OpAsmParser::OperandType &operand, Optional &optOperand, SmallVectorImpl &varOperands, Type &operandType, Type &optOperandType, SmallVectorImpl &varOperandTypes) { if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || parseCustomDirectiveResults(parser, operandType, optOperandType, varOperandTypes)) return failure(); return success(); } static ParseResult parseCustomDirectiveRegions( OpAsmParser &parser, Region ®ion, SmallVectorImpl> &varRegions) { if (parser.parseRegion(region)) return failure(); if (failed(parser.parseOptionalComma())) return success(); std::unique_ptr varRegion = std::make_unique(); if (parser.parseRegion(*varRegion)) return failure(); varRegions.emplace_back(std::move(varRegion)); return success(); } static ParseResult parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, SmallVectorImpl &varSuccessors) { if (parser.parseSuccessor(successor)) return failure(); if (failed(parser.parseOptionalComma())) return success(); Block *varSuccessor; if (parser.parseSuccessor(varSuccessor)) return failure(); varSuccessors.append(2, varSuccessor); return success(); } static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser, IntegerAttr &attr, IntegerAttr &optAttr) { if (parser.parseAttribute(attr)) return failure(); if (succeeded(parser.parseOptionalComma())) { if (parser.parseAttribute(optAttr)) return failure(); } 
return success(); } //===----------------------------------------------------------------------===// // Printing static void printCustomDirectiveOperands(OpAsmPrinter &printer, Value operand, Value optOperand, OperandRange varOperands) { printer << operand; if (optOperand) printer << ", " << optOperand; printer << " -> (" << varOperands << ")"; } static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType, Type optOperandType, TypeRange varOperandTypes) { printer << " : " << operandType; if (optOperandType) printer << ", " << optOperandType; printer << " -> (" << varOperandTypes << ")"; } static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, Type operandType, Type optOperandType, TypeRange varOperandTypes) { printer << " type_refs_capture "; printCustomDirectiveResults(printer, operandType, optOperandType, varOperandTypes); } static void printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand, Value optOperand, OperandRange varOperands, Type operandType, Type optOperandType, TypeRange varOperandTypes) { printCustomDirectiveOperands(printer, operand, optOperand, varOperands); printCustomDirectiveResults(printer, operandType, optOperandType, varOperandTypes); } static void printCustomDirectiveRegions(OpAsmPrinter &printer, Region ®ion, MutableArrayRef varRegions) { printer.printRegion(region); if (!varRegions.empty()) { printer << ", "; for (Region ®ion : varRegions) printer.printRegion(region); } } static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Block *successor, SuccessorRange varSuccessors) { printer << successor; if (!varSuccessors.empty()) printer << ", " << varSuccessors.front(); } static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Attribute attribute, Attribute optAttribute) { printer << attribute; if (optAttribute) printer << ", " << optAttribute; } //===----------------------------------------------------------------------===// // Test IsolatedRegionOp - parse passthrough 
region arguments. //===----------------------------------------------------------------------===// static ParseResult parseIsolatedRegionOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType argInfo; Type argType = parser.getBuilder().getIndexType(); // Parse the input operand. if (parser.parseOperand(argInfo) || parser.resolveOperand(argInfo, argType, result.operands)) return failure(); // Parse the body region, and reuse the operand info as the argument info. Region *body = result.addRegion(); return parser.parseRegion(*body, argInfo, argType, /*enableNameShadowing=*/true); } static void print(OpAsmPrinter &p, IsolatedRegionOp op) { p << "test.isolated_region "; p.printOperand(op.getOperand()); p.shadowRegionArgs(op.region(), op.getOperand()); p.printRegion(op.region(), /*printEntryBlockArgs=*/false); } //===----------------------------------------------------------------------===// // Test SSACFGRegionOp //===----------------------------------------------------------------------===// RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { return RegionKind::SSACFG; } //===----------------------------------------------------------------------===// // Test GraphRegionOp //===----------------------------------------------------------------------===// static ParseResult parseGraphRegionOp(OpAsmParser &parser, OperationState &result) { // Parse the body region, and reuse the operand info as the argument info. 
Region *body = result.addRegion(); return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); } static void print(OpAsmPrinter &p, GraphRegionOp op) { p << "test.graph_region "; p.printRegion(op.region(), /*printEntryBlockArgs=*/false); } RegionKind GraphRegionOp::getRegionKind(unsigned index) { return RegionKind::Graph; } //===----------------------------------------------------------------------===// // Test AffineScopeOp //===----------------------------------------------------------------------===// static ParseResult parseAffineScopeOp(OpAsmParser &parser, OperationState &result) { // Parse the body region, and reuse the operand info as the argument info. Region *body = result.addRegion(); return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); } static void print(OpAsmPrinter &p, AffineScopeOp op) { p << "test.affine_scope "; p.printRegion(op.region(), /*printEntryBlockArgs=*/false); } //===----------------------------------------------------------------------===// // Test parser. //===----------------------------------------------------------------------===// static ParseResult parseWrappedKeywordOp(OpAsmParser &parser, OperationState &result) { StringRef keyword; if (parser.parseKeyword(&keyword)) return failure(); result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword)); return success(); } static void print(OpAsmPrinter &p, WrappedKeywordOp op) { p << WrappedKeywordOp::getOperationName() << " " << op.keyword(); } //===----------------------------------------------------------------------===// // Test WrapRegionOp - wrapping op exercising `parseGenericOperation()`. 
static ParseResult parseWrappingRegionOp(OpAsmParser &parser, OperationState &result) { if (parser.parseKeyword("wraps")) return failure(); // Parse the wrapped op in a region Region &body = *result.addRegion(); body.push_back(new Block); Block &block = body.back(); Operation *wrapped_op = parser.parseGenericOperation(&block, block.begin()); if (!wrapped_op) return failure(); // Create a return terminator in the inner region, pass as operand to the // terminator the returned values from the wrapped operation. SmallVector return_operands(wrapped_op->getResults()); OpBuilder builder(parser.getBuilder().getContext()); builder.setInsertionPointToEnd(&block); builder.create(wrapped_op->getLoc(), return_operands); // Get the results type for the wrapping op from the terminator operands. Operation &return_op = body.back().back(); result.types.append(return_op.operand_type_begin(), return_op.operand_type_end()); // Use the location of the wrapped op for the "test.wrapping_region" op. result.location = wrapped_op->getLoc(); return success(); } static void print(OpAsmPrinter &p, WrappingRegionOp op) { p << op.getOperationName() << " wraps "; p.printGenericOp(&op.region().front().front()); } //===----------------------------------------------------------------------===// // Test PolyForOp - parse list of region arguments. //===----------------------------------------------------------------------===// static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { SmallVector ivsInfo; // Parse list of region arguments without a delimiter. if (parser.parseRegionArgumentList(ivsInfo)) return failure(); // Parse the body region. Region *body = result.addRegion(); auto &builder = parser.getBuilder(); SmallVector argTypes(ivsInfo.size(), builder.getIndexType()); return parser.parseRegion(*body, ivsInfo, argTypes); } //===----------------------------------------------------------------------===// // Test removing op with inner ops. 
//===----------------------------------------------------------------------===// namespace { struct TestRemoveOpWithInnerOps : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TestOpWithRegionPattern op, PatternRewriter &rewriter) const override { rewriter.eraseOp(op); return success(); } }; } // end anonymous namespace void TestOpWithRegionPattern::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } OpFoldResult TestOpWithRegionFold::fold(ArrayRef operands) { return operand(); } +OpFoldResult TestOpConstant::fold(ArrayRef operands) { + return getValue(); +} + LogicalResult TestOpWithVariadicResultsAndFolder::fold( ArrayRef operands, SmallVectorImpl &results) { for (Value input : this->operands()) { results.push_back(input); } return success(); } OpFoldResult TestOpInPlaceFold::fold(ArrayRef operands) { assert(operands.size() == 1); if (operands.front()) { setAttr("attr", operands.front()); return getResult(); } return {}; } LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( MLIRContext *, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { if (operands[0].getType() != operands[1].getType()) { return emitOptionalError(location, "operand type mismatch ", operands[0].getType(), " vs ", operands[1].getType()); } inferredReturnTypes.assign({operands[0].getType()}); return success(); } LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( MLIRContext *context, Optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { // Create return type consisting of the last element of the first operand. 
auto operandType = *operands.getTypes().begin(); auto sval = operandType.dyn_cast(); if (!sval) { return emitOptionalError(location, "only shaped type operands allowed"); } int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize; auto type = IntegerType::get(17, context); inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type)); return success(); } LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( OpBuilder &builder, llvm::SmallVectorImpl &shapes) { shapes = SmallVector{ builder.createOrFold(getLoc(), getOperand(0), 0)}; return success(); } //===----------------------------------------------------------------------===// // Test SideEffect interfaces //===----------------------------------------------------------------------===// namespace { /// A test resource for side effects. struct TestResource : public SideEffects::Resource::Base { StringRef getName() final { return ""; } }; } // end anonymous namespace void SideEffectOp::getEffects( SmallVectorImpl &effects) { // Check for an effects attribute on the op instance. ArrayAttr effectsAttr = getAttrOfType("effects"); if (!effectsAttr) return; // If there is one, it is an array of dictionary attributes that hold // information on the effects of this operation. for (Attribute element : effectsAttr) { DictionaryAttr effectElement = element.cast(); // Get the specific memory effect. MemoryEffects::Effect *effect = StringSwitch( effectElement.get("effect").cast().getValue()) .Case("allocate", MemoryEffects::Allocate::get()) .Case("free", MemoryEffects::Free::get()) .Case("read", MemoryEffects::Read::get()) .Case("write", MemoryEffects::Write::get()); // Check for a result to affect. Value value; if (effectElement.get("on_result")) value = getResult(); // Check for a non-default resource to use. 
SideEffects::Resource *resource = SideEffects::DefaultResource::get(); if (effectElement.get("test_resource")) resource = TestResource::get(); effects.emplace_back(effect, value, resource); } } //===----------------------------------------------------------------------===// // StringAttrPrettyNameOp //===----------------------------------------------------------------------===// // This op has fancy handling of its SSA result name. static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, OperationState &result) { // Add the result types. for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) result.addTypes(parser.getBuilder().getIntegerType(32)); if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) return failure(); // If the attribute dictionary contains no 'names' attribute, infer it from // the SSA name (if specified). bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { return attr.first == "names"; }); // If there was no name specified, check to see if there was a useful name // specified in the asm file. if (hadNames || parser.getNumResults() == 0) return success(); SmallVector names; auto *context = result.getContext(); for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { auto resultName = parser.getResultName(i); StringRef nameStr; if (!resultName.first.empty() && !isdigit(resultName.first[0])) nameStr = resultName.first; names.push_back(nameStr); } auto namesAttr = parser.getBuilder().getStrArrayAttr(names); result.attributes.push_back({Identifier::get("names", context), namesAttr}); return success(); } static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { p << "test.string_attr_pretty_name"; // Note that we only need to print the "name" attribute if the asmprinter // result name disagrees with it. This can happen in strange cases, e.g. // when there are conflicts. 
bool namesDisagree = op.names().size() != op.getNumResults(); SmallString<32> resultNameStr; for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) { resultNameStr.clear(); llvm::raw_svector_ostream tmpStream(resultNameStr); p.printOperand(op.getResult(i), tmpStream); auto expectedName = op.names()[i].dyn_cast(); if (!expectedName || tmpStream.str().drop_front() != expectedName.getValue()) { namesDisagree = true; } } if (namesDisagree) p.printOptionalAttrDictWithKeyword(op.getAttrs()); else p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"}); } // We set the SSA name in the asm syntax to the contents of the name // attribute. void StringAttrPrettyNameOp::getAsmResultNames( function_ref setNameFn) { auto value = names(); for (size_t i = 0, e = value.size(); i != e; ++i) if (auto str = value[i].dyn_cast()) if (!str.getValue().empty()) setNameFn(getResult(i), str.getValue()); } //===----------------------------------------------------------------------===// // RegionIfOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, RegionIfOp op) { p << RegionIfOp::getOperationName() << " "; p.printOperands(op.getOperands()); p << ": " << op.getOperandTypes(); p.printArrowTypeList(op.getResultTypes()); p << " then"; p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); p << " else"; p.printRegion(op.elseRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); p << " join"; p.printRegion(op.joinRegion(), /*printEntryBlockArgs=*/true, /*printBlockTerminators=*/true); } static ParseResult parseRegionIfOp(OpAsmParser &parser, OperationState &result) { SmallVector operandInfos; SmallVector operandTypes; result.regions.reserve(3); Region *thenRegion = result.addRegion(); Region *elseRegion = result.addRegion(); Region *joinRegion = result.addRegion(); // Parse operand, type and arrow type lists. 
if (parser.parseOperandList(operandInfos) || parser.parseColonTypeList(operandTypes) || parser.parseArrowTypeList(result.types)) return failure(); // Parse all attached regions. if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) return failure(); return parser.resolveOperands(operandInfos, operandTypes, parser.getCurrentLocation(), result.operands); } OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { assert(index < 2 && "invalid region index"); return getOperands(); } void RegionIfOp::getSuccessorRegions( Optional index, ArrayRef operands, SmallVectorImpl ®ions) { // We always branch to the join region. if (index.hasValue()) { if (index.getValue() < 2) regions.push_back(RegionSuccessor(&joinRegion(), getJoinArgs())); else regions.push_back(RegionSuccessor(getResults())); return; } // The then and else regions are the entry regions of this op. regions.push_back(RegionSuccessor(&thenRegion(), getThenArgs())); regions.push_back(RegionSuccessor(&elseRegion(), getElseArgs())); } #include "TestOpEnums.cpp.inc" #include "TestOpStructs.cpp.inc" #include "TestTypeInterfaces.cpp.inc" #define GET_OP_CLASSES #include "TestOps.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index aef39a9e19fe..fcc677361dcc 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1,1794 +1,1810 @@ //===-- TestOps.td - Test dialect operation definitions ----*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #ifndef TEST_OPS #define TEST_OPS include "mlir/IR/OpBase.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/RegionKindInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" def Test_Dialect : Dialect { let name = "test"; let cppNamespace = "::mlir"; let hasOperationAttrVerify = 1; let hasRegionArgAttrVerify = 1; let hasRegionResultAttrVerify = 1; } class TEST_Op traits = []> : Op; //===----------------------------------------------------------------------===// // Test Types //===----------------------------------------------------------------------===// def IntTypesOp : TEST_Op<"int_types"> { let results = (outs AnyI16:$any_i16, SI32:$si32, UI64:$ui64, AnyInteger:$any_int ); } def ComplexF64 : Complex; def ComplexOp : TEST_Op<"complex_f64"> { let results = (outs ComplexF64); } def ComplexTensorOp : TEST_Op<"complex_f64_tensor"> { let results = (outs TensorOf<[ComplexF64]>); } def TupleOp : TEST_Op<"tuple_32_bit"> { let results = (outs TupleOf<[I32, F32]>); } def NestedTupleOp : TEST_Op<"nested_tuple_32_bit"> { let results = (outs NestedTupleOf<[I32, F32]>); } def TakesStaticMemRefOp : TEST_Op<"takes_static_memref"> { let arguments = (ins AnyStaticShapeMemRef:$x); } def RankLessThan2I8F32MemRefOp : TEST_Op<"rank_less_than_2_I8_F32_memref"> { let results = (outs MemRefRankOf<[I8, F32], [0, 1]>); } def NDTensorOfOp : TEST_Op<"nd_tensor_of"> { let arguments = (ins 0DTensorOf<[F32]>:$arg0, 1DTensorOf<[F32]>:$arg1, 2DTensorOf<[I16]>:$arg2, 3DTensorOf<[I16]>:$arg3, 4DTensorOf<[I16]>:$arg4 ); } def RankedTensorOp : TEST_Op<"ranked_tensor_op"> { let arguments = (ins 
AnyRankedTensor:$input); } def MultiTensorRankOf : TEST_Op<"multi_tensor_rank_of"> { let arguments = (ins TensorRankOf<[I8, I32, F32], [0, 1]>:$arg0 ); } def TEST_TestType : DialectType()">, "test">, BuildableType<"$_builder.getType<::mlir::TestType>()">; //===----------------------------------------------------------------------===// // Test Symbols //===----------------------------------------------------------------------===// def SymbolOp : TEST_Op<"symbol", [Symbol]> { let summary = "operation which defines a new symbol"; let arguments = (ins StrAttr:$sym_name, OptionalAttr:$sym_visibility); } def SymbolScopeOp : TEST_Op<"symbol_scope", [SymbolTable, SingleBlockImplicitTerminator<"TerminatorOp">]> { let summary = "operation which defines a new symbol table"; let regions = (region SizedRegion<1>:$region); } def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> { let summary = "operation which defines a new symbol table without a " "restriction on a terminator"; let regions = (region SizedRegion<1>:$region); } //===----------------------------------------------------------------------===// // Test Operands //===----------------------------------------------------------------------===// def MixedNormalVariadicOperandOp : TEST_Op< "mixed_normal_variadic_operand", [SameVariadicOperandSize]> { let arguments = (ins Variadic:$input1, AnyTensor:$input2, Variadic:$input3 ); } def VariadicWithSameOperandsResult : TEST_Op<"variadic_with_same_operand_results", [SameOperandsAndResultType]> { let arguments = (ins Variadic:$operands); let results = (outs AnySignlessInteger:$result); } //===----------------------------------------------------------------------===// // Test Results //===----------------------------------------------------------------------===// def MixedNormalVariadicResults : TEST_Op< "mixed_normal_variadic_result", [SameVariadicResultSize]> { let results = (outs Variadic:$output1, AnyTensor:$output2, Variadic:$output3 ); } 
//===----------------------------------------------------------------------===//
// Test Attributes
//===----------------------------------------------------------------------===//

// NOTE(review): several angle-bracket argument lists in this section look
// truncated (e.g. `Confined:$i32attr`, `DefaultValuedAttr:$attr`,
// `Variadic:$inputs`, `.cast()`) -- presumably lost when the text was
// extracted; confirm against the upstream file before use.

def NonNegIntAttrOp : TEST_Op<"non_negative_int_attr"> {
  let arguments = (ins
      Confined:$i32attr,
      Confined:$i64attr
  );
}

def PositiveIntAttrOp : TEST_Op<"positive_int_attr"> {
  let arguments = (ins
      Confined:$i32attr,
      Confined:$i64attr
  );
}

def TypeArrayAttrOp : TEST_Op<"type_array_attr"> {
  let arguments = (ins TypeArrayAttr:$attr);
}

def TypeArrayAttrWithDefaultOp : TEST_Op<"type_array_attr_with_default"> {
  let arguments = (ins DefaultValuedAttr:$attr);
}

// Op with a typed string attribute and a declarative assembly format.
def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> {
  let arguments = (ins TypedStrAttr:$attr);
  let assemblyFormat = "$attr attr-dict";
}

// String enum with cases "A" and "B", and an op that takes it as attribute.
def StrCaseA: StrEnumAttrCase<"A">;
def StrCaseB: StrEnumAttrCase<"B">;

def SomeStrEnum: StrEnumAttr<
  "SomeStrEnum", "", [StrCaseA, StrCaseB]>;

def StrEnumAttrOp : TEST_Op<"str_enum_attr"> {
  let arguments = (ins SomeStrEnum:$attr);
  let results = (outs I32:$val);
}

// 32-bit integer enum with cases 5 and 10, and an op that takes it.
def I32Case5:  I32EnumAttrCase<"case5", 5>;
def I32Case10: I32EnumAttrCase<"case10", 10>;

def SomeI32Enum: I32EnumAttr<
  "SomeI32Enum", "", [I32Case5, I32Case10]>;

def I32EnumAttrOp : TEST_Op<"i32_enum_attr"> {
  let arguments = (ins SomeI32Enum:$attr);
  let results = (outs I32:$val);
}

// 64-bit integer enum with cases 5 and 10, and an op that takes it.
def I64Case5:  I64EnumAttrCase<"case5", 5>;
def I64Case10: I64EnumAttrCase<"case10", 10>;

def SomeI64Enum: I64EnumAttr<
  "SomeI64Enum", "", [I64Case5, I64Case10]>;

def I64EnumAttrOp : TEST_Op<"i64_enum_attr"> {
  let arguments = (ins SomeI64Enum:$attr);
  let results = (outs I32:$val);
}

// Struct attribute with two I64Attr fields, and an op that takes it.
def SomeStructAttr : StructAttr<"SomeStructAttr", Test_Dialect, [
  StructFieldAttr<"some_field", I64Attr>,
  StructFieldAttr<"some_other_field", I64Attr>
]> {}

def StructAttrOp : TEST_Op<"struct_attr"> {
  let arguments = (ins SomeStructAttr:$the_struct_attr);
  let results = (outs);
}

def IntAttrOp : TEST_Op<"int_attrs"> {
  let arguments = (ins
    AnyI32Attr:$any_i32_attr,
    IndexAttr:$index_attr,
    UI32Attr:$ui32_attr,
    SI32Attr:$si32_attr
  );
}

def FloatElementsAttrOp : TEST_Op<"float_elements_attr"> {
  let arguments = (ins
      RankedF32ElementsAttr<[2]>:$scalar_f32_attr,
      RankedF64ElementsAttr<[4, 8]>:$tensor_f64_attr
  );
}

// A pattern that updates dense<[3.0, 4.0]> to dense<[5.0, 6.0]>.
// This tests both matching and generating float elements attributes.
def UpdateFloatElementsAttr : Pat<
  (FloatElementsAttrOp
    ConstantAttr, "{3.0f, 4.0f}">:$f32attr,
    $f64attr),
  (FloatElementsAttrOp
    ConstantAttr, "{5.0f, 6.0f}">:$f32attr,
    $f64attr)>;

def IntElementsAttrOp : TEST_Op<"int_elements_attr"> {
  let arguments = (ins
      AnyI32ElementsAttr:$any_i32_attr,
      I32ElementsAttr:$i32_attr
  );
}

def RankedIntElementsAttrOp : TEST_Op<"ranked_int_elements_attr"> {
  let arguments = (ins
      RankedI32ElementsAttr<[2]>:$vector_i32_attr,
      RankedI64ElementsAttr<[4, 8]>:$matrix_i64_attr
  );
}

// Op declaring two derived attributes computed from the `output` result type.
def DerivedTypeAttrOp : TEST_Op<"derived_type_attr", []> {
  let results = (outs AnyTensor:$output);
  DerivedTypeAttr element_dtype =
    DerivedTypeAttr<"return getElementTypeOrSelf(output().getType());">;
  DerivedAttr size = DerivedAttr<"int",
    "return output().getType().cast().getSizeInBits();",
    "$_builder.getI32IntegerAttr($_self)">;
}

def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
  let arguments = (ins
      StringElementsAttr:$scalar_string_attr
  );
}

//===----------------------------------------------------------------------===//
// Test Attribute Constraints
//===----------------------------------------------------------------------===//

def SymbolRefOp : TEST_Op<"symbol_ref_attr"> {
  let arguments = (ins
    Confined]>:$symbol
  );
}

//===----------------------------------------------------------------------===//
// Test Regions
//===----------------------------------------------------------------------===//

def OneRegionOp : TEST_Op<"one_region_op", []> {
  let regions = (region AnyRegion);
}

def TwoRegionOp : TEST_Op<"two_region_op", []> {
  let regions = (region AnyRegion, AnyRegion);
}

def SizedRegionOp : TEST_Op<"sized_region_op", []> {
  let regions = (region SizedRegion<2>:$my_region, SizedRegion<1>);
}

//===----------------------------------------------------------------------===//
// Test Call Interfaces
//===----------------------------------------------------------------------===//

// Call-like op: carries CallOpInterface and supplies the two accessors below
// through extraClassDeclaration.
def ConversionCallOp : TEST_Op<"conversion_call_op",
    [CallOpInterface]> {
  let arguments = (ins Variadic:$inputs, SymbolRefAttr:$callee);
  let results = (outs Variadic);

  let extraClassDeclaration = [{
    /// Get the argument operands to the called function.
    operand_range getArgOperands() { return inputs(); }

    /// Return the callee of this operation.
    CallInterfaceCallable getCallableForCallee() {
      return getAttrOfType("callee");
    }
  }];
}

// Callable op: carries CallableOpInterface and exposes its `body` region.
def FunctionalRegionOp : TEST_Op<"functional_region_op",
    [CallableOpInterface]> {
  let regions = (region AnyRegion:$body);
  let results = (outs FunctionType);

  let extraClassDeclaration = [{
    Region *getCallableRegion() { return &body(); }
    ArrayRef getCallableResults() {
      return getType().cast().getResults();
    }
  }];
}

def FoldToCallOp : TEST_Op<"fold_to_call_op"> {
  let arguments = (ins FlatSymbolRefAttr:$callee);
  let hasCanonicalizer = 1;
}

//===----------------------------------------------------------------------===//
// Test Traits
//===----------------------------------------------------------------------===//

// Each op below attaches the trait(s) named in its trait list.

def SameOperandElementTypeOp : TEST_Op<"same_operand_element_type",
    [SameOperandsElementType]> {
  let arguments = (ins AnyType, AnyType);
  let results = (outs AnyType);
}

def SameOperandAndResultElementTypeOp :
    TEST_Op<"same_operand_and_result_element_type",
    [SameOperandsAndResultElementType]> {
  let arguments = (ins Variadic);
  let results = (outs Variadic);
}

def SameOperandShapeOp : TEST_Op<"same_operand_shape", [SameOperandsShape]> {
  let arguments = (ins Variadic);
}

def SameOperandAndResultShapeOp : TEST_Op<"same_operand_and_result_shape",
    [SameOperandsAndResultShape]> {
  let arguments = (ins Variadic);
  let results = (outs Variadic);
}

def SameOperandAndResultTypeOp : TEST_Op<"same_operand_and_result_type",
    [SameOperandsAndResultType]> {
  let arguments = (ins Variadic);
  let results = (outs Variadic);
}

def ArgAndResHaveFixedElementTypesOp :
    TEST_Op<"arg_and_res_have_fixed_element_types",
      [PredOpTrait<"fixed type combination",
         And<[ElementTypeIsPred<"x", I32>,
              ElementTypeIsPred<"y", F32>]>>,
      ElementTypeIs<"res", I16>]> {
  let arguments = (ins
    AnyShaped:$x, AnyShaped:$y);
  let results = (outs AnyShaped:$res);
}

def OperandsHaveSameElementType : TEST_Op<"operands_have_same_element_type", [
    AllElementTypesMatch<["x", "y"]>]> {
  let arguments = (ins AnyType:$x, AnyType:$y);
}

def OperandZeroAndResultHaveSameElementType : TEST_Op<
    "operand0_and_result_have_same_element_type",
    [AllElementTypesMatch<["x", "res"]>]> {
  let arguments = (ins AnyType:$x, AnyType:$y);
  let results = (outs AnyType:$res);
}

def OperandsHaveSameType :
    TEST_Op<"operands_have_same_type", [AllTypesMatch<["x", "y"]>]> {
  let arguments = (ins AnyType:$x, AnyType:$y);
}

def OperandZeroAndResultHaveSameType :
    TEST_Op<"operand0_and_result_have_same_type",
            [AllTypesMatch<["x", "res"]>]> {
  let arguments = (ins AnyType:$x, AnyType:$y);
  let results = (outs AnyType:$res);
}

def OperandsHaveSameRank :
    TEST_Op<"operands_have_same_rank", [AllRanksMatch<["x", "y"]>]> {
  let arguments = (ins AnyShaped:$x, AnyShaped:$y);
}

def OperandZeroAndResultHaveSameRank :
    TEST_Op<"operand0_and_result_have_same_rank",
            [AllRanksMatch<["x", "res"]>]> {
  let arguments = (ins AnyShaped:$x, AnyShaped:$y);
  let results = (outs AnyShaped:$res);
}

def OperandZeroAndResultHaveSameShape :
    TEST_Op<"operand0_and_result_have_same_shape",
            [AllShapesMatch<["x", "res"]>]> {
  let arguments = (ins AnyShaped:$x, AnyShaped:$y);
  let results = (outs AnyShaped:$res);
}

def OperandZeroAndResultHaveSameElementCount :
    TEST_Op<"operand0_and_result_have_same_element_count",
            [AllElementCountsMatch<["x", "res"]>]> {
  let arguments = (ins AnyShaped:$x, AnyShaped:$y);
  let results = (outs AnyShaped:$res);
}
// NOTE(review): several angle-bracket argument lists in this section look
// truncated (e.g. `Neg>]>>]>`, `Constraint>;`, `Variadic:$a`, bare
// `DeclareOpInterfaceMethods`, `ConstantAttr)`) -- presumably lost when the
// text was extracted; confirm against the upstream file before use.

// AllMatch over the two literal strings "5" and "4".
def FourEqualsFive :
    TEST_Op<"four_equals_five", [AllMatch<["5", "4"], "4 equals 5">]>;

def OperandRankEqualsResultSize :
    TEST_Op<"operand_rank_equals_result_size",
            [AllMatch<[Rank<"operand">.result, ElementCount<"result">.result],
                      "operand rank equals result size">]> {
  let arguments = (ins AnyShaped:$operand);
  let results = (outs AnyShaped:$result);
}

def IfFirstOperandIsNoneThenSoIsSecond :
    TEST_Op<"if_first_operand_is_none_then_so_is_second", [PredOpTrait<
    "has either both none type operands or first is not none",
     Or<[
        And<[TypeIsPred<"x", NoneType>, TypeIsPred<"y", NoneType>]>,
        Neg>]>>]> {
  let arguments = (ins AnyType:$x, AnyType:$y);
}

def BroadcastableOp : TEST_Op<"broadcastable", [ResultsBroadcastableShape]> {
  let arguments = (ins Variadic);
  let results = (outs AnyTensor);
}

// HasParent trait
def ParentOp : TEST_Op<"parent"> {
  let regions = (region AnyRegion);
}
def ChildOp : TEST_Op<"child", [HasParent<"ParentOp">]>;

// ParentOneOf trait
def ParentOp1 : TEST_Op<"parent1"> {
  let regions = (region AnyRegion);
}
def ChildWithParentOneOf : TEST_Op<"child_with_parent_one_of",
                                [ParentOneOf<["ParentOp", "ParentOp1"]>]>;

// Terminator op; it is the one named by the SingleBlockImplicitTerminator
// trait on the op that follows.
def TerminatorOp : TEST_Op<"finish", [Terminator]>;
def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",
    [SingleBlockImplicitTerminator<"TerminatorOp">]> {
  let regions = (region SizedRegion<1>:$region);
}

def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> {
  let arguments = (ins I32ElementsAttr:$attr);
}

def IndexElementsAttrOp : TEST_Op<"indexElementsAttr"> {
  let arguments = (ins IndexElementsAttr:$attr);
}

def OpWithInferTypeInterfaceOp : TEST_Op<"op_with_infer_type_if", [
    DeclareOpInterfaceMethods]> {
  let arguments = (ins AnyTensor, AnyTensor);
  let results = (outs AnyTensor);
}

def InferTensorType : NativeOpTrait<"InferTensorType">;
def OpWithShapedTypeInferTypeInterfaceOp : TEST_Op<"op_with_shaped_type_infer_type_if", [
    // Op implements infer type op interface.
    InferTypeOpInterface,
    // The op will have methods implementing the ShapedType type infer interface.
    DeclareOpInterfaceMethods,
    // The op produces tensors and will use the ShapedType type infer interface
    // along with knowledge that it is producing Tensors to infer shape.
    InferTensorType
  ]> {
  let arguments = (ins AnyTensor, AnyTensor);
  let results = (outs AnyTensor);

  let extraClassDeclaration = [{
    LogicalResult reifyReturnTypeShapes(OpBuilder &builder,
                                        SmallVectorImpl &shapes);
  }];
}

// Rewrites an I32ElementsAttrOp into another I32ElementsAttrOp, guarded by the
// IsNotScalar constraint on the matched attribute.
def IsNotScalar : Constraint>;

def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
                     (I32ElementsAttrOp ConstantAttr),
                     [(IsNotScalar $attr)]>;

// Terminator op with one successor and variadic target operands.
def TestBranchOp : TEST_Op<"br",
    [DeclareOpInterfaceMethods, Terminator]> {
  let arguments = (ins Variadic:$targetOperands);
  let successors = (successor AnySuccessor:$target);
}

// Operand segment sizes are carried in the `operand_segment_sizes` attribute.
def AttrSizedOperandOp : TEST_Op<"attr_sized_operands",
                                 [AttrSizedOperandSegments]> {
  let arguments = (ins
    Variadic:$a,
    Variadic:$b,
    I32:$c,
    Variadic:$d,
    I32ElementsAttr:$operand_segment_sizes
  );
}

// Result segment sizes are carried in the `result_segment_sizes` attribute.
def AttrSizedResultOp : TEST_Op<"attr_sized_results",
                                [AttrSizedResultSegments]> {
  let arguments = (ins
    I32ElementsAttr:$result_segment_sizes
  );
  let results = (outs
    Variadic:$a,
    Variadic:$b,
    I32:$c,
    Variadic:$d
  );
}

// This is used to test encoding of a string attribute into an SSA name of a
// pretty printed value name.
// Op with a custom printer and parser supplied as code literals.
// NOTE(review): `Variadic:$r`, the bare `DeclareOpInterfaceMethods` and
// `ConstantAttr`-style uses nearby look like they lost an angle-bracket
// argument list during extraction; confirm against the upstream file.
def StringAttrPrettyNameOp
 : TEST_Op<"string_attr_pretty_name",
           [DeclareOpInterfaceMethods]> {
  let arguments = (ins StrArrayAttr:$names);
  let results = (outs Variadic:$r);

  let printer = [{ return ::print(p, *this); }];
  let parser = [{ return ::parse$cppClass(parser, result); }];
}

//===----------------------------------------------------------------------===//
// Test Locations
//===----------------------------------------------------------------------===//

def TestLocationSrcOp : TEST_Op<"loc_src"> {
  let arguments = (ins I32:$input);
  let results = (outs I32:$output);
}

def TestLocationDstOp : TEST_Op<"loc_dst", [SameOperandsAndResultType]> {
  let arguments = (ins I32:$input);
  let results = (outs I32:$output);
}

//===----------------------------------------------------------------------===//
// Test Patterns
//===----------------------------------------------------------------------===//

// OpA and OpB share the same signature: one I32 operand, one I32Attr, one I32
// result.
def OpA : TEST_Op<"op_a"> {
  let arguments = (ins I32, I32Attr:$attr);
  let results = (outs I32);
}

def OpB : TEST_Op<"op_b"> {
  let arguments = (ins I32, I32Attr:$attr);
  let results = (outs I32);
}

// Test named pattern.
def TestNamedPatternRule : Pat<(OpA $input, $attr), (OpB $input, $attr)>;

// Test with fused location.
def : Pat<(OpA (OpA $input, $attr), $bttr), (OpB $input, $bttr)>;

// Test added benefit.
def OpD : TEST_Op<"op_d">, Arguments<(ins I32)>, Results<(outs I32)>;
def OpE : TEST_Op<"op_e">, Arguments<(ins I32)>, Results<(outs I32)>;
def OpF : TEST_Op<"op_f">, Arguments<(ins I32)>, Results<(outs I32)>;
def OpG : TEST_Op<"op_g">, Arguments<(ins I32)>, Results<(outs I32)>;
// Verify that bumping benefit results in selecting different op.
def : Pat<(OpD $input), (OpE $input)>;
def : Pat<(OpD $input), (OpF $input), [], (addBenefit 10)>;
// Verify that patterns with more source nodes are selected before those with fewer.
// NOTE(review): `ConstantAttr:$attr` and `OptionalAttr:$optional_attr` below
// look like they lost an angle-bracket argument list during extraction;
// confirm against the upstream file.
def : Pat<(OpG $input), (OpB $input, ConstantAttr:$attr)>;
def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr:$attr)>;

// Test patterns for zero-result op.
def OpH : TEST_Op<"op_h">, Arguments<(ins I32)>, Results<(outs)>;
def OpI : TEST_Op<"op_i">, Arguments<(ins I32)>, Results<(outs)>;
def : Pat<(OpH $input), (OpI $input)>;

// Test patterns for zero-input op.
def OpJ : TEST_Op<"op_j">, Arguments<(ins)>, Results<(outs I32)>;
def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32)>;
def : Pat<(OpJ), (OpK)>;

// Test that native calls are only called once during rewrites.
def OpM : TEST_Op<"op_m"> {
  let arguments = (ins I32, OptionalAttr:$optional_attr);
  let results = (outs I32);
}

// Helper ops for the equal-argument patterns below: OpN takes two I32
// operands, OpO one, OpP six; each produces one I32 result.
def OpN : TEST_Op<"op_n"> {
  let arguments = (ins I32, I32);
  let results = (outs I32);
}

def OpO : TEST_Op<"op_o"> {
  let arguments = (ins I32);
  let results = (outs I32);
}

def OpP : TEST_Op<"op_p"> {
  let arguments = (ins I32, I32, I32, I32, I32, I32);
  let results = (outs I32);
}

// Test same operand name enforces equality condition check.
def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>;

// Test when equality is enforced at different depth.
def TestNestedOpEqualArgsPattern :
  Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>;

// Test multiple equal arguments check enforced.
def TestMultipleEqualArgsPattern :
  Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>;

// Test for memrefs normalization of an op with normalizable memrefs.
def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
  let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
}
// Test for memrefs normalization of an op without normalizable memrefs.
def OpNonNorm : TEST_Op<"op_nonnorm"> {
  let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
}
// Test for memrefs normalization of an op with a reference to a function
// symbol.
def OpFuncRef : TEST_Op<"op_funcref"> {
  let summary = "Test op with a reference to a function symbol";
  let description = [{ The "test.op_funcref" is a test op with a reference to a function symbol. }];
  let builders = [OpBuilder<[{FuncOp function}]>];
}

// This pattern adds the argument plus an increasing static number hidden in
// the OpMTest function. That value is set into the optional argument.
// That way, we will know whether the operation is called once or twice.
// NOTE(review): `Constraint, "Attribute is null">` looks like it lost an
// angle-bracket argument list during extraction; confirm against the upstream
// file.
def OpMGetNullAttr : NativeCodeCall<"Attribute()">;
def OpMAttributeIsNull : Constraint, "Attribute is null">;
def OpMVal : NativeCodeCall<"OpMTest($_builder, $0)">;
def : Pat<(OpM $attr, $optAttr), (OpM $attr, (OpMVal $attr) ),
    [(OpMAttributeIsNull:$optAttr)]>;

// Test `$_` for ignoring op argument match.
def TestIgnoreArgMatchSrcOp : TEST_Op<"ignore_arg_match_src"> {
  let arguments = (ins
    AnyType:$a, AnyType:$b, AnyType:$c,
    AnyAttr:$d, AnyAttr:$e, AnyAttr:$f);
}
def TestIgnoreArgMatchDstOp : TEST_Op<"ignore_arg_match_dst"> {
  let arguments = (ins AnyType:$b, AnyAttr:$f);
}
def : Pat<(TestIgnoreArgMatchSrcOp $_, $b, I32, I64Attr:$_, $_, $f),
          (TestIgnoreArgMatchDstOp $b, $f)>;

// Two ops with identical signatures whose operands and attributes alternate.
def OpInterleavedOperandAttribute1 : TEST_Op<"interleaved_operand_attr1"> {
  let arguments = (ins
    I32:$input1,
    I64Attr:$attr1,
    I32:$input2,
    I64Attr:$attr2
  );
}

def OpInterleavedOperandAttribute2 : TEST_Op<"interleaved_operand_attr2"> {
  let arguments = (ins
    I32:$input1,
    I64Attr:$attr1,
    I32:$input2,
    I64Attr:$attr2
  );
}

// Op with nine I32 operands followed by nine I64Attr attributes.
def ManyArgsOp : TEST_Op<"many_arguments"> {
  let arguments = (ins
    I32:$input1, I32:$input2, I32:$input3, I32:$input4, I32:$input5,
    I32:$input6, I32:$input7, I32:$input8, I32:$input9,
    I64Attr:$attr1, I64Attr:$attr2, I64Attr:$attr3, I64Attr:$attr4,
    I64Attr:$attr5, I64Attr:$attr6, I64Attr:$attr7, I64Attr:$attr8,
    I64Attr:$attr9
  );
}

// Test that DRR does not blow up when seeing lots of arguments.
// NOTE(review): the bare `ConstantAttr,` in both halves of this pattern looks
// like it lost an angle-bracket argument list during extraction; confirm
// against the upstream file.
def : Pat<(ManyArgsOp
            $input1, $input2, $input3, $input4, $input5,
            $input6, $input7, $input8, $input9,
            ConstantAttr,
            $attr2, $attr3, $attr4, $attr5, $attr6,
            $attr7, $attr8, $attr9),
          (ManyArgsOp
            $input1, $input2, $input3, $input4, $input5,
            $input6, $input7, $input8, $input9,
            ConstantAttr,
            $attr2, $attr3, $attr4, $attr5, $attr6,
            $attr7, $attr8, $attr9)>;

// Test that we can capture and reference interleaved operands and attributes.
def : Pat<(OpInterleavedOperandAttribute1 $input1, $attr1, $input2, $attr2),
          (OpInterleavedOperandAttribute2 $input1, $attr1, $input2, $attr2)>;

// Test NativeCodeCall.
def OpNativeCodeCall1 : TEST_Op<"native_code_call1"> {
  let arguments = (ins
    I32:$input1, I32:$input2,
    BoolAttr:$choice,
    I64Attr:$attr1, I64Attr:$attr2
  );
  let results = (outs I32);
}
def OpNativeCodeCall2 : TEST_Op<"native_code_call2"> {
  let arguments = (ins I32:$input, I64ArrayAttr:$attr);
  let results = (outs I32);
}
// Native code call to invoke a C++ function
def CreateOperand: NativeCodeCall<"chooseOperand($0, $1, $2)">;
// Native code call to invoke a C++ expression
def CreateArrayAttr: NativeCodeCall<"$_builder.getArrayAttr({$0, $1})">;
// Test that we can use NativeCodeCall to create operand and attribute.
// This pattern chooses between $input1 and $input2 according to $choice and
// it combines $attr1 and $attr2 into an array attribute.
def : Pat<(OpNativeCodeCall1 $input1, $input2,
                             ConstBoolAttrTrue:$choice, $attr1, $attr2),
          (OpNativeCodeCall2 (CreateOperand $input1, $input2, $choice),
                             (CreateArrayAttr $attr1, $attr2))>;
// Note: the following is just for testing purpose.
// Should use the replaceWithValue directive instead.
def UseOpResult: NativeCodeCall<"$0">;
// Test that we can use NativeCodeCall to create result.
def : Pat<(OpNativeCodeCall1 $input1, $input2,
                             ConstBoolAttrFalse, $attr1, $attr2),
          (UseOpResult $input2)>;

def OpNativeCodeCall3 : TEST_Op<"native_code_call3"> {
  let arguments = (ins I32:$input);
  let results = (outs I32);
}
// Test that NativeCodeCall is not ignored if it is not used to directly
// replace the matched root op.
def : Pattern<(OpNativeCodeCall3 $input),
              [(NativeCodeCall<"createOpI($_builder, $_loc, $0)"> $input), (OpK)]>;

// Test AllAttrConstraintsOf.
def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> {
  let arguments = (ins I64ArrayAttr:$attr);
  let results = (outs I32);
}
def OpAllAttrConstraint2 : TEST_Op<"all_attr_constraint_of2"> {
  let arguments = (ins I64ArrayAttr:$attr);
  let results = (outs I32);
}
// Constraint0 / Constraint1 check element 0 == 0 and element 1 == 1 of the
// array attribute respectively.
// NOTE(review): the `.cast()` calls in these predicates look like they lost
// their template arguments during extraction; confirm against the upstream
// file.
def Constraint0 : AttrConstraint<
    CPred<"$_self.cast()[0]."
          "cast().getInt() == 0">,
    "[0] == 0">;
def Constraint1 : AttrConstraint<
    CPred<"$_self.cast()[1].cast().getInt() == 1">,
    "[1] == 1">;
def : Pat<(OpAllAttrConstraint1
            AllAttrConstraintsOf<[Constraint0, Constraint1]>:$attr),
          (OpAllAttrConstraint2 $attr)>;

// Op for testing RewritePattern removing op with inner ops.
def TestOpWithRegionPattern : TEST_Op<"op_with_region_pattern"> {
  let regions = (region SizedRegion<1>:$region);
  let hasCanonicalizer = 1;
}

// NOTE(review): the leading `+` on the lines below are unified-diff addition
// markers carried over from the patch this text was extracted from. They are
// kept so the added hunk stays identifiable; strip them if this is meant to
// be consumed as TableGen.
+def TestOpConstant : TEST_Op<"constant", [ConstantLike, NoSideEffect]> {
+  let arguments = (ins AnyAttr:$value);
+  let results = (outs AnyType);
+  let extraClassDeclaration = [{
+    Attribute getValue() { return getAttr("value"); }
+  }];
+
+  let hasFolder = 1;
+}
+
+def OpR : TEST_Op<"op_r">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>;
+def OpS : TEST_Op<"op_s">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>;
+
+def : Pat<(OpR $input1, (ConstantLikeMatcher I32Attr:$input2)),
+          (OpS:$unused $input1, $input2)>;
+
// Op for testing trivial removal via folding of op with inner ops and no uses.
def TestOpWithRegionFoldNoSideEffect : TEST_Op<
    "op_with_region_fold_no_side_effect", [NoSideEffect]> {
  let regions = (region SizedRegion<1>:$region);
}

// Op for testing folding of outer op with inner ops.
def TestOpWithRegionFold : TEST_Op<"op_with_region_fold"> {
  let arguments = (ins I32:$operand);
  let results = (outs I32);
  let regions = (region SizedRegion<1>:$region);
  let hasFolder = 1;
}

// Variadic operands and results, with a folder.
// NOTE(review): `Variadic:$operands` and `Variadic)` look like they lost an
// angle-bracket argument list during extraction; confirm against the upstream
// file.
def TestOpWithVariadicResultsAndFolder: TEST_Op<"op_with_variadic_results_and_folder"> {
  let arguments = (ins Variadic:$operands);
  let results = (outs Variadic);
  let hasFolder = 1;
}

def TestCommutativeOp : TEST_Op<"op_commutative", [Commutative]> {
  let arguments = (ins I32:$op1, I32:$op2, I32:$op3, I32:$op4);
  let results = (outs I32);
}

// Three ops carrying the Involution trait; the first declares no folder, the
// other two set hasFolder.
def TestInvolutionTraitNoOperationFolderOp
  : TEST_Op<"op_involution_trait_no_operation_fold",
            [SameOperandsAndResultType, NoSideEffect, Involution]> {
  let arguments = (ins I32:$op1);
  let results = (outs I32);
}

def TestInvolutionTraitFailingOperationFolderOp
  : TEST_Op<"op_involution_trait_failing_operation_fold",
            [SameOperandsAndResultType, NoSideEffect, Involution]> {
  let arguments = (ins I32:$op1);
  let results = (outs I32);
  let hasFolder = 1;
}

def TestInvolutionTraitSuccesfulOperationFolderOp
  : TEST_Op<"op_involution_trait_succesful_operation_fold",
            [SameOperandsAndResultType, NoSideEffect, Involution]> {
  let arguments = (ins I32:$op1);
  let results = (outs I32);
  let hasFolder = 1;
}

def TestOpInPlaceFoldAnchor : TEST_Op<"op_in_place_fold_anchor"> {
  let arguments = (ins I32);
  let results = (outs I32);
}

def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
  let arguments = (ins I32:$op, I32Attr:$attr);
  let results = (outs I32);
  let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// Test Patterns (Symbol Binding)

// Test symbol binding.
def OpSymbolBindingA : TEST_Op<"symbol_binding_a", []> {
  let arguments = (ins I32:$operand, I64Attr:$attr);
  let results = (outs I32);
}
def OpSymbolBindingB : TEST_Op<"symbol_binding_b", []> {
  let arguments = (ins I32:$operand);
  let results = (outs I32);
}
// Same signature as OpSymbolBindingB; its builders are taken from
// OpSymbolBindingB.
def OpSymbolBindingC : TEST_Op<"symbol_binding_c", []> {
  let arguments = (ins I32:$operand);
  let results = (outs I32);
  let builders = OpSymbolBindingB.builders;
}
def OpSymbolBindingD : TEST_Op<"symbol_binding_d", []> {
  let arguments = (ins I32:$input1, I32:$input2, I64Attr:$attr);
  let results = (outs I32);
}
// NOTE(review): `Constraint, "has one use">` looks like it lost an
// angle-bracket argument list during extraction; confirm against the upstream
// file.
def HasOneUse: Constraint, "has one use">;
def : Pattern<
    // Bind to source pattern op operand/attribute/result
    (OpSymbolBindingA:$res_a $operand, $attr), [
        // Bind to auxiliary op result
        (OpSymbolBindingC:$res_c (OpSymbolBindingB:$res_b $operand)),

        // Use bound symbols in resultant ops
        (OpSymbolBindingD $res_b, $res_c, $attr)],
    // Use bound symbols in additional constraints
    [(HasOneUse $res_a)]>;

def OpSymbolBindingNoResult : TEST_Op<"symbol_binding_no_result", []> {
  let arguments = (ins I32:$operand);
}

// Test that we can bind to an op without results and reference it later.
def : Pat<(OpSymbolBindingNoResult:$op $operand),
          (NativeCodeCall<"handleNoResultOp($_builder, $0)"> $op)>;

//===----------------------------------------------------------------------===//
// Test Patterns (Attributes)

// Test matching against op attributes.
// NOTE(review): `OptionalAttr:$optional_attr`,
// `DefaultValuedAttr:$default_valued_attr`, `$_self.cast()` and
// `ConstantAttr:$attr` below look like they lost an angle-bracket argument
// list during extraction; confirm against the upstream file.
def OpAttrMatch1 : TEST_Op<"match_op_attribute1"> {
  let arguments = (ins
    I32Attr:$required_attr,
    OptionalAttr:$optional_attr,
    DefaultValuedAttr:$default_valued_attr,
    I32Attr:$more_attr
  );
  let results = (outs I32);
}
// Reuses OpAttrMatch1's argument list verbatim.
def OpAttrMatch2 : TEST_Op<"match_op_attribute2"> {
  let arguments = OpAttrMatch1.arguments;
  let results = (outs I32);
}
// Attribute constraint whose predicate compares the integer value with 4.
def MoreConstraint : AttrConstraint<
    CPred<"$_self.cast().getInt() == 4">, "more constraint">;
def : Pat<(OpAttrMatch1 $required, $optional, $default_valued,
                        MoreConstraint:$more),
          (OpAttrMatch2 $required, $optional, $default_valued, $more)>;

// Test unit attrs.
def OpAttrMatch3 : TEST_Op<"match_op_attribute3"> {
  let arguments = (ins UnitAttr:$attr);
  let results = (outs I32);
}
def OpAttrMatch4 : TEST_Op<"match_op_attribute4"> {
  let arguments = (ins UnitAttr:$attr1, UnitAttr:$attr2);
  let results = (outs I32);
}
def : Pat<(OpAttrMatch3 $attr), (OpAttrMatch4 ConstUnitAttr, $attr)>;

// Test with constant attr.
def OpC : TEST_Op<"op_c">, Arguments<(ins I32)>, Results<(outs I32)>;
def : Pat<(OpC $input), (OpB $input, ConstantAttr:$attr)>;

// Test string enum attribute in rewrites.
def : Pat<(StrEnumAttrOp StrCaseA), (StrEnumAttrOp StrCaseB)>;
// Test integer enum attribute in rewrites.
def : Pat<(I32EnumAttrOp I32Case5), (I32EnumAttrOp I32Case10)>;
def : Pat<(I64EnumAttrOp I64Case5), (I64EnumAttrOp I64Case10)>;

//===----------------------------------------------------------------------===//
// Test Patterns (Multi-result Ops)

// Enum with six cases; each pattern in this section matches a different case.
def MultiResultOpKind1: I64EnumAttrCase<"kind1", 1>;
def MultiResultOpKind2: I64EnumAttrCase<"kind2", 2>;
def MultiResultOpKind3: I64EnumAttrCase<"kind3", 3>;
def MultiResultOpKind4: I64EnumAttrCase<"kind4", 4>;
def MultiResultOpKind5: I64EnumAttrCase<"kind5", 5>;
def MultiResultOpKind6: I64EnumAttrCase<"kind6", 6>;

def MultiResultOpEnum: I64EnumAttr<
  "MultiResultOpEnum", "Multi-result op kinds", [
    MultiResultOpKind1, MultiResultOpKind2, MultiResultOpKind3,
    MultiResultOpKind4, MultiResultOpKind5, MultiResultOpKind6
  ]>;

def ThreeResultOp : TEST_Op<"three_result"> {
  let arguments = (ins MultiResultOpEnum:$kind);
  let results = (outs I32:$result1, F32:$result2, F32:$result3);
}

def AnotherThreeResultOp : TEST_Op<"another_three_result"> {
  let arguments = (ins MultiResultOpEnum:$kind);
  let results = (outs I32:$result1, F32:$result2, F32:$result3);
}

def TwoResultOp : TEST_Op<"two_result"> {
  let arguments = (ins MultiResultOpEnum:$kind);
  let results = (outs I32:$result1, F32:$result2);
}

def AnotherTwoResultOp : TEST_Op<"another_two_result"> {
  let arguments = (ins MultiResultOpEnum:$kind);
  let results = (outs F32:$result1, F32:$result2);
}

def OneResultOp1 : TEST_Op<"one_result1"> {
  let arguments = (ins MultiResultOpEnum:$kind);
  let results = (outs F32:$result1);
}

def OneResultOp2 : TEST_Op<"one_result2"> {
  let arguments = (ins MultiResultOpEnum:$kind);
  let results = (outs I32:$result1);
}

def OneResultOp3 : TEST_Op<"one_result3"> {
  let arguments = (ins F32);
  let results = (outs I32:$result1);
}

// Test using multi-result op as a whole
def : Pat<(ThreeResultOp MultiResultOpKind1),
          (AnotherThreeResultOp MultiResultOpKind1)>;

// Test using multi-result op as a whole for partial replacement
def : Pattern<(ThreeResultOp MultiResultOpKind2),
              [(TwoResultOp MultiResultOpKind2),
               (OneResultOp1 MultiResultOpKind2)]>;
def : Pattern<(ThreeResultOp MultiResultOpKind3),
              [(OneResultOp2 MultiResultOpKind3),
               (AnotherTwoResultOp MultiResultOpKind3)]>;

// Test using results separately in a multi-result op
def : Pattern<(ThreeResultOp MultiResultOpKind4),
              [(TwoResultOp:$res1__0 MultiResultOpKind4),
               (OneResultOp1 MultiResultOpKind4),
               (TwoResultOp:$res2__1 MultiResultOpKind4)]>;

// Test referencing a single value in the value pack
// This rule only matches TwoResultOp if its second result has no use.
def : Pattern<(TwoResultOp:$res MultiResultOpKind5),
              [(OneResultOp2 MultiResultOpKind5),
               (OneResultOp1 MultiResultOpKind5)],
              [(HasNoUseOf:$res__1)]>;

// Test using auxiliary ops for replacing multi-result op
def : Pattern<
    (ThreeResultOp MultiResultOpKind6), [
        // Auxiliary op generated to help building the final result but not
        // directly used to replace the source op's results.
        (TwoResultOp:$interm MultiResultOpKind6),

        (OneResultOp3 $interm__1),
        (AnotherTwoResultOp MultiResultOpKind6)
    ]>;

//===----------------------------------------------------------------------===//
// Test Patterns (Variadic Ops)

// NOTE(review): the bare `Variadic` uses below look like they lost an
// angle-bracket argument list during extraction; confirm against the upstream
// file.
def OneVResOneVOperandOp1 : TEST_Op<"one_variadic_out_one_variadic_in1"> {
  let arguments = (ins Variadic);
  let results = (outs Variadic);
}
def OneVResOneVOperandOp2 : TEST_Op<"one_variadic_out_one_variadic_in2"> {
  let arguments = (ins Variadic);
  let results = (outs Variadic);
}

// Rewrite an op with one variadic operand and one variadic result to
// another similar op.
def : Pat<(OneVResOneVOperandOp1 $inputs), (OneVResOneVOperandOp2 $inputs)>;

// NOTE(review): `Variadic` and `ConstantAttr` appear below without their
// `<...>` arguments; they look like they were lost when this text was
// extracted. They are left exactly as found -- confirm against upstream.

// Two ops that each take a variadic operand, an F32 operand and a second
// variadic operand; used as source and target of the rewrite below.
def MixedVOperandOp1 : TEST_Op<"mixed_variadic_in1",
                               [SameVariadicOperandSize]> {
  let arguments = (ins
    Variadic:$input1,
    F32:$input2,
    Variadic:$input3
  );
}
def MixedVOperandOp2 : TEST_Op<"mixed_variadic_in2",
                               [SameVariadicOperandSize]> {
  let arguments = (ins
    Variadic:$input1,
    F32:$input2,
    Variadic:$input3
  );
}

// Rewrite an op with both variadic operands and normal operands.
def : Pat<(MixedVOperandOp1 $input1, $input2, $input3),
          (MixedVOperandOp2 $input1, $input2, $input3)>;

// Two ops that each produce a variadic result, an F32 result and a second
// variadic result; used as source and target of the rewrite below.
def MixedVResultOp1 : TEST_Op<"mixed_variadic_out1", [SameVariadicResultSize]> {
  let results = (outs
    Variadic:$output1,
    F32:$output2,
    Variadic:$output3
  );
}
def MixedVResultOp2 : TEST_Op<"mixed_variadic_out2", [SameVariadicResultSize]> {
  let results = (outs
    Variadic:$output1,
    F32:$output2,
    Variadic:$output3
  );
}

// Rewrite an op with both variadic results and normal results.
// Note that because we are generating the op with a top-level result pattern,
// we are able to deduce the correct result types for the generated op using
// the information from the matched root op.
def : Pat<(MixedVResultOp1), (MixedVResultOp2)>;

def OneI32ResultOp : TEST_Op<"one_i32_out"> {
  let results = (outs I32);
}

def MixedVOperandOp3 : TEST_Op<"mixed_variadic_in3",
                               [SameVariadicOperandSize]> {
  let arguments = (ins
    I32:$input1,
    Variadic:$input2,
    Variadic:$input3,
    I32Attr:$count
  );

  let results = (outs I32);
}

def MixedVResultOp3 : TEST_Op<"mixed_variadic_out3",
                               [SameVariadicResultSize]> {
  let arguments = (ins I32Attr:$count);

  let results = (outs
    I32:$output1,
    Variadic:$output2,
    Variadic:$output3
  );

  // We will use this op in a nested result pattern, where we cannot deduce the
  // result type. So need to provide a builder not requiring result types.
  // The builder adds one i32 result for $output1, then `count` i32 results for
  // each of $output2 and $output3, and records `count` as an attribute.
  let builders = [
    OpBuilder<
      "IntegerAttr count",
      [{
        auto i32Type = $_builder.getIntegerType(32);
        $_state.addTypes(i32Type); // $output1
        SmallVector types(count.getInt(), i32Type);
        $_state.addTypes(types); // $output2
        $_state.addTypes(types); // $output3
        $_state.addAttribute("count", count);
      }]>
  ];
}

// Generates an op with variadic results using nested pattern.
def : Pat<(OneI32ResultOp),
          (MixedVOperandOp3
              (MixedVResultOp3:$results__0 ConstantAttr),
              (replaceWithValue $results__1),
              (replaceWithValue $results__2),
              ConstantAttr)>;

//===----------------------------------------------------------------------===//
// Test Patterns (Location)

// Test that we can specify locations for generated ops.
def : Pat<(TestLocationSrcOp:$res1
           (TestLocationSrcOp:$res2
            (TestLocationSrcOp:$res3 $input))),
          (TestLocationDstOp
            (TestLocationDstOp
              (TestLocationDstOp $input, (location $res1)),
              (location "named")),
            (location "fused", $res2, $res3))>;

//===----------------------------------------------------------------------===//
// Test Legalization
//===----------------------------------------------------------------------===//

def Test_LegalizerEnum_Success : StrEnumAttrCase<"Success">;
def Test_LegalizerEnum_Failure : StrEnumAttrCase<"Failure">;

def Test_LegalizerEnum : StrEnumAttr<"Success", "Failure",
  [Test_LegalizerEnum_Success, Test_LegalizerEnum_Failure]>;

def ILLegalOpA : TEST_Op<"illegal_op_a">, Results<(outs I32)>;
def ILLegalOpB : TEST_Op<"illegal_op_b">, Results<(outs I32)>;
def ILLegalOpC : TEST_Op<"illegal_op_c">, Results<(outs I32)>;
def ILLegalOpD : TEST_Op<"illegal_op_d">, Results<(outs I32)>;
def ILLegalOpE : TEST_Op<"illegal_op_e">, Results<(outs I32)>;
def ILLegalOpF : TEST_Op<"illegal_op_f">, Results<(outs I32)>;
def LegalOpA : TEST_Op<"legal_op_a">,
  Arguments<(ins Test_LegalizerEnum:$status)>, Results<(outs I32)>;
def LegalOpB : TEST_Op<"legal_op_b">, Results<(outs I32)>;

// Check that the conversion infrastructure can properly undo the creation of
// operations where
// an operation was created before its parent, in this case,
// in the parent's builder.
def IllegalOpTerminator : TEST_Op<"illegal_op_terminator", [Terminator]>;
def IllegalOpWithRegion : TEST_Op<"illegal_op_with_region"> {
  let skipDefaultBuilders = 1;
  // The builder adds a region to the op state, creates a block in it and
  // creates an op at the end of that block, restoring the builder's insertion
  // point afterwards through the insertion guard.
  // NOTE(review): `$_builder.create(` has no `<OpType>` argument; it looks like
  // it was lost when this text was extracted -- confirm against upstream.
  let builders = [OpBuilder<"",
    [{
       Region *bodyRegion = $_state.addRegion();
       OpBuilder::InsertionGuard g($_builder);
       Block *body = $_builder.createBlock(bodyRegion);
       $_builder.setInsertionPointToEnd(body);
       $_builder.create($_state.location);
    }]>];
}
def IllegalOpWithRegionAnchor : TEST_Op<"illegal_op_with_region_anchor">;

// Check that smaller pattern depths are chosen, i.e. prioritize more direct
// mappings.
def : Pat<(ILLegalOpA), (LegalOpA Test_LegalizerEnum_Success)>;

def : Pat<(ILLegalOpA), (ILLegalOpB)>;
def : Pat<(ILLegalOpB), (LegalOpA Test_LegalizerEnum_Failure)>;

// Check that the higher benefit pattern is taken for multiple legalizations
// with the same depth.
def : Pat<(ILLegalOpC), (ILLegalOpD)>;
def : Pat<(ILLegalOpD), (LegalOpA Test_LegalizerEnum_Failure)>;

def : Pat<(ILLegalOpC), (ILLegalOpE), [], (addBenefit 10)>;
def : Pat<(ILLegalOpE), (LegalOpA Test_LegalizerEnum_Success)>;

// Check that patterns use the most up-to-date value when being replaced.
def TestRewriteOp : TEST_Op<"rewrite">,
  Arguments<(ins AnyType)>, Results<(outs AnyType)>;
def : Pat<(TestRewriteOp $input), (replaceWithValue $input)>;

// Check that patterns can specify bounded recursion when rewriting.
def TestRecursiveRewriteOp : TEST_Op<"recursive_rewrite"> {
  let arguments = (ins I64Attr:$depth);
  let assemblyFormat = "$depth attr-dict";
}

//===----------------------------------------------------------------------===//
// Test Type Legalization
//===----------------------------------------------------------------------===//

// NOTE(review): `Variadic`, `VariadicRegion` and `DeclareOpInterfaceMethods`
// appear below without their `<...>` arguments; they look like they were lost
// when this text was extracted. They are left exactly as found -- confirm
// against the upstream file.

def TestRegionBuilderOp : TEST_Op<"region_builder">;
def TestReturnOp : TEST_Op<"return", [ReturnLike, Terminator]> {
  let arguments = (ins Variadic);
  // Convenience builder taking no arguments; forwards an empty list to the
  // generated build method.
  let builders = [
    OpBuilder<"", [{ build($_builder, $_state, {}); }]>
  ];
}
def TestCastOp : TEST_Op<"cast">,
  Arguments<(ins Variadic)>, Results<(outs AnyType)>;
def TestInvalidOp : TEST_Op<"invalid", [Terminator]>,
  Arguments<(ins Variadic)>;
def TestTypeProducerOp : TEST_Op<"type_producer">,
  Results<(outs AnyType)>;
def TestTypeConsumerOp : TEST_Op<"type_consumer">,
  Arguments<(ins AnyType)>;
def TestValidOp : TEST_Op<"valid", [Terminator]>,
  Arguments<(ins Variadic)>;

def TestMergeBlocksOp : TEST_Op<"merge_blocks"> {
  let summary = "merge_blocks operation";
  let description = [{
    Test op with multiple blocks that are merged with Dialect Conversion"
  }];

  let regions = (region AnyRegion:$body);
  let results = (outs Variadic:$result);
}

//===----------------------------------------------------------------------===//
// Test parser.
//===----------------------------------------------------------------------===//

def WrappedKeywordOp : TEST_Op<"wrapped_keyword"> {
  let arguments = (ins StrAttr:$keyword);
  let parser = [{ return ::parse$cppClass(parser, result); }];
  let printer = [{ return ::print(p, *this); }];
}

//===----------------------------------------------------------------------===//
// Test region argument list parsing.

def IsolatedRegionOp : TEST_Op<"isolated_region", [IsolatedFromAbove]> {
  let summary = "isolated region operation";
  let description = [{
    Test op with an isolated region, to test passthrough region arguments. Each
    argument is of index type.
  }];

  let arguments = (ins Index);
  let regions = (region SizedRegion<1>:$region);
  let parser = [{ return ::parse$cppClass(parser, result); }];
  let printer = [{ return ::print(p, *this); }];
}

def SSACFGRegionOp : TEST_Op<"ssacfg_region", [
    DeclareOpInterfaceMethods]> {
  let summary = "operation with an SSACFG region";
  let description = [{
    Test op that defines an SSACFG region.
  }];

  let regions = (region VariadicRegion:$regions);
  let arguments = (ins Variadic);
  let results = (outs Variadic);
}

def GraphRegionOp : TEST_Op<"graph_region", [
    DeclareOpInterfaceMethods]> {
  let summary = "operation with a graph region";
  let description = [{
    Test op that defines a graph region.
  }];

  let regions = (region AnyRegion:$region);
  let parser = [{ return ::parse$cppClass(parser, result); }];
  let printer = [{ return ::print(p, *this); }];
}

def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> {
  let summary = "affine scope operation";
  let description = [{
    Test op that defines a new affine scope.
  }];

  let regions = (region SizedRegion<1>:$region);
  let parser = [{ return ::parse$cppClass(parser, result); }];
  let printer = [{ return ::print(p, *this); }];
}

def WrappingRegionOp : TEST_Op<"wrapping_region",
    [SingleBlockImplicitTerminator<"TestReturnOp">]> {
  let summary = "wrapping region operation";
  let description = [{
    Test op wrapping another op in a region, to test calling
    parseGenericOperation from the custom parser.
  }];

  let results = (outs Variadic);
  let regions = (region SizedRegion<1>:$region);
  let parser = [{ return ::parse$cppClass(parser, result); }];
  let printer = [{ return ::print(p, *this); }];
}

def PolyForOp : TEST_Op<"polyfor">
{
  let summary = "polyfor operation";
  let description = [{
    Test op with multiple region arguments, each argument of index type.
  }];

  let regions = (region SizedRegion<1>:$region);
  let parser = [{ return ::parse$cppClass(parser, result); }];
}

//===----------------------------------------------------------------------===//
// Test OpAsmInterface.
// NOTE(review): `Variadic` and `OptionalAttr` appear below without their
// `<...>` arguments; they look like they were lost when this text was
// extracted. They are left exactly as found -- confirm against upstream.

def AsmInterfaceOp : TEST_Op<"asm_interface_op"> {
  let results = (outs AnyType:$first, Variadic:$middle_results,
                      AnyType);
}

def AsmDialectInterfaceOp : TEST_Op<"asm_dialect_interface_op"> {
  let results = (outs AnyType);
}

//===----------------------------------------------------------------------===//
// Test Op Asm Format
//===----------------------------------------------------------------------===//

// Test a format made only of literals followed by the attribute dictionary.
def FormatLiteralOp : TEST_Op<"format_literal_op"> {
  let assemblyFormat = [{
    `keyword_$.` `->` `:` `,` `=` `<` `>` `(` `)` `[` `]` attr-dict
  }];
}

// Test that we elide attributes that are within the syntax.
def FormatAttrOp : TEST_Op<"format_attr_op"> {
  let arguments = (ins I64Attr:$attr);
  let assemblyFormat = "$attr attr-dict";
}

// Test that we elide optional attributes that are within the syntax.
def FormatOptAttrAOp : TEST_Op<"format_opt_attr_op_a"> {
  let arguments = (ins OptionalAttr:$opt_attr);
  let assemblyFormat = "(`(` $opt_attr^ `)` )? attr-dict";
}
def FormatOptAttrBOp : TEST_Op<"format_opt_attr_op_b"> {
  let arguments = (ins OptionalAttr:$opt_attr);
  let assemblyFormat = "($opt_attr^)? attr-dict";
}

// Test that we format symbol name attributes properly.
def FormatSymbolNameAttrOp : TEST_Op<"format_symbol_name_attr_op"> {
  let arguments = (ins SymbolNameAttr:$attr);
  let assemblyFormat = "$attr attr-dict";
}

// Test that we format optional symbol name attributes properly.
def FormatOptSymbolNameAttrOp : TEST_Op<"format_opt_symbol_name_attr_op"> {
  let arguments = (ins OptionalAttr:$opt_attr);
  let assemblyFormat = "($opt_attr^)? attr-dict";
}

// Test that we elide attributes that are within the syntax.
def FormatAttrDictWithKeywordOp : TEST_Op<"format_attr_dict_w_keyword"> {
  let arguments = (ins I64Attr:$attr, OptionalAttr:$opt_attr);
  let assemblyFormat = "attr-dict-with-keyword";
}

// Test that we don't need to provide types in the format if they are buildable.
def FormatBuildableTypeOp : TEST_Op<"format_buildable_type_op"> {
  let arguments = (ins I64:$buildable);
  let results = (outs I64:$buildable_res);
  let assemblyFormat = "$buildable attr-dict";
}

// NOTE(review): the `class` declarations below use `suffix` and `fmt` but carry
// no parameter list, and every `def` instantiates them with two arguments. The
// parameter list was presumably lost when this text was extracted (likely
// something of the form `<string suffix, string fmt>`); the same applies to the
// bare `VariadicRegion`. They are left exactly as found -- confirm against the
// upstream file.

// Test various mixings of region formatting.
class FormatRegionBase
  : TEST_Op<"format_region_" # suffix # "_op"> {
  let regions = (region AnyRegion:$region);
  let assemblyFormat = fmt;
}
def FormatRegionAOp : FormatRegionBase<"a", [{
  regions attr-dict
}]>;
def FormatRegionBOp : FormatRegionBase<"b", [{
  $region attr-dict
}]>;
def FormatRegionCOp : FormatRegionBase<"c", [{
  (`region` $region^)? attr-dict
}]>;
class FormatVariadicRegionBase
  : TEST_Op<"format_variadic_region_" # suffix # "_op"> {
  let regions = (region VariadicRegion:$regions);
  let assemblyFormat = fmt;
}
def FormatVariadicRegionAOp : FormatVariadicRegionBase<"a", [{
  $regions attr-dict
}]>;
def FormatVariadicRegionBOp : FormatVariadicRegionBase<"b", [{
  ($regions^ `found_regions`)? attr-dict
}]>;
class FormatRegionImplicitTerminatorBase
  : TEST_Op<"format_implicit_terminator_region_" # suffix # "_op",
            [SingleBlockImplicitTerminator<"TestReturnOp">]> {
  let regions = (region AnyRegion:$region);
  let assemblyFormat = fmt;
}
def FormatFormatRegionImplicitTerminatorAOp
  : FormatRegionImplicitTerminatorBase<"a", [{
  $region attr-dict
}]>;

// Test various mixings of result type formatting.
class FormatResultBase
  : TEST_Op<"format_result_" # suffix # "_op"> {
  let results = (outs I64:$buildable_res, AnyMemRef:$result);
  let assemblyFormat = fmt;
}
def FormatResultAOp : FormatResultBase<"a", [{
  type($result) attr-dict
}]>;
def FormatResultBOp : FormatResultBase<"b", [{
  type(results) attr-dict
}]>;
def FormatResultCOp : FormatResultBase<"c", [{
  functional-type($buildable_res, $result) attr-dict
}]>;

// Test various mixings of operand type formatting.
// NOTE(review): throughout this section several `<...>` argument lists look
// like they were lost when this text was extracted: the `class` declarations
// use `suffix`/`fmt` without declaring them, and `Variadic`, `Optional`,
// `OptionalAttr`, `DeclareOpInterfaceMethods` and the `custom(` directive
// appear without arguments. They are left exactly as found -- confirm against
// the upstream file.
class FormatOperandBase
  : TEST_Op<"format_operand_" # suffix # "_op"> {
  let arguments = (ins I64:$buildable, AnyMemRef:$operand);
  let assemblyFormat = fmt;
}

def FormatOperandAOp : FormatOperandBase<"a", [{
  operands `:` type(operands) attr-dict
}]>;
def FormatOperandBOp : FormatOperandBase<"b", [{
  operands `:` type($operand) attr-dict
}]>;
def FormatOperandCOp : FormatOperandBase<"c", [{
  $buildable `,` $operand `:` type(operands) attr-dict
}]>;
def FormatOperandDOp : FormatOperandBase<"d", [{
  $buildable `,` $operand `:` type($operand) attr-dict
}]>;
def FormatOperandEOp : FormatOperandBase<"e", [{
  $buildable `,` $operand `:` type($buildable) `,` type($operand) attr-dict
}]>;

def FormatSuccessorAOp : TEST_Op<"format_successor_a_op", [Terminator]> {
  let successors = (successor VariadicSuccessor:$targets);
  let assemblyFormat = "$targets attr-dict";
}

// Test various mixings of optional operand and result type formatting.
class FormatOptionalOperandResultOpBase
  : TEST_Op<"format_optional_operand_result_" # suffix # "_op",
            [AttrSizedOperandSegments]> {
  let arguments = (ins Optional:$optional, Variadic:$variadic);
  let results = (outs Optional:$optional_res);
  let assemblyFormat = fmt;
}

def FormatOptionalOperandResultAOp : FormatOptionalOperandResultOpBase<"a", [{
  `(` $optional `:` type($optional) `)` `:` type($optional_res)
  (`[` $variadic^ `]`)? attr-dict
}]>;

def FormatOptionalOperandResultBOp : FormatOptionalOperandResultOpBase<"b", [{
  (`(` $optional^ `:` type($optional) `)`)? `:` type($optional_res)
  (`[` $variadic^ `]`)? attr-dict
}]>;

def FormatTwoVariadicOperandsNoBuildableTypeOp
    : TEST_Op<"format_two_variadic_operands_no_buildable_type_op",
              [AttrSizedOperandSegments]> {
  let arguments = (ins Variadic:$a,
                       Variadic:$b);
  let assemblyFormat = [{
    `(` $a `:` type($a) `)` `->` `(` $b `:` type($b) `)`  attr-dict
  }];
}

def FormatInferVariadicTypeFromNonVariadic
    : TEST_Op<"format_infer_variadic_type_from_non_variadic",
              [SameOperandsAndResultType]> {
  let arguments = (ins Variadic:$operands);
  let results = (outs AnyType:$result);
  let assemblyFormat = "$operands attr-dict `:` type($result)";
}

def FormatOptionalUnitAttr : TEST_Op<"format_optional_unit_attribute"> {
  let arguments = (ins UnitAttr:$is_optional);
  let assemblyFormat = "(`is_optional` $is_optional^)? attr-dict";
}

def FormatOptionalUnitAttrNoElide
    : TEST_Op<"format_optional_unit_attribute_no_elide"> {
  let arguments = (ins UnitAttr:$is_optional);
  let assemblyFormat = "($is_optional^)? attr-dict";
}

//===----------------------------------------------------------------------===//
// Custom Directives

def FormatCustomDirectiveOperands
    : TEST_Op<"format_custom_directive_operands", [AttrSizedOperandSegments]> {
  let arguments = (ins I64:$operand, Optional:$optOperand,
                       Variadic:$varOperands);
  let assemblyFormat = [{
    custom(
      $operand, $optOperand, $varOperands
    )
    attr-dict
  }];
}

def FormatCustomDirectiveOperandsAndTypes
    : TEST_Op<"format_custom_directive_operands_and_types",
              [AttrSizedOperandSegments]> {
  let arguments = (ins AnyType:$operand, Optional:$optOperand,
                       Variadic:$varOperands);
  let assemblyFormat = [{
    custom(
      $operand, $optOperand, $varOperands,
      type($operand), type($optOperand), type($varOperands)
    )
    attr-dict
  }];
}

def FormatCustomDirectiveRegions : TEST_Op<"format_custom_directive_regions"> {
  let regions = (region AnyRegion:$region, VariadicRegion:$regions);
  let assemblyFormat = [{
    custom(
      $region, $regions
    )
    attr-dict
  }];
}

def FormatCustomDirectiveResults
    : TEST_Op<"format_custom_directive_results", [AttrSizedResultSegments]> {
  let results = (outs AnyType:$result, Optional:$optResult,
                      Variadic:$varResults);
  let assemblyFormat = [{
    custom(
      type($result), type($optResult), type($varResults)
    )
    attr-dict
  }];
}

def FormatCustomDirectiveResultsWithTypeRefs
    : TEST_Op<"format_custom_directive_results_with_type_refs",
              [AttrSizedResultSegments]> {
  let results = (outs AnyType:$result, Optional:$optResult,
                      Variadic:$varResults);
  let assemblyFormat = [{
    custom(
      type($result), type($optResult), type($varResults)
    )
    custom(
      type_ref($result), type_ref($optResult), type_ref($varResults)
    )
    attr-dict
  }];
}

def FormatCustomDirectiveSuccessors
    : TEST_Op<"format_custom_directive_successors", [Terminator]> {
  let successors = (successor AnySuccessor:$successor,
                              VariadicSuccessor:$successors);
  let assemblyFormat = [{
    custom(
      $successor, $successors
    )
    attr-dict
  }];
}

def FormatCustomDirectiveAttributes
    : TEST_Op<"format_custom_directive_attributes"> {
  let arguments = (ins I64Attr:$attr, OptionalAttr:$optAttr);
  let assemblyFormat = [{
    custom(
      $attr, $optAttr
    )
    attr-dict
  }];
}

//===----------------------------------------------------------------------===//
// AllTypesMatch type inference

def FormatAllTypesMatchVarOp : TEST_Op<"format_all_types_match_var", [
    AllTypesMatch<["value1", "value2", "result"]>
  ]> {
  let arguments = (ins AnyType:$value1, AnyType:$value2);
  let results = (outs AnyType:$result);
  let assemblyFormat = "attr-dict $value1 `,` $value2 `:` type($value1)";
}

def FormatAllTypesMatchAttrOp : TEST_Op<"format_all_types_match_attr", [
    AllTypesMatch<["value1", "value2", "result"]>
  ]> {
  let arguments = (ins AnyAttr:$value1, AnyType:$value2);
  let results = (outs AnyType:$result);
  let assemblyFormat = "attr-dict $value1 `,` $value2";
}

//===----------------------------------------------------------------------===//
// TypesMatchWith type inference

def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [
    TypesMatchWith<"result type matches operand", "value", "result", "$_self">
  ]> {
  let arguments = (ins AnyType:$value);
  let results = (outs AnyType:$result);
  let assemblyFormat = "attr-dict $value `:` type($value)";
}

def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [
    TypesMatchWith<"result type matches constant", "value", "result", "$_self">
  ]> {
  let arguments = (ins AnyAttr:$value);
  let results = (outs AnyType:$result);
  let assemblyFormat = "attr-dict $value";
}

//===----------------------------------------------------------------------===//
// Test SideEffects
//===----------------------------------------------------------------------===//

def SideEffectOp : TEST_Op<"side_effect_op",
    [DeclareOpInterfaceMethods]> {
  let results = (outs AnyType:$result);
}

//===----------------------------------------------------------------------===//
// Test RegionBranchOpInterface
//===----------------------------------------------------------------------===//

def RegionIfYieldOp : TEST_Op<"region_if_yield",
      [NoSideEffect, ReturnLike, Terminator]> {
  let arguments = (ins Variadic:$results);
  let assemblyFormat = [{
    $results `:` type($results) attr-dict
  }];
}

def RegionIfOp : TEST_Op<"region_if",
      [DeclareOpInterfaceMethods,
       SingleBlockImplicitTerminator<"RegionIfYieldOp">,
       RecursiveSideEffects]> {
  let description =[{
    Represents an abstract if-then-else-join pattern. In this context, the then
    and else regions jump to the join region, which finally returns to its
    parent op.
  }];

  let printer = [{ return ::print(p, *this); }];
  let parser = [{ return ::parseRegionIfOp(parser, result); }];
  let arguments = (ins Variadic);
  let results = (outs Variadic:$results);
  let regions = (region SizedRegion<1>:$thenRegion,
                        AnyRegion:$elseRegion,
                        AnyRegion:$joinRegion);
  // Accessors for the block arguments of body 0 (then), 1 (else) and 2 (join).
  let extraClassDeclaration = [{
    Block::BlockArgListType getThenArgs() {
      return getBody(0)->getArguments();
    }
    Block::BlockArgListType getElseArgs() {
      return getBody(1)->getArguments();
    }
    Block::BlockArgListType getJoinArgs() {
      return getBody(2)->getArguments();
    }
    OperandRange getSuccessorEntryOperands(unsigned index);
  }];
}

//===----------------------------------------------------------------------===//
// Test TableGen generated build() methods
//===----------------------------------------------------------------------===//

def TableGenConstant : TEST_Op<"tblgen_constant"> {
  let results = (outs AnyType);
}

// No variadic args or results.
def TableGenBuildOp0 : TEST_Op<"tblgen_build_0"> {
  let arguments = (ins AnyType:$value);
  let results = (outs AnyType:$result);
}

// Single variadic arg and single variadic results.
def TableGenBuildOp1 : TEST_Op<"tblgen_build_1"> {
  let arguments = (ins Variadic:$inputs);
  let results = (outs Variadic:$results);
}

// Single variadic arg and non-variadic results.
def TableGenBuildOp2 : TEST_Op<"tblgen_build_2"> {
  let arguments = (ins Variadic:$inputs);
  let results = (outs AnyType:$result);
}

// Single variadic arg and multiple variadic results.
def TableGenBuildOp3 : TEST_Op<"tblgen_build_3", [SameVariadicResultSize]> {
  let arguments = (ins Variadic:$inputs);
  let results = (outs Variadic:$resultA, Variadic:$resultB);
}

// Single variadic arg, non variadic results, with SameOperandsAndResultType.
// Tests suppression of ambiguous build methods for operations with
// SameOperandsAndResultType trait.
// NOTE(review): `Variadic`, `Optional` and `SmallVectorImpl` appear below
// without their `<...>` arguments; they look like they were lost when this text
// was extracted. They are left exactly as found -- confirm against upstream.
def TableGenBuildOp4 : TEST_Op<"tblgen_build_4", [SameOperandsAndResultType]> {
  let arguments = (ins Variadic:$inputs);
  let results = (outs AnyType:$result);
}

// Single variadic arg with SameOperandsAndResultType and InferTypeOpInterface.
// Tests suppression of ambiguous build methods for operations with
// SameOperandsAndResultType and InferTypeOpInterface.
def TableGenBuildOp5 : TEST_Op<"tblgen_build_5",
    [SameOperandsAndResultType, InferTypeOpInterface]> {
  let arguments = (ins Variadic:$inputs);
  let results = (outs AnyType:$result);

  // The inferred return type is the type of the first operand.
  let extraClassDeclaration = [{
    static LogicalResult inferReturnTypes(MLIRContext *,
          Optional location, ValueRange operands,
          DictionaryAttr attributes, RegionRange regions,
          SmallVectorImpl &inferredReturnTypes) {
      inferredReturnTypes.assign({operands[0].getType()});
      return success();
    }
   }];
}

//===----------------------------------------------------------------------===//
// Test BufferPlacement
//===----------------------------------------------------------------------===//

def GetTupleElementOp: TEST_Op<"get_tuple_element"> {
  let description = [{
    Test op that returns a specified element of the tuple.
  }];

  let arguments = (ins
    TupleOf<[AnyType]>,
    I32Attr:$index
  );
  let results = (outs AnyType);
}

def MakeTupleOp: TEST_Op<"make_tuple"> {
  let description = [{
    Test op that creates a tuple value from a list of values.
  }];

  let arguments = (ins
    Variadic:$inputs
  );
  let results = (outs TupleOf<[AnyType]>);
}

#endif // TEST_OPS
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 32d618d9008e..282d31065549 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1,985 +1,986 @@
//===- TestPatterns.cpp - Test dialect pattern driver ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "TestDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/FuncConversions.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/FoldUtils.h" using namespace mlir; // Native function for testing NativeCodeCall static Value chooseOperand(Value input1, Value input2, BoolAttr choice) { return choice.getValue() ? input1 : input2; } static void createOpI(PatternRewriter &rewriter, Location loc, Value input) { rewriter.create(loc, input); } static void handleNoResultOp(PatternRewriter &rewriter, OpSymbolBindingNoResult op) { // Turn the no result op to a one-result op. rewriter.create(op.getLoc(), op.operand().getType(), op.operand()); } // Test that natives calls are only called once during rewrites. // OpM_Test will return Pi, increased by 1 for each subsequent calls. // This let us check the number of times OpM_Test was called by inspecting // the returned value in the MLIR output. static int64_t opMIncreasingValue = 314159265; static Attribute OpMTest(PatternRewriter &rewriter, Value val) { int64_t i = opMIncreasingValue++; return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); } namespace { #include "TestPatterns.inc" } // end anonymous namespace //===----------------------------------------------------------------------===// // Canonicalizer Driver. 
//===----------------------------------------------------------------------===// namespace { struct FoldingPattern : public RewritePattern { public: FoldingPattern(MLIRContext *context) : RewritePattern(TestOpInPlaceFoldAnchor::getOperationName(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { // Exercice OperationFolder API for a single-result operation that is folded // upon construction. The operation being created through the folder has an // in-place folder, and it should be still present in the output. // Furthermore, the folder should not crash when attempting to recover the // (unchanged) operation result. OperationFolder folder(op->getContext()); Value result = folder.create( rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0), rewriter.getI32IntegerAttr(0)); assert(result); rewriter.replaceOp(op, result); return success(); } }; struct TestPatternDriver : public PassWrapper { void runOnFunction() override { mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), patterns); // Verify named pattern is generated with expected name. patterns.insert(&getContext()); applyPatternsAndFoldGreedily(getFunction(), patterns); } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // ReturnType Driver. //===----------------------------------------------------------------------===// namespace { // Generate ops for each instance where the type can be successfully inferred. template static void invokeCreateWithInferredReturnType(Operation *op) { auto *context = op->getContext(); auto fop = op->getParentOfType(); auto location = UnknownLoc::get(context); OpBuilder b(op); b.setInsertionPointAfter(op); // Use permutations of 2 args as operands. 
assert(fop.getNumArguments() >= 2); for (int i = 0, e = fop.getNumArguments(); i < e; ++i) { for (int j = 0; j < e; ++j) { std::array values = {{fop.getArgument(i), fop.getArgument(j)}}; SmallVector inferredReturnTypes; if (succeeded(OpTy::inferReturnTypes( context, llvm::None, values, op->getAttrDictionary(), op->getRegions(), inferredReturnTypes))) { OperationState state(location, OpTy::getOperationName()); // TODO: Expand to regions. OpTy::build(b, state, values, op->getAttrs()); (void)b.createOperation(state); } } } } static void reifyReturnShape(Operation *op) { OpBuilder b(op); // Use permutations of 2 args as operands. auto shapedOp = cast(op); SmallVector shapes; if (failed(shapedOp.reifyReturnTypeShapes(b, shapes))) return; for (auto it : llvm::enumerate(shapes)) op->emitRemark() << "value " << it.index() << ": " << it.value().getDefiningOp(); } struct TestReturnTypeDriver : public PassWrapper { void runOnFunction() override { if (getFunction().getName() == "testCreateFunctions") { std::vector ops; // Collect ops to avoid triggering on inserted ops. for (auto &op : getFunction().getBody().front()) ops.push_back(&op); // Generate test patterns for each, but skip terminator. for (auto *op : llvm::makeArrayRef(ops).drop_back()) { // Test create method of each of the Op classes below. The resultant // output would be in reverse order underneath `op` from which // the attributes and regions are used. invokeCreateWithInferredReturnType(op); invokeCreateWithInferredReturnType< OpWithShapedTypeInferTypeInterfaceOp>(op); }; return; } if (getFunction().getName() == "testReifyFunctions") { std::vector ops; // Collect ops to avoid triggering on inserted ops. for (auto &op : getFunction().getBody().front()) if (isa(op)) ops.push_back(&op); // Generate test patterns for each, but skip terminator. 
for (auto *op : ops) reifyReturnShape(op); } } }; } // end anonymous namespace namespace { struct TestDerivedAttributeDriver : public PassWrapper { void runOnFunction() override; }; } // end anonymous namespace void TestDerivedAttributeDriver::runOnFunction() { getFunction().walk([](DerivedAttributeOpInterface dOp) { auto dAttr = dOp.materializeDerivedAttributes(); if (!dAttr) return; for (auto d : dAttr) dOp.emitRemark() << d.first << " = " << d.second; }); } //===----------------------------------------------------------------------===// // Legalization Driver. //===----------------------------------------------------------------------===// namespace { //===----------------------------------------------------------------------===// // Region-Block Rewrite Testing /// This pattern is a simple pattern that inlines the first region of a given /// operation into the parent region. struct TestRegionRewriteBlockMovement : public ConversionPattern { TestRegionRewriteBlockMovement(MLIRContext *ctx) : ConversionPattern("test.region", 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Inline this region into the parent region. auto &parentRegion = *op->getParentRegion(); if (op->getAttr("legalizer.should_clone")) rewriter.cloneRegionBefore(op->getRegion(0), parentRegion, parentRegion.end()); else rewriter.inlineRegionBefore(op->getRegion(0), parentRegion, parentRegion.end()); // Drop this operation. rewriter.eraseOp(op); return success(); } }; /// This pattern is a simple pattern that generates a region containing an /// illegal operation. struct TestRegionRewriteUndo : public RewritePattern { TestRegionRewriteUndo(MLIRContext *ctx) : RewritePattern("test.region_builder", 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final { // Create the region operation with an entry block containing arguments. 
OperationState newRegion(op->getLoc(), "test.region"); newRegion.addRegion(); auto *regionOp = rewriter.createOperation(newRegion); auto *entryBlock = rewriter.createBlock(®ionOp->getRegion(0)); entryBlock->addArgument(rewriter.getIntegerType(64)); // Add an explicitly illegal operation to ensure the conversion fails. rewriter.create(op->getLoc(), rewriter.getIntegerType(32)); rewriter.create(op->getLoc(), ArrayRef()); // Drop this operation. rewriter.eraseOp(op); return success(); } }; /// A simple pattern that creates a block at the end of the parent region of the /// matched operation. struct TestCreateBlock : public RewritePattern { TestCreateBlock(MLIRContext *ctx) : RewritePattern("test.create_block", /*benefit=*/1, ctx) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final { Region ®ion = *op->getParentRegion(); Type i32Type = rewriter.getIntegerType(32); rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); rewriter.create(op->getLoc()); rewriter.replaceOp(op, {}); return success(); } }; /// A simple pattern that creates a block containing an invalid operation in /// order to trigger the block creation undo mechanism. struct TestCreateIllegalBlock : public RewritePattern { TestCreateIllegalBlock(MLIRContext *ctx) : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final { Region ®ion = *op->getParentRegion(); Type i32Type = rewriter.getIntegerType(32); rewriter.createBlock(®ion, region.end(), {i32Type, i32Type}); // Create an illegal op to ensure the conversion fails. rewriter.create(op->getLoc(), i32Type); rewriter.create(op->getLoc()); rewriter.replaceOp(op, {}); return success(); } }; /// A simple pattern that tests the undo mechanism when replacing the uses of a /// block argument. 
struct TestUndoBlockArgReplace : public ConversionPattern { TestUndoBlockArgReplace(MLIRContext *ctx) : ConversionPattern("test.undo_block_arg_replace", /*benefit=*/1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto illegalOp = rewriter.create(op->getLoc(), rewriter.getF32Type()); rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), illegalOp); rewriter.updateRootInPlace(op, [] {}); return success(); } }; /// A rewrite pattern that tests the undo mechanism when erasing a block. struct TestUndoBlockErase : public ConversionPattern { TestUndoBlockErase(MLIRContext *ctx) : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { Block *secondBlock = &*std::next(op->getRegion(0).begin()); rewriter.setInsertionPointToStart(secondBlock); rewriter.create(op->getLoc(), rewriter.getF32Type()); rewriter.eraseBlock(secondBlock); rewriter.updateRootInPlace(op, [] {}); return success(); } }; //===----------------------------------------------------------------------===// // Type-Conversion Rewrite Testing /// This patterns erases a region operation that has had a type conversion. struct TestDropOpSignatureConversion : public ConversionPattern { TestDropOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) : ConversionPattern("test.drop_region_op", 1, converter, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Region ®ion = op->getRegion(0); Block *entry = ®ion.front(); // Convert the original entry arguments. 
TypeConverter &converter = *getTypeConverter(); TypeConverter::SignatureConversion result(entry->getNumArguments()); if (failed(converter.convertSignatureArgs(entry->getArgumentTypes(), result)) || failed(rewriter.convertRegionTypes(®ion, converter, &result))) return failure(); // Convert the region signature and just drop the operation. rewriter.eraseOp(op); return success(); } }; /// This pattern simply updates the operands of the given operation. struct TestPassthroughInvalidOp : public ConversionPattern { TestPassthroughInvalidOp(MLIRContext *ctx) : ConversionPattern("test.invalid", 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { rewriter.replaceOpWithNewOp(op, llvm::None, operands, llvm::None); return success(); } }; /// This pattern handles the case of a split return value. struct TestSplitReturnType : public ConversionPattern { TestSplitReturnType(MLIRContext *ctx) : ConversionPattern("test.return", 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Check for a return of F32. if (op->getNumOperands() != 1 || !op->getOperand(0).getType().isF32()) return failure(); // Check if the first operation is a cast operation, if it is we use the // results directly. auto *defOp = operands[0].getDefiningOp(); if (auto packerOp = llvm::dyn_cast_or_null(defOp)) { rewriter.replaceOpWithNewOp(op, packerOp.getOperands()); return success(); } // Otherwise, fail to match. 
return failure(); } }; //===----------------------------------------------------------------------===// // Multi-Level Type-Conversion Rewrite Testing struct TestChangeProducerTypeI32ToF32 : public ConversionPattern { TestChangeProducerTypeI32ToF32(MLIRContext *ctx) : ConversionPattern("test.type_producer", 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // If the type is I32, change the type to F32. if (!Type(*op->result_type_begin()).isSignlessInteger(32)) return failure(); rewriter.replaceOpWithNewOp(op, rewriter.getF32Type()); return success(); } }; struct TestChangeProducerTypeF32ToF64 : public ConversionPattern { TestChangeProducerTypeF32ToF64(MLIRContext *ctx) : ConversionPattern("test.type_producer", 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // If the type is F32, change the type to F64. if (!Type(*op->result_type_begin()).isF32()) return rewriter.notifyMatchFailure(op, "expected single f32 operand"); rewriter.replaceOpWithNewOp(op, rewriter.getF64Type()); return success(); } }; struct TestChangeProducerTypeF32ToInvalid : public ConversionPattern { TestChangeProducerTypeF32ToInvalid(MLIRContext *ctx) : ConversionPattern("test.type_producer", 10, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Always convert to B16, even though it is not a legal type. This tests // that values are unmapped correctly. rewriter.replaceOpWithNewOp(op, rewriter.getBF16Type()); return success(); } }; struct TestUpdateConsumerType : public ConversionPattern { TestUpdateConsumerType(MLIRContext *ctx) : ConversionPattern("test.type_consumer", 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { // Verify that the incoming operand has been successfully remapped to F64. 
if (!operands[0].getType().isF64()) return failure(); rewriter.replaceOpWithNewOp(op, operands[0]); return success(); } }; //===----------------------------------------------------------------------===// // Non-Root Replacement Rewrite Testing /// This pattern generates an invalid operation, but replaces it before the /// pattern is finished. This checks that we don't need to legalize the /// temporary op. struct TestNonRootReplacement : public RewritePattern { TestNonRootReplacement(MLIRContext *ctx) : RewritePattern("test.replace_non_root", 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final { auto resultType = *op->result_type_begin(); auto illegalOp = rewriter.create(op->getLoc(), resultType); auto legalOp = rewriter.create(op->getLoc(), resultType); rewriter.replaceOp(illegalOp, {legalOp}); rewriter.replaceOp(op, {illegalOp}); return success(); } }; //===----------------------------------------------------------------------===// // Recursive Rewrite Testing /// This pattern is applied to the same operation multiple times, but has a /// bounded recursion. struct TestBoundedRecursiveRewrite : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TestRecursiveRewriteOp op, PatternRewriter &rewriter) const final { // Decrement the depth of the op in-place. rewriter.updateRootInPlace(op, [&] { op.setAttr("depth", rewriter.getI64IntegerAttr(op.depth() - 1)); }); return success(); } /// The conversion target handles bounding the recursion of this pattern. 
bool hasBoundedRewriteRecursion() const final { return true; } }; struct TestNestedOpCreationUndoRewrite : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(IllegalOpWithRegionAnchor op, PatternRewriter &rewriter) const final { // rewriter.replaceOpWithNewOp(op); rewriter.replaceOpWithNewOp(op); return success(); }; }; } // namespace namespace { struct TestTypeConverter : public TypeConverter { using TypeConverter::TypeConverter; TestTypeConverter() { addConversion(convertType); addArgumentMaterialization(materializeCast); addArgumentMaterialization(materializeOneToOneCast); addSourceMaterialization(materializeCast); } static LogicalResult convertType(Type t, SmallVectorImpl &results) { // Drop I16 types. if (t.isSignlessInteger(16)) return success(); // Convert I64 to F64. if (t.isSignlessInteger(64)) { results.push_back(FloatType::getF64(t.getContext())); return success(); } // Convert I42 to I43. if (t.isInteger(42)) { results.push_back(IntegerType::get(43, t.getContext())); return success(); } // Split F32 into F16,F16. if (t.isF32()) { results.assign(2, FloatType::getF16(t.getContext())); return success(); } // Otherwise, convert the type directly. results.push_back(t); return success(); } /// Hook for materializing a conversion. This is necessary because we generate /// 1->N type mappings. static Optional materializeCast(OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { if (inputs.size() == 1) return inputs[0]; return builder.create(loc, resultType, inputs).getResult(); } /// Materialize the cast for one-to-one conversion from i64 to f64. 
static Optional materializeOneToOneCast(OpBuilder &builder, IntegerType resultType, ValueRange inputs, Location loc) { if (resultType.getWidth() == 42 && inputs.size() == 1) return builder.create(loc, resultType, inputs).getResult(); return llvm::None; } }; struct TestLegalizePatternDriver : public PassWrapper> { /// The mode of conversion to use with the driver. enum class ConversionMode { Analysis, Full, Partial }; TestLegalizePatternDriver(ConversionMode mode) : mode(mode) {} void runOnOperation() override { TestTypeConverter converter; mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), patterns); patterns.insert< TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp, TestSplitReturnType, TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, TestNonRootReplacement, TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite>(&getContext()); patterns.insert(&getContext(), converter); mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter); mlir::populateCallOpTypeConversionPattern(patterns, &getContext(), converter); // Define the conversion target used for the test. ConversionTarget target(getContext()); target.addLegalOp(); target.addLegalOp(); target .addIllegalOp(); target.addDynamicallyLegalOp([](TestReturnOp op) { // Don't allow F32 operands. return llvm::none_of(op.getOperandTypes(), [](Type type) { return type.isF32(); }); }); target.addDynamicallyLegalOp([&](FuncOp op) { return converter.isSignatureLegal(op.getType()) && converter.isLegal(&op.getBody()); }); // Expect the type_producer/type_consumer operations to only operate on f64. 
target.addDynamicallyLegalOp( [](TestTypeProducerOp op) { return op.getType().isF64(); }); target.addDynamicallyLegalOp([](TestTypeConsumerOp op) { return op.getOperand().getType().isF64(); }); // Check support for marking certain operations as recursively legal. target.markOpRecursivelyLegal([](Operation *op) { return static_cast( op->getAttrOfType("test.recursively_legal")); }); // Mark the bound recursion operation as dynamically legal. target.addDynamicallyLegalOp( [](TestRecursiveRewriteOp op) { return op.depth() == 0; }); // Handle a partial conversion. if (mode == ConversionMode::Partial) { DenseSet unlegalizedOps; (void)applyPartialConversion(getOperation(), target, patterns, &unlegalizedOps); // Emit remarks for each legalizable operation. for (auto *op : unlegalizedOps) op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; return; } // Handle a full conversion. if (mode == ConversionMode::Full) { // Check support for marking unknown operations as dynamically legal. target.markUnknownOpDynamicallyLegal([](Operation *op) { return (bool)op->getAttrOfType("test.dynamically_legal"); }); (void)applyFullConversion(getOperation(), target, patterns); return; } // Otherwise, handle an analysis conversion. assert(mode == ConversionMode::Analysis); // Analyze the convertible operations. DenseSet legalizedOps; if (failed(applyAnalysisConversion(getOperation(), target, patterns, legalizedOps))) return signalPassFailure(); // Emit remarks for each legalizable operation. for (auto *op : legalizedOps) op->emitRemark() << "op '" << op->getName() << "' is legalizable"; } /// The mode of conversion to use. 
ConversionMode mode; }; } // end anonymous namespace static llvm::cl::opt legalizerConversionMode( "test-legalize-mode", llvm::cl::desc("The legalization mode to use with the test driver"), llvm::cl::init(TestLegalizePatternDriver::ConversionMode::Partial), llvm::cl::values( clEnumValN(TestLegalizePatternDriver::ConversionMode::Analysis, "analysis", "Perform an analysis conversion"), clEnumValN(TestLegalizePatternDriver::ConversionMode::Full, "full", "Perform a full conversion"), clEnumValN(TestLegalizePatternDriver::ConversionMode::Partial, "partial", "Perform a partial conversion"))); //===----------------------------------------------------------------------===// // ConversionPatternRewriter::getRemappedValue testing. This method is used // to get the remapped value of an original value that was replaced using // ConversionPatternRewriter. namespace { /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original /// operand twice. /// /// Example: /// %1 = test.one_variadic_out_one_variadic_in1"(%0) /// is replaced with: /// %1 = test.one_variadic_out_one_variadic_in1"(%0, %0) struct OneVResOneVOperandOp1Converter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(OneVResOneVOperandOp1 op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto origOps = op.getOperands(); assert(std::distance(origOps.begin(), origOps.end()) == 1 && "One operand expected"); Value origOp = *origOps.begin(); SmallVector remappedOperands; // Replicate the remapped original operand twice. Note that we don't used // the remapped 'operand' since the goal is testing 'getRemappedValue'. 
remappedOperands.push_back(rewriter.getRemappedValue(origOp)); remappedOperands.push_back(rewriter.getRemappedValue(origOp)); rewriter.replaceOpWithNewOp(op, op.getResultTypes(), remappedOperands); return success(); } }; struct TestRemappedValue : public mlir::PassWrapper { void runOnFunction() override { mlir::OwningRewritePatternList patterns; patterns.insert(&getContext()); mlir::ConversionTarget target(getContext()); target.addLegalOp(); // We make OneVResOneVOperandOp1 legal only when it has more that one // operand. This will trigger the conversion that will replace one-operand // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. target.addDynamicallyLegalOp( [](Operation *op) -> bool { return std::distance(op->operand_begin(), op->operand_end()) > 1; }); if (failed(mlir::applyFullConversion(getFunction(), target, patterns))) { signalPassFailure(); } } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // Test patterns without a specific root operation kind //===----------------------------------------------------------------------===// namespace { /// This pattern matches and removes any operation in the test dialect. 
struct RemoveTestDialectOps : public RewritePattern { RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (!isa(op->getDialect())) return failure(); rewriter.eraseOp(op); return success(); } }; struct TestUnknownRootOpDriver : public mlir::PassWrapper { void runOnFunction() override { mlir::OwningRewritePatternList patterns; patterns.insert(); mlir::ConversionTarget target(getContext()); target.addIllegalDialect(); if (failed(applyPartialConversion(getFunction(), target, patterns))) signalPassFailure(); } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // Test type conversions //===----------------------------------------------------------------------===// namespace { struct TestTypeConversionProducer : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TestTypeProducerOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { Type resultType = op.getType(); if (resultType.isa()) resultType = rewriter.getF64Type(); else if (resultType.isInteger(16)) resultType = rewriter.getIntegerType(64); else return failure(); rewriter.replaceOpWithNewOp(op, resultType); return success(); } }; struct TestTypeConversionDriver : public PassWrapper> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); } void runOnOperation() override { // Initialize the type converter. TypeConverter converter; /// Add the legal set of type conversions. converter.addConversion([](Type type) -> Type { // Treat F64 as legal. if (type.isF64()) return type; // Allow converting BF16/F16/F32 to F64. if (type.isBF16() || type.isF16() || type.isF32()) return FloatType::getF64(type.getContext()); // Otherwise, the type is illegal. 
return nullptr; }); converter.addConversion([](IntegerType type, SmallVectorImpl &) { // Drop all integer types. return success(); }); /// Add the legal set of type materializations. converter.addSourceMaterialization([](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Value { // Allow casting from F64 back to F32. if (!resultType.isF16() && inputs.size() == 1 && inputs[0].getType().isF64()) return builder.create(loc, resultType, inputs).getResult(); // Allow producing an i32 or i64 from nothing. if ((resultType.isInteger(32) || resultType.isInteger(64)) && inputs.empty()) return builder.create(loc, resultType); // Allow producing an i64 from an integer. if (resultType.isa() && inputs.size() == 1 && inputs[0].getType().isa()) return builder.create(loc, resultType, inputs).getResult(); // Otherwise, fail. return nullptr; }); // Initialize the conversion target. mlir::ConversionTarget target(getContext()); target.addDynamicallyLegalOp([](TestTypeProducerOp op) { return op.getType().isF64() || op.getType().isInteger(64); }); target.addDynamicallyLegalOp([&](FuncOp op) { return converter.isSignatureLegal(op.getType()) && converter.isLegal(&op.getBody()); }); target.addDynamicallyLegalOp([&](TestCastOp op) { // Allow casts from F64 to F32. return (*op.operand_type_begin()).isF64() && op.getType().isF32(); }); // Initialize the set of rewrite patterns. OwningRewritePatternList patterns; patterns.insert(converter, &getContext()); mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter); if (failed(applyPartialConversion(getOperation(), target, patterns))) signalPassFailure(); } }; } // end anonymous namespace namespace { /// A rewriter pattern that tests that blocks can be merged. 
struct TestMergeBlock : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(TestMergeBlocksOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { Block &firstBlock = op.body().front(); Operation *branchOp = firstBlock.getTerminator(); Block *secondBlock = &*(std::next(op.body().begin())); auto succOperands = branchOp->getOperands(); SmallVector replacements(succOperands); rewriter.eraseOp(branchOp); rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); rewriter.updateRootInPlace(op, [] {}); return success(); } }; /// A rewrite pattern to tests the undo mechanism of blocks being merged. struct TestUndoBlocksMerge : public ConversionPattern { TestUndoBlocksMerge(MLIRContext *ctx) : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { Block &firstBlock = op->getRegion(0).front(); Operation *branchOp = firstBlock.getTerminator(); Block *secondBlock = &*(std::next(op->getRegion(0).begin())); rewriter.setInsertionPointToStart(secondBlock); rewriter.create(op->getLoc(), rewriter.getF32Type()); auto succOperands = branchOp->getOperands(); SmallVector replacements(succOperands); rewriter.eraseOp(branchOp); rewriter.mergeBlocks(secondBlock, &firstBlock, replacements); rewriter.updateRootInPlace(op, [] {}); return success(); } }; /// A rewrite mechanism to inline the body of the op into its parent, when both /// ops can have a single block. 
struct TestMergeSingleBlockOps : public OpConversionPattern { using OpConversionPattern< SingleBlockImplicitTerminatorOp>::OpConversionPattern; LogicalResult matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { SingleBlockImplicitTerminatorOp parentOp = op.getParentOfType(); if (!parentOp) return failure(); Block &innerBlock = op.region().front(); TerminatorOp innerTerminator = cast(innerBlock.getTerminator()); rewriter.mergeBlockBefore(&innerBlock, op); rewriter.eraseOp(innerTerminator); rewriter.eraseOp(op); rewriter.updateRootInPlace(op, [] {}); return success(); } }; struct TestMergeBlocksPatternDriver : public PassWrapper> { void runOnOperation() override { mlir::OwningRewritePatternList patterns; MLIRContext *context = &getContext(); patterns .insert( context); ConversionTarget target(*context); target.addLegalOp(); target.addIllegalOp(); /// Expect the op to have a single block after legalization. target.addDynamicallyLegalOp( [&](TestMergeBlocksOp op) -> bool { return llvm::hasSingleElement(op.body()); }); /// Only allow `test.br` within test.merge_blocks op. target.addDynamicallyLegalOp([&](TestBranchOp op) -> bool { return op.getParentOfType(); }); /// Expect that all nested test.SingleBlockImplicitTerminator ops are /// inlined. 
target.addDynamicallyLegalOp( [&](SingleBlockImplicitTerminatorOp op) -> bool { return !op.getParentOfType(); }); DenseSet unlegalizedOps; (void)applyPartialConversion(getOperation(), target, patterns, &unlegalizedOps); for (auto *op : unlegalizedOps) op->emitRemark() << "op '" << op->getName() << "' is not legalizable"; } }; } // namespace //===----------------------------------------------------------------------===// // PassRegistration //===----------------------------------------------------------------------===// namespace mlir { void registerPatternsTestPass() { PassRegistration("test-return-type", "Run return type functions"); PassRegistration("test-derived-attr", "Run test derived attributes"); PassRegistration("test-patterns", "Run test dialect patterns"); PassRegistration( "test-legalize-patterns", "Run test dialect legalization patterns", [] { return std::make_unique( legalizerConversionMode); }); PassRegistration( "test-remapped-value", "Test public remapped value mechanism in ConversionPatternRewriter"); PassRegistration( "test-legalize-unknown-root-patterns", "Test public remapped value mechanism in ConversionPatternRewriter"); PassRegistration( "test-legalize-type-conversion", "Test various type conversion functionalities in DialectConversion"); PassRegistration{ "test-merge-blocks", "Test Merging operation in ConversionPatternRewriter"}; } } // namespace mlir diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index 5986be6240f9..616e116cb170 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -1,430 +1,482 @@ // RUN: mlir-opt -test-patterns -mlir-print-debuginfo %s | FileCheck %s // CHECK-LABEL: verifyFusedLocs func @verifyFusedLocs(%arg0 : i32) -> i32 { %0 = "test.op_a"(%arg0) {attr = 10 : i32} : (i32) -> i32 loc("a") %result = "test.op_a"(%0) {attr = 20 : i32} : (i32) -> i32 loc("b") // CHECK: "test.op_b"(%arg0) {attr = 10 : i32} : (i32) -> i32 loc("a") // CHECK: 
"test.op_b"(%arg0) {attr = 20 : i32} : (i32) -> i32 loc(fused["b", "a"]) return %result : i32 } // CHECK-LABEL: verifyDesignatedLoc func @verifyDesignatedLoc(%arg0 : i32) -> i32 { %0 = "test.loc_src"(%arg0) : (i32) -> i32 loc("loc3") %1 = "test.loc_src"(%0) : (i32) -> i32 loc("loc2") %2 = "test.loc_src"(%1) : (i32) -> i32 loc("loc1") // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc("loc1") // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc("named") // CHECK: "test.loc_dst"({{.*}}) : (i32) -> i32 loc(fused<"fused">["loc2", "loc3"]) return %1 : i32 } // CHECK-LABEL: verifyZeroResult func @verifyZeroResult(%arg0 : i32) { // CHECK: "test.op_i"(%arg0) : (i32) -> () "test.op_h"(%arg0) : (i32) -> () return } // CHECK-LABEL: verifyZeroArg func @verifyZeroArg() -> i32 { // CHECK: "test.op_k"() : () -> i32 %0 = "test.op_j"() : () -> i32 return %0 : i32 } // CHECK-LABEL: testIgnoreArgMatch // CHECK-SAME: (%{{[a-z0-9]*}}: i32, %[[ARG1:[a-z0-9]*]]: i32 func @testIgnoreArgMatch(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: f32) { // CHECK: "test.ignore_arg_match_dst"(%[[ARG1]]) {f = 15 : i64} "test.ignore_arg_match_src"(%arg0, %arg1, %arg2) {d = 42, e = 24, f = 15} : (i32, i32, i32) -> () // CHECK: test.ignore_arg_match_src // Not match because wrong type for $c. "test.ignore_arg_match_src"(%arg0, %arg1, %arg3) {d = 42, e = 24, f = 15} : (i32, i32, f32) -> () // CHECK: test.ignore_arg_match_src // Not match because wrong type for $f. 
"test.ignore_arg_match_src"(%arg0, %arg1, %arg2) {d = 42 : i32, e = 24, f = 15} : (i32, i32, i32) -> () return } // CHECK-LABEL: verifyInterleavedOperandAttribute // CHECK-SAME: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32 func @verifyInterleavedOperandAttribute(%arg0: i32, %arg1: i32) { // CHECK: "test.interleaved_operand_attr2"(%[[ARG0]], %[[ARG1]]) {attr1 = 15 : i64, attr2 = 42 : i64} "test.interleaved_operand_attr1"(%arg0, %arg1) {attr1 = 15, attr2 = 42} : (i32, i32) -> () return } // CHECK-LABEL: verifyBenefit func @verifyBenefit(%arg0 : i32) -> i32 { %0 = "test.op_d"(%arg0) : (i32) -> i32 %1 = "test.op_g"(%arg0) : (i32) -> i32 %2 = "test.op_g"(%1) : (i32) -> i32 // CHECK: "test.op_f"(%arg0) // CHECK: "test.op_b"(%arg0) {attr = 34 : i32} return %0 : i32 } // CHECK-LABEL: verifyNativeCodeCall func @verifyNativeCodeCall(%arg0: i32, %arg1: i32) -> (i32, i32) { // CHECK: %0 = "test.native_code_call2"(%arg0) {attr = [42, 24]} : (i32) -> i32 // CHECK: return %0, %arg1 %0 = "test.native_code_call1"(%arg0, %arg1) {choice = true, attr1 = 42, attr2 = 24} : (i32, i32) -> (i32) %1 = "test.native_code_call1"(%arg0, %arg1) {choice = false, attr1 = 42, attr2 = 24} : (i32, i32) -> (i32) return %0, %1: i32, i32 } // CHECK-LABEL: verifyAuxiliaryNativeCodeCall func @verifyAuxiliaryNativeCodeCall(%arg0: i32) -> (i32) { // CHECK: test.op_i // CHECK: test.op_k %0 = "test.native_code_call3"(%arg0) : (i32) -> (i32) return %0 : i32 } // CHECK-LABEL: verifyAllAttrConstraintOf func @verifyAllAttrConstraintOf() -> (i32, i32, i32) { // CHECK: "test.all_attr_constraint_of2" %0 = "test.all_attr_constraint_of1"() {attr = [0, 1]} : () -> (i32) // CHECK: "test.all_attr_constraint_of1" %1 = "test.all_attr_constraint_of1"() {attr = [0, 2]} : () -> (i32) // CHECK: "test.all_attr_constraint_of1" %2 = "test.all_attr_constraint_of1"() {attr = [-1, 1]} : () -> (i32) return %0, %1, %2: i32, i32, i32 } // CHECK-LABEL: verifyManyArgs // CHECK-SAME: (%[[ARG:.*]]: i32) func @verifyManyArgs(%arg: i32) { // CHECK: 
"test.many_arguments"(%[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]], %[[ARG]]) // CHECK-SAME: {attr1 = 24 : i64, attr2 = 42 : i64, attr3 = 42 : i64, attr4 = 42 : i64, attr5 = 42 : i64, attr6 = 42 : i64, attr7 = 42 : i64, attr8 = 42 : i64, attr9 = 42 : i64} "test.many_arguments"(%arg, %arg, %arg, %arg, %arg, %arg, %arg, %arg, %arg) { attr1 = 42, attr2 = 42, attr3 = 42, attr4 = 42, attr5 = 42, attr6 = 42, attr7 = 42, attr8 = 42, attr9 = 42 } : (i32, i32, i32, i32, i32, i32, i32, i32, i32) -> () return } // CHECK-LABEL: verifyEqualArgs func @verifyEqualArgs(%arg0: i32, %arg1: i32) { // def TestEqualArgsPattern : Pat<(OpN $a, $a), (OpO $a)>; // CHECK: "test.op_o"(%arg0) : (i32) -> i32 "test.op_n"(%arg0, %arg0) : (i32, i32) -> (i32) // CHECK: "test.op_n"(%arg0, %arg1) : (i32, i32) -> i32 "test.op_n"(%arg0, %arg1) : (i32, i32) -> (i32) return } // CHECK-LABEL: verifyNestedOpEqualArgs func @verifyNestedOpEqualArgs( %arg0: i32, %arg1: i32, %arg2 : i32, %arg3 : i32, %arg4 : i32, %arg5 : i32) { // def TestNestedOpEqualArgsPattern : // Pat<(OpN $b, (OpP $a, $b, $c, $d, $e, $f)), (replaceWithValue $b)>; // CHECK: %arg1 %0 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (i32, i32, i32, i32, i32, i32) -> (i32) %1 = "test.op_n"(%arg1, %0) : (i32, i32) -> (i32) // CHECK: test.op_p // CHECK: test.op_n %2 = "test.op_p"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (i32, i32, i32, i32, i32, i32) -> (i32) %3 = "test.op_n"(%arg0, %2) : (i32, i32) -> (i32) return } // CHECK-LABEL: verifyMultipleEqualArgs func @verifyMultipleEqualArgs( %arg0: i32, %arg1 : i32, %arg2 : i32, %arg3 : i32, %arg4 : i32) { // def TestMultipleEqualArgsPattern : // Pat<(OpP $a, $b, $a, $a, $b, $c), (OpN $c, $b)>; // CHECK: "test.op_n"(%arg2, %arg1) : (i32, i32) -> i32 "test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg1, %arg2) : (i32, i32, i32, i32 , i32, i32) -> i32 // CHECK: test.op_p "test.op_p"(%arg0, %arg1, %arg0, %arg0, %arg0, %arg2) : (i32, i32, i32, i32 , i32, i32) -> 
i32 // CHECK: test.op_p "test.op_p"(%arg0, %arg1, %arg1, %arg0, %arg1, %arg2) : (i32, i32, i32, i32 , i32, i32) -> i32 // CHECK: test.op_p "test.op_p"(%arg0, %arg1, %arg2, %arg2, %arg3, %arg4) : (i32, i32, i32, i32 , i32, i32) -> i32 return } //===----------------------------------------------------------------------===// // Test Symbol Binding //===----------------------------------------------------------------------===// // CHECK-LABEL: symbolBinding func @symbolBinding(%arg0: i32) -> i32 { // An op with one use is matched. // CHECK: %0 = "test.symbol_binding_b"(%arg0) // CHECK: %1 = "test.symbol_binding_c"(%0) // CHECK: %2 = "test.symbol_binding_d"(%0, %1) {attr = 42 : i64} %0 = "test.symbol_binding_a"(%arg0) {attr = 42} : (i32) -> (i32) // An op without any use is not matched. // CHECK: "test.symbol_binding_a"(%arg0) %1 = "test.symbol_binding_a"(%arg0) {attr = 42} : (i32) -> (i32) // CHECK: return %2 return %0: i32 } // CHECK-LABEL: symbolBindingNoResult func @symbolBindingNoResult(%arg0: i32) { // CHECK: test.symbol_binding_b "test.symbol_binding_no_result"(%arg0) : (i32) -> () return } //===----------------------------------------------------------------------===// // Test Attributes //===----------------------------------------------------------------------===// // CHECK-LABEL: succeedMatchOpAttr func @succeedMatchOpAttr() -> i32 { // CHECK: "test.match_op_attribute2"() {default_valued_attr = 3 : i32, more_attr = 4 : i32, optional_attr = 2 : i32, required_attr = 1 : i32} %0 = "test.match_op_attribute1"() {required_attr = 1: i32, optional_attr = 2: i32, default_valued_attr = 3: i32, more_attr = 4: i32} : () -> (i32) return %0: i32 } // CHECK-LABEL: succeedMatchMissingOptionalAttr func @succeedMatchMissingOptionalAttr() -> i32 { // CHECK: "test.match_op_attribute2"() {default_valued_attr = 3 : i32, more_attr = 4 : i32, required_attr = 1 : i32} %0 = "test.match_op_attribute1"() {required_attr = 1: i32, default_valued_attr = 3: i32, more_attr = 4: i32} : () -> 
(i32) return %0: i32 } // CHECK-LABEL: succeedMatchMissingDefaultValuedAttr func @succeedMatchMissingDefaultValuedAttr() -> i32 { // CHECK: "test.match_op_attribute2"() {default_valued_attr = 42 : i32, more_attr = 4 : i32, optional_attr = 2 : i32, required_attr = 1 : i32} %0 = "test.match_op_attribute1"() {required_attr = 1: i32, optional_attr = 2: i32, more_attr = 4: i32} : () -> (i32) return %0: i32 } // CHECK-LABEL: failedMatchAdditionalConstraintNotSatisfied func @failedMatchAdditionalConstraintNotSatisfied() -> i32 { // CHECK: "test.match_op_attribute1"() %0 = "test.match_op_attribute1"() {required_attr = 1: i32, optional_attr = 2: i32, more_attr = 5: i32} : () -> (i32) return %0: i32 } // CHECK-LABEL: verifyConstantAttr func @verifyConstantAttr(%arg0 : i32) -> i32 { // CHECK: "test.op_b"(%arg0) {attr = 17 : i32} : (i32) -> i32 loc("a") %0 = "test.op_c"(%arg0) : (i32) -> i32 loc("a") return %0 : i32 } // CHECK-LABEL: verifyUnitAttr func @verifyUnitAttr() -> (i32, i32) { // Unit attribute present in the matched op is propagated as attr2. // CHECK: "test.match_op_attribute4"() {attr1, attr2} : () -> i32 %0 = "test.match_op_attribute3"() {attr} : () -> i32 // Since the original op doesn't have the unit attribute, the new op // only has the constant-constructed unit attribute attr1. 
// CHECK: "test.match_op_attribute4"() {attr1} : () -> i32 %1 = "test.match_op_attribute3"() : () -> i32 return %0, %1 : i32, i32 } +//===----------------------------------------------------------------------===// +// Test Constant Matching +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: testConstOp +func @testConstOp() -> (i32) { + // CHECK-NEXT: [[C0:%.+]] = constant 1 + %0 = "test.constant"() {value = 1 : i32} : () -> i32 + + // CHECK-NEXT: return [[C0]] + return %0 : i32 +} + +// CHECK-LABEL: testConstOpUsed +func @testConstOpUsed() -> (i32) { + // CHECK-NEXT: [[C0:%.+]] = constant 1 + %0 = "test.constant"() {value = 1 : i32} : () -> i32 + + // CHECK-NEXT: [[V0:%.+]] = "test.op_s"([[C0]]) + %1 = "test.op_s"(%0) {value = 1 : i32} : (i32) -> i32 + + // CHECK-NEXT: return [[V0]] + return %1 : i32 +} + +// CHECK-LABEL: testConstOpReplaced +func @testConstOpReplaced() -> (i32) { + // CHECK-NEXT: [[C0:%.+]] = constant 1 + %0 = "test.constant"() {value = 1 : i32} : () -> i32 + %1 = "test.constant"() {value = 2 : i32} : () -> i32 + + // CHECK: [[V0:%.+]] = "test.op_s"([[C0]]) {value = 2 : i32} + %2 = "test.op_r"(%0, %1) : (i32, i32) -> i32 + + // CHECK: [[V0]] + return %2 : i32 +} +// CHECK-LABEL: testConstOpMatchFailure +func @testConstOpMatchFailure() -> (i64) { + // CHECK-DAG: [[C0:%.+]] = constant 1 + %0 = "test.constant"() {value = 1 : i64} : () -> i64 + + // CHECK-DAG: [[C1:%.+]] = constant 2 + %1 = "test.constant"() {value = 2 : i64} : () -> i64 + + // CHECK: [[V0:%.+]] = "test.op_r"([[C0]], [[C1]]) + %2 = "test.op_r"(%0, %1) : (i64, i64) -> i64 + + // CHECK: [[V0]] + return %2 : i64 +} + //===----------------------------------------------------------------------===// // Test Enum Attributes //===----------------------------------------------------------------------===// // CHECK-LABEL: verifyStrEnumAttr func @verifyStrEnumAttr() -> i32 { // CHECK: "test.str_enum_attr"() {attr = "B"} %0 = 
"test.str_enum_attr"() {attr = "A"} : () -> i32 return %0 : i32 } // CHECK-LABEL: verifyI32EnumAttr func @verifyI32EnumAttr() -> i32 { // CHECK: "test.i32_enum_attr"() {attr = 10 : i32} %0 = "test.i32_enum_attr"() {attr = 5: i32} : () -> i32 return %0 : i32 } // CHECK-LABEL: verifyI64EnumAttr func @verifyI64EnumAttr() -> i32 { // CHECK: "test.i64_enum_attr"() {attr = 10 : i64} %0 = "test.i64_enum_attr"() {attr = 5: i64} : () -> i32 return %0 : i32 } //===----------------------------------------------------------------------===// // Test ElementsAttr //===----------------------------------------------------------------------===// // CHECK-LABEL: rewrite_i32elementsattr func @rewrite_i32elementsattr() -> () { // CHECK: attr = dense<0> : tensor "test.i32ElementsAttr"() {attr = dense<[3, 5]>:tensor<2xi32>} : () -> () return } // CHECK-LABEL: rewrite_f64elementsattr func @rewrite_f64elementsattr() -> () { "test.float_elements_attr"() { // Should match // CHECK: scalar_f32_attr = dense<[5.000000e+00, 6.000000e+00]> : tensor<2xf32> scalar_f32_attr = dense<[3.0, 4.0]> : tensor<2xf32>, tensor_f64_attr = dense<6.0> : tensor<4x8xf64> } : () -> () "test.float_elements_attr"() { // Should not match // CHECK: scalar_f32_attr = dense<7.000000e+00> : tensor<2xf32> scalar_f32_attr = dense<7.0> : tensor<2xf32>, tensor_f64_attr = dense<3.0> : tensor<4x8xf64> } : () -> () return } //===----------------------------------------------------------------------===// // Test Multi-result Ops //===----------------------------------------------------------------------===// // CHECK-LABEL: @useMultiResultOpToReplaceWhole func @useMultiResultOpToReplaceWhole() -> (i32, f32, f32) { // CHECK: %[[A:.*]], %[[B:.*]], %[[C:.*]] = "test.another_three_result"() // CHECK: return %[[A]], %[[B]], %[[C]] %0:3 = "test.three_result"() {kind = 1} : () -> (i32, f32, f32) return %0#0, %0#1, %0#2 : i32, f32, f32 } // CHECK-LABEL: @useMultiResultOpToReplacePartial1 func @useMultiResultOpToReplacePartial1() -> 
(i32, f32, f32) { // CHECK: %[[A:.*]], %[[B:.*]] = "test.two_result"() // CHECK: %[[C:.*]] = "test.one_result1"() // CHECK: return %[[A]], %[[B]], %[[C]] %0:3 = "test.three_result"() {kind = 2} : () -> (i32, f32, f32) return %0#0, %0#1, %0#2 : i32, f32, f32 } // CHECK-LABEL: @useMultiResultOpToReplacePartial2 func @useMultiResultOpToReplacePartial2() -> (i32, f32, f32) { // CHECK: %[[A:.*]] = "test.one_result2"() // CHECK: %[[B:.*]], %[[C:.*]] = "test.another_two_result"() // CHECK: return %[[A]], %[[B]], %[[C]] %0:3 = "test.three_result"() {kind = 3} : () -> (i32, f32, f32) return %0#0, %0#1, %0#2 : i32, f32, f32 } // CHECK-LABEL: @useMultiResultOpResultsSeparately func @useMultiResultOpResultsSeparately() -> (i32, f32, f32) { // CHECK: %[[A:.*]], %[[B:.*]] = "test.two_result"() // CHECK: %[[C:.*]] = "test.one_result1"() // CHECK: %[[D:.*]], %[[E:.*]] = "test.two_result"() // CHECK: return %[[A]], %[[C]], %[[E]] %0:3 = "test.three_result"() {kind = 4} : () -> (i32, f32, f32) return %0#0, %0#1, %0#2 : i32, f32, f32 } // CHECK-LABEL: @constraintOnSourceOpResult func @constraintOnSourceOpResult() -> (i32, f32, i32) { // CHECK: %[[A:.*]], %[[B:.*]] = "test.two_result"() // CHECK: %[[C:.*]] = "test.one_result2"() // CHECK: %[[D:.*]] = "test.one_result1"() // CHECK: return %[[A]], %[[B]], %[[C]] %0:2 = "test.two_result"() {kind = 5} : () -> (i32, f32) %1:2 = "test.two_result"() {kind = 5} : () -> (i32, f32) return %0#0, %0#1, %1#0 : i32, f32, i32 } // CHECK-LABEL: @useAuxiliaryOpToReplaceMultiResultOp func @useAuxiliaryOpToReplaceMultiResultOp() -> (i32, f32, f32) { // An auxiliary op is generated to help building the op for replacing the // matched op. 
// CHECK: %[[A:.*]], %[[B:.*]] = "test.two_result"() // CHECK: %[[C:.*]] = "test.one_result3"(%[[B]]) // CHECK: %[[D:.*]], %[[E:.*]] = "test.another_two_result"() // CHECK: return %[[C]], %[[D]], %[[E]] %0:3 = "test.three_result"() {kind = 6} : () -> (i32, f32, f32) return %0#0, %0#1, %0#2 : i32, f32, f32 } //===----------------------------------------------------------------------===// // Test Multi-result Ops //===----------------------------------------------------------------------===// // CHECK-LABEL: @replaceOneVariadicOutOneVariadicInOp func @replaceOneVariadicOutOneVariadicInOp(%arg0: i32, %arg1: i32, %arg2: i32) -> (i32, i32, i32, i32, i32, i32) { // CHECK: %[[cnt1:.*]] = "test.one_variadic_out_one_variadic_in2"(%arg0) // CHECK: %[[cnt2:.*]]:2 = "test.one_variadic_out_one_variadic_in2"(%arg0, %arg1) // CHECK: %[[cnt3:.*]]:3 = "test.one_variadic_out_one_variadic_in2"(%arg0, %arg1, %arg2) // CHECK: return %[[cnt1]], %[[cnt2]]#0, %[[cnt2]]#1, %[[cnt3]]#0, %[[cnt3]]#1, %[[cnt3]]#2 %0 = "test.one_variadic_out_one_variadic_in1"(%arg0) : (i32) -> (i32) %1:2 = "test.one_variadic_out_one_variadic_in1"(%arg0, %arg1) : (i32, i32) -> (i32, i32) %2:3 = "test.one_variadic_out_one_variadic_in1"(%arg0, %arg1, %arg2) : (i32, i32, i32) -> (i32, i32, i32) return %0, %1#0, %1#1, %2#0, %2#1, %2#2 : i32, i32, i32, i32, i32, i32 } // CHECK-LABEL: @replaceMixedVariadicInputOp func @replaceMixedVariadicInputOp(%arg0: i32, %arg1: f32, %arg2: i32) -> () { // CHECK: "test.mixed_variadic_in2"(%arg1) // CHECK: "test.mixed_variadic_in2"(%arg0, %arg1, %arg2) // CHECK: "test.mixed_variadic_in2"(%arg0, %arg0, %arg1, %arg2, %arg2) "test.mixed_variadic_in1"(%arg1) : (f32) -> () "test.mixed_variadic_in1"(%arg0, %arg1, %arg2) : (i32, f32, i32) -> () "test.mixed_variadic_in1"(%arg0, %arg0, %arg1, %arg2, %arg2) : (i32, i32, f32, i32, i32) -> () return } // CHECK-LABEL: @replaceMixedVariadicOutputOp func @replaceMixedVariadicOutputOp() -> (f32, i32, f32, i32, i32, i32, f32, i32, i32) { // CHECK: 
%[[cnt1:.*]] = "test.mixed_variadic_out2"() // CHECK: %[[cnt3_a:.*]], %[[cnt3_b:.*]], %[[cnt3_c:.*]] = "test.mixed_variadic_out2"() // CHECK: %[[cnt5_a:.*]]:2, %[[cnt5_b:.*]], %[[cnt5_c:.*]]:2 = "test.mixed_variadic_out2"() // CHECK: return %[[cnt1]], %[[cnt3_a]], %[[cnt3_b]], %[[cnt3_c]], %[[cnt5_a]]#0, %[[cnt5_a]]#1, %[[cnt5_b]], %[[cnt5_c]]#0, %[[cnt5_c]]#1 %0 = "test.mixed_variadic_out1"() : () -> (f32) %1:3 = "test.mixed_variadic_out1"() : () -> (i32, f32, i32) %2:5 = "test.mixed_variadic_out1"() : () -> (i32, i32, f32, i32, i32) return %0, %1#0, %1#1, %1#2, %2#0, %2#1, %2#2, %2#3, %2#4 : f32, i32, f32, i32, i32, i32, f32, i32, i32 } // CHECK-LABEL: @generateVariadicOutputOpInNestedPattern func @generateVariadicOutputOpInNestedPattern() -> (i32) { // CHECK: %[[cnt5_a:.*]], %[[cnt5_b:.*]]:2, %[[cnt5_c:.*]]:2 = "test.mixed_variadic_out3"() // CHECK: %[[res:.*]] = "test.mixed_variadic_in3"(%[[cnt5_a]], %[[cnt5_b]]#0, %[[cnt5_b]]#1, %[[cnt5_c]]#0, %[[cnt5_c]]#1) // CHECK: return %[[res]] %0 = "test.one_i32_out"() : () -> (i32) return %0 : i32 } //===----------------------------------------------------------------------===// // Test that natives calls are only called once during rewrites. 
//===----------------------------------------------------------------------===// // CHECK-LABEL: redundantTest func @redundantTest(%arg0: i32) -> i32 { %0 = "test.op_m"(%arg0) : (i32) -> i32 // CHECK: "test.op_m"(%arg0) {optional_attr = 314159265 : i32} : (i32) -> i32 return %0 : i32 } diff --git a/mlir/test/mlir-tblgen/rewriter-errors.td b/mlir/test/mlir-tblgen/rewriter-errors.td new file mode 100644 index 000000000000..eeb049482b88 --- /dev/null +++ b/mlir/test/mlir-tblgen/rewriter-errors.td @@ -0,0 +1,29 @@ +// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s +// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s + +include "mlir/IR/OpBase.td" + +// Check using the dialect name as the namespace +def A_Dialect : Dialect { + let name = "a"; +} + +class A_Op traits = []> : + Op; + +def OpA : A_Op<"op_a">, Arguments<(ins AnyInteger, AnyInteger)>, Results<(outs AnyInteger)>; +def OpB : A_Op<"op_b">, Arguments<(ins AnyInteger, AnyAttr:$value)>, Results<(outs AnyInteger)>; + +#ifdef ERROR1 +def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">; +// ERROR1: [[@LINE+1]]:1: error: binding symbol 'error' to native code call unsupported right now +def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg), + (OpB $val, $arg)>; +#endif + +#ifdef ERROR2 +def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">; +// ERROR2: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for +def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg), + (OpB $val, $arg)>; +#endif diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 7bff3e3b40b6..5521eea38252 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1,1198 +1,1309 @@ //===- RewriterGen.cpp - MLIR pattern rewriter generator ------------------===// // // Part of the 
LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // RewriterGen uses pattern rewrite definitions to generate rewriter matchers. // //===----------------------------------------------------------------------===// #include "mlir/Support/IndentedOstream.h" #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/Pattern.h" #include "mlir/TableGen/Predicate.h" #include "mlir/TableGen/Type.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatAdapters.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Main.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" using namespace mlir; using namespace mlir::tblgen; using llvm::formatv; using llvm::Record; using llvm::RecordKeeper; #define DEBUG_TYPE "mlir-tblgen-rewritergen" namespace llvm { template <> struct format_provider { static void format(const mlir::tblgen::Pattern::IdentifierLine &v, raw_ostream &os, StringRef style) { os << v.first << ":" << v.second; } }; } // end namespace llvm //===----------------------------------------------------------------------===// // PatternEmitter //===----------------------------------------------------------------------===// namespace { class PatternEmitter { public: PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os); // Emits the mlir::RewritePattern struct named `rewriteName`. void emit(StringRef rewriteName); private: // Emits the code for matching ops. 
- void emitMatchLogic(DagNode tree); + void emitMatchLogic(DagNode tree, StringRef opName); // Emits the code for rewriting ops. void emitRewriteLogic(); //===--------------------------------------------------------------------===// // Match utilities //===--------------------------------------------------------------------===// + // Emits C++ statements for matching the DAG structure. + void emitMatch(DagNode tree, StringRef name, int depth); + + // Emits C++ statements for matching using a native code call. + void emitNativeCodeMatch(DagNode tree, StringRef name, int depth); + // Emits C++ statements for matching the op constrained by the given DAG - // `tree`. - void emitOpMatch(DagNode tree, int depth); + // `tree` returning the op's variable name. + void emitOpMatch(DagNode tree, StringRef opName, int depth); // Emits C++ statements for matching the `argIndex`-th argument of the given // DAG `tree` as an operand. - void emitOperandMatch(DagNode tree, int argIndex, int depth); + void emitOperandMatch(DagNode tree, StringRef opName, int argIndex, + int depth); // Emits C++ statements for matching the `argIndex`-th argument of the given // DAG `tree` as an attribute. - void emitAttributeMatch(DagNode tree, int argIndex, int depth); + void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, + int depth); // Emits C++ for checking a match with a corresponding match failure // diagnostic. - void emitMatchCheck(int depth, const FmtObjectBase &matchFmt, + void emitMatchCheck(StringRef opName, const FmtObjectBase &matchFmt, const llvm::formatv_object_base &failureFmt); // Emits C++ for checking a match with a corresponding match failure // diagnostics. 
- void emitMatchCheck(int depth, const std::string &matchStr, + void emitMatchCheck(StringRef opName, const std::string &matchStr, const std::string &failureStr); //===--------------------------------------------------------------------===// // Rewrite utilities //===--------------------------------------------------------------------===// // The entry point for handling a result pattern rooted at `resultTree`. This // method dispatches to concrete handlers according to `resultTree`'s kind and // returns a symbol representing the whole value pack. Callers are expected to // further resolve the symbol according to the specific use case. // // `depth` is the nesting level of `resultTree`; 0 means top-level result // pattern. For top-level result pattern, `resultIndex` indicates which result // of the matched root op this pattern is intended to replace, which can be // used to deduce the result type of the op generated from this result // pattern. std::string handleResultPattern(DagNode resultTree, int resultIndex, int depth); // Emits the C++ statement to replace the matched DAG with a value built via // calling native C++ code. - std::string handleReplaceWithNativeCodeCall(DagNode resultTree); + std::string handleReplaceWithNativeCodeCall(DagNode resultTree, int depth); // Returns the symbol of the old value serving as the replacement. StringRef handleReplaceWithValue(DagNode tree); // Returns the location value to use. std::pair getLocation(DagNode tree); // Returns the location value to use. std::string handleLocationDirective(DagNode tree); // Emits the C++ statement to build a new op out of the given DAG `tree` and // returns the variable name that this op is assigned to. If the root op in // DAG `tree` has a specified name, the created op will be assigned to a // variable of the given name. Otherwise, a unique name will be used as the // result value name. 
std::string handleOpCreation(DagNode tree, int resultIndex, int depth); using ChildNodeIndexNameMap = DenseMap; // Emits a local variable for each value and attribute to be used for creating // an op. void createSeparateLocalVarsForOpArgs(DagNode node, ChildNodeIndexNameMap &childNodeNames); // Emits the concrete arguments used to call an op's builder. void supplyValuesForOpArgs(DagNode node, - const ChildNodeIndexNameMap &childNodeNames); + const ChildNodeIndexNameMap &childNodeNames, + int depth); // Emits the local variables for holding all values as a whole and all named // attributes as a whole to be used for creating an op. void createAggregateLocalVarsForOpArgs( - DagNode node, const ChildNodeIndexNameMap &childNodeNames); + DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth); // Returns the C++ expression to construct a constant attribute of the given // `value` for the given attribute kind `attr`. std::string handleConstantAttr(Attribute attr, StringRef value); // Returns the C++ expression to build an argument from the given DAG `leaf`. // `patArgName` is used to bound the argument to the source pattern. std::string handleOpArgument(DagLeaf leaf, StringRef patArgName); //===--------------------------------------------------------------------===// // General utilities //===--------------------------------------------------------------------===// // Collects all of the operations within the given dag tree. void collectOps(DagNode tree, llvm::SmallPtrSetImpl &ops); // Returns a unique symbol for a local variable of the given `op`. std::string getUniqueSymbol(const Operator *op); //===--------------------------------------------------------------------===// // Symbol utilities //===--------------------------------------------------------------------===// // Returns how many static values the given DAG `node` correspond to. 
int getNodeValueCount(DagNode node); private: // Pattern instantiation location followed by the location of multiclass // prototypes used. This is intended to be used as a whole to // PrintFatalError() on errors. ArrayRef loc; // Op's TableGen Record to wrapper object. RecordOperatorMap *opMap; // Handy wrapper for pattern being emitted. Pattern pattern; // Map for all bound symbols' info. SymbolInfoMap symbolInfoMap; // The next unused ID for newly created values. unsigned nextValueId; raw_indented_ostream os; // Format contexts containing placeholder substitutions. FmtContext fmtCtx; // Number of op processed. int opCounter = 0; }; } // end anonymous namespace PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os) : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { fmtCtx.withBuilder("rewriter"); } std::string PatternEmitter::handleConstantAttr(Attribute attr, StringRef value) { if (!attr.isConstBuildable()) PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + " does not have the 'constBuilderCall' field"); // TODO: Verify the constants here return std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, value)); } // Helper function to match patterns. -void PatternEmitter::emitOpMatch(DagNode tree, int depth) { +void PatternEmitter::emitMatch(DagNode tree, StringRef name, int depth) { + if (tree.isNativeCodeCall()) { + emitNativeCodeMatch(tree, name, depth); + return; + } + + if (tree.isOperation()) { + emitOpMatch(tree, name, depth); + return; + } + + PrintFatalError(loc, "encountered non-op, non-NativeCodeCall match."); +} + +// Helper function to match patterns. 
+void PatternEmitter::emitNativeCodeMatch(DagNode tree, StringRef opName, + int depth) { + LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall matcher pattern: "); + LLVM_DEBUG(tree.print(llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << '\n'); + + // TODO(suderman): iterate through arguments, determine their types, output + // names. + SmallVector capture(8); + if (tree.getNumArgs() > 8) { + PrintFatalError(loc, + "unsupported NativeCodeCall matcher argument numbers: " + + Twine(tree.getNumArgs())); + } + + raw_indented_ostream::DelimitedScope scope(os); + + for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { + std::string argName = formatv("arg{0}_{1}", depth, i); + if (DagNode argTree = tree.getArgAsNestedDag(i)) { + os << "Value " << argName << ";\n"; + } else { + auto leaf = tree.getArgAsLeaf(i); + if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { + os << "Attribute " << argName << ";\n"; + } else if (leaf.isOperandMatcher()) { + os << "Operation " << argName << ";\n"; + } + } + + capture[i] = std::move(argName); + } + + bool hasLocationDirective; + std::string locToUse; + std::tie(hasLocationDirective, locToUse) = getLocation(tree); + + auto fmt = tree.getNativeCodeTemplate(); + auto nativeCodeCall = std::string(tgfmt( + fmt, &fmtCtx.addSubst("_loc", locToUse), opName, capture[0], capture[1], + capture[2], capture[3], capture[4], capture[5], capture[6], capture[7])); + + os << "if (failed(" << nativeCodeCall << ")) return failure();\n"; + + for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { + auto name = tree.getArgName(i); + if (!name.empty() && name != "_") { + os << formatv("{0} = {1};\n", name, capture[i]); + } + } + + for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { + std::string argName = capture[i]; + + // Handle nested DAG construct first + if (DagNode argTree = tree.getArgAsNestedDag(i)) { + PrintFatalError( + loc, formatv("Matching nested tree in NativeCodecall not support for " + "{0} as arg {1}", + argName, i)); + } + + DagLeaf leaf = 
tree.getArgAsLeaf(i); + auto constraint = leaf.getAsConstraint(); + + auto self = formatv("{0}", argName); + emitMatchCheck( + opName, + tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), + formatv("\"operand {0} of native code call '{1}' failed to satisfy " + "constraint: " + "'{2}'\"", + i, tree.getNativeCodeTemplate(), constraint.getDescription())); + } + + LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n"); +} + +// Helper function to match patterns. +void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { Operator &op = tree.getDialectOp(opMap); LLVM_DEBUG(llvm::dbgs() << "start emitting match for op '" << op.getOperationName() << "' at depth " << depth << '\n'); - int indent = 4 + 2 * depth; - os.indent(indent) << formatv( - "auto castedOp{0} = ::llvm::dyn_cast_or_null<{1}>(op{0}); " - "(void)castedOp{0};\n", - depth, op.getQualCppClassName()); + std::string castedName = formatv("castedOp{0}", depth); + os << formatv("auto {0} = ::llvm::dyn_cast_or_null<{2}>({1}); " + "(void){0};\n", + castedName, opName, op.getQualCppClassName()); // Skip the operand matching at depth 0 as the pattern rewriter already does. if (depth != 0) { // Skip if there is no defining operation (e.g., arguments to function). - os << formatv("if (!castedOp{0})\n return failure();\n", depth); + os << formatv("if (!{0}) return failure();\n", castedName); } if (tree.getNumArgs() != op.getNumArgs()) { PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " "pattern vs. {2} in definition", op.getOperationName(), tree.getNumArgs(), op.getNumArgs())); } // If the operand's name is set, set to that variable. 
auto name = tree.getSymbol(); if (!name.empty()) - os << formatv("{0} = castedOp{1};\n", name, depth); + os << formatv("{0} = {1};\n", name, castedName); for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { auto opArg = op.getArg(i); + std::string argName = formatv("op{0}", depth + 1); // Handle nested DAG construct first if (DagNode argTree = tree.getArgAsNestedDag(i)) { if (auto *operand = opArg.dyn_cast()) { if (operand->isVariableLength()) { auto error = formatv("use nested DAG construct to match op {0}'s " "variadic operand #{1} unsupported now", op.getOperationName(), i); PrintFatalError(loc, error); } } os << "{\n"; os.indent() << formatv( - "auto *op{0} = " - "(*castedOp{1}.getODSOperands({2}).begin()).getDefiningOp();\n", - depth + 1, depth, i); - emitOpMatch(argTree, depth + 1); - os << formatv("tblgen_ops[{0}] = op{1};\n", ++opCounter, depth + 1); + "auto *{0} = " + "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n", + argName, castedName, i); + emitMatch(argTree, argName, depth + 1); + os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName); os.unindent() << "}\n"; continue; } // Next handle DAG leaf: operand or attribute if (opArg.is()) { - emitOperandMatch(tree, i, depth); + emitOperandMatch(tree, castedName, i, depth); } else if (opArg.is()) { - emitAttributeMatch(tree, i, depth); + emitAttributeMatch(tree, opName, i, depth); } else { PrintFatalError(loc, "unhandled case when matching op"); } } LLVM_DEBUG(llvm::dbgs() << "done emitting match for op '" << op.getOperationName() << "' at depth " << depth << '\n'); } -void PatternEmitter::emitOperandMatch(DagNode tree, int argIndex, int depth) { +void PatternEmitter::emitOperandMatch(DagNode tree, StringRef opName, + int argIndex, int depth) { Operator &op = tree.getDialectOp(opMap); auto *operand = op.getArg(argIndex).get(); auto matcher = tree.getArgAsLeaf(argIndex); // If a constraint is specified, we need to generate C++ statements to // check the constraint. 
if (!matcher.isUnspecified()) { if (!matcher.isOperandMatcher()) { PrintFatalError( loc, formatv("the {1}-th argument of op '{0}' should be an operand", op.getOperationName(), argIndex + 1)); } // Only need to verify if the matcher's type is different from the one // of op definition. Constraint constraint = matcher.getAsConstraint(); if (operand->constraint != constraint) { if (operand->isVariableLength()) { auto error = formatv( "further constrain op {0}'s variadic operand #{1} unsupported now", op.getOperationName(), argIndex); PrintFatalError(loc, error); } - auto self = - formatv("(*castedOp{0}.getODSOperands({1}).begin()).getType()", depth, - argIndex); + auto self = formatv("(*{0}.getODSOperands({1}).begin()).getType()", + opName, argIndex); emitMatchCheck( - depth, + opName, tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), formatv("\"operand {0} of op '{1}' failed to satisfy constraint: " "'{2}'\"", operand - op.operand_begin(), op.getOperationName(), constraint.getDescription())); } } // Capture the value auto name = tree.getArgName(argIndex); // `$_` is a special symbol to ignore op argument matching. if (!name.empty() && name != "_") { // We need to subtract the number of attributes before this operand to get // the index in the operand list. 
auto numPrevAttrs = std::count_if( op.arg_begin(), op.arg_begin() + argIndex, [](const Argument &arg) { return arg.is(); }); auto res = symbolInfoMap.findBoundSymbol(name, op, argIndex); - os << formatv("{0} = castedOp{1}.getODSOperands({2});\n", - res->second.getVarName(name), depth, argIndex - numPrevAttrs); + os << formatv("{0} = {1}.getODSOperands({2});\n", + res->second.getVarName(name), opName, + argIndex - numPrevAttrs); } } -void PatternEmitter::emitAttributeMatch(DagNode tree, int argIndex, int depth) { +void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, + int argIndex, int depth) { Operator &op = tree.getDialectOp(opMap); auto *namedAttr = op.getArg(argIndex).get(); const auto &attr = namedAttr->attr; os << "{\n"; - os.indent() << formatv( - "auto tblgen_attr = op{0}->getAttrOfType<{1}>(\"{2}\"); " - "(void)tblgen_attr;\n", - depth, attr.getStorageType(), namedAttr->name); + os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" + "(void)tblgen_attr;\n", + opName, attr.getStorageType(), namedAttr->name); // TODO: This should use getter method to avoid duplication. if (attr.hasDefaultValue()) { os << "if (!tblgen_attr) tblgen_attr = " << std::string(tgfmt(attr.getConstBuilderTemplate(), &fmtCtx, attr.getDefaultValue())) << ";\n"; } else if (attr.isOptional()) { // For a missing attribute that is optional according to definition, we // should just capture a mlir::Attribute() to signal the missing state. // That is precisely what getAttr() returns on missing attributes. 
} else { - emitMatchCheck(depth, tgfmt("tblgen_attr", &fmtCtx), + emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx), formatv("\"expected op '{0}' to have attribute '{1}' " "of type '{2}'\"", op.getOperationName(), namedAttr->name, attr.getStorageType())); } auto matcher = tree.getArgAsLeaf(argIndex); if (!matcher.isUnspecified()) { if (!matcher.isAttrMatcher()) { PrintFatalError( loc, formatv("the {1}-th argument of op '{0}' should be an attribute", op.getOperationName(), argIndex + 1)); } // If a constraint is specified, we need to generate C++ statements to // check the constraint. emitMatchCheck( - depth, + opName, tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")), formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " "{2}\"", op.getOperationName(), namedAttr->name, matcher.getAsConstraint().getDescription())); } // Capture the value auto name = tree.getArgName(argIndex); // `$_` is a special symbol to ignore op argument matching. if (!name.empty() && name != "_") { os << formatv("{0} = tblgen_attr;\n", name); } os.unindent() << "}\n"; } void PatternEmitter::emitMatchCheck( - int depth, const FmtObjectBase &matchFmt, + StringRef opName, const FmtObjectBase &matchFmt, const llvm::formatv_object_base &failureFmt) { - emitMatchCheck(depth, matchFmt.str(), failureFmt.str()); + emitMatchCheck(opName, matchFmt.str(), failureFmt.str()); } -void PatternEmitter::emitMatchCheck(int depth, const std::string &matchStr, +void PatternEmitter::emitMatchCheck(StringRef opName, + const std::string &matchStr, const std::string &failureStr) { + os << "if (!(" << matchStr << "))"; - os.scope("{\n", "\n}\n").os - << "return rewriter.notifyMatchFailure(op" << depth - << ", [&](::mlir::Diagnostic &diag) {\n diag << " << failureStr - << ";\n});"; + os.scope("{\n", "\n}\n").os << "return rewriter.notifyMatchFailure(" << opName + << ", [&](::mlir::Diagnostic &diag) {\n diag << " + << failureStr << ";\n});"; } -void 
PatternEmitter::emitMatchLogic(DagNode tree) { +void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { LLVM_DEBUG(llvm::dbgs() << "--- start emitting match logic ---\n"); int depth = 0; - emitOpMatch(tree, depth); + emitMatch(tree, opName, depth); for (auto &appliedConstraint : pattern.getConstraints()) { auto &constraint = appliedConstraint.constraint; auto &entities = appliedConstraint.entities; auto condition = constraint.getConditionTemplate(); if (isa(constraint)) { auto self = formatv("({0}.getType())", symbolInfoMap.getValueAndRangeUse(entities.front())); emitMatchCheck( - depth, tgfmt(condition, &fmtCtx.withSelf(self.str())), + opName, tgfmt(condition, &fmtCtx.withSelf(self.str())), formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"", entities.front(), constraint.getDescription())); } else if (isa(constraint)) { PrintFatalError( loc, "cannot use AttrConstraint in Pattern multi-entity constraints"); } else { // TODO: replace formatv arguments with the exact specified // args. if (entities.size() > 4) { PrintFatalError(loc, "only support up to 4-entity constraints now"); } SmallVector names; int i = 0; for (int e = entities.size(); i < e; ++i) names.push_back(symbolInfoMap.getValueAndRangeUse(entities[i])); std::string self = appliedConstraint.self; if (!self.empty()) self = symbolInfoMap.getValueAndRangeUse(self); for (; i < 4; ++i) names.push_back(""); - emitMatchCheck(depth, + emitMatchCheck(opName, tgfmt(condition, &fmtCtx.withSelf(self), names[0], names[1], names[2], names[3]), formatv("\"entities '{0}' failed to satisfy constraint: " "{1}\"", llvm::join(entities, ", "), constraint.getDescription())); } } // Some of the operands could be bound to the same symbol name, we need // to enforce equality constraint on those. // TODO: we should be able to emit equality checks early // and short circuit unnecessary work if vars are not equal. 
for (auto symbolInfoIt = symbolInfoMap.begin(); symbolInfoIt != symbolInfoMap.end();) { auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first); auto startRange = range.first; auto endRange = range.second; auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first); for (++startRange; startRange != endRange; ++startRange) { auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); emitMatchCheck( - depth, + opName, formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, secondOperand)); } symbolInfoIt = endRange; } LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); } void PatternEmitter::collectOps(DagNode tree, llvm::SmallPtrSetImpl &ops) { // Check if this tree is an operation. if (tree.isOperation()) { const Operator &op = tree.getDialectOp(opMap); LLVM_DEBUG(llvm::dbgs() << "found operation " << op.getOperationName() << '\n'); ops.insert(&op); } // Recurse the arguments of the tree. for (unsigned i = 0, e = tree.getNumArgs(); i != e; ++i) if (auto child = tree.getArgAsNestedDag(i)) collectOps(child, ops); } void PatternEmitter::emit(StringRef rewriteName) { // Get the DAG tree for the source pattern. DagNode sourceTree = pattern.getSourcePattern(); const Operator &rootOp = pattern.getSourceRootOp(); auto rootName = rootOp.getOperationName(); // Collect the set of result operations. llvm::SmallPtrSet resultOps; LLVM_DEBUG(llvm::dbgs() << "start collecting ops used in result patterns\n"); for (unsigned i = 0, e = pattern.getNumResultPatterns(); i != e; ++i) { collectOps(pattern.getResultPattern(i), resultOps); } LLVM_DEBUG(llvm::dbgs() << "done collecting ops used in result patterns\n"); // Emit RewritePattern for Pattern. 
auto locs = pattern.getLocation(); os << formatv("/* Generated from:\n {0:$[ instantiating\n ]}\n*/\n", make_range(locs.rbegin(), locs.rend())); os << formatv(R"(struct {0} : public ::mlir::RewritePattern { {0}(::mlir::MLIRContext *context) : ::mlir::RewritePattern("{1}", {{)", rewriteName, rootName); // Sort result operators by name. llvm::SmallVector sortedResultOps(resultOps.begin(), resultOps.end()); llvm::sort(sortedResultOps, [&](const Operator *lhs, const Operator *rhs) { return lhs->getOperationName() < rhs->getOperationName(); }); llvm::interleaveComma(sortedResultOps, os, [&](const Operator *op) { os << '"' << op->getOperationName() << '"'; }); os << formatv(R"(}, {0}, context) {{})", pattern.getBenefit()) << "\n"; // Emit matchAndRewrite() function. { auto classScope = os.scope(); os.reindent(R"( ::mlir::LogicalResult matchAndRewrite(::mlir::Operation *op0, ::mlir::PatternRewriter &rewriter) const override {)") << '\n'; { auto functionScope = os.scope(); // Register all symbols bound in the source pattern. pattern.collectSourcePatternBoundSymbols(symbolInfoMap); LLVM_DEBUG(llvm::dbgs() << "start creating local variables for capturing matches\n"); os << "// Variables for capturing values and attributes used while " "creating ops\n"; // Create local variables for storing the arguments and results bound // to symbols. for (const auto &symbolInfoPair : symbolInfoMap) { const auto &symbol = symbolInfoPair.first; const auto &info = symbolInfoPair.second; os << info.getVarDecl(symbol); } // TODO: capture ops with consistent numbering so that it can be // reused for fused loc. 
os << formatv("::mlir::Operation *tblgen_ops[{0}];\n\n", pattern.getSourcePattern().getNumOps()); LLVM_DEBUG(llvm::dbgs() << "done creating local variables for capturing matches\n"); os << "// Match\n"; os << "tblgen_ops[0] = op0;\n"; - emitMatchLogic(sourceTree); + emitMatchLogic(sourceTree, "op0"); os << "\n// Rewrite\n"; emitRewriteLogic(); os << "return ::mlir::success();\n"; } os << "};\n"; } os << "};\n\n"; } void PatternEmitter::emitRewriteLogic() { LLVM_DEBUG(llvm::dbgs() << "--- start emitting rewrite logic ---\n"); const Operator &rootOp = pattern.getSourceRootOp(); int numExpectedResults = rootOp.getNumResults(); int numResultPatterns = pattern.getNumResultPatterns(); // First register all symbols bound to ops generated in result patterns. pattern.collectResultPatternBoundSymbols(symbolInfoMap); // Only the last N static values generated are used to replace the matched // root N-result op. We need to calculate the starting index (of the results // of the matched op) each result pattern is to replace. SmallVector offsets(numResultPatterns + 1, numExpectedResults); // If we don't need to replace any value at all, set the replacement starting // index as the number of result patterns so we skip all of them when trying // to replace the matched op's results. int replStartIndex = numExpectedResults == 0 ? 
numResultPatterns : -1; for (int i = numResultPatterns - 1; i >= 0; --i) { auto numValues = getNodeValueCount(pattern.getResultPattern(i)); offsets[i] = offsets[i + 1] - numValues; if (offsets[i] == 0) { if (replStartIndex == -1) replStartIndex = i; } else if (offsets[i] < 0 && offsets[i + 1] > 0) { auto error = formatv( "cannot use the same multi-result op '{0}' to generate both " "auxiliary values and values to be used for replacing the matched op", pattern.getResultPattern(i).getSymbol()); PrintFatalError(loc, error); } } if (offsets.front() > 0) { const char error[] = "no enough values generated to replace the matched op"; PrintFatalError(loc, error); } os << "auto odsLoc = rewriter.getFusedLoc({"; for (int i = 0, e = pattern.getSourcePattern().getNumOps(); i != e; ++i) { os << (i ? ", " : "") << "tblgen_ops[" << i << "]->getLoc()"; } os << "}); (void)odsLoc;\n"; // Process auxiliary result patterns. for (int i = 0; i < replStartIndex; ++i) { DagNode resultTree = pattern.getResultPattern(i); auto val = handleResultPattern(resultTree, offsets[i], 0); // Normal op creation will be streamed to `os` by the above call; but // NativeCodeCall will only be materialized to `os` if it is used. Here // we are handling auxiliary patterns so we want the side effect even if // NativeCodeCall is not replacing matched root op's results. if (resultTree.isNativeCodeCall()) os << val << ";\n"; } if (numExpectedResults == 0) { assert(replStartIndex >= numResultPatterns && "invalid auxiliary vs. replacement pattern division!"); // No result to replace. Just erase the op. os << "rewriter.eraseOp(op0);\n"; } else { // Process replacement result patterns. os << "::llvm::SmallVector<::mlir::Value, 4> tblgen_repl_values;\n"; for (int i = replStartIndex; i < numResultPatterns; ++i) { DagNode resultTree = pattern.getResultPattern(i); auto val = handleResultPattern(resultTree, offsets[i], 0); os << "\n"; // Resolve each symbol for all range use so that we can loop over them. 
// We need an explicit cast to `SmallVector` to capture the cases where // `{0}` resolves to an `Operation::result_range` as well as cases that // are not iterable (e.g. vector that gets wrapped in additional braces by // RewriterGen). // TODO: Revisit the need for materializing a vector. os << symbolInfoMap.getAllRangeUse( val, "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n" " tblgen_repl_values.push_back(v);\n}\n", "\n"); } os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n"; } LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n"); } std::string PatternEmitter::getUniqueSymbol(const Operator *op) { return std::string( formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++)); } std::string PatternEmitter::handleResultPattern(DagNode resultTree, int resultIndex, int depth) { LLVM_DEBUG(llvm::dbgs() << "handle result pattern: "); LLVM_DEBUG(resultTree.print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << '\n'); if (resultTree.isLocationDirective()) { PrintFatalError(loc, "location directive can only be used with op creation"); } if (resultTree.isNativeCodeCall()) { - auto symbol = handleReplaceWithNativeCodeCall(resultTree); + auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth); symbolInfoMap.bindValue(symbol); return symbol; } if (resultTree.isReplaceWithValue()) return handleReplaceWithValue(resultTree).str(); // Normal op creation. auto symbol = handleOpCreation(resultTree, resultIndex, depth); if (resultTree.getSymbol().empty()) { // This is an op not explicitly bound to a symbol in the rewrite rule. // Register the auto-generated symbol for it. 
symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); } return symbol; } StringRef PatternEmitter::handleReplaceWithValue(DagNode tree) { assert(tree.isReplaceWithValue()); if (tree.getNumArgs() != 1) { PrintFatalError( loc, "replaceWithValue directive must take exactly one argument"); } if (!tree.getSymbol().empty()) { PrintFatalError(loc, "cannot bind symbol to replaceWithValue"); } return tree.getArgName(0); } std::string PatternEmitter::handleLocationDirective(DagNode tree) { assert(tree.isLocationDirective()); auto lookUpArgLoc = [this, &tree](int idx) { const auto *const lookupFmt = "(*{0}.begin()).getLoc()"; return symbolInfoMap.getAllRangeUse(tree.getArgName(idx), lookupFmt); }; if (tree.getNumArgs() == 0) llvm::PrintFatalError( "At least one argument to location directive required"); if (!tree.getSymbol().empty()) PrintFatalError(loc, "cannot bind symbol to location"); if (tree.getNumArgs() == 1) { DagLeaf leaf = tree.getArgAsLeaf(0); if (leaf.isStringAttr()) return formatv("::mlir::NameLoc::get(rewriter.getIdentifier(\"{0}\"), " "rewriter.getContext())", leaf.getStringAttr()) .str(); return lookUpArgLoc(0); } std::string ret; llvm::raw_string_ostream os(ret); std::string strAttr; os << "rewriter.getFusedLoc({"; bool first = true; for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { DagLeaf leaf = tree.getArgAsLeaf(i); // Handle the optional string value. if (leaf.isStringAttr()) { if (!strAttr.empty()) llvm::PrintFatalError("Only one string attribute may be specified"); strAttr = leaf.getStringAttr(); continue; } os << (first ? 
"" : ", ") << lookUpArgLoc(i); first = false; } os << "}"; if (!strAttr.empty()) { os << ", rewriter.getStringAttr(\"" << strAttr << "\")"; } os << ")"; return os.str(); } std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef patArgName) { if (leaf.isStringAttr()) PrintFatalError(loc, "raw string not supported as argument"); if (leaf.isConstantAttr()) { auto constAttr = leaf.getAsConstantAttr(); return handleConstantAttr(constAttr.getAttribute(), constAttr.getConstantValue()); } if (leaf.isEnumAttrCase()) { auto enumCase = leaf.getAsEnumAttrCase(); if (enumCase.isStrCase()) return handleConstantAttr(enumCase, enumCase.getSymbol()); // This is an enum case backed by an IntegerAttr. We need to get its value // to build the constant. std::string val = std::to_string(enumCase.getValue()); return handleConstantAttr(enumCase, val); } LLVM_DEBUG(llvm::dbgs() << "handle argument '" << patArgName << "'\n"); auto argName = symbolInfoMap.getValueAndRangeUse(patArgName); if (leaf.isUnspecified() || leaf.isOperandMatcher()) { LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << argName << "' (via symbol ref)\n"); return argName; } if (leaf.isNativeCodeCall()) { auto repl = tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(argName)); LLVM_DEBUG(llvm::dbgs() << "replace " << patArgName << " with '" << repl << "' (via NativeCodeCall)\n"); return std::string(repl); } PrintFatalError(loc, "unhandled case when rewriting op"); } -std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { +std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree, + int depth) { LLVM_DEBUG(llvm::dbgs() << "handle NativeCodeCall pattern: "); LLVM_DEBUG(tree.print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << '\n'); auto fmt = tree.getNativeCodeTemplate(); // TODO: replace formatv arguments with the exact specified args. 
SmallVector attrs(8); if (tree.getNumArgs() > 8) { - PrintFatalError(loc, "unsupported NativeCodeCall argument numbers: " + - Twine(tree.getNumArgs())); + PrintFatalError(loc, + "unsupported NativeCodeCall replace argument numbers: " + + Twine(tree.getNumArgs())); } bool hasLocationDirective; std::string locToUse; std::tie(hasLocationDirective, locToUse) = getLocation(tree); for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { - attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); + if (tree.isNestedDagArg(i)) { + attrs[i] = handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1); + } else { + attrs[i] = handleOpArgument(tree.getArgAsLeaf(i), tree.getArgName(i)); + } LLVM_DEBUG(llvm::dbgs() << "NativeCodeCall argument #" << i << " replacement: " << attrs[i] << "\n"); } return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs[0], attrs[1], attrs[2], attrs[3], attrs[4], attrs[5], attrs[6], attrs[7])); } int PatternEmitter::getNodeValueCount(DagNode node) { if (node.isOperation()) { // If the op is bound to a symbol in the rewrite rule, query its result // count from the symbol info map. auto symbol = node.getSymbol(); if (!symbol.empty()) { return symbolInfoMap.getStaticValueCount(symbol); } // Otherwise this is an unbound op; we will use all its results. return pattern.getDialectOp(node).getNumResults(); } // TODO: This considers all NativeCodeCall as returning one // value. Enhance if multi-value ones are needed. return 1; } std::pair PatternEmitter::getLocation(DagNode tree) { auto numPatArgs = tree.getNumArgs(); if (numPatArgs != 0) { if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1)) if (lastArg.isLocationDirective()) { return std::make_pair(true, handleLocationDirective(lastArg)); } } // If no explicit location is given, use the default, all fused, location. 
return std::make_pair(false, "odsLoc"); } std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, int depth) { LLVM_DEBUG(llvm::dbgs() << "create op for pattern: "); LLVM_DEBUG(tree.print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << '\n'); Operator &resultOp = tree.getDialectOp(opMap); auto numOpArgs = resultOp.getNumArgs(); auto numPatArgs = tree.getNumArgs(); bool hasLocationDirective; std::string locToUse; std::tie(hasLocationDirective, locToUse) = getLocation(tree); auto inPattern = numPatArgs - hasLocationDirective; if (numOpArgs != inPattern) { PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " "{1} in pattern vs. {2} in definition", resultOp.getOperationName(), inPattern, numOpArgs)); } // A map to collect all nested DAG child nodes' names, with operand index as // the key. This includes both bound and unbound child nodes. ChildNodeIndexNameMap childNodeNames; // First go through all the child nodes who are nested DAG constructs to // create ops for them and remember the symbol names for them, so that we can // use the results in the current node. This happens in a recursive manner. for (int i = 0, e = resultOp.getNumOperands(); i != e; ++i) { if (auto child = tree.getArgAsNestedDag(i)) childNodeNames[i] = handleResultPattern(child, i, depth + 1); } // The name of the local variable holding this op. std::string valuePackName; // The symbol for holding the result of this pattern. Note that the result of // this pattern is not necessarily the same as the variable created by this // pattern because we can use `__N` suffix to refer only a specific result if // the generated op is a multi-result op. std::string resultValue; if (tree.getSymbol().empty()) { // No symbol is explicitly bound to this op in the pattern. Generate a // unique name. 
valuePackName = resultValue = getUniqueSymbol(&resultOp); } else { resultValue = std::string(tree.getSymbol()); // Strip the index to get the name for the value pack and use it to name the // local variable for the op. valuePackName = std::string(SymbolInfoMap::getValuePackName(resultValue)); } // Create the local variable for this op. os << formatv("{0} {1};\n{{\n", resultOp.getQualCppClassName(), valuePackName); // Right now ODS don't have general type inference support. Except a few // special cases listed below, DRR needs to supply types for all results // when building an op. bool isSameOperandsAndResultType = resultOp.getTrait("::mlir::OpTrait::SameOperandsAndResultType"); bool useFirstAttr = resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"); if (isSameOperandsAndResultType || useFirstAttr) { // We know how to deduce the result type for ops with these traits and we've // generated builders taking aggregate parameters. Use those builders to // create the ops. // First prepare local variables for op arguments used in builder call. - createAggregateLocalVarsForOpArgs(tree, childNodeNames); + createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); // Then create the op. os.scope("", "\n}\n").os << formatv( "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);", valuePackName, resultOp.getQualCppClassName(), locToUse); return resultValue; } bool usePartialResults = valuePackName != resultValue; if (usePartialResults || depth > 0 || resultIndex < 0) { // For these cases (broadcastable ops, op results used both as auxiliary // values and replacement values, ops in nested patterns, auxiliary ops), we // still need to supply the result types when building the op. But because // we don't generate a builder automatically with ODS for them, it's the // developer's responsibility to make sure such a builder (with result type // deduction ability) exists. 
We go through the separate-parameter builder // here given that it's easier for developers to write compared to // aggregate-parameter builders. createSeparateLocalVarsForOpArgs(tree, childNodeNames); os.scope().os << formatv("{0} = rewriter.create<{1}>({2}", valuePackName, resultOp.getQualCppClassName(), locToUse); - supplyValuesForOpArgs(tree, childNodeNames); + supplyValuesForOpArgs(tree, childNodeNames, depth); os << "\n );\n}\n"; return resultValue; } // If depth == 0 and resultIndex >= 0, it means we are replacing the values // generated from the source pattern root op. Then we can use the source // pattern's value types to determine the value type of the generated op // here. // First prepare local variables for op arguments used in builder call. - createAggregateLocalVarsForOpArgs(tree, childNodeNames); + createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); // Then prepare the result types. We need to specify the types for all // results. os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; " "(void)tblgen_types;\n"); int numResults = resultOp.getNumResults(); if (numResults != 0) { for (int i = 0; i < numResults; ++i) os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n" " tblgen_types.push_back(v.getType());\n}\n", resultIndex + i); } os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " "tblgen_values, tblgen_attrs);\n", valuePackName, resultOp.getQualCppClassName(), locToUse); os.unindent() << "}\n"; return resultValue; } void PatternEmitter::createSeparateLocalVarsForOpArgs( DagNode node, ChildNodeIndexNameMap &childNodeNames) { Operator &resultOp = node.getDialectOp(opMap); // Now prepare operands used for building this op: // * If the operand is non-variadic, we create a `Value` local variable. // * If the operand is variadic, we create a `SmallVector` local // variable. int valueIndex = 0; // An index for uniquing local variable names. 
for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { const auto *operand = resultOp.getArg(argIndex).dyn_cast(); // We do not need special handling for attributes. if (!operand) continue; raw_indented_ostream::DelimitedScope scope(os); std::string varName; if (operand->isVariadic()) { varName = std::string(formatv("tblgen_values_{0}", valueIndex++)); os << formatv("::mlir::SmallVector<::mlir::Value, 4> {0};\n", varName); std::string range; if (node.isNestedDagArg(argIndex)) { range = childNodeNames[argIndex]; } else { range = std::string(node.getArgName(argIndex)); } // Resolve the symbol for all range use so that we have a uniform way of // capturing the values. range = symbolInfoMap.getValueAndRangeUse(range); os << formatv("for (auto v: {0}) {{\n {1}.push_back(v);\n}\n", range, varName); } else { varName = std::string(formatv("tblgen_value_{0}", valueIndex++)); os << formatv("::mlir::Value {0} = ", varName); if (node.isNestedDagArg(argIndex)) { os << symbolInfoMap.getValueAndRangeUse(childNodeNames[argIndex]); } else { DagLeaf leaf = node.getArgAsLeaf(argIndex); auto symbol = symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); if (leaf.isNativeCodeCall()) { os << std::string( tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); } else { os << symbol; } } os << ";\n"; } // Update to use the newly created local variable for building the op later. childNodeNames[argIndex] = varName; } } void PatternEmitter::supplyValuesForOpArgs( - DagNode node, const ChildNodeIndexNameMap &childNodeNames) { + DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { Operator &resultOp = node.getDialectOp(opMap); for (int argIndex = 0, numOpArgs = resultOp.getNumArgs(); argIndex != numOpArgs; ++argIndex) { // Start each argument on its own line. os << ",\n "; Argument opArg = resultOp.getArg(argIndex); // Handle the case of operand first. 
if (auto *operand = opArg.dyn_cast()) { if (!operand->name.empty()) os << "/*" << operand->name << "=*/"; os << childNodeNames.lookup(argIndex); continue; } // The argument in the op definition. auto opArgName = resultOp.getArgName(argIndex); if (auto subTree = node.getArgAsNestedDag(argIndex)) { if (!subTree.isNativeCodeCall()) PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " "for creating attribute"); os << formatv("/*{0}=*/{1}", opArgName, - handleReplaceWithNativeCodeCall(subTree)); + handleReplaceWithNativeCodeCall(subTree, depth)); } else { auto leaf = node.getArgAsLeaf(argIndex); // The argument in the result DAG pattern. auto patArgName = node.getArgName(argIndex); if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { // TODO: Refactor out into map to avoid recomputing these. if (!opArg.is()) PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); if (!patArgName.empty()) os << "/*" << patArgName << "=*/"; } else { os << "/*" << opArgName << "=*/"; } os << handleOpArgument(leaf, patArgName); } } } void PatternEmitter::createAggregateLocalVarsForOpArgs( - DagNode node, const ChildNodeIndexNameMap &childNodeNames) { + DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { Operator &resultOp = node.getDialectOp(opMap); auto scope = os.scope(); os << formatv("::mlir::SmallVector<::mlir::Value, 4> " "tblgen_values; (void)tblgen_values;\n"); os << formatv("::mlir::SmallVector<::mlir::NamedAttribute, 4> " "tblgen_attrs; (void)tblgen_attrs;\n"); const char *addAttrCmd = "if (auto tmpAttr = {1}) {\n" " tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), " "tmpAttr);\n}\n"; for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { if (resultOp.getArg(argIndex).is()) { // The argument in the op definition. 
auto opArgName = resultOp.getArgName(argIndex); if (auto subTree = node.getArgAsNestedDag(argIndex)) { if (!subTree.isNativeCodeCall()) PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " "for creating attribute"); os << formatv(addAttrCmd, opArgName, - handleReplaceWithNativeCodeCall(subTree)); + handleReplaceWithNativeCodeCall(subTree, depth + 1)); } else { auto leaf = node.getArgAsLeaf(argIndex); // The argument in the result DAG pattern. auto patArgName = node.getArgName(argIndex); os << formatv(addAttrCmd, opArgName, handleOpArgument(leaf, patArgName)); } continue; } const auto *operand = resultOp.getArg(argIndex).get(); std::string varName; if (operand->isVariadic()) { std::string range; if (node.isNestedDagArg(argIndex)) { range = childNodeNames.lookup(argIndex); } else { range = std::string(node.getArgName(argIndex)); } // Resolve the symbol for all range use so that we have a uniform way of // capturing the values. range = symbolInfoMap.getValueAndRangeUse(range); os << formatv("for (auto v: {0}) {{\n tblgen_values.push_back(v);\n}\n", range); } else { os << formatv("tblgen_values.push_back("); if (node.isNestedDagArg(argIndex)) { os << symbolInfoMap.getValueAndRangeUse( childNodeNames.lookup(argIndex)); } else { DagLeaf leaf = node.getArgAsLeaf(argIndex); auto symbol = symbolInfoMap.getValueAndRangeUse(node.getArgName(argIndex)); if (leaf.isNativeCodeCall()) { os << std::string( tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol))); } else { os << symbol; } } os << ");\n"; } } } static void emitRewriters(const RecordKeeper &recordKeeper, raw_ostream &os) { emitSourceFileHeader("Rewriters", os); const auto &patterns = recordKeeper.getAllDerivedDefinitions("Pattern"); auto numPatterns = patterns.size(); // We put the map here because it can be shared among multiple patterns. 
RecordOperatorMap recordOpMap; std::vector rewriterNames; rewriterNames.reserve(numPatterns); std::string baseRewriterName = "GeneratedConvert"; int rewriterIndex = 0; for (Record *p : patterns) { std::string name; if (p->isAnonymous()) { // If no name is provided, ensure unique rewriter names simply by // appending unique suffix. name = baseRewriterName + llvm::utostr(rewriterIndex++); } else { name = std::string(p->getName()); } LLVM_DEBUG(llvm::dbgs() << "=== start generating pattern '" << name << "' ===\n"); PatternEmitter(p, &recordOpMap, os).emit(name); LLVM_DEBUG(llvm::dbgs() << "=== done generating pattern '" << name << "' ===\n"); rewriterNames.push_back(std::move(name)); } // Emit function to add the generated matchers to the pattern list. os << "void LLVM_ATTRIBUTE_UNUSED populateWithGenerated(::mlir::MLIRContext " "*context, ::mlir::OwningRewritePatternList &patterns) {\n"; for (const auto &name : rewriterNames) { os << " patterns.insert<" << name << ">(context);\n"; } os << "}\n"; } static mlir::GenRegistration genRewriters("gen-rewriters", "Generate pattern rewriters", [](const RecordKeeper &records, raw_ostream &os) { emitRewriters(records, os); return false; });