diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -14,7 +14,1605 @@ #ifndef ATTRTYPEBASE_TD #define ATTRTYPEBASE_TD -include "mlir/IR/OpBase.td" +include "mlir/IR/Constraints.td" +include "mlir/IR/DialectBase.td" +include "mlir/IR/Properties.td" +include "mlir/IR/Traits.td" + +//===----------------------------------------------------------------------===// +// Type definitions +//===----------------------------------------------------------------------===// + +// A type, carries type constraints. +class Type : + TypeConstraint { + string description = ""; + string builderCall = ""; +} + +// Allows providing an alternative name and summary to an existing type def. +class TypeAlias : + Type { + let description = t.description; + let builderCall = t.builderCall; +} + +// A type of a specific dialect. +class DialectType : + Type { + Dialect dialect = d; +} + +// A variadic type constraint. It expands to zero or more of the base type. This +// class is used for supporting variadic operands/results. +class Variadic : TypeConstraint { + Type baseType = type; + int minSize = 0; +} + +// A nested variadic type constraint. It expands to zero or more variadic ranges +// of the base type. This class is used for supporting variadic operands and +// results. `variadicSegmentAttrName` should correspond to the name of an +// DenseI32ArrayAttr argument that provides the sizes of the inner variadic +// operand groups. +class VariadicOfVariadic + : Variadic { + string segmentAttrName = variadicSegmentAttrName; +} + +// An optional type constraint. It expands to either zero or one of the base +// type. This class is used for supporting optional operands/results. +class Optional : TypeConstraint { + Type baseType = type; +} + +// A type that can be constructed using MLIR::Builder. 
+// Note that this does not "inherit" from Type because it would require +// duplicating Type subclasses for buildable and non-buildable cases to avoid +// diamond "inheritance". +// TODO: we may extend this to a more general 'Buildable' trait, making some +// Types and some Attrs buildable. +class BuildableType { + // The builder call to invoke (if specified) to construct the BuildableType. + code builderCall = builder; +} + +// A type that's buildable iff the type passed as an argument is buildable. +// This is intended for use by types like container types, which are only +// buildable if the type of their elements is buildable. +class SameBuildabilityAs { + code builderCall = !if(!empty(type.builderCall), "", builder); +} + +// Any type at all. +def AnyType : Type, "any type">; + +// None type +def NoneType : Type($_self)">, "none type", + "::mlir::NoneType">, + BuildableType<"$_builder.getType<::mlir::NoneType>()">; + +// Any type from the given list +class AnyTypeOf allowedTypes, string summary = "", + string cppClassName = "::mlir::Type"> : Type< + // Satisfy any of the allowed types' conditions. + Or, + !if(!eq(summary, ""), + !interleave(!foreach(t, allowedTypes, t.summary), " or "), + summary), + cppClassName>; + +// A type that satisfies the constraints of all given types. +class AllOfType allowedTypes, string summary = "", + string cppClassName = "::mlir::Type"> : Type< + // Satisfy all of the allowedf types' conditions. + And, + !if(!eq(summary, ""), + !interleave(!foreach(t, allowedTypes, t.summary), " and "), + summary), + cppClassName>; + +// A type that satisfies additional predicates. +class ConfinedType predicates, string summary = "", + string cppClassName = type.cppClassName> : Type< + And, + summary, cppClassName>; + +// Integer types. + +// Any integer type irrespective of its width and signedness semantics. 
+def AnyInteger : Type($_self)">, "integer", + "::mlir::IntegerType">; + +// Any integer type (regardless of signedness semantics) of a specific width. +class AnyI + : Type, width # "-bit integer"> { + int bitwidth = width; +} + +class AnyIntOfWidths widths> : + AnyTypeOf), + !interleave(widths, "/") # "-bit integer", + "::mlir::IntegerType">; + +def AnyI1 : AnyI<1>; +def AnyI8 : AnyI<8>; +def AnyI16 : AnyI<16>; +def AnyI32 : AnyI<32>; +def AnyI64 : AnyI<64>; + +// Any signless integer type irrespective of its width. +def AnySignlessInteger : Type< + CPred<"$_self.isSignlessInteger()">, "signless integer", + "::mlir::IntegerType">; + +// Signless integer type of a specific width. +class I + : Type, + width # "-bit signless integer", "::mlir::IntegerType">, + BuildableType<"$_builder.getIntegerType(" # width # ")"> { + int bitwidth = width; +} + +class SignlessIntOfWidths widths> : + AnyTypeOf), + !interleave(widths, "/") # "-bit signless integer">; + +def I1 : I<1>; +def I8 : I<8>; +def I16 : I<16>; +def I32 : I<32>; +def I64 : I<64>; +def I128 : I<128>; + +// Any signed integer type irrespective of its width. +def AnySignedInteger : Type< + CPred<"$_self.isSignedInteger()">, "signed integer">; + +// Signed integer type of a specific width. +class SI + : Type, + width # "-bit signed integer", "::mlir::IntegerType">, + BuildableType< + "$_builder.getIntegerType(" # width # ", /*isSigned=*/true)"> { + int bitwidth = width; +} + +class SignedIntOfWidths widths> : + AnyTypeOf), + !interleave(widths, "/") # "-bit signed integer">; + +def SI1 : SI<1>; +def SI8 : SI<8>; +def SI16 : SI<16>; +def SI32 : SI<32>; +def SI64 : SI<64>; + +// Any unsigned integer type irrespective of its width. +def AnyUnsignedInteger : Type< + CPred<"$_self.isUnsignedInteger()">, "unsigned integer">; + +// Unsigned integer type of a specific width. 
+class UI + : Type, + width # "-bit unsigned integer", "::mlir::IntegerType">, + BuildableType< + "$_builder.getIntegerType(" # width # ", /*isSigned=*/false)"> { + int bitwidth = width; +} + +class UnsignedIntOfWidths widths> : + AnyTypeOf), + !interleave(widths, "/") # "-bit unsigned integer">; + +def UI1 : UI<1>; +def UI8 : UI<8>; +def UI16 : UI<16>; +def UI32 : UI<32>; +def UI64 : UI<64>; + +// Index type. +def Index : Type($_self)">, "index", + "::mlir::IndexType">, + BuildableType<"$_builder.getIndexType()">; + +// Any signless integer type or index type. +def AnySignlessIntegerOrIndex : Type, + "signless integer or index">; + +// Floating point types. + +// Any float type irrespective of its width. +def AnyFloat : Type($_self)">, "floating-point", + "::mlir::FloatType">; + +// Float type of a specific width. +class F + : Type, + width # "-bit float", "::mlir::FloatType">, + BuildableType<"$_builder.getF" # width # "Type()"> { + int bitwidth = width; +} + +class FloatOfWidths widths> : + AnyTypeOf), + !interleave(widths, "/") # "-bit float">; + +def F16 : F<16>; +def F32 : F<32>; +def F64 : F<64>; +def F80 : F<80>; +def F128 : F<128>; + +def BF16 : Type, "bfloat16 type">, + BuildableType<"$_builder.getBF16Type()">; +def TF32 : Type, "tf32 type">, + BuildableType<"$_builder.getTF32Type()">; +def F8E4M3FN : Type, "f8E4M3FN type">, + BuildableType<"$_builder.getFloat8E4M3FNType()">; +def F8E5M2 : Type, "f8E5M2 type">, + BuildableType<"$_builder.getFloat8E5M2Type()">; +def F8E4M3FNUZ : Type, "f8E4M3FNUZ type">, + BuildableType<"$_builder.getFloat8E4M3FNUZType()">; +def F8E4M3B11FNUZ : Type, "f8E4M3B11FNUZ type">, + BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">; +def F8E5M2FNUZ : Type, "f8E5M2FNUZ type">, + BuildableType<"$_builder.getFloat8E5M2FNUZType()">; + +def AnyComplex : Type($_self)">, + "complex-type", "::mlir::ComplexType">; + +class Complex + : ConfinedType($_self).getElementType()", + type.predicate>], + "complex type with " # type.summary # " 
elements", + "::mlir::ComplexType">, + SameBuildabilityAs { + Type elementType = type; +} + +class OpaqueType + : Type, + summary, "::mlir::OpaqueType">, + BuildableType<"::mlir::OpaqueType::get(" + "$_builder.getStringAttr(\"" # dialect # "\"), \"" + # name # "\")">; + +// Function Type + +// Any function type. +def FunctionType : Type($_self)">, + "function type", "::mlir::FunctionType">; + +// A container type is a type that has another type embedded within it. +class ContainerType : + // First, check the container predicate. Then, substitute the extracted + // element into the element type checker. + Type(elementTypeCall), + etype.predicate>]>, + descr # " of " # etype.summary # " values", cppClassName>; + +class ShapedContainerType allowedTypes, + Pred containerPred, string descr, + string cppClassName = "::mlir::Type"> : + Type.predicate>, + "; }(::llvm::cast<::mlir::ShapedType>($_self).getElementType())">]>, + descr # " of " # AnyTypeOf.summary # " values", cppClassName>; + +// Whether a shaped type is ranked. +def HasRankPred : CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasRank()">; + +// Whether a shaped type has one of the specified ranks. +class HasAnyRankOfPred ranks> : And<[ + HasRankPred, + Or($_self).getRank() + == }] + # rank>)>]>; + +// Whether a shaped type has a rank greater than or equal of the specified rank. +class HasRankGreaterOrEqualPred : And<[ + HasRankPred, + CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank() >= }] # rank> +]>; + +// Vector types. + +class VectorOf allowedTypes> : + ShapedContainerType; + +// Temporary vector type clone that allows gradual transition to 0-D vectors. +// TODO: Remove this when all ops support 0-D vectors. 
+class VectorOfAnyRankOf allowedTypes> : + ShapedContainerType; + +class FixedVectorOf allowedTypes> : + ShapedContainerType; + +class ScalableVectorOf allowedTypes> : + ShapedContainerType; + +// Whether the number of elements of a vector is from the given +// `allowedRanks` list +class IsVectorOfRankPred allowedRanks> : + And<[IsVectorTypePred, + Or($_self).getRank() + == }] + # allowedlength>)>]>; + +// Whether the number of elements of a fixed-length vector is from the given +// `allowedRanks` list +class IsFixedVectorOfRankPred allowedRanks> : + And<[IsFixedVectorTypePred, + Or($_self).getRank() + == }] + # allowedlength>)>]>; + +// Whether the number of elements of a scalable vector is from the given +// `allowedRanks` list +class IsScalableVectorOfRankPred allowedRanks> : + And<[IsScalableVectorTypePred, + Or($_self).getRank() + == }] + # allowedlength>)>]>; + +// Any vector where the rank is from the given `allowedRanks` list +class VectorOfRank allowedRanks> : Type< + IsVectorOfRankPred, + " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">; + +// Any fixed-length vector where the rank is from the given `allowedRanks` list +class FixedVectorOfRank allowedRanks> : Type< + IsFixedVectorOfRankPred, + " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">; + +// Any scalable vector where the rank is from the given `allowedRanks` list +class ScalableVectorOfRank allowedRanks> : Type< + IsScalableVectorOfRankPred, + " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">; + +// Any vector where the rank is from the given `allowedRanks` list and the type +// is from the given `allowedTypes` list +class VectorOfRankAndType allowedRanks, + list allowedTypes> : AllOfType< + [VectorOf, VectorOfRank], + VectorOf.summary # VectorOfRank.summary, + "::mlir::VectorType">; + +// Whether the number of elements of a vector is from the given +// `allowedLengths` list +class IsVectorOfLengthPred allowedLengths> : + 
And<[IsVectorTypePred, + Or($_self).getNumElements() + == }] + # allowedlength>)>]>; + +// Whether the number of elements of a fixed-length vector is from the given +// `allowedLengths` list +class IsFixedVectorOfLengthPred allowedLengths> : + And<[IsFixedVectorTypePred, + Or($_self).getNumElements() + == }] + # allowedlength>)>]>; + +// Whether the number of elements of a scalable vector is from the given +// `allowedLengths` list +class IsScalableVectorOfLengthPred allowedLengths> : + And<[IsScalableVectorTypePred, + Or($_self).getNumElements() + == }] + # allowedlength>)>]>; + +// Whether the shape of a vector matches the given `shape` list. +class IsVectorOfShape shape> + : CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef({" # !interleave(shape, ", ") # "})">; + +// Any vector where the number of elements is from the given +// `allowedLengths` list +class VectorOfLength allowedLengths> : Type< + IsVectorOfLengthPred, + " of length " # !interleave(allowedLengths, "/"), + "::mlir::VectorType">; + +// Any fixed-length vector where the number of elements is from the given +// `allowedLengths` list +class FixedVectorOfLength allowedLengths> : Type< + IsFixedVectorOfLengthPred, + " of length " # !interleave(allowedLengths, "/"), + "::mlir::VectorType">; + +// Any scalable vector where the number of elements is from the given +// `allowedLengths` list +class ScalableVectorOfLength allowedLengths> : Type< + IsScalableVectorOfLengthPred, + " of length " # !interleave(allowedLengths, "/"), + "::mlir::VectorType">; + +// Any vector where the number of elements is from the given +// `allowedLengths` list and the type is from the given `allowedTypes` +// list +class VectorOfLengthAndType allowedLengths, + list allowedTypes> : AllOfType< + [VectorOf, VectorOfLength], + VectorOf.summary # VectorOfLength.summary, + "::mlir::VectorType">; + +// Any fixed-length vector where the number of elements is from the given +// `allowedLengths` list and the type is 
from the given `allowedTypes` list +class FixedVectorOfLengthAndType allowedLengths, + list allowedTypes> : AllOfType< + [FixedVectorOf, FixedVectorOfLength], + FixedVectorOf.summary # + FixedVectorOfLength.summary, + "::mlir::VectorType">; + +// Any scalable vector where the number of elements is from the given +// `allowedLengths` list and the type is from the given `allowedTypes` list +class ScalableVectorOfLengthAndType allowedLengths, + list allowedTypes> : AllOfType< + [ScalableVectorOf, ScalableVectorOfLength], + ScalableVectorOf.summary # + ScalableVectorOfLength.summary, + "::mlir::VectorType">; + +def AnyVector : VectorOf<[AnyType]>; +// Temporary vector type clone that allows gradual transition to 0-D vectors. +def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; + +def AnyFixedVector : FixedVectorOf<[AnyType]>; + +def AnyScalableVector : ScalableVectorOf<[AnyType]>; + +// Shaped types. + +def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped", + "::mlir::ShapedType">; + +//===----------------------------------------------------------------------===// +// Tensor types. + +// Unranked tensor type whose element type is from the given `allowedTypes` +// list, and which additionally satisfies an optional list of predicates. +class UnrankedTensorOf allowedTypes, list preds = [], + string summary = "unranked tensor"> + : ShapedContainerType< + allowedTypes, And, + summary, "::mlir::UnrankedTensorType">; + +// Ranked tensor type whose element type is from the given `allowedTypes` list, +// and which additionally satisfies an optional list of predicates. +class RankedTensorOf allowedTypes, list preds = [], + string summary = "ranked tensor"> + : ShapedContainerType< + allowedTypes, And, + summary, "::mlir::RankedTensorType">; + +// Any tensor type whose element type is from the given `allowedTypes` +// list, and which additionally satisfies an optional list of predicates. 
+// +// TODO: use `Constraint` instead of `Pred`, so we can generate a better +// default summary (a la `ConfinedAttr`). +class TensorOf< + list allowedTypes, + list preds = [], + string summary = "tensor"> + : ShapedContainerType, + summary, "::mlir::TensorType">; + +def AnyTensor : TensorOf<[AnyType]>; + +def I1Tensor : TensorOf<[I1]>; +def I8Tensor : TensorOf<[I8]>; +def I16Tensor : TensorOf<[I16]>; +def I32Tensor : TensorOf<[I32]>; +def I64Tensor : TensorOf<[I64]>; +def IndexTensor: TensorOf<[Index]>; + +def BF16Tensor : TensorOf<[BF16]>; +def F16Tensor : TensorOf<[F16]>; +def F32Tensor : TensorOf<[F32]>; +def F64Tensor : TensorOf<[F64]>; + +class Non0RankedTensorOf allowedTypes> + : TensorOf], + "non-0-ranked.tensor">; + +def AnyRankedTensor : RankedTensorOf<[AnyType]>; +def AnyNon0RankedTensor : Non0RankedTensorOf<[AnyType]>; +def AnyUnrankedTensor : UnrankedTensorOf<[AnyType]>; + +def AnyNon0RankedOrUnrankedTensor + : AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor], + "non-0-ranked or unranked tensor", "::mlir::TensorType">; + +// Ranked tensor type with one of the specified types and ranks. +class TensorRankOf allowedTypes, list ranks> + : RankedTensorOf], + !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">; + +class 0DTensorOf allowedTypes> : TensorRankOf; +class 1DTensorOf allowedTypes> : TensorRankOf; +class 2DTensorOf allowedTypes> : TensorRankOf; +class 3DTensorOf allowedTypes> : TensorRankOf; +class 4DTensorOf allowedTypes> : TensorRankOf; + +class StaticShapeTensorOf allowedTypes> + : RankedTensorOf; + +def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>; + +//===----------------------------------------------------------------------===// +// Memref type. + +// Any unranked memref whose element type is from the given `allowedTypes` list. 
+class UnrankedMemRefOf allowedTypes> : + ShapedContainerType; + +def AnyUnrankedMemRef : UnrankedMemRefOf<[AnyType]>; + +// Any ranked memref whose element type is from the given `allowedTypes` list. +class MemRefOf allowedTypes> : + ShapedContainerType; + +class Non0RankedMemRefOf allowedTypes> : + ConfinedType, [HasRankGreaterOrEqualPred<1>], + "non-0-ranked." # MemRefOf.summary, + "::mlir::MemRefType">; + +def AnyMemRef : MemRefOf<[AnyType]>; +def AnyNon0RankedMemRef : Non0RankedMemRefOf<[AnyType]>; + +// Any memref (ranked or unranked) whose element type is from the given +// `allowedTypes` list, and which additionally satisfies an optional list of +// predicates. +class RankedOrUnrankedMemRefOf< + list allowedTypes, + list preds = [], + string summary = "ranked or unranked memref"> + : ShapedContainerType, + summary, "::mlir::BaseMemRefType">; + +def AnyRankedOrUnrankedMemRef : RankedOrUnrankedMemRefOf<[AnyType]>; +def AnyNon0RankedOrUnrankedMemRef: + AnyTypeOf<[AnyUnrankedMemRef, AnyNon0RankedMemRef]>; + +// Memref declarations handle any memref, independent of rank, size, (static or +// dynamic), layout, or memory space. +def I1MemRef : MemRefOf<[I1]>; +def I8MemRef : MemRefOf<[I8]>; +def I16MemRef : MemRefOf<[I16]>; +def I32MemRef : MemRefOf<[I32]>; +def I64MemRef : MemRefOf<[I64]>; + +def BF16MemRef : MemRefOf<[BF16]>; +def F16MemRef : MemRefOf<[F16]>; +def F32MemRef : MemRefOf<[F32]>; +def F64MemRef : MemRefOf<[F64]>; + +// TODO: Have an easy way to add another constraint to a type. +class MemRefRankOf allowedTypes, list ranks> : + ConfinedType, [HasAnyRankOfPred], + !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " # + MemRefOf.summary, + "::mlir::MemRefType">; + +class StaticShapeMemRefOf allowedTypes> : + ConfinedType, [HasStaticShapePred], + "statically shaped " # MemRefOf.summary, + "::mlir::MemRefType">; + +def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>; + +// For a MemRefType, verify that it has strides. 
+def HasStridesPred : CPred<[{ isStrided(::llvm::cast<::mlir::MemRefType>($_self)) }]>; + +class StridedMemRefOf allowedTypes> : + ConfinedType, [HasStridesPred], + "strided " # MemRefOf.summary>; + +def AnyStridedMemRef : StridedMemRefOf<[AnyType]>; + +class AnyStridedMemRefOfRank : + AllOfType<[AnyStridedMemRef, MemRefRankOf<[AnyType], [rank]>], + AnyStridedMemRef.summary # " of rank " # rank>; + +class StridedMemRefRankOf allowedTypes, list ranks> : + ConfinedType, [HasAnyRankOfPred], + !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " # + MemRefOf.summary>; + +// This represents a generic tuple without any constraints on element type. +def AnyTuple : Type; + +// A container type that has other types embedded in it, but (unlike +// ContainerType) can hold elements with a mix of types. Requires a call that +// produces a list of all elements' types. +class MixedContainerType : + Type< + And<[ + containerPred, + Concat< + "::llvm::all_of(" # elementTypesCall # ", [](::mlir::Type t) { " + "return t && (", + SubstLeaves<"$_self", "t", etype.predicate>, + "); })" + > + ]>, + descr # " with any combination of " # etype.summary # " values"> { + // The type of elements in the container. + Type elementType = etype; + + // Call to retrieve. + code getElementTypesCall = elementTypesCall; +} + +// A Tuple that holds a mix of elements of the allowed types. +class TupleOf allowedTypes> + : MixedContainerType, IsTupleTypePred, + "::llvm::cast<::mlir::TupleType>($_self).getTypes()", + "tuple">; + +// A Tuple with arbitrary nesting, where all elements are a mix of the allowed +// types. 
+class NestedTupleOf allowedTypes> : + MixedContainerType, IsTupleTypePred, + "getFlattenedTypes(::llvm::cast<::mlir::TupleType>($_self))", + "nested tuple">; + +//===----------------------------------------------------------------------===// +// Common type constraints +//===----------------------------------------------------------------------===// +// Type constraint for types that are "like" some type or set of types T, that is +// they're either a T, a vector of Ts, or a tensor of Ts +class TypeOrContainer : TypeConstraint.predicate, + TensorOf<[allowedType]>.predicate]>, + name>; + +// Temporary constraint to allow gradual transition to supporting 0-D vectors. +// TODO: Remove this when all ops support 0-D vectors. +class TypeOrContainerOfAnyRank : TypeConstraint.predicate, + TensorOf<[allowedType]>.predicate]>, + name>; + + +// Type constraint for bool-like types: bools, vectors of bools, tensors of +// bools. +def BoolLike : TypeOrContainer; + +def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank; + +// Type constraint for signless-integer-like types: signless integers, indices, +// vectors of signless integers or indices, tensors of signless integers. +def SignlessIntegerLike : TypeOrContainer; + +def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank< + AnySignlessIntegerOrIndex, + "signless-integer-like">; + +// Type constraint for float-like types: floats, vectors or tensors thereof. +def FloatLike : TypeOrContainer; + +// Type constraint for signless-integer-like or float-like types. +def SignlessIntegerOrFloatLike : TypeConstraint, + "signless-integer-like or floating-point-like">; + +//===----------------------------------------------------------------------===// +// Attribute definitions +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Base attribute definition + +// Base class for all attributes. 
+class Attr : + AttrConstraint { + code storageType = ?; // The backing mlir::Attribute type + code returnType = ?; // The underlying C++ value type + + // The call expression to convert from the storage type to the return + // type. For example, an enum can be stored as an int but returned as an + // enum class. + // + // Format: $_self will be expanded to the attribute. + // + // For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will + // expand to `getAttrOfType("val").getValue().getSExtValue()`. + code convertFromStorage = "$_self.getValue()"; + + // The call expression to build an attribute from a constant value. + // + // Format: $0 will be expanded to the constant value of the attribute. + // + // For example, `$_builder.getStringAttr("$0")` for `StringAttr:"foo"` will + // expand to `builder.getStringAttr("foo")`. + string constBuilderCall = ?; + + // Default value for attribute. + // Requires a constBuilderCall defined. + string defaultValue = ?; + + // The value type of this attribute. This corresponds to the mlir::Type that + // this attribute returns via `getType()`. + Type valueType = ?; + + // Whether the attribute is optional. Typically requires a custom + // convertFromStorage method to handle the case where the attribute is + // not present. + bit isOptional = 0; + + // What is the base-level Attr instantiation that this Attr is built upon. + // Unset means this is a base-level Attr. + // + // This field is used by attribute wrapper classes (DefaultValuedAttr, + // OptionalAttr, etc.) to retrieve the base-level attribute definition. + // This can be used for getting its name; otherwise, we will see + // "anonymous_" as the attribute def name because of template + // instantiation. + // TOOD(b/132458159): deduplicate the fields in attribute wrapper classes. + Attr baseAttr = ?; + + // The fully-qualified C++ namespace where the generated class lives. + string cppNamespace = ""; + + // The full description of this attribute. 
+ string description = ""; +} + +// An attribute of a specific dialect. +class DialectAttr : + Attr { + Dialect dialect = d; + let cppNamespace = d.cppNamespace; +} + +//===----------------------------------------------------------------------===// +// Attribute modifier definition + +// Decorates an attribute to have an (unvalidated) default value if not present. +class DefaultValuedAttr : + Attr { + // Construct this attribute with the input attribute and change only + // the default value. + // Note: this has to be kept up to date with Attr above. + let storageType = attr.storageType; + let returnType = attr.returnType; + let convertFromStorage = attr.convertFromStorage; + let constBuilderCall = attr.constBuilderCall; + let defaultValue = val; + let valueType = attr.valueType; + + let baseAttr = attr; +} + +// Decorates an optional attribute to have an (unvalidated) default value +// return by ODS generated accessors if not present. +class DefaultValuedOptionalAttr : + Attr { + // Construct this attribute with the input attribute and change only + // the default value. + // Note: this has to be kept up to date with Attr above. + let storageType = attr.storageType; + let returnType = attr.returnType; + let convertFromStorage = attr.convertFromStorage; + let constBuilderCall = attr.constBuilderCall; + let defaultValue = val; + let valueType = attr.valueType; + let isOptional = 1; + + let baseAttr = attr; +} + +// Decorates an attribute as optional. The return type of the generated +// attribute accessor method will be Optional<>. +class OptionalAttr : Attr { + // Rewrite the attribute to be optional. + // Note: this has to be kept up to date with Attr above. + let storageType = attr.storageType; + let returnType = "::std::optional<" # attr.returnType #">"; + let convertFromStorage = "$_self ? 
" # returnType # "(" # + attr.convertFromStorage # ") : (::std::nullopt)"; + let valueType = attr.valueType; + let isOptional = 1; + + let baseAttr = attr; +} + +// Default-valued string-based attribute. Wraps the default value in escaped +// quotes. +class DefaultValuedStrAttr + : DefaultValuedAttr; +class DefaultValuedOptionalStrAttr + : DefaultValuedOptionalAttr; + +//===----------------------------------------------------------------------===// +// Primitive property kinds + +// Any kind of integer stored as properties. +class IntProperty : + Property { + code writeToMlirBytecode = [{ + $_writer.writeVarInt($_storage); + }]; + code readFromMlirBytecode = [{ + uint64_t val; + if (failed($_reader.readVarInt(val))) + return ::mlir::failure(); + $_storage = val; + }]; +} + +class ArrayProperty : + Property { + let interfaceType = "::llvm::ArrayRef<" # storageTypeParam # ">"; + let convertFromStorage = "$_storage"; + let assignToStorage = "::llvm::copy($_value, $_storage)"; +} + +//===----------------------------------------------------------------------===// +// Primitive attribute kinds + +// A generic attribute that must be constructed around a specific buildable type +// `attrValType`. Backed by MLIR attribute kind `attrKind`. +class TypedAttrBase : + Attr { + let constBuilderCall = "$_builder.get" # attrKind # "(" # + attrValType.builderCall # ", $0)"; + let storageType = "::mlir::" # attrKind; + let valueType = attrValType; +} + +// Any attribute. 
+def AnyAttr : Attr, "any attribute"> { + let storageType = "::mlir::Attribute"; + let returnType = "::mlir::Attribute"; + let convertFromStorage = "$_self"; + let constBuilderCall = "$0"; +} + +// Any attribute from the given list +class AnyAttrOf allowedAttrs, string summary = "", + string cppClassName = "::mlir::Attribute", + string fromStorage = "$_self"> : Attr< + // Satisfy any of the allowed attribute's condition + Or, + !if(!eq(summary, ""), + !interleave(!foreach(t, allowedAttrs, t.summary), " or "), + summary)> { + let returnType = cppClassName; + let convertFromStorage = fromStorage; +} + +def LocationAttr : Attr($_self)">, + "location attribute">; + +def BoolAttr : Attr($_self)">, "bool attribute"> { + let storageType = [{ ::mlir::BoolAttr }]; + let returnType = [{ bool }]; + let valueType = I1; + let constBuilderCall = "$_builder.getBoolAttr($0)"; +} + +// Index attribute. +def IndexAttr : + TypedAttrBase< + Index, "IntegerAttr", + And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">, + CPred<"::llvm::isa<::mlir::IndexType>(::llvm::cast<::mlir::IntegerAttr>($_self).getType())">]>, + "index attribute"> { + let returnType = [{ ::llvm::APInt }]; +} + +// Base class for any integer (regardless of signedness semantics) attributes +// of fixed width. +class AnyIntegerAttrBase : + TypedAttrBase< + attrValType, "IntegerAttr", + And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">, + CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType()." 
+ "isInteger(" # attrValType.bitwidth # ")">]>, + descr> { + let returnType = [{ ::llvm::APInt }]; + let constBuilderCall = ?; +} + +def AnyI1Attr : AnyIntegerAttrBase; +def AnyI8Attr : AnyIntegerAttrBase; +def AnyI16Attr : AnyIntegerAttrBase; +def AnyI32Attr : AnyIntegerAttrBase; +def AnyI64Attr : AnyIntegerAttrBase; + +def APIntAttr : Attr($_self)">, + "arbitrary integer attribute"> { + let storageType = [{ ::mlir::IntegerAttr }]; + let returnType = [{ ::mlir::APInt }]; +} + +// Base class for signless integer attributes of fixed width. +class SignlessIntegerAttrBase : + TypedAttrBase< + attrValType, "IntegerAttr", + And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">, + CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType()." + "isSignlessInteger(" # attrValType.bitwidth # ")">]>, + descr> { + let returnType = [{ ::llvm::APInt }]; +} +// Base class for signless integer attributes of fixed width that have a +// corresponding C++ type. +class TypedSignlessIntegerAttrBase + : SignlessIntegerAttrBase { + let returnType = retType; + let convertFromStorage = "$_self.getValue().getZExtValue()"; +} + +def I1Attr : TypedSignlessIntegerAttrBase< + I1, "bool", "1-bit signless integer attribute">; +def I8Attr : TypedSignlessIntegerAttrBase< + I8, "uint8_t", "8-bit signless integer attribute">; +def I16Attr : TypedSignlessIntegerAttrBase< + I16, "uint16_t", "16-bit signless integer attribute">; +def I32Attr : TypedSignlessIntegerAttrBase< + I32, "uint32_t", "32-bit signless integer attribute">; +def I64Attr : TypedSignlessIntegerAttrBase< + I64, "uint64_t", "64-bit signless integer attribute">; + +// Base class for signed integer attributes of fixed width. +class SignedIntegerAttrBase : + TypedAttrBase< + attrValType, "IntegerAttr", + And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">, + CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType()." 
+ "isSignedInteger(" # attrValType.bitwidth # ")">]>, + descr> { + let returnType = [{ ::llvm::APInt }]; +} +// Base class for signed integer attributes of fixed width that have a +// corresponding C++ type. +class TypedSignedIntegerAttrBase + : SignedIntegerAttrBase { + let returnType = retType; + let convertFromStorage = "$_self.getValue().getSExtValue()"; +} + +def SI1Attr : TypedSignedIntegerAttrBase< + SI1, "bool", "1-bit signed integer attribute">; +def SI8Attr : TypedSignedIntegerAttrBase< + SI8, "int8_t", "8-bit signed integer attribute">; +def SI16Attr : TypedSignedIntegerAttrBase< + SI16, "int16_t", "16-bit signed integer attribute">; +def SI32Attr : TypedSignedIntegerAttrBase< + SI32, "int32_t", "32-bit signed integer attribute">; +def SI64Attr : TypedSignedIntegerAttrBase< + SI64, "int64_t", "64-bit signed integer attribute">; + +// Base class for unsigned integer attributes of fixed width. +class UnsignedIntegerAttrBase : + TypedAttrBase< + attrValType, "IntegerAttr", + And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">, + CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType()." + "isUnsignedInteger(" # attrValType.bitwidth # ")">]>, + descr> { + let returnType = [{ ::llvm::APInt }]; +} +// Base class for unsigned integer attributes of fixed width that have a +// corresponding C++ type. 
+class TypedUnsignedIntegerAttrBase + : UnsignedIntegerAttrBase { + let returnType = retType; + let convertFromStorage = "$_self.getValue().getZExtValue()"; +} + +def UI1Attr : TypedUnsignedIntegerAttrBase< + UI1, "bool", "1-bit unsigned integer attribute">; +def UI8Attr : TypedUnsignedIntegerAttrBase< + UI8, "uint8_t", "8-bit unsigned integer attribute">; +def UI16Attr : TypedUnsignedIntegerAttrBase< + UI16, "uint16_t", "16-bit unsigned integer attribute">; +def UI32Attr : TypedUnsignedIntegerAttrBase< + UI32, "uint32_t", "32-bit unsigned integer attribute">; +def UI64Attr : TypedUnsignedIntegerAttrBase< + UI64, "uint64_t", "64-bit unsigned integer attribute">; + +// Base class for float attributes of fixed width. +class FloatAttrBase : + TypedAttrBase($_self)">, + CPred<"::llvm::cast<::mlir::FloatAttr>($_self).getType().isF" # + attrValType.bitwidth # "()">]>, + descr> { + let returnType = [{ ::llvm::APFloat }]; +} + +def F32Attr : FloatAttrBase; +def F64Attr : FloatAttrBase; + +// An attribute backed by a string type. +class StringBasedAttr : Attr { + let constBuilderCall = "$_builder.getStringAttr($0)"; + let storageType = [{ ::mlir::StringAttr }]; + let returnType = [{ ::llvm::StringRef }]; + let valueType = NoneType; +} + +def StrAttr : StringBasedAttr($_self)">, + "string attribute">; + +// A string attribute that represents the name of a symbol. +def SymbolNameAttr : StringBasedAttr($_self)">, + "string attribute">; + +// String attribute that has a specific value type. +class TypedStrAttr + : StringBasedAttr($_self)">, + "string attribute"> { + let valueType = ty; +} + +// Base class for attributes containing types. Example: +// def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute"> +// defines a type attribute containing an integer type. 
+class TypeAttrBase> : + Attr($_self)">, + CPred<"::llvm::isa<" # retType # ">(::llvm::cast<::mlir::TypeAttr>($_self).getValue())">, + SubstLeaves<"$_self", + "::llvm::cast<::mlir::TypeAttr>($_self).getValue()", typePred>]>, + summary> { + let storageType = [{ ::mlir::TypeAttr }]; + let returnType = retType; + let valueType = NoneType; + let convertFromStorage = "::llvm::cast<" # retType # ">($_self.getValue())"; +} + +def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> { + let constBuilderCall = "::mlir::TypeAttr::get($0)"; +} + +class TypeAttrOf + : TypeAttrBase { + let constBuilderCall = "::mlir::TypeAttr::get($0)"; +} + +// The mere presence of unit attributes has a meaning. Therefore, unit +// attributes are always treated as optional and accessors to them return +// "true" if the attribute is present and "false" otherwise. +def UnitAttr : Attr($_self)">, "unit attribute"> { + let storageType = [{ ::mlir::UnitAttr }]; + let constBuilderCall = "(($0) ? $_builder.getUnitAttr() : nullptr)"; + let convertFromStorage = "$_self != nullptr"; + let returnType = "bool"; + let defaultValue = "false"; + let valueType = NoneType; + let isOptional = 1; +} + +//===----------------------------------------------------------------------===// +// Composite attribute kinds + +class DictionaryAttrBase : + Attr { + let storageType = [{ ::mlir::DictionaryAttr }]; + let constBuilderCall = "$_builder.getDictionaryAttr($0)"; + let returnType = [{ ::mlir::DictionaryAttr }]; + let valueType = NoneType; + let convertFromStorage = "$_self"; +} + +def DictionaryAttr + : DictionaryAttrBase($_self)">, + "dictionary of named attribute values">; + +class ElementsAttrBase : + Attr { + let storageType = [{ ::mlir::ElementsAttr }]; + let returnType = [{ ::mlir::ElementsAttr }]; + let convertFromStorage = "$_self"; +} + +def ElementsAttr : ElementsAttrBase($_self)">, + "constant vector/tensor attribute">; + +class IntElementsAttrBase : + ElementsAttrBase($_self)">, + condition]>, + 
summary> { + let storageType = [{ ::mlir::DenseIntElementsAttr }]; + let returnType = [{ ::mlir::DenseIntElementsAttr }]; + + let convertFromStorage = "$_self"; +} + +class DenseArrayAttrBase : + ElementsAttrBase($_self)">, + summaryName # " dense array attribute"> { + let storageType = "::mlir::" # denseAttrName; + let returnType = "::llvm::ArrayRef<" # cppType # ">"; + let constBuilderCall = "$_builder.get" # denseAttrName # "($0)"; +} +def DenseBoolArrayAttr : DenseArrayAttrBase<"DenseBoolArrayAttr", "bool", "i1">; +def DenseI8ArrayAttr : DenseArrayAttrBase<"DenseI8ArrayAttr", "int8_t", "i8">; +def DenseI16ArrayAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">; +def DenseI32ArrayAttr : DenseArrayAttrBase<"DenseI32ArrayAttr", "int32_t", "i32">; +def DenseI64ArrayAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">; +def DenseF32ArrayAttr : DenseArrayAttrBase<"DenseF32ArrayAttr", "float", "f32">; +def DenseF64ArrayAttr : DenseArrayAttrBase<"DenseF64ArrayAttr", "double", "f64">; + +def IndexElementsAttr + : IntElementsAttrBase($_self) + .getType() + .getElementType() + .isIndex()}]>, + "index elements attribute">; + +def AnyIntElementsAttr : IntElementsAttrBase, "integer elements attribute">; + +class IntElementsAttrOf : IntElementsAttrBase< + CPred<"::llvm::cast<::mlir::DenseIntElementsAttr>($_self).getType()." + "getElementType().isInteger(" # width # ")">, + width # "-bit integer elements attribute">; + +def AnyI32ElementsAttr : IntElementsAttrOf<32>; +def AnyI64ElementsAttr : IntElementsAttrOf<64>; + +class SignlessIntElementsAttr : IntElementsAttrBase< + CPred<"::llvm::cast<::mlir::DenseIntElementsAttr>($_self).getType()." + "getElementType().isSignlessInteger(" # width # ")">, + width # "-bit signless integer elements attribute"> { + + // Note that this is only constructing scalar elements attribute. 
+ let constBuilderCall = "::llvm::cast<::mlir::DenseIntElementsAttr>(" + "::mlir::DenseElementsAttr::get(" + "::mlir::RankedTensorType::get({}, $_builder.getIntegerType(" # width # ")), " + "::llvm::ArrayRef($0)))"; +} + +def I32ElementsAttr : SignlessIntElementsAttr<32>; +def I64ElementsAttr : SignlessIntElementsAttr<64>; + +// A `width`-bit signless integer elements attribute. The attribute should be +// ranked and has a shape as specified in `dims`. +class RankedSignlessIntElementsAttr dims> : + SignlessIntElementsAttr { + // Check that this has the specified shape. + let predicate = And<[ + SignlessIntElementsAttr.predicate, + CPred<"::llvm::cast<::mlir::DenseIntElementsAttr>($_self).getType().getShape() == " + "::mlir::ArrayRef({" # !interleave(dims, ", ") # "})">]>; + + let summary = width # "-bit signless int elements attribute of shape [" # + !interleave(dims, ", ") # "]"; + + let constBuilderCall = "::mlir::DenseIntElementsAttr::get(" + "::mlir::RankedTensorType::get({" # !interleave(dims, ", ") # + "}, $_builder.getIntegerType(" # width # ")), ::llvm::ArrayRef($0))"; +} + +class RankedI32ElementsAttr dims> : + RankedSignlessIntElementsAttr<32, dims>; +class RankedI64ElementsAttr dims> : + RankedSignlessIntElementsAttr<64, dims>; + +class FloatElementsAttr : ElementsAttrBase< + CPred<"::llvm::isa<::mlir::DenseFPElementsAttr>($_self) &&" + "::llvm::cast<::mlir::DenseElementsAttr>($_self).getType()." + "getElementType().isF" # width # "()">, + width # "-bit float elements attribute"> { + + let storageType = [{ ::mlir::DenseElementsAttr }]; + let returnType = [{ ::mlir::DenseElementsAttr }]; + + // Note that this is only constructing scalar elements attribute. + let constBuilderCall = "::mlir::DenseElementsAttr::get(" + "::mlir::RankedTensorType::get({}, $_builder.getF" # width # "Type())," + "::llvm::ArrayRef($0))"; + let convertFromStorage = "$_self"; +} + +def F64ElementsAttr : FloatElementsAttr<64>; + +// A `width`-bit floating point elements attribute. 
The attribute should be +// ranked and has a shape as specified in `dims`. +class RankedFloatElementsAttr dims> : ElementsAttrBase< + CPred<"::llvm::isa<::mlir::DenseFPElementsAttr>($_self) &&" + "::llvm::cast<::mlir::DenseFPElementsAttr>($_self).getType()." + "getElementType().isF" # width # "() && " + // Check that this is ranked and has the specified shape. + "::llvm::cast<::mlir::DenseFPElementsAttr>($_self).getType().hasRank() && " + "::llvm::cast<::mlir::DenseFPElementsAttr>($_self).getType().getShape() == " + "::mlir::ArrayRef({" # !interleave(dims, ", ") # "})">, + width # "-bit float elements attribute of shape [" # + !interleave(dims, ", ") # "]"> { + + let storageType = [{ ::mlir::DenseFPElementsAttr }]; + let returnType = [{ ::mlir::DenseFPElementsAttr }]; + + let constBuilderCall = "::llvm::cast<::mlir::DenseFPElementsAttr>(" + "::mlir::DenseElementsAttr::get(" + "::mlir::RankedTensorType::get({" # !interleave(dims, ", ") # + "}, $_builder.getF" # width # "Type()), " + "::llvm::ArrayRef($0)))"; + let convertFromStorage = "$_self"; +} + +class RankedF32ElementsAttr dims> : RankedFloatElementsAttr<32, dims>; +class RankedF64ElementsAttr dims> : RankedFloatElementsAttr<64, dims>; + +def StringElementsAttr : ElementsAttrBase< + CPred<"::llvm::isa<::mlir::DenseStringElementsAttr>($_self)" >, + "string elements attribute"> { + + let storageType = [{ ::mlir::DenseElementsAttr }]; + let returnType = [{ ::mlir::DenseElementsAttr }]; + + let convertFromStorage = "$_self"; +} + +// Attributes containing affine maps. +def AffineMapAttr : Attr< +CPred<"::llvm::isa<::mlir::AffineMapAttr>($_self)">, "AffineMap attribute"> { + let storageType = [{::mlir::AffineMapAttr }]; + let returnType = [{ ::mlir::AffineMap }]; + let valueType = Index; + let constBuilderCall = "::mlir::AffineMapAttr::get($0)"; +} + +// Base class for array attributes. 
+class ArrayAttrBase : Attr { + let storageType = [{ ::mlir::ArrayAttr }]; + let returnType = [{ ::mlir::ArrayAttr }]; + let valueType = NoneType; + let convertFromStorage = "$_self"; + let constBuilderCall = "$_builder.getArrayAttr($0)"; +} + +def ArrayAttr : ArrayAttrBase($_self)">, + "array attribute">; + +// Base class for array attributes whose elements are of the same kind. +// `element` specifies the element attribute kind stored in this array. +class TypedArrayAttrBase: ArrayAttrBase< + And<[ + // Guarantee this is an ArrayAttr first + CPred<"::llvm::isa<::mlir::ArrayAttr>($_self)">, + // Guarantee all elements satisfy the constraints from `element` + Concat<"::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>($_self), " + "[&](::mlir::Attribute attr) { return attr && (", + SubstLeaves<"$_self", "attr", element.predicate>, + "); })">]>, + summary> { + + Attr elementAttr = element; +} + +def LocationArrayAttr : TypedArrayAttrBase; + +def AffineMapArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getAffineMapArrayAttr($0)"; +} + +def BoolArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getBoolArrayAttr($0)"; +} +def I32ArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getI32ArrayAttr($0)"; +} +def I64ArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getI64ArrayAttr($0)"; +} +// Variant of I64ArrayAttr whose user accessor is SmallVector. 
+def I64SmallVectorArrayAttr : + TypedArrayAttrBase { + let returnType = [{ ::llvm::SmallVector }]; + let convertFromStorage = [{ + llvm::to_vector<4>( + llvm::map_range($_self.getAsRange(), + [](IntegerAttr attr) { return attr.getInt(); })); + }]; + let constBuilderCall = "$_builder.getI64ArrayAttr($0)"; +} +def F32ArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getF32ArrayAttr($0)"; +} +def F64ArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getF64ArrayAttr($0)"; +} +def StrArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getStrArrayAttr($0)"; +} +def TypeArrayAttr : TypedArrayAttrBase { + let constBuilderCall = "$_builder.getTypeArrayAttr($0)"; +} +def IndexListArrayAttr : + TypedArrayAttrBase; +def DictArrayAttr : + TypedArrayAttrBase; + +// Attributes containing symbol references. +def SymbolRefAttr : Attr($_self)">, + "symbol reference attribute"> { + let storageType = [{ ::mlir::SymbolRefAttr }]; + let returnType = [{ ::mlir::SymbolRefAttr }]; + let valueType = NoneType; + let constBuilderCall = + "::mlir::SymbolRefAttr::get($_builder.getContext(), $0)"; + let convertFromStorage = "$_self"; +} + +def FlatSymbolRefAttr : Attr($_self)">, + "flat symbol reference attribute"> { + let storageType = [{ ::mlir::FlatSymbolRefAttr }]; + let returnType = [{ ::llvm::StringRef }]; + let valueType = NoneType; + let constBuilderCall = + "::mlir::SymbolRefAttr::get($_builder.getContext(), $0)"; + let convertFromStorage = "$_self.getValue()"; +} + +def SymbolRefArrayAttr : + TypedArrayAttrBase { + let constBuilderCall = ?; +} + +def FlatSymbolRefArrayAttr : + TypedArrayAttrBase { + let constBuilderCall = ?; +} + +//===----------------------------------------------------------------------===// +// Derive attribute kinds + +// DerivedAttr are attributes whose value is computed from properties +// of the operation. They do not require additional storage and are +// materialized as needed. 
+// Note: All derived attributes should be materializable as an Attribute. E.g., +// do not use DerivedAttr for things that could not have been stored as +// Attribute. +// +class DerivedAttr : + Attr, "derived attribute"> { + let returnType = ret; + code body = b; + + // Specify how to convert from the derived attribute to an attribute. + // + // ## Special placeholders + // + // Special placeholders can be used to refer to entities during conversion: + // + // * `$_builder` will be replaced by a mlir::Builder instance. + // * `$_ctxt` will be replaced by the MLIRContext* instance. + // * `$_self` will be replaced with the derived attribute (value produces + // `returnType`). + let convertFromStorage = convert; +} + +// Derived attribute that returns a mlir::Type. +class DerivedTypeAttr : DerivedAttr<"::mlir::Type", body> { + let convertFromStorage = "::mlir::TypeAttr::get($_self)"; +} + +//===----------------------------------------------------------------------===// +// Constant attribute kinds + +// Represents a constant attribute of specific Attr type. A constant +// attribute can be specified only of attributes that have a constant +// builder call defined. The constant value is specified as a string. +// +// If used as a constraint, it generates a matcher on a constant attribute by +// using the constant value builder of the attribute and the value. +class ConstantAttr : AttrConstraint< + CPred<"$_self == " # !subst("$0", val, attribute.constBuilderCall)>, + "constant attribute " # val> { + Attr attr = attribute; + string value = val; +} + +class ConstF32Attr : ConstantAttr; +def ConstBoolAttrFalse : ConstantAttr; +def ConstBoolAttrTrue : ConstantAttr; +def ConstUnitAttr : ConstantAttr; + +// Constant string-based attribute. Wraps the desired string in escaped quotes. 
+class ConstantStrAttr + : ConstantAttr; + +//===----------------------------------------------------------------------===// +// Common attribute constraints +//===----------------------------------------------------------------------===// + +// A general mechanism to further confine the given `attr` with all the +// `constraints`. This allows to compose complex constraints out of a series +// of more primitive ones. +class ConfinedAttr constraints> : Attr< + And, + !foldl(/*init*/attr.summary, /*list*/constraints, + prev, cur, prev # " " # cur.summary)> { + let storageType = attr.storageType; + let returnType = attr.returnType; + let convertFromStorage = attr.convertFromStorage; + let constBuilderCall = attr.constBuilderCall; + let defaultValue = attr.defaultValue; + let valueType = attr.valueType; + let isOptional = attr.isOptional; + + let baseAttr = attr; +} + +// An AttrConstraint that holds if all attr constraints specified in +// 'constraints' hold. +class AllAttrOf constraints> : AttrConstraint< + And, + !interleave(!foreach(con, constraints, con.summary), " and ")> { +} + +class IntNEQValue : AttrConstraint< + CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getInt() != " # n>, + "whose minimum value is " # n>; + +class IntMinValue : AttrConstraint< + CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getInt() >= " # n>, + "whose minimum value is " # n>; + +class IntMaxValue : AttrConstraint< + CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getInt() <= " # n>, + "whose maximum value is " # n>; + +def IntNonNegative : AttrConstraint< + CPred<"!::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isNegative()">, + "whose value is non-negative">; + +def IntPositive : AttrConstraint< + CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isStrictlyPositive()">, + "whose value is positive">; + +class ArrayMinCount : AttrConstraint< + CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() >= " # n>, + "with at least " # n # " elements">; + +class 
ArrayCount : AttrConstraint< + CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() == " #n>, + "with exactly " # n # " elements">; + +class DenseArrayCount : AttrConstraint< + CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() == " #n>, + "with exactly " # n # " elements">; + +class DenseArrayStrictlyPositive : AttrConstraint< + CPred<"::llvm::all_of(::llvm::cast<" # arrayType #">($_self).asArrayRef(), " + "[&](auto v) { return v > 0; })">, + "whose value is positive">; + +class DenseArrayNonNegative : AttrConstraint< + CPred<"::llvm::all_of(::llvm::cast<" # arrayType #">($_self).asArrayRef(), " + "[&](auto v) { return v >= 0; })">, + "whose value is non-negative">; + +class DenseArraySorted : AttrConstraint< + CPred<"llvm::is_sorted(::llvm::cast<" # arrayType # ">($_self).asArrayRef())">, + "should be in non-decreasing order">; + +class DenseArrayStrictlySorted : AttrConstraint< + And<[ + CPred<"llvm::is_sorted(::llvm::cast<" # arrayType # ">($_self).asArrayRef())">, + // Check that no two adjacent elements are the same. 
+ CPred<"[](" # arrayType.returnType # " a) {\n" + "return std::adjacent_find(std::begin(a), std::end(a)) == " + "std::end(a);\n" + "}(::llvm::cast<" # arrayType # ">($_self).asArrayRef())" + >]>, + "should be in increasing order">; + +class IntArrayNthElemEq : AttrConstraint< + And<[ + CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() > " # index>, + CPred<"::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>($_self)[" + # index # "]).getInt() == " # value> + ]>, + "whose " # index # "-th element must be " # value>; + +class IntArrayNthElemMinValue : AttrConstraint< + And<[ + CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() > " # index>, + CPred<"::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>($_self)[" + # index # "]).getInt() >= " # min> + ]>, + "whose " # index # "-th element must be at least " # min>; + +class IntArrayNthElemMaxValue : AttrConstraint< + And<[ + CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() > " # index>, + CPred<"::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>($_self)[" + # index # "]).getInt() <= " # max> + ]>, + "whose " # index # "-th element must be at most " # max>; + +class IntArrayNthElemInRange : AttrConstraint< + And<[ + CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() > " # index>, + CPred<"::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>($_self)[" + # index # "]).getInt() >= " # min>, + CPred<"::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>($_self)[" + # index # "]).getInt() <= " # max> + ]>, + "whose " # index # "-th element must be at least " # min # " and at most " # max>; + +def IsNullAttr : AttrConstraint< + CPred<"!$_self">, "empty attribute (for optional attributes)">; + +//===----------------------------------------------------------------------===// +// Region definitions +//===----------------------------------------------------------------------===// + +class Region : + RegionConstraint; + +// Any region. 
+def AnyRegion : Region, "any region">; + +// A region with the given number of blocks. +class SizedRegion : Region< + CPred<"::llvm::hasNItems($_self, " # numBlocks # ")">, + "region with " # numBlocks # " blocks">; + +// A region with at least the given number of blocks. +class MinSizedRegion : Region< + CPred<"::llvm::hasNItemsOrMore($_self, " # numBlocks # ")">, + "region with at least " # numBlocks # " blocks">; + +// A region with at most the given number of blocks. +class MaxSizedRegion : Region< + CPred<"::llvm::hasNItemsOrLess($_self, " # numBlocks # ")">, + "region with at most " # numBlocks # " blocks">; + +// A variadic region constraint. It expands to zero or more of the base region. +class VariadicRegion + : Region; + //-------------------------------------------------------------------------===// // AttrTrait definitions diff --git a/mlir/include/mlir/IR/Constraints.td b/mlir/include/mlir/IR/Constraints.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/Constraints.td @@ -0,0 +1,243 @@ +//===-- Constraints.td - Constraints definition file ----------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines constraints/predicates for verifiers. +// +//===----------------------------------------------------------------------===// + +#ifndef CONSTRAINTS +#define CONSTRAINTS + +include "mlir/IR/Utils.td" + + +//===----------------------------------------------------------------------===// +// Predicate definitions +//===----------------------------------------------------------------------===// + +// Base class for logical predicates. +// +// Predicates are used to compose constraints (see next section for details). 
+// There are two categories of predicates: +// +// 1. CPred: the primitive leaf predicate. +// 2. Compound predicate: a predicate composed from child predicates using +// predicate combiners ("conjunction", "disjunction", "negation" or +// "substitution"). +class Pred; + +// A logical predicate wrapping any C expression. +// +// This is the basis for composing more complex predicates. It is the "atom" +// predicate from the perspective of TableGen and the "interface" between +// TableGen and C++. What is inside is already C++ code, which will be treated +// as opaque strings with special placeholders to be substituted. +// +// ## Special placeholders +// +// Special placeholders can be used to refer to entities in the context where +// this predicate is used. They serve as "hooks" to the enclosing environment. +// The following special placeholders are supported in constraints for an op: +// +// * `$_builder` will be replaced by a mlir::Builder instance. +// * `$_op` will be replaced by the current operation. +// * `$_self` will be replaced with the entity this predicate is attached to. +// E.g., `BoolAttr` is an attribute constraint that wraps a +// `CPred<"::llvm::isa($_self)">` (see the following sections for details). +// Then for `F32:$attr`,`$_self` will be replaced by `$attr`. +// For type constraints, it's a little bit special since we want the +// constraints on each type definition reads naturally and we want to attach +// type constraints directly to an operand/result, $_self will be replaced +// by the operand/result's type. E.g., for `F32` in `F32:$operand`, its +// `$_self` will be expanded as `getOperand(...).getType()`. +// +// One thing to be noticed, while using these placeholders in the C expression, +// the type of placeholder is only guaranteed to be the base type. For example, +// if you have a predicate in the form `CPred<"CheckType($_self)">, the argument +// type of the function `CheckType` should be `mlir::Type`. 
+class CPred : Pred { + code predExpr = "(" # pred # ")"; +} + +// Kinds of predicate combiners. These must closely match the predicates +// implemented by the C++ backend (tblgen::PredCombinerKind). +class PredCombinerKind; +def PredCombinerAnd : PredCombinerKind; +def PredCombinerOr : PredCombinerKind; +def PredCombinerNot : PredCombinerKind; +def PredCombinerSubstLeaves : PredCombinerKind; +def PredCombinerConcat : PredCombinerKind; + +// A predicate that combines other predicates as defined by PredCombinerKind. +// Instantiated below. +class CombinedPred c> : Pred { + PredCombinerKind kind = k; + list children = c; +} + +// Predicate combiners + +// A predicate that holds if all of its children hold. Always holds for zero +// children. +class And children> : CombinedPred; + +// A predicate that holds if any of its children hold. Never holds for zero +// children. +class Or children> : CombinedPred; + +// A predicate that holds if its child does not. +class Neg : CombinedPred; + +// A predicate that substitutes "pat" with "repl" in predicate calls of the +// leaves of the predicate tree (i.e., not CombinedPred). +// +// This is plain string substitution without regular expressions or captures. +// New predicates with more complex logical can be introduced should the need +// arise. +class SubstLeaves + : CombinedPred { + string pattern = pat; + string replacement = repl; +} + +// A predicate that prepends `pre` and appends `suf` to the final predicate +// string composed from `child`. This is plain string concatenation and there +// will be no substitution happening for `pre` and `suf`. +class Concat : + CombinedPred { + string prefix = pre; + string suffix = suf; +} + +//===----------------------------------------------------------------------===// +// Constraint definitions +//===----------------------------------------------------------------------===// + +// TODO: Merge Constraints into Pred. + +// Base class for named constraints. 
+// +// An op's operands/attributes/results can have various requirements, e.g., +// having certain types, having values inside a certain range, and so on. +// Besides, for a graph rewrite rule, the source pattern used to match against +// the existing graph has conditions, like the op's operand must be of a more +// constrained subtype, the attribute must have a certain value, and so on. +// +// These requirements and conditions are modeled using this class. Records of +// this class are used to generate verification code in op verifier, and +// matching code in pattern matcher. +// +// Constraints are predicates with descriptive names, to facilitate inspection, +// provide nice error messages, etc. +class Constraint { + // The predicates that this constraint requires. + Pred predicate = pred; + // User-readable one line summary used in error reporting messages. If empty, + // a generic message will be used. + string summary = desc; +} + +// Subclasses used to differentiate different constraint kinds. These are used +// as markers for the TableGen backend to handle different constraint kinds +// differently if needed. Constraints not deriving from the following subclasses +// are considered as uncategorized constraints. + +// Subclass for constraints on a type. +class TypeConstraint : + Constraint { + // The name of the C++ Type class if known, or Type if not. + string cppClassName = cppClassNameParam; +} + +// Subclass for constraints on an attribute. +class AttrConstraint : + Constraint; + +// Subclass for constraints on a region. +class RegionConstraint : + Constraint; + +// Subclass for constraints on a successor. 
+class SuccessorConstraint : + Constraint; + +// How to use these constraint categories: +// +// * Use TypeConstraint to specify +// * Constraints on an op's operand/result definition +// * Further constraints to match an op's operand/result in source pattern +// +// * Use Attr (a subclass for AttrConstraint) for +// * Constraints on an op's attribute definition +// * Use AttrConstraint to specify +// * Further constraints to match an op's attribute in source pattern +// +// * Use uncategorized constraint to specify +// * Multi-entity constraints in rewrite rules + +//===----------------------------------------------------------------------===// +// Common predicates +//===----------------------------------------------------------------------===// + +// Whether a type is a VectorType. +// Explicitly disallow 0-D vectors for now until we have good enough coverage. +def IsVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">, + CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>; + +// Temporary vector type clone that allows gradual transition to 0-D vectors. +// TODO: Remove this when all ops support 0-D vectors. +def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">; + +// Whether a type is a fixed-length VectorType. +def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && + !::llvm::cast($_self).isScalable()}]>; + +// Whether a type is a scalable VectorType. +def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && + ::llvm::cast($_self).isScalable()}]>; + +// Whether a type is a VectorType and all dimensions are scalable. +def allDimsScalableVectorTypePred : And<[ + IsVectorTypePred, + CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]> +]>; + +// Whether a type is a TensorType. +def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">; + +// Whether a type is a MemRefType. 
+def IsMemRefTypePred : CPred<"::llvm::isa<::mlir::MemRefType>($_self)">; + +// Whether a type is an UnrankedMemRefType +def IsUnrankedMemRefTypePred + : CPred<"::llvm::isa<::mlir::UnrankedMemRefType>($_self)">; + +// Whether a type is an UnrankedTensorType +def IsUnrankedTensorTypePred + : CPred<"::llvm::isa<::mlir::UnrankedTensorType>($_self)">; + +// Whether a type is a RankedTensorType +def IsRankedTensorTypePred + : CPred<"::llvm::isa<::mlir::RankedTensorType>($_self)">; + +// Whether a type is a BaseMemRefType +def IsBaseMemRefTypePred + : CPred<"::llvm::isa<::mlir::BaseMemRefType>($_self)">; + +// Whether a type is a ShapedType. +def IsShapedTypePred : CPred<"::llvm::isa<::mlir::ShapedType>($_self)">; + +// For a ShapedType, verify that it has a static shape. +def HasStaticShapePred : + CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasStaticShape()">; + +// Whether a type is a TupleType. +def IsTupleTypePred : CPred<"::llvm::isa<::mlir::TupleType>($_self)">; + +#endif // CONSTRAINTS diff --git a/mlir/include/mlir/IR/DialectBase.td b/mlir/include/mlir/IR/DialectBase.td --- a/mlir/include/mlir/IR/DialectBase.td +++ b/mlir/include/mlir/IR/DialectBase.td @@ -13,23 +13,7 @@ #ifndef DIALECTBASE_TD #define DIALECTBASE_TD -// Helper for marking deprecated classes or defs in TableGen. To mark a def as -// deprecated, mix in the `Deprecate` class with a reason. -// Usage of a deprecated def within TableGen will cause a warning with the -// given message. -class Deprecated { - string odsDeprecated = reason; -} - -// Helper for marking entities in ODS generated C++ as deprecated. -// Usage of such an entity from C++ code will cause a warning being emitted by -// the C++ compiler with the given message. -// -// Note: Support has to be implemented by the code generator of a given -// entity. 
-class CppDeprecated { - string odsCppDeprecated = reason; -} +include "mlir/IR/Utils.td" //===----------------------------------------------------------------------===// // Dialect definitions diff --git a/mlir/include/mlir/IR/Interfaces.td b/mlir/include/mlir/IR/Interfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/Interfaces.td @@ -0,0 +1,193 @@ +//===-- Interfaces.td - Interfaces defination file ------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definations for Interfaces. +// +//===----------------------------------------------------------------------===// + +#ifndef INTERFACES_TD +#define INTERFACES_TD + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/Constraints.td" +include "mlir/IR/Traits.td" + +//===----------------------------------------------------------------------===// +// Interface definitions +//===----------------------------------------------------------------------===// + +// InterfaceTrait corresponds to a specific 'Interface' class defined in C++. +// The purpose to wrap around C++ symbol string with this class is to make +// interfaces specified for ops in TableGen less alien and more integrated. +class InterfaceTrait : NativeTrait<"", ""> { + let trait = name # "::Trait"; + let cppNamespace = ""; + + // An optional code block containing extra declarations to place in the + // interface trait declaration. + code extraTraitClassDeclaration = ""; +} + +// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in +// C++. The purpose to wrap around C++ symbol string with this class is to make +// interfaces specified for ops in TableGen less alien and more integrated. 
+class OpInterfaceTrait traits = []> + : InterfaceTrait { + // Specify the body of the verification function. `$_op` will be replaced with + // the operation being verified. + code verify = verifyBody; + + // A bit indicating if the verifier needs to access the ops in the regions. If + // it set to `1`, the region ops will be verified before invoking this + // verifier. + bit verifyWithRegions = 0; + + // Specify the list of traits that need to be verified before the verification + // of this OpInterfaceTrait. + list dependentTraits = traits; +} + +// This class represents a single, optionally static, interface method. +// Note: non-static interface methods have an implicit parameter, either +// $_op/$_attr/$_type corresponding to an instance of the derived value. +class InterfaceMethod { + // A human-readable description of what this method does. + string description = desc; + + // The name of the interface method. + string name = methodName; + + // The c++ type-name of the return type. + string returnType = retTy; + + // A dag of string that correspond to the arguments of the method. + dag arguments = args; + + // An optional body to the method. + code body = methodBody; + + // An optional default implementation of the method. + code defaultBody = defaultImplementation; +} + +// This class represents a single static interface method. +class StaticInterfaceMethod + : InterfaceMethod; + +// Interface represents a base interface. +class Interface baseInterfacesArg = []> { + // A human-readable description of what this interface does. + string description = ""; + + // The name given to the c++ interface class. + string cppInterfaceName = name; + + // The C++ namespace that this interface should be placed into. + // + // To specify nested namespaces, use "::" as the delimiter, e.g., given + // "A::B", ops will be placed in `namespace A { namespace B { } }`. + string cppNamespace = ""; + + // The list of methods defined by this interface. 
+ list methods = []; + + // An optional code block containing extra declarations to place in the + // interface declaration. + code extraClassDeclaration = ""; + + // An optional code block containing extra declarations to place in both + // the interface and trait declaration. + code extraSharedClassDeclaration = ""; + + // An optional code block for adding additional "classof" logic. This can + // be used to better enable "optional" interfaces, where an entity only + // implements the interface if some dynamic characteristic holds. + // `$_attr`/`$_op`/`$_type` may be used to refer to an instance of the + // entity being checked. + code extraClassOf = ""; + + // An optional set of base interfaces that this interface + // "derives" from. + list baseInterfaces = baseInterfacesArg; +} + +// AttrInterface represents an interface registered to an attribute. +class AttrInterface baseInterfaces = []> + : Interface, InterfaceTrait, + Attr($_self)">, + name # " instance" + > { + let storageType = !if(!empty(cppNamespace), "", cppNamespace # "::") # name; + let returnType = storageType; + let convertFromStorage = "$_self"; +} + +// OpInterface represents an interface registered to an operation. +class OpInterface baseInterfaces = []> + : Interface, OpInterfaceTrait; + +// TypeInterface represents an interface registered to a type. +class TypeInterface baseInterfaces = []> + : Interface, InterfaceTrait, + Type($_self)">, + name # " instance", + !if(!empty(cppNamespace),"", cppNamespace # "::") # name + >; + +// Whether to declare the interface methods in the user entity's header. This +// class simply wraps an Interface but is used to indicate that the method +// declarations should be generated. This class takes an optional set of methods +// that should have declarations generated even if the method has a default +// implementation. 
+class DeclareInterfaceMethods overridenMethods = []> { + // This field contains a set of method names that should always have their + // declarations generated. This allows for generating declarations for + // methods with default implementations that need to be overridden. + list alwaysOverriddenMethods = overridenMethods; +} +class DeclareAttrInterfaceMethods overridenMethods = []> + : DeclareInterfaceMethods, + AttrInterface { + let description = interface.description; + let cppInterfaceName = interface.cppInterfaceName; + let cppNamespace = interface.cppNamespace; + let methods = interface.methods; + let baseInterfaces = interface.baseInterfaces; +} +class DeclareOpInterfaceMethods overridenMethods = []> + : DeclareInterfaceMethods, + OpInterface { + let description = interface.description; + let cppInterfaceName = interface.cppInterfaceName; + let cppNamespace = interface.cppNamespace; + let methods = interface.methods; + let baseInterfaces = interface.baseInterfaces; +} +class DeclareTypeInterfaceMethods overridenMethods = []> + : DeclareInterfaceMethods, + TypeInterface { + let description = interface.description; + let cppInterfaceName = interface.cppInterfaceName; + let cppNamespace = interface.cppNamespace; + let methods = interface.methods; + let baseInterfaces = interface.baseInterfaces; +} + + +#endif // INTERFACES_TD diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -13,1966 +13,13 @@ #ifndef OP_BASE #define OP_BASE +include "mlir/IR/Constraints.td" include "mlir/IR/DialectBase.td" - -//===----------------------------------------------------------------------===// -// Common utilities for defining TableGen mechanisms -//===----------------------------------------------------------------------===// - -// A workaround for the inability to define functions in Tablegen. 
-// -// The template parameter defines a string that can be extracted from an -// instance of this class by accessing the "result" member. Subclasses can take -// their own template parameters as function "arguments" and use them to -// populate result. -// For example, if it didn't already exist, a concat function could be defined -// like: -// -// class StrConcat strings> : -// StrFunc -// -// and then called like -// -// StrConcat<["a", "b", "c"]>.result -// -// to get the string "abc" -class StrFunc { - string result = r; -} - -//===----------------------------------------------------------------------===// -// Predicate definitions -//===----------------------------------------------------------------------===// - -// Base class for logical predicates. -// -// Predicates are used to compose constraints (see next section for details). -// There are two categories of predicates: -// -// 1. CPred: the primitive leaf predicate. -// 2. Compound predicate: a predicate composed from child predicates using -// predicate combiners ("conjunction", "disjunction", "negation" or -// "substitution"). -class Pred; - -// A logical predicate wrapping any C expression. -// -// This is the basis for composing more complex predicates. It is the "atom" -// predicate from the perspective of TableGen and the "interface" between -// TableGen and C++. What is inside is already C++ code, which will be treated -// as opaque strings with special placeholders to be substituted. -// -// ## Special placeholders -// -// Special placeholders can be used to refer to entities in the context where -// this predicate is used. They serve as "hooks" to the enclosing environment. -// The following special placeholders are supported in constraints for an op: -// -// * `$_builder` will be replaced by a mlir::Builder instance. -// * `$_op` will be replaced by the current operation. -// * `$_self` will be replaced with the entity this predicate is attached to. 
-// E.g., `BoolAttr` is an attribute constraint that wraps a -// `CPred<"::llvm::isa($_self)">` (see the following sections for details). -// Then for `F32:$attr`,`$_self` will be replaced by `$attr`. -// For type constraints, it's a little bit special since we want the -// constraints on each type definition reads naturally and we want to attach -// type constraints directly to an operand/result, $_self will be replaced -// by the operand/result's type. E.g., for `F32` in `F32:$operand`, its -// `$_self` will be expanded as `getOperand(...).getType()`. -// -// One thing to be noticed, while using these placeholders in the C expression, -// the type of placeholder is only guaranteed to be the base type. For example, -// if you have a predicate in the form `CPred<"CheckType($_self)">, the argument -// type of the function `CheckType` should be `mlir::Type`. -class CPred : Pred { - code predExpr = "(" # pred # ")"; -} - -// Kinds of predicate combiners. These must closely match the predicates -// implemented by the C++ backend (tblgen::PredCombinerKind). -class PredCombinerKind; -def PredCombinerAnd : PredCombinerKind; -def PredCombinerOr : PredCombinerKind; -def PredCombinerNot : PredCombinerKind; -def PredCombinerSubstLeaves : PredCombinerKind; -def PredCombinerConcat : PredCombinerKind; - -// A predicate that combines other predicates as defined by PredCombinerKind. -// Instantiated below. -class CombinedPred c> : Pred { - PredCombinerKind kind = k; - list children = c; -} - -// Predicate combiners - -// A predicate that holds if all of its children hold. Always holds for zero -// children. -class And children> : CombinedPred; - -// A predicate that holds if any of its children hold. Never holds for zero -// children. -class Or children> : CombinedPred; - -// A predicate that holds if its child does not. 
-class Neg : CombinedPred; - -// A predicate that substitutes "pat" with "repl" in predicate calls of the -// leaves of the predicate tree (i.e., not CombinedPred). -// -// This is plain string substitution without regular expressions or captures. -// New predicates with more complex logical can be introduced should the need -// arise. -class SubstLeaves - : CombinedPred { - string pattern = pat; - string replacement = repl; -} - -// A predicate that prepends `pre` and appends `suf` to the final predicate -// string composed from `child`. This is plain string concatenation and there -// will be no substitution happening for `pre` and `suf`. -class Concat : - CombinedPred { - string prefix = pre; - string suffix = suf; -} - -//===----------------------------------------------------------------------===// -// Constraint definitions -//===----------------------------------------------------------------------===// - -// TODO: Merge Constraints into Pred. - -// Base class for named constraints. -// -// An op's operands/attributes/results can have various requirements, e.g., -// having certain types, having values inside a certain range, and so on. -// Besides, for a graph rewrite rule, the source pattern used to match against -// the existing graph has conditions, like the op's operand must be of a more -// constrained subtype, the attribute must have a certain value, and so on. -// -// These requirements and conditions are modeled using this class. Records of -// this class are used to generate verification code in op verifier, and -// matching code in pattern matcher. -// -// Constraints are predicates with descriptive names, to facilitate inspection, -// provide nice error messages, etc. -class Constraint { - // The predicates that this constraint requires. - Pred predicate = pred; - // User-readable one line summary used in error reporting messages. If empty, - // a generic message will be used. 
- string summary = desc; -} - -// Subclasses used to differentiate different constraint kinds. These are used -// as markers for the TableGen backend to handle different constraint kinds -// differently if needed. Constraints not deriving from the following subclasses -// are considered as uncategorized constraints. - -// Subclass for constraints on a type. -class TypeConstraint : - Constraint { - // The name of the C++ Type class if known, or Type if not. - string cppClassName = cppClassNameParam; -} - -// Base class for defining properties. -class Property { - // User-readable one line summary used in error reporting messages. If empty, - // a generic message will be used. - string summary = desc; - // The full description of this property. - string description = ""; - code storageType = storageTypeParam; - code interfaceType = storageTypeParam; - - // The expression to convert from the storage type to the Interface - // type. For example, an enum can be stored as an int but returned as an - // enum class. - // - // Format: - // - `$_storage` will contain the property in the storage type. - // - `$_ctxt` will contain an `MLIRContext *`. - code convertFromStorage = "$_storage"; - - // The call expression to build a property storage from the interface type. - // - // Format: - // - `$_storage` will contain the property in the storage type. - // - `$_value` will contain the property in the user interface type. - code assignToStorage = "$_storage = $_value"; - - // The call expression to convert from the storage type to an attribute. - // - // Format: - // - `$_storage` is the storage type value. - // - `$_ctxt` is a `MLIRContext *`. - // - // The expression must result in an Attribute. - code convertToAttribute = [{ - convertToAttribute($_ctxt, $_storage) - }]; - - // The call expression to convert from an Attribute to the storage type. - // - // Format: - // - `$_storage` is the storage type value. - // - `$_attr` is the attribute. 
- // - `$_diag` is an optional Diagnostic pointer to emit error. - // - // The expression must return a LogicalResult - code convertFromAttribute = [{ - return convertFromAttribute($_storage, $_attr, $_diag); - }]; - - // The call expression to hash the property. - // - // Format: - // - `$_storage` is the variable to hash. - // - // The expression should define a llvm::hash_code. - code hashProperty = [{ - llvm::hash_value($_storage); - }]; - - // The call expression to emit the storage type to bytecode. - // - // Format: - // - `$_storage` is the storage type value. - // - `$_writer` is a `DialectBytecodeWriter`. - // - `$_ctxt` is a `MLIRContext *`. - code writeToMlirBytecode = [{ - writeToMlirBytecode($_writer, $_storage) - }]; - - // The call expression to read the storage type from bytecode. - // - // Format: - // - `$_storage` is the storage type value. - // - `$_reader` is a `DialectBytecodeReader`. - // - `$_ctxt` is a `MLIRContext *`. - code readFromMlirBytecode = [{ - if (::mlir::failed(readFromMlirBytecode($_reader, $_storage))) - return ::mlir::failure(); - }]; - - // Default value for the property. - string defaultValue = ?; -} - -/// Implementation of the Property class's `readFromMlirBytecode` field using -/// the default `convertFromAttribute` implementation. -/// Users not wanting to implement their own `readFromMlirBytecode` and -/// `writeToMlirBytecode` implementations can opt into using this implementation -/// by writing: -/// -/// let writeToMlirBytecode = writeMlirBytecodeWithConvertToAttribute; -/// let readFromMlirBytecode = readMlirBytecodeUsingConvertFromAttribute; -/// -/// in their property definition. -/// Serialization and deserialization is performed using the attributes -/// returned by `convertFromAttribute` and `convertToAttribute`. -/// -/// WARNING: This implementation creates a less than optimal encoding. 
-/// Users caring about optimal encoding should not use this implementation and -/// implement `readFromMlirBytecode` and `writeToMlirBytecode` themselves. -defvar readMlirBytecodeUsingConvertFromAttribute = [{ - ::mlir::Attribute attr; - if (::mlir::failed($_reader.readAttribute(attr))) - return ::mlir::failure(); - if (::mlir::failed(convertFromAttribute($_storage, attr, nullptr))) - return ::mlir::failure(); -}]; - -/// Implementation of the Property class's `writeToMlirBytecode` field using -/// the default `convertToAttribute` implementation. -/// See description of `readMlirBytecodeUsingConvertFromAttribute` above for -/// details. -defvar writeMlirBytecodeWithConvertToAttribute = [{ - $_writer.writeAttribute(convertToAttribute($_ctxt, $_storage)) -}]; - -// Subclass for constraints on an attribute. -class AttrConstraint : - Constraint; - -// Subclass for constraints on a region. -class RegionConstraint : - Constraint; - -// Subclass for constraints on a successor. -class SuccessorConstraint : - Constraint; - -// How to use these constraint categories: -// -// * Use TypeConstraint to specify -// * Constraints on an op's operand/result definition -// * Further constraints to match an op's operand/result in source pattern -// -// * Use Attr (a subclass for AttrConstraint) for -// * Constraints on an op's attribute definition -// * Use AttrConstraint to specify -// * Further constraints to match an op's attribute in source pattern -// -// * Use uncategorized constraint to specify -// * Multi-entity constraints in rewrite rules - -//===----------------------------------------------------------------------===// -// Common predicates -//===----------------------------------------------------------------------===// - -// Whether a type is a VectorType. -// Explicitly disallow 0-D vectors for now until we have good enough coverage. 
-def IsVectorTypePred : And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">, - CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>; - -// Temporary vector type clone that allows gradual transition to 0-D vectors. -// TODO: Remove this when all ops support 0-D vectors. -def IsVectorOfAnyRankTypePred : CPred<"::llvm::isa<::mlir::VectorType>($_self)">; - -// Whether a type is a fixed-length VectorType. -def IsFixedVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && - !::llvm::cast($_self).isScalable()}]>; - -// Whether a type is a scalable VectorType. -def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) && - ::llvm::cast($_self).isScalable()}]>; - -// Whether a type is a VectorType and all dimensions are scalable. -def allDimsScalableVectorTypePred : And<[ - IsVectorTypePred, - CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]> -]>; - -// Whether a type is a TensorType. -def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">; - -// Whether a type is a MemRefType. -def IsMemRefTypePred : CPred<"::llvm::isa<::mlir::MemRefType>($_self)">; - -// Whether a type is an UnrankedMemRefType -def IsUnrankedMemRefTypePred - : CPred<"::llvm::isa<::mlir::UnrankedMemRefType>($_self)">; - -// Whether a type is an UnrankedTensorType -def IsUnrankedTensorTypePred - : CPred<"::llvm::isa<::mlir::UnrankedTensorType>($_self)">; - -// Whether a type is a RankedTensorType -def IsRankedTensorTypePred - : CPred<"::llvm::isa<::mlir::RankedTensorType>($_self)">; - -// Whether a type is a BaseMemRefType -def IsBaseMemRefTypePred - : CPred<"::llvm::isa<::mlir::BaseMemRefType>($_self)">; - -// Whether a type is a ShapedType. -def IsShapedTypePred : CPred<"::llvm::isa<::mlir::ShapedType>($_self)">; - -// For a ShapedType, verify that it has a static shape. -def HasStaticShapePred : - CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasStaticShape()">; - -// Whether a type is a TupleType. 
-def IsTupleTypePred : CPred<"::llvm::isa<::mlir::TupleType>($_self)">; - -//===----------------------------------------------------------------------===// -// Type definitions -//===----------------------------------------------------------------------===// - -// A type, carries type constraints. -class Type : - TypeConstraint { - string description = ""; - string builderCall = ""; -} - -// Allows providing an alternative name and summary to an existing type def. -class TypeAlias : - Type { - let description = t.description; - let builderCall = t.builderCall; -} - -// A type of a specific dialect. -class DialectType : - Type { - Dialect dialect = d; -} - -// A variadic type constraint. It expands to zero or more of the base type. This -// class is used for supporting variadic operands/results. -class Variadic : TypeConstraint { - Type baseType = type; - int minSize = 0; -} - -// A nested variadic type constraint. It expands to zero or more variadic ranges -// of the base type. This class is used for supporting variadic operands and -// results. `variadicSegmentAttrName` should correspond to the name of an -// DenseI32ArrayAttr argument that provides the sizes of the inner variadic -// operand groups. -class VariadicOfVariadic - : Variadic { - string segmentAttrName = variadicSegmentAttrName; -} - -// An optional type constraint. It expands to either zero or one of the base -// type. This class is used for supporting optional operands/results. -class Optional : TypeConstraint { - Type baseType = type; -} - -// A type that can be constructed using MLIR::Builder. -// Note that this does not "inherit" from Type because it would require -// duplicating Type subclasses for buildable and non-buildable cases to avoid -// diamond "inheritance". -// TODO: we may extend this to a more general 'Buildable' trait, making some -// Types and some Attrs buildable. -class BuildableType { - // The builder call to invoke (if specified) to construct the BuildableType. 
- code builderCall = builder; -} - -// A type that's buildable iff the type passed as an argument is buildable. -// This is intended for use by types like container types, which are only -// buildable if the type of their elements is buildable. -class SameBuildabilityAs { - code builderCall = !if(!empty(type.builderCall), "", builder); -} - -// Any type at all. -def AnyType : Type, "any type">; - -// None type -def NoneType : Type($_self)">, "none type", - "::mlir::NoneType">, - BuildableType<"$_builder.getType<::mlir::NoneType>()">; - -// Any type from the given list -class AnyTypeOf allowedTypes, string summary = "", - string cppClassName = "::mlir::Type"> : Type< - // Satisfy any of the allowed types' conditions. - Or, - !if(!eq(summary, ""), - !interleave(!foreach(t, allowedTypes, t.summary), " or "), - summary), - cppClassName>; - -// A type that satisfies the constraints of all given types. -class AllOfType allowedTypes, string summary = "", - string cppClassName = "::mlir::Type"> : Type< - // Satisfy all of the allowedf types' conditions. - And, - !if(!eq(summary, ""), - !interleave(!foreach(t, allowedTypes, t.summary), " and "), - summary), - cppClassName>; - -// A type that satisfies additional predicates. -class ConfinedType predicates, string summary = "", - string cppClassName = type.cppClassName> : Type< - And, - summary, cppClassName>; - -// Integer types. - -// Any integer type irrespective of its width and signedness semantics. -def AnyInteger : Type($_self)">, "integer", - "::mlir::IntegerType">; - -// Any integer type (regardless of signedness semantics) of a specific width. -class AnyI - : Type, width # "-bit integer"> { - int bitwidth = width; -} - -class AnyIntOfWidths widths> : - AnyTypeOf), - !interleave(widths, "/") # "-bit integer", - "::mlir::IntegerType">; - -def AnyI1 : AnyI<1>; -def AnyI8 : AnyI<8>; -def AnyI16 : AnyI<16>; -def AnyI32 : AnyI<32>; -def AnyI64 : AnyI<64>; - -// Any signless integer type irrespective of its width. 
-def AnySignlessInteger : Type< - CPred<"$_self.isSignlessInteger()">, "signless integer", - "::mlir::IntegerType">; - -// Signless integer type of a specific width. -class I - : Type, - width # "-bit signless integer", "::mlir::IntegerType">, - BuildableType<"$_builder.getIntegerType(" # width # ")"> { - int bitwidth = width; -} - -class SignlessIntOfWidths widths> : - AnyTypeOf), - !interleave(widths, "/") # "-bit signless integer">; - -def I1 : I<1>; -def I8 : I<8>; -def I16 : I<16>; -def I32 : I<32>; -def I64 : I<64>; -def I128 : I<128>; - -// Any signed integer type irrespective of its width. -def AnySignedInteger : Type< - CPred<"$_self.isSignedInteger()">, "signed integer">; - -// Signed integer type of a specific width. -class SI - : Type, - width # "-bit signed integer", "::mlir::IntegerType">, - BuildableType< - "$_builder.getIntegerType(" # width # ", /*isSigned=*/true)"> { - int bitwidth = width; -} - -class SignedIntOfWidths widths> : - AnyTypeOf), - !interleave(widths, "/") # "-bit signed integer">; - -def SI1 : SI<1>; -def SI8 : SI<8>; -def SI16 : SI<16>; -def SI32 : SI<32>; -def SI64 : SI<64>; - -// Any unsigned integer type irrespective of its width. -def AnyUnsignedInteger : Type< - CPred<"$_self.isUnsignedInteger()">, "unsigned integer">; - -// Unsigned integer type of a specific width. -class UI - : Type, - width # "-bit unsigned integer", "::mlir::IntegerType">, - BuildableType< - "$_builder.getIntegerType(" # width # ", /*isSigned=*/false)"> { - int bitwidth = width; -} - -class UnsignedIntOfWidths widths> : - AnyTypeOf), - !interleave(widths, "/") # "-bit unsigned integer">; - -def UI1 : UI<1>; -def UI8 : UI<8>; -def UI16 : UI<16>; -def UI32 : UI<32>; -def UI64 : UI<64>; - -// Index type. -def Index : Type($_self)">, "index", - "::mlir::IndexType">, - BuildableType<"$_builder.getIndexType()">; - -// Any signless integer type or index type. -def AnySignlessIntegerOrIndex : Type, - "signless integer or index">; - -// Floating point types. 
- -// Any float type irrespective of its width. -def AnyFloat : Type($_self)">, "floating-point", - "::mlir::FloatType">; - -// Float type of a specific width. -class F - : Type, - width # "-bit float", "::mlir::FloatType">, - BuildableType<"$_builder.getF" # width # "Type()"> { - int bitwidth = width; -} - -class FloatOfWidths widths> : - AnyTypeOf), - !interleave(widths, "/") # "-bit float">; - -def F16 : F<16>; -def F32 : F<32>; -def F64 : F<64>; -def F80 : F<80>; -def F128 : F<128>; - -def BF16 : Type, "bfloat16 type">, - BuildableType<"$_builder.getBF16Type()">; -def TF32 : Type, "tf32 type">, - BuildableType<"$_builder.getTF32Type()">; -def F8E4M3FN : Type, "f8E4M3FN type">, - BuildableType<"$_builder.getFloat8E4M3FNType()">; -def F8E5M2 : Type, "f8E5M2 type">, - BuildableType<"$_builder.getFloat8E5M2Type()">; -def F8E4M3FNUZ : Type, "f8E4M3FNUZ type">, - BuildableType<"$_builder.getFloat8E4M3FNUZType()">; -def F8E4M3B11FNUZ : Type, "f8E4M3B11FNUZ type">, - BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">; -def F8E5M2FNUZ : Type, "f8E5M2FNUZ type">, - BuildableType<"$_builder.getFloat8E5M2FNUZType()">; - -def AnyComplex : Type($_self)">, - "complex-type", "::mlir::ComplexType">; - -class Complex - : ConfinedType($_self).getElementType()", - type.predicate>], - "complex type with " # type.summary # " elements", - "::mlir::ComplexType">, - SameBuildabilityAs { - Type elementType = type; -} - -class OpaqueType - : Type, - summary, "::mlir::OpaqueType">, - BuildableType<"::mlir::OpaqueType::get(" - "$_builder.getStringAttr(\"" # dialect # "\"), \"" - # name # "\")">; - -// Function Type - -// Any function type. -def FunctionType : Type($_self)">, - "function type", "::mlir::FunctionType">; - -// A container type is a type that has another type embedded within it. -class ContainerType : - // First, check the container predicate. Then, substitute the extracted - // element into the element type checker. 
- Type(elementTypeCall), - etype.predicate>]>, - descr # " of " # etype.summary # " values", cppClassName>; - -class ShapedContainerType allowedTypes, - Pred containerPred, string descr, - string cppClassName = "::mlir::Type"> : - Type.predicate>, - "; }(::llvm::cast<::mlir::ShapedType>($_self).getElementType())">]>, - descr # " of " # AnyTypeOf.summary # " values", cppClassName>; - -// Whether a shaped type is ranked. -def HasRankPred : CPred<"::llvm::cast<::mlir::ShapedType>($_self).hasRank()">; - -// Whether a shaped type has one of the specified ranks. -class HasAnyRankOfPred ranks> : And<[ - HasRankPred, - Or($_self).getRank() - == }] - # rank>)>]>; - -// Whether a shaped type has a rank greater than or equal of the specified rank. -class HasRankGreaterOrEqualPred : And<[ - HasRankPred, - CPred<[{::llvm::cast<::mlir::ShapedType>($_self).getRank() >= }] # rank> -]>; - -// Vector types. - -class VectorOf allowedTypes> : - ShapedContainerType; - -// Temporary vector type clone that allows gradual transition to 0-D vectors. -// TODO: Remove this when all ops support 0-D vectors. 
-class VectorOfAnyRankOf allowedTypes> : - ShapedContainerType; - -class FixedVectorOf allowedTypes> : - ShapedContainerType; - -class ScalableVectorOf allowedTypes> : - ShapedContainerType; - -// Whether the number of elements of a vector is from the given -// `allowedRanks` list -class IsVectorOfRankPred allowedRanks> : - And<[IsVectorTypePred, - Or($_self).getRank() - == }] - # allowedlength>)>]>; - -// Whether the number of elements of a fixed-length vector is from the given -// `allowedRanks` list -class IsFixedVectorOfRankPred allowedRanks> : - And<[IsFixedVectorTypePred, - Or($_self).getRank() - == }] - # allowedlength>)>]>; - -// Whether the number of elements of a scalable vector is from the given -// `allowedRanks` list -class IsScalableVectorOfRankPred allowedRanks> : - And<[IsScalableVectorTypePred, - Or($_self).getRank() - == }] - # allowedlength>)>]>; - -// Any vector where the rank is from the given `allowedRanks` list -class VectorOfRank allowedRanks> : Type< - IsVectorOfRankPred, - " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">; - -// Any fixed-length vector where the rank is from the given `allowedRanks` list -class FixedVectorOfRank allowedRanks> : Type< - IsFixedVectorOfRankPred, - " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">; - -// Any scalable vector where the rank is from the given `allowedRanks` list -class ScalableVectorOfRank allowedRanks> : Type< - IsScalableVectorOfRankPred, - " of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">; - -// Any vector where the rank is from the given `allowedRanks` list and the type -// is from the given `allowedTypes` list -class VectorOfRankAndType allowedRanks, - list allowedTypes> : AllOfType< - [VectorOf, VectorOfRank], - VectorOf.summary # VectorOfRank.summary, - "::mlir::VectorType">; - -// Whether the number of elements of a vector is from the given -// `allowedLengths` list -class IsVectorOfLengthPred allowedLengths> : - 
And<[IsVectorTypePred, - Or($_self).getNumElements() - == }] - # allowedlength>)>]>; - -// Whether the number of elements of a fixed-length vector is from the given -// `allowedLengths` list -class IsFixedVectorOfLengthPred allowedLengths> : - And<[IsFixedVectorTypePred, - Or($_self).getNumElements() - == }] - # allowedlength>)>]>; - -// Whether the number of elements of a scalable vector is from the given -// `allowedLengths` list -class IsScalableVectorOfLengthPred allowedLengths> : - And<[IsScalableVectorTypePred, - Or($_self).getNumElements() - == }] - # allowedlength>)>]>; - -// Whether the shape of a vector matches the given `shape` list. -class IsVectorOfShape shape> - : CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef({" # !interleave(shape, ", ") # "})">; - -// Any vector where the number of elements is from the given -// `allowedLengths` list -class VectorOfLength allowedLengths> : Type< - IsVectorOfLengthPred, - " of length " # !interleave(allowedLengths, "/"), - "::mlir::VectorType">; - -// Any fixed-length vector where the number of elements is from the given -// `allowedLengths` list -class FixedVectorOfLength allowedLengths> : Type< - IsFixedVectorOfLengthPred, - " of length " # !interleave(allowedLengths, "/"), - "::mlir::VectorType">; - -// Any scalable vector where the number of elements is from the given -// `allowedLengths` list -class ScalableVectorOfLength allowedLengths> : Type< - IsScalableVectorOfLengthPred, - " of length " # !interleave(allowedLengths, "/"), - "::mlir::VectorType">; - -// Any vector where the number of elements is from the given -// `allowedLengths` list and the type is from the given `allowedTypes` -// list -class VectorOfLengthAndType allowedLengths, - list allowedTypes> : AllOfType< - [VectorOf, VectorOfLength], - VectorOf.summary # VectorOfLength.summary, - "::mlir::VectorType">; - -// Any fixed-length vector where the number of elements is from the given -// `allowedLengths` list and the type is 
from the given `allowedTypes` list -class FixedVectorOfLengthAndType allowedLengths, - list allowedTypes> : AllOfType< - [FixedVectorOf, FixedVectorOfLength], - FixedVectorOf.summary # - FixedVectorOfLength.summary, - "::mlir::VectorType">; - -// Any scalable vector where the number of elements is from the given -// `allowedLengths` list and the type is from the given `allowedTypes` list -class ScalableVectorOfLengthAndType allowedLengths, - list allowedTypes> : AllOfType< - [ScalableVectorOf, ScalableVectorOfLength], - ScalableVectorOf.summary # - ScalableVectorOfLength.summary, - "::mlir::VectorType">; - -def AnyVector : VectorOf<[AnyType]>; -// Temporary vector type clone that allows gradual transition to 0-D vectors. -def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>; - -def AnyFixedVector : FixedVectorOf<[AnyType]>; - -def AnyScalableVector : ScalableVectorOf<[AnyType]>; - -// Shaped types. - -def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped", - "::mlir::ShapedType">; - -//===----------------------------------------------------------------------===// -// Tensor types. - -// Unranked tensor type whose element type is from the given `allowedTypes` -// list, and which additionally satisfies an optional list of predicates. -class UnrankedTensorOf allowedTypes, list preds = [], - string summary = "unranked tensor"> - : ShapedContainerType< - allowedTypes, And, - summary, "::mlir::UnrankedTensorType">; - -// Ranked tensor type whose element type is from the given `allowedTypes` list, -// and which additionally satisfies an optional list of predicates. -class RankedTensorOf allowedTypes, list preds = [], - string summary = "ranked tensor"> - : ShapedContainerType< - allowedTypes, And, - summary, "::mlir::RankedTensorType">; - -// Any tensor type whose element type is from the given `allowedTypes` -// list, and which additionally satisfies an optional list of predicates. 
-// -// TODO: use `Constraint` instead of `Pred`, so we can generate a better -// default summary (a la `ConfinedAttr`). -class TensorOf< - list allowedTypes, - list preds = [], - string summary = "tensor"> - : ShapedContainerType, - summary, "::mlir::TensorType">; - -def AnyTensor : TensorOf<[AnyType]>; - -def I1Tensor : TensorOf<[I1]>; -def I8Tensor : TensorOf<[I8]>; -def I16Tensor : TensorOf<[I16]>; -def I32Tensor : TensorOf<[I32]>; -def I64Tensor : TensorOf<[I64]>; -def IndexTensor: TensorOf<[Index]>; - -def BF16Tensor : TensorOf<[BF16]>; -def F16Tensor : TensorOf<[F16]>; -def F32Tensor : TensorOf<[F32]>; -def F64Tensor : TensorOf<[F64]>; - -class Non0RankedTensorOf allowedTypes> - : TensorOf], - "non-0-ranked.tensor">; - -def AnyRankedTensor : RankedTensorOf<[AnyType]>; -def AnyNon0RankedTensor : Non0RankedTensorOf<[AnyType]>; -def AnyUnrankedTensor : UnrankedTensorOf<[AnyType]>; - -def AnyNon0RankedOrUnrankedTensor - : AnyTypeOf<[AnyUnrankedTensor, AnyNon0RankedTensor], - "non-0-ranked or unranked tensor", "::mlir::TensorType">; - -// Ranked tensor type with one of the specified types and ranks. -class TensorRankOf allowedTypes, list ranks> - : RankedTensorOf], - !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">; - -class 0DTensorOf allowedTypes> : TensorRankOf; -class 1DTensorOf allowedTypes> : TensorRankOf; -class 2DTensorOf allowedTypes> : TensorRankOf; -class 3DTensorOf allowedTypes> : TensorRankOf; -class 4DTensorOf allowedTypes> : TensorRankOf; - -class StaticShapeTensorOf allowedTypes> - : RankedTensorOf; - -def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>; - -//===----------------------------------------------------------------------===// -// Memref type. - -// Any unranked memref whose element type is from the given `allowedTypes` list. 
-class UnrankedMemRefOf allowedTypes> : - ShapedContainerType; - -def AnyUnrankedMemRef : UnrankedMemRefOf<[AnyType]>; - -// Any ranked memref whose element type is from the given `allowedTypes` list. -class MemRefOf allowedTypes> : - ShapedContainerType; - -class Non0RankedMemRefOf allowedTypes> : - ConfinedType, [HasRankGreaterOrEqualPred<1>], - "non-0-ranked." # MemRefOf.summary, - "::mlir::MemRefType">; - -def AnyMemRef : MemRefOf<[AnyType]>; -def AnyNon0RankedMemRef : Non0RankedMemRefOf<[AnyType]>; - -// Any memref (ranked or unranked) whose element type is from the given -// `allowedTypes` list, and which additionally satisfies an optional list of -// predicates. -class RankedOrUnrankedMemRefOf< - list allowedTypes, - list preds = [], - string summary = "ranked or unranked memref"> - : ShapedContainerType, - summary, "::mlir::BaseMemRefType">; - -def AnyRankedOrUnrankedMemRef : RankedOrUnrankedMemRefOf<[AnyType]>; -def AnyNon0RankedOrUnrankedMemRef: - AnyTypeOf<[AnyUnrankedMemRef, AnyNon0RankedMemRef]>; - -// Memref declarations handle any memref, independent of rank, size, (static or -// dynamic), layout, or memory space. -def I1MemRef : MemRefOf<[I1]>; -def I8MemRef : MemRefOf<[I8]>; -def I16MemRef : MemRefOf<[I16]>; -def I32MemRef : MemRefOf<[I32]>; -def I64MemRef : MemRefOf<[I64]>; - -def BF16MemRef : MemRefOf<[BF16]>; -def F16MemRef : MemRefOf<[F16]>; -def F32MemRef : MemRefOf<[F32]>; -def F64MemRef : MemRefOf<[F64]>; - -// TODO: Have an easy way to add another constraint to a type. -class MemRefRankOf allowedTypes, list ranks> : - ConfinedType, [HasAnyRankOfPred], - !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " # - MemRefOf.summary, - "::mlir::MemRefType">; - -class StaticShapeMemRefOf allowedTypes> : - ConfinedType, [HasStaticShapePred], - "statically shaped " # MemRefOf.summary, - "::mlir::MemRefType">; - -def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>; - -// For a MemRefType, verify that it has strides. 
-def HasStridesPred : CPred<[{ isStrided(::llvm::cast<::mlir::MemRefType>($_self)) }]>; - -class StridedMemRefOf allowedTypes> : - ConfinedType, [HasStridesPred], - "strided " # MemRefOf.summary>; - -def AnyStridedMemRef : StridedMemRefOf<[AnyType]>; - -class AnyStridedMemRefOfRank : - AllOfType<[AnyStridedMemRef, MemRefRankOf<[AnyType], [rank]>], - AnyStridedMemRef.summary # " of rank " # rank>; - -class StridedMemRefRankOf allowedTypes, list ranks> : - ConfinedType, [HasAnyRankOfPred], - !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " # - MemRefOf.summary>; - -// This represents a generic tuple without any constraints on element type. -def AnyTuple : Type; - -// A container type that has other types embedded in it, but (unlike -// ContainerType) can hold elements with a mix of types. Requires a call that -// produces a list of all elements' types. -class MixedContainerType : - Type< - And<[ - containerPred, - Concat< - "::llvm::all_of(" # elementTypesCall # ", [](::mlir::Type t) { " - "return t && (", - SubstLeaves<"$_self", "t", etype.predicate>, - "); })" - > - ]>, - descr # " with any combination of " # etype.summary # " values"> { - // The type of elements in the container. - Type elementType = etype; - - // Call to retrieve. - code getElementTypesCall = elementTypesCall; -} - -// A Tuple that holds a mix of elements of the allowed types. -class TupleOf allowedTypes> - : MixedContainerType, IsTupleTypePred, - "::llvm::cast<::mlir::TupleType>($_self).getTypes()", - "tuple">; - -// A Tuple with arbitrary nesting, where all elements are a mix of the allowed -// types. 
-class NestedTupleOf allowedTypes> : - MixedContainerType, IsTupleTypePred, - "getFlattenedTypes(::llvm::cast<::mlir::TupleType>($_self))", - "nested tuple">; - -//===----------------------------------------------------------------------===// -// Common type constraints -//===----------------------------------------------------------------------===// -// Type constraint for types that are "like" some type or set of types T, that is -// they're either a T, a vector of Ts, or a tensor of Ts -class TypeOrContainer : TypeConstraint.predicate, - TensorOf<[allowedType]>.predicate]>, - name>; - -// Temporary constraint to allow gradual transition to supporting 0-D vectors. -// TODO: Remove this when all ops support 0-D vectors. -class TypeOrContainerOfAnyRank : TypeConstraint.predicate, - TensorOf<[allowedType]>.predicate]>, - name>; - - -// Type constraint for bool-like types: bools, vectors of bools, tensors of -// bools. -def BoolLike : TypeOrContainer; - -def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank; - -// Type constraint for signless-integer-like types: signless integers, indices, -// vectors of signless integers or indices, tensors of signless integers. -def SignlessIntegerLike : TypeOrContainer; - -def SignlessIntegerLikeOfAnyRank : TypeOrContainerOfAnyRank< - AnySignlessIntegerOrIndex, - "signless-integer-like">; - -// Type constraint for float-like types: floats, vectors or tensors thereof. -def FloatLike : TypeOrContainer; - -// Type constraint for signless-integer-like or float-like types. -def SignlessIntegerOrFloatLike : TypeConstraint, - "signless-integer-like or floating-point-like">; - -//===----------------------------------------------------------------------===// -// Attribute definitions -//===----------------------------------------------------------------------===// - -//===----------------------------------------------------------------------===// -// Base attribute definition - -// Base class for all attributes. 
-class Attr : - AttrConstraint { - code storageType = ?; // The backing mlir::Attribute type - code returnType = ?; // The underlying C++ value type - - // The call expression to convert from the storage type to the return - // type. For example, an enum can be stored as an int but returned as an - // enum class. - // - // Format: $_self will be expanded to the attribute. - // - // For example, `$_self.getValue().getSExtValue()` for `IntegerAttr val` will - // expand to `getAttrOfType("val").getValue().getSExtValue()`. - code convertFromStorage = "$_self.getValue()"; - - // The call expression to build an attribute from a constant value. - // - // Format: $0 will be expanded to the constant value of the attribute. - // - // For example, `$_builder.getStringAttr("$0")` for `StringAttr:"foo"` will - // expand to `builder.getStringAttr("foo")`. - string constBuilderCall = ?; - - // Default value for attribute. - // Requires a constBuilderCall defined. - string defaultValue = ?; - - // The value type of this attribute. This corresponds to the mlir::Type that - // this attribute returns via `getType()`. - Type valueType = ?; - - // Whether the attribute is optional. Typically requires a custom - // convertFromStorage method to handle the case where the attribute is - // not present. - bit isOptional = 0; - - // What is the base-level Attr instantiation that this Attr is built upon. - // Unset means this is a base-level Attr. - // - // This field is used by attribute wrapper classes (DefaultValuedAttr, - // OptionalAttr, etc.) to retrieve the base-level attribute definition. - // This can be used for getting its name; otherwise, we will see - // "anonymous_" as the attribute def name because of template - // instantiation. - // TOOD(b/132458159): deduplicate the fields in attribute wrapper classes. - Attr baseAttr = ?; - - // The fully-qualified C++ namespace where the generated class lives. - string cppNamespace = ""; - - // The full description of this attribute. 
- string description = ""; -} - -// An attribute of a specific dialect. -class DialectAttr : - Attr { - Dialect dialect = d; - let cppNamespace = d.cppNamespace; -} - -//===----------------------------------------------------------------------===// -// Attribute modifier definition - -// Decorates an attribute to have an (unvalidated) default value if not present. -class DefaultValuedAttr : - Attr { - // Construct this attribute with the input attribute and change only - // the default value. - // Note: this has to be kept up to date with Attr above. - let storageType = attr.storageType; - let returnType = attr.returnType; - let convertFromStorage = attr.convertFromStorage; - let constBuilderCall = attr.constBuilderCall; - let defaultValue = val; - let valueType = attr.valueType; - - let baseAttr = attr; -} - -// Decorates an optional attribute to have an (unvalidated) default value -// return by ODS generated accessors if not present. -class DefaultValuedOptionalAttr : - Attr { - // Construct this attribute with the input attribute and change only - // the default value. - // Note: this has to be kept up to date with Attr above. - let storageType = attr.storageType; - let returnType = attr.returnType; - let convertFromStorage = attr.convertFromStorage; - let constBuilderCall = attr.constBuilderCall; - let defaultValue = val; - let valueType = attr.valueType; - let isOptional = 1; - - let baseAttr = attr; -} - -// Decorates an attribute as optional. The return type of the generated -// attribute accessor method will be Optional<>. -class OptionalAttr : Attr { - // Rewrite the attribute to be optional. - // Note: this has to be kept up to date with Attr above. - let storageType = attr.storageType; - let returnType = "::std::optional<" # attr.returnType #">"; - let convertFromStorage = "$_self ? 
" # returnType # "(" # - attr.convertFromStorage # ") : (::std::nullopt)"; - let valueType = attr.valueType; - let isOptional = 1; - - let baseAttr = attr; -} - -// Default-valued string-based attribute. Wraps the default value in escaped -// quotes. -class DefaultValuedStrAttr - : DefaultValuedAttr; -class DefaultValuedOptionalStrAttr - : DefaultValuedOptionalAttr; - -//===----------------------------------------------------------------------===// -// Primitive property kinds - -// Any kind of integer stored as properties. -class IntProperty : - Property { - code writeToMlirBytecode = [{ - $_writer.writeVarInt($_storage); - }]; - code readFromMlirBytecode = [{ - uint64_t val; - if (failed($_reader.readVarInt(val))) - return ::mlir::failure(); - $_storage = val; - }]; -} - -class ArrayProperty : - Property { - let interfaceType = "::llvm::ArrayRef<" # storageTypeParam # ">"; - let convertFromStorage = "$_storage"; - let assignToStorage = "::llvm::copy($_value, $_storage)"; -} - -//===----------------------------------------------------------------------===// -// Primitive attribute kinds - -// A generic attribute that must be constructed around a specific buildable type -// `attrValType`. Backed by MLIR attribute kind `attrKind`. -class TypedAttrBase : - Attr { - let constBuilderCall = "$_builder.get" # attrKind # "(" # - attrValType.builderCall # ", $0)"; - let storageType = "::mlir::" # attrKind; - let valueType = attrValType; -} - -// Any attribute. 
-def AnyAttr : Attr, "any attribute"> { - let storageType = "::mlir::Attribute"; - let returnType = "::mlir::Attribute"; - let convertFromStorage = "$_self"; - let constBuilderCall = "$0"; -} - -// Any attribute from the given list -class AnyAttrOf allowedAttrs, string summary = "", - string cppClassName = "::mlir::Attribute", - string fromStorage = "$_self"> : Attr< - // Satisfy any of the allowed attribute's condition - Or, - !if(!eq(summary, ""), - !interleave(!foreach(t, allowedAttrs, t.summary), " or "), - summary)> { - let returnType = cppClassName; - let convertFromStorage = fromStorage; -} - -def LocationAttr : Attr($_self)">, - "location attribute">; - -def BoolAttr : Attr($_self)">, "bool attribute"> { - let storageType = [{ ::mlir::BoolAttr }]; - let returnType = [{ bool }]; - let valueType = I1; - let constBuilderCall = "$_builder.getBoolAttr($0)"; -} - -// Index attribute. -def IndexAttr : - TypedAttrBase< - Index, "IntegerAttr", - And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">, - CPred<"::llvm::isa<::mlir::IndexType>(::llvm::cast<::mlir::IntegerAttr>($_self).getType())">]>, - "index attribute"> { - let returnType = [{ ::llvm::APInt }]; -} - -// Base class for any integer (regardless of signedness semantics) attributes -// of fixed width. -class AnyIntegerAttrBase : - TypedAttrBase< - attrValType, "IntegerAttr", - And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">, - CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType()." 
- "isInteger(" # attrValType.bitwidth # ")">]>, - descr> { - let returnType = [{ ::llvm::APInt }]; - let constBuilderCall = ?; -} - -def AnyI1Attr : AnyIntegerAttrBase; -def AnyI8Attr : AnyIntegerAttrBase; -def AnyI16Attr : AnyIntegerAttrBase; -def AnyI32Attr : AnyIntegerAttrBase; -def AnyI64Attr : AnyIntegerAttrBase; - -def APIntAttr : Attr($_self)">, - "arbitrary integer attribute"> { - let storageType = [{ ::mlir::IntegerAttr }]; - let returnType = [{ ::mlir::APInt }]; -} - -// Base class for signless integer attributes of fixed width. -class SignlessIntegerAttrBase : - TypedAttrBase< - attrValType, "IntegerAttr", - And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">, - CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType()." - "isSignlessInteger(" # attrValType.bitwidth # ")">]>, - descr> { - let returnType = [{ ::llvm::APInt }]; -} -// Base class for signless integer attributes of fixed width that have a -// corresponding C++ type. -class TypedSignlessIntegerAttrBase - : SignlessIntegerAttrBase { - let returnType = retType; - let convertFromStorage = "$_self.getValue().getZExtValue()"; -} - -def I1Attr : TypedSignlessIntegerAttrBase< - I1, "bool", "1-bit signless integer attribute">; -def I8Attr : TypedSignlessIntegerAttrBase< - I8, "uint8_t", "8-bit signless integer attribute">; -def I16Attr : TypedSignlessIntegerAttrBase< - I16, "uint16_t", "16-bit signless integer attribute">; -def I32Attr : TypedSignlessIntegerAttrBase< - I32, "uint32_t", "32-bit signless integer attribute">; -def I64Attr : TypedSignlessIntegerAttrBase< - I64, "uint64_t", "64-bit signless integer attribute">; - -// Base class for signed integer attributes of fixed width. -class SignedIntegerAttrBase : - TypedAttrBase< - attrValType, "IntegerAttr", - And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">, - CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType()." 
- "isSignedInteger(" # attrValType.bitwidth # ")">]>, - descr> { - let returnType = [{ ::llvm::APInt }]; -} -// Base class for signed integer attributes of fixed width that have a -// corresponding C++ type. -class TypedSignedIntegerAttrBase - : SignedIntegerAttrBase { - let returnType = retType; - let convertFromStorage = "$_self.getValue().getSExtValue()"; -} - -def SI1Attr : TypedSignedIntegerAttrBase< - SI1, "bool", "1-bit signed integer attribute">; -def SI8Attr : TypedSignedIntegerAttrBase< - SI8, "int8_t", "8-bit signed integer attribute">; -def SI16Attr : TypedSignedIntegerAttrBase< - SI16, "int16_t", "16-bit signed integer attribute">; -def SI32Attr : TypedSignedIntegerAttrBase< - SI32, "int32_t", "32-bit signed integer attribute">; -def SI64Attr : TypedSignedIntegerAttrBase< - SI64, "int64_t", "64-bit signed integer attribute">; - -// Base class for unsigned integer attributes of fixed width. -class UnsignedIntegerAttrBase : - TypedAttrBase< - attrValType, "IntegerAttr", - And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">, - CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType()." - "isUnsignedInteger(" # attrValType.bitwidth # ")">]>, - descr> { - let returnType = [{ ::llvm::APInt }]; -} -// Base class for unsigned integer attributes of fixed width that have a -// corresponding C++ type. 
-class TypedUnsignedIntegerAttrBase - : UnsignedIntegerAttrBase { - let returnType = retType; - let convertFromStorage = "$_self.getValue().getZExtValue()"; -} - -def UI1Attr : TypedUnsignedIntegerAttrBase< - UI1, "bool", "1-bit unsigned integer attribute">; -def UI8Attr : TypedUnsignedIntegerAttrBase< - UI8, "uint8_t", "8-bit unsigned integer attribute">; -def UI16Attr : TypedUnsignedIntegerAttrBase< - UI16, "uint16_t", "16-bit unsigned integer attribute">; -def UI32Attr : TypedUnsignedIntegerAttrBase< - UI32, "uint32_t", "32-bit unsigned integer attribute">; -def UI64Attr : TypedUnsignedIntegerAttrBase< - UI64, "uint64_t", "64-bit unsigned integer attribute">; - -// Base class for float attributes of fixed width. -class FloatAttrBase : - TypedAttrBase($_self)">, - CPred<"::llvm::cast<::mlir::FloatAttr>($_self).getType().isF" # - attrValType.bitwidth # "()">]>, - descr> { - let returnType = [{ ::llvm::APFloat }]; -} - -def F32Attr : FloatAttrBase; -def F64Attr : FloatAttrBase; - -// An attribute backed by a string type. -class StringBasedAttr : Attr { - let constBuilderCall = "$_builder.getStringAttr($0)"; - let storageType = [{ ::mlir::StringAttr }]; - let returnType = [{ ::llvm::StringRef }]; - let valueType = NoneType; -} - -def StrAttr : StringBasedAttr($_self)">, - "string attribute">; - -// A string attribute that represents the name of a symbol. -def SymbolNameAttr : StringBasedAttr($_self)">, - "string attribute">; - -// String attribute that has a specific value type. -class TypedStrAttr - : StringBasedAttr($_self)">, - "string attribute"> { - let valueType = ty; -} - -// Base class for attributes containing types. Example: -// def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute"> -// defines a type attribute containing an integer type. 
-class TypeAttrBase> : - Attr($_self)">, - CPred<"::llvm::isa<" # retType # ">(::llvm::cast<::mlir::TypeAttr>($_self).getValue())">, - SubstLeaves<"$_self", - "::llvm::cast<::mlir::TypeAttr>($_self).getValue()", typePred>]>, - summary> { - let storageType = [{ ::mlir::TypeAttr }]; - let returnType = retType; - let valueType = NoneType; - let convertFromStorage = "::llvm::cast<" # retType # ">($_self.getValue())"; -} - -def TypeAttr : TypeAttrBase<"::mlir::Type", "any type attribute"> { - let constBuilderCall = "::mlir::TypeAttr::get($0)"; -} - -class TypeAttrOf - : TypeAttrBase { - let constBuilderCall = "::mlir::TypeAttr::get($0)"; -} - -// The mere presence of unit attributes has a meaning. Therefore, unit -// attributes are always treated as optional and accessors to them return -// "true" if the attribute is present and "false" otherwise. -def UnitAttr : Attr($_self)">, "unit attribute"> { - let storageType = [{ ::mlir::UnitAttr }]; - let constBuilderCall = "(($0) ? $_builder.getUnitAttr() : nullptr)"; - let convertFromStorage = "$_self != nullptr"; - let returnType = "bool"; - let defaultValue = "false"; - let valueType = NoneType; - let isOptional = 1; -} - -//===----------------------------------------------------------------------===// -// Composite attribute kinds - -class DictionaryAttrBase : - Attr { - let storageType = [{ ::mlir::DictionaryAttr }]; - let constBuilderCall = "$_builder.getDictionaryAttr($0)"; - let returnType = [{ ::mlir::DictionaryAttr }]; - let valueType = NoneType; - let convertFromStorage = "$_self"; -} - -def DictionaryAttr - : DictionaryAttrBase($_self)">, - "dictionary of named attribute values">; - -class ElementsAttrBase : - Attr { - let storageType = [{ ::mlir::ElementsAttr }]; - let returnType = [{ ::mlir::ElementsAttr }]; - let convertFromStorage = "$_self"; -} - -def ElementsAttr : ElementsAttrBase($_self)">, - "constant vector/tensor attribute">; - -class IntElementsAttrBase : - ElementsAttrBase($_self)">, - condition]>, - 
summary> { - let storageType = [{ ::mlir::DenseIntElementsAttr }]; - let returnType = [{ ::mlir::DenseIntElementsAttr }]; - - let convertFromStorage = "$_self"; -} - -class DenseArrayAttrBase : - ElementsAttrBase($_self)">, - summaryName # " dense array attribute"> { - let storageType = "::mlir::" # denseAttrName; - let returnType = "::llvm::ArrayRef<" # cppType # ">"; - let constBuilderCall = "$_builder.get" # denseAttrName # "($0)"; -} -def DenseBoolArrayAttr : DenseArrayAttrBase<"DenseBoolArrayAttr", "bool", "i1">; -def DenseI8ArrayAttr : DenseArrayAttrBase<"DenseI8ArrayAttr", "int8_t", "i8">; -def DenseI16ArrayAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">; -def DenseI32ArrayAttr : DenseArrayAttrBase<"DenseI32ArrayAttr", "int32_t", "i32">; -def DenseI64ArrayAttr : DenseArrayAttrBase<"DenseI64ArrayAttr", "int64_t", "i64">; -def DenseF32ArrayAttr : DenseArrayAttrBase<"DenseF32ArrayAttr", "float", "f32">; -def DenseF64ArrayAttr : DenseArrayAttrBase<"DenseF64ArrayAttr", "double", "f64">; - -def IndexElementsAttr - : IntElementsAttrBase($_self) - .getType() - .getElementType() - .isIndex()}]>, - "index elements attribute">; - -def AnyIntElementsAttr : IntElementsAttrBase, "integer elements attribute">; - -class IntElementsAttrOf : IntElementsAttrBase< - CPred<"::llvm::cast<::mlir::DenseIntElementsAttr>($_self).getType()." - "getElementType().isInteger(" # width # ")">, - width # "-bit integer elements attribute">; - -def AnyI32ElementsAttr : IntElementsAttrOf<32>; -def AnyI64ElementsAttr : IntElementsAttrOf<64>; - -class SignlessIntElementsAttr : IntElementsAttrBase< - CPred<"::llvm::cast<::mlir::DenseIntElementsAttr>($_self).getType()." - "getElementType().isSignlessInteger(" # width # ")">, - width # "-bit signless integer elements attribute"> { - - // Note that this is only constructing scalar elements attribute. 
- let constBuilderCall = "::llvm::cast<::mlir::DenseIntElementsAttr>(" - "::mlir::DenseElementsAttr::get(" - "::mlir::RankedTensorType::get({}, $_builder.getIntegerType(" # width # ")), " - "::llvm::ArrayRef($0)))"; -} - -def I32ElementsAttr : SignlessIntElementsAttr<32>; -def I64ElementsAttr : SignlessIntElementsAttr<64>; - -// A `width`-bit signless integer elements attribute. The attribute should be -// ranked and has a shape as specified in `dims`. -class RankedSignlessIntElementsAttr dims> : - SignlessIntElementsAttr { - // Check that this has the specified shape. - let predicate = And<[ - SignlessIntElementsAttr.predicate, - CPred<"::llvm::cast<::mlir::DenseIntElementsAttr>($_self).getType().getShape() == " - "::mlir::ArrayRef({" # !interleave(dims, ", ") # "})">]>; - - let summary = width # "-bit signless int elements attribute of shape [" # - !interleave(dims, ", ") # "]"; - - let constBuilderCall = "::mlir::DenseIntElementsAttr::get(" - "::mlir::RankedTensorType::get({" # !interleave(dims, ", ") # - "}, $_builder.getIntegerType(" # width # ")), ::llvm::ArrayRef($0))"; -} - -class RankedI32ElementsAttr dims> : - RankedSignlessIntElementsAttr<32, dims>; -class RankedI64ElementsAttr dims> : - RankedSignlessIntElementsAttr<64, dims>; - -class FloatElementsAttr : ElementsAttrBase< - CPred<"::llvm::isa<::mlir::DenseFPElementsAttr>($_self) &&" - "::llvm::cast<::mlir::DenseElementsAttr>($_self).getType()." - "getElementType().isF" # width # "()">, - width # "-bit float elements attribute"> { - - let storageType = [{ ::mlir::DenseElementsAttr }]; - let returnType = [{ ::mlir::DenseElementsAttr }]; - - // Note that this is only constructing scalar elements attribute. - let constBuilderCall = "::mlir::DenseElementsAttr::get(" - "::mlir::RankedTensorType::get({}, $_builder.getF" # width # "Type())," - "::llvm::ArrayRef($0))"; - let convertFromStorage = "$_self"; -} - -def F64ElementsAttr : FloatElementsAttr<64>; - -// A `width`-bit floating point elements attribute. 
The attribute should be -// ranked and has a shape as specified in `dims`. -class RankedFloatElementsAttr dims> : ElementsAttrBase< - CPred<"::llvm::isa<::mlir::DenseFPElementsAttr>($_self) &&" - "::llvm::cast<::mlir::DenseFPElementsAttr>($_self).getType()." - "getElementType().isF" # width # "() && " - // Check that this is ranked and has the specified shape. - "::llvm::cast<::mlir::DenseFPElementsAttr>($_self).getType().hasRank() && " - "::llvm::cast<::mlir::DenseFPElementsAttr>($_self).getType().getShape() == " - "::mlir::ArrayRef({" # !interleave(dims, ", ") # "})">, - width # "-bit float elements attribute of shape [" # - !interleave(dims, ", ") # "]"> { - - let storageType = [{ ::mlir::DenseFPElementsAttr }]; - let returnType = [{ ::mlir::DenseFPElementsAttr }]; - - let constBuilderCall = "::llvm::cast<::mlir::DenseFPElementsAttr>(" - "::mlir::DenseElementsAttr::get(" - "::mlir::RankedTensorType::get({" # !interleave(dims, ", ") # - "}, $_builder.getF" # width # "Type()), " - "::llvm::ArrayRef($0)))"; - let convertFromStorage = "$_self"; -} - -class RankedF32ElementsAttr dims> : RankedFloatElementsAttr<32, dims>; -class RankedF64ElementsAttr dims> : RankedFloatElementsAttr<64, dims>; - -def StringElementsAttr : ElementsAttrBase< - CPred<"::llvm::isa<::mlir::DenseStringElementsAttr>($_self)" >, - "string elements attribute"> { - - let storageType = [{ ::mlir::DenseElementsAttr }]; - let returnType = [{ ::mlir::DenseElementsAttr }]; - - let convertFromStorage = "$_self"; -} - -// Attributes containing affine maps. -def AffineMapAttr : Attr< -CPred<"::llvm::isa<::mlir::AffineMapAttr>($_self)">, "AffineMap attribute"> { - let storageType = [{::mlir::AffineMapAttr }]; - let returnType = [{ ::mlir::AffineMap }]; - let valueType = Index; - let constBuilderCall = "::mlir::AffineMapAttr::get($0)"; -} - -// Base class for array attributes. 
-class ArrayAttrBase : Attr { - let storageType = [{ ::mlir::ArrayAttr }]; - let returnType = [{ ::mlir::ArrayAttr }]; - let valueType = NoneType; - let convertFromStorage = "$_self"; - let constBuilderCall = "$_builder.getArrayAttr($0)"; -} - -def ArrayAttr : ArrayAttrBase($_self)">, - "array attribute">; - -// Base class for array attributes whose elements are of the same kind. -// `element` specifies the element attribute kind stored in this array. -class TypedArrayAttrBase: ArrayAttrBase< - And<[ - // Guarantee this is an ArrayAttr first - CPred<"::llvm::isa<::mlir::ArrayAttr>($_self)">, - // Guarantee all elements satisfy the constraints from `element` - Concat<"::llvm::all_of(::llvm::cast<::mlir::ArrayAttr>($_self), " - "[&](::mlir::Attribute attr) { return attr && (", - SubstLeaves<"$_self", "attr", element.predicate>, - "); })">]>, - summary> { - - Attr elementAttr = element; -} - -def LocationArrayAttr : TypedArrayAttrBase; - -def AffineMapArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getAffineMapArrayAttr($0)"; -} - -def BoolArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getBoolArrayAttr($0)"; -} -def I32ArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getI32ArrayAttr($0)"; -} -def I64ArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getI64ArrayAttr($0)"; -} -// Variant of I64ArrayAttr whose user accessor is SmallVector. 
-def I64SmallVectorArrayAttr : - TypedArrayAttrBase { - let returnType = [{ ::llvm::SmallVector }]; - let convertFromStorage = [{ - llvm::to_vector<4>( - llvm::map_range($_self.getAsRange(), - [](IntegerAttr attr) { return attr.getInt(); })); - }]; - let constBuilderCall = "$_builder.getI64ArrayAttr($0)"; -} -def F32ArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getF32ArrayAttr($0)"; -} -def F64ArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getF64ArrayAttr($0)"; -} -def StrArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getStrArrayAttr($0)"; -} -def TypeArrayAttr : TypedArrayAttrBase { - let constBuilderCall = "$_builder.getTypeArrayAttr($0)"; -} -def IndexListArrayAttr : - TypedArrayAttrBase; -def DictArrayAttr : - TypedArrayAttrBase; - -// Attributes containing symbol references. -def SymbolRefAttr : Attr($_self)">, - "symbol reference attribute"> { - let storageType = [{ ::mlir::SymbolRefAttr }]; - let returnType = [{ ::mlir::SymbolRefAttr }]; - let valueType = NoneType; - let constBuilderCall = - "::mlir::SymbolRefAttr::get($_builder.getContext(), $0)"; - let convertFromStorage = "$_self"; -} - -def FlatSymbolRefAttr : Attr($_self)">, - "flat symbol reference attribute"> { - let storageType = [{ ::mlir::FlatSymbolRefAttr }]; - let returnType = [{ ::llvm::StringRef }]; - let valueType = NoneType; - let constBuilderCall = - "::mlir::SymbolRefAttr::get($_builder.getContext(), $0)"; - let convertFromStorage = "$_self.getValue()"; -} - -def SymbolRefArrayAttr : - TypedArrayAttrBase { - let constBuilderCall = ?; -} - -def FlatSymbolRefArrayAttr : - TypedArrayAttrBase { - let constBuilderCall = ?; -} - -//===----------------------------------------------------------------------===// -// Derive attribute kinds - -// DerivedAttr are attributes whose value is computed from properties -// of the operation. They do not require additional storage and are -// materialized as needed. 
-// Note: All derived attributes should be materializable as an Attribute. E.g., -// do not use DerivedAttr for things that could not have been stored as -// Attribute. -// -class DerivedAttr : - Attr, "derived attribute"> { - let returnType = ret; - code body = b; - - // Specify how to convert from the derived attribute to an attribute. - // - // ## Special placeholders - // - // Special placeholders can be used to refer to entities during conversion: - // - // * `$_builder` will be replaced by a mlir::Builder instance. - // * `$_ctxt` will be replaced by the MLIRContext* instance. - // * `$_self` will be replaced with the derived attribute (value produces - // `returnType`). - let convertFromStorage = convert; -} - -// Derived attribute that returns a mlir::Type. -class DerivedTypeAttr : DerivedAttr<"::mlir::Type", body> { - let convertFromStorage = "::mlir::TypeAttr::get($_self)"; -} - -//===----------------------------------------------------------------------===// -// Constant attribute kinds - -// Represents a constant attribute of specific Attr type. A constant -// attribute can be specified only of attributes that have a constant -// builder call defined. The constant value is specified as a string. -// -// If used as a constraint, it generates a matcher on a constant attribute by -// using the constant value builder of the attribute and the value. -class ConstantAttr : AttrConstraint< - CPred<"$_self == " # !subst("$0", val, attribute.constBuilderCall)>, - "constant attribute " # val> { - Attr attr = attribute; - string value = val; -} - -class ConstF32Attr : ConstantAttr; -def ConstBoolAttrFalse : ConstantAttr; -def ConstBoolAttrTrue : ConstantAttr; -def ConstUnitAttr : ConstantAttr; - -// Constant string-based attribute. Wraps the desired string in escaped quotes. 
-class ConstantStrAttr - : ConstantAttr; - -//===----------------------------------------------------------------------===// -// Common attribute constraints -//===----------------------------------------------------------------------===// - -// A general mechanism to further confine the given `attr` with all the -// `constraints`. This allows to compose complex constraints out of a series -// of more primitive ones. -class ConfinedAttr constraints> : Attr< - And, - !foldl(/*init*/attr.summary, /*list*/constraints, - prev, cur, prev # " " # cur.summary)> { - let storageType = attr.storageType; - let returnType = attr.returnType; - let convertFromStorage = attr.convertFromStorage; - let constBuilderCall = attr.constBuilderCall; - let defaultValue = attr.defaultValue; - let valueType = attr.valueType; - let isOptional = attr.isOptional; - - let baseAttr = attr; -} - -// An AttrConstraint that holds if all attr constraints specified in -// 'constraints' hold. -class AllAttrOf constraints> : AttrConstraint< - And, - !interleave(!foreach(con, constraints, con.summary), " and ")> { -} - -class IntNEQValue : AttrConstraint< - CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getInt() != " # n>, - "whose minimum value is " # n>; - -class IntMinValue : AttrConstraint< - CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getInt() >= " # n>, - "whose minimum value is " # n>; - -class IntMaxValue : AttrConstraint< - CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getInt() <= " # n>, - "whose maximum value is " # n>; - -def IntNonNegative : AttrConstraint< - CPred<"!::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isNegative()">, - "whose value is non-negative">; - -def IntPositive : AttrConstraint< - CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getValue().isStrictlyPositive()">, - "whose value is positive">; - -class ArrayMinCount : AttrConstraint< - CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() >= " # n>, - "with at least " # n # " elements">; - -class 
ArrayCount : AttrConstraint< - CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() == " #n>, - "with exactly " # n # " elements">; - -class DenseArrayCount : AttrConstraint< - CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() == " #n>, - "with exactly " # n # " elements">; - -class DenseArrayStrictlyPositive : AttrConstraint< - CPred<"::llvm::all_of(::llvm::cast<" # arrayType #">($_self).asArrayRef(), " - "[&](auto v) { return v > 0; })">, - "whose value is positive">; - -class DenseArrayNonNegative : AttrConstraint< - CPred<"::llvm::all_of(::llvm::cast<" # arrayType #">($_self).asArrayRef(), " - "[&](auto v) { return v >= 0; })">, - "whose value is non-negative">; - -class DenseArraySorted : AttrConstraint< - CPred<"llvm::is_sorted(::llvm::cast<" # arrayType # ">($_self).asArrayRef())">, - "should be in non-decreasing order">; - -class DenseArrayStrictlySorted : AttrConstraint< - And<[ - CPred<"llvm::is_sorted(::llvm::cast<" # arrayType # ">($_self).asArrayRef())">, - // Check that no two adjacent elements are the same. 
- CPred<"[](" # arrayType.returnType # " a) {\n" - "return std::adjacent_find(std::begin(a), std::end(a)) == " - "std::end(a);\n" - "}(::llvm::cast<" # arrayType # ">($_self).asArrayRef())" - >]>, - "should be in increasing order">; - -class IntArrayNthElemEq : AttrConstraint< - And<[ - CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() > " # index>, - CPred<"::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>($_self)[" - # index # "]).getInt() == " # value> - ]>, - "whose " # index # "-th element must be " # value>; - -class IntArrayNthElemMinValue : AttrConstraint< - And<[ - CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() > " # index>, - CPred<"::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>($_self)[" - # index # "]).getInt() >= " # min> - ]>, - "whose " # index # "-th element must be at least " # min>; - -class IntArrayNthElemMaxValue : AttrConstraint< - And<[ - CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() > " # index>, - CPred<"::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>($_self)[" - # index # "]).getInt() <= " # max> - ]>, - "whose " # index # "-th element must be at most " # max>; - -class IntArrayNthElemInRange : AttrConstraint< - And<[ - CPred<"::llvm::cast<::mlir::ArrayAttr>($_self).size() > " # index>, - CPred<"::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>($_self)[" - # index # "]).getInt() >= " # min>, - CPred<"::llvm::cast<::mlir::IntegerAttr>(::llvm::cast<::mlir::ArrayAttr>($_self)[" - # index # "]).getInt() <= " # max> - ]>, - "whose " # index # "-th element must be at least " # min # " and at most " # max>; - -def IsNullAttr : AttrConstraint< - CPred<"!$_self">, "empty attribute (for optional attributes)">; - -//===----------------------------------------------------------------------===// -// Region definitions -//===----------------------------------------------------------------------===// - -class Region : - RegionConstraint; - -// Any region. 
-def AnyRegion : Region, "any region">; - -// A region with the given number of blocks. -class SizedRegion : Region< - CPred<"::llvm::hasNItems($_self, " # numBlocks # ")">, - "region with " # numBlocks # " blocks">; - -// A region with at least the given number of blocks. -class MinSizedRegion : Region< - CPred<"::llvm::hasNItemsOrMore($_self, " # numBlocks # ")">, - "region with at least " # numBlocks # " blocks">; - -// A region with at most the given number of blocks. -class MaxSizedRegion : Region< - CPred<"::llvm::hasNItemsOrLess($_self, " # numBlocks # ")">, - "region with at most " # numBlocks # " blocks">; - -// A variadic region constraint. It expands to zero or more of the base region. -class VariadicRegion - : Region; +include "mlir/IR/Interfaces.td" +include "mlir/IR/Properties.td" +include "mlir/IR/Traits.td" +include "mlir/IR/Utils.td" +include "mlir/IR/AttrTypeBase.td" //===----------------------------------------------------------------------===// // Successor definitions @@ -1989,406 +36,10 @@ class VariadicSuccessor : Successor; - -//===----------------------------------------------------------------------===// -// Trait definitions -//===----------------------------------------------------------------------===// - -// Trait represents a trait regarding an attribute, operation, or type. -class Trait; - -// Define a Trait corresponding to a list of Traits, this allows for specifying -// a list of traits as trait. Avoids needing to do `[Traits, ...] # ListOfTraits -// # [Others, ...]` while still allowing providing convenient groupings. -class TraitList props> : Trait { - list traits = props; -} - -// NativeTrait corresponds to the MLIR C++ trait mechanism. The purpose to wrap -// around C++ symbol string with this class is to make traits specified for -// entities in TableGen less alien and more integrated. 
-// `extraConcreteClassDeclaration` and `extraConcreteClassDefinition` code -// get injected into the entities in which the NativeTrait is specified for. -class NativeTrait : Trait { - string trait = name; - string cppNamespace = "::mlir::" # entityType # "Trait"; - - code extraConcreteClassDeclaration = extraClassDeclaration; - code extraConcreteClassDefinition = extraClassDefinition; -} - -// ParamNativeTrait corresponds to the template-parameterized traits in the C++ -// implementation. MLIR uses nested class templates to implement such traits -// leading to constructs of the form "TraitName::Impl". Use the -// value in `prop` as the trait name and the value in `params` as parameters to -// construct the native trait class name. -class ParamNativeTrait - : NativeTrait::Impl", entityType>; - -// GenInternalTrait is a trait that does not have direct C++ mapping but affects -// an entities definition generator internals, like how operation builders and -// operand/attribute/result getters are generated. -class GenInternalTrait : Trait { - string trait = "::mlir::" # entityType # "Trait::" # prop; -} - -// PredTrait is a trait implemented by way of a predicate on an entity. -class PredTrait : Trait { - string summary = descr; - Pred predicate = pred; -} - -//===----------------------------------------------------------------------===// -// OpTrait definitions -//===----------------------------------------------------------------------===// - -// A trait that describes the structure of operation will be marked with -// `StructuralOpTrait` and they will be verified first. -class StructuralOpTrait; - -// These classes are used to define operation specific traits. - -// Specify op specific declarations and definitions in `extraOpDeclaration` -// and `extraOpDefinition` template arguments. 
-class NativeOpTrait traits = [], - code extraOpDeclaration = [{}], - code extraOpDefinition = [{}]> - : NativeTrait { - // Specify the list of traits that need to be verified before the verification - // of this NativeOpTrait. - list dependentTraits = traits; -} -class ParamNativeOpTrait traits = []> - : ParamNativeTrait { - // Specify the list of traits that need to be verified before the verification - // of this ParamNativeOpTrait. - list dependentTraits = traits; -} -class GenInternalOpTrait traits = []> - : GenInternalTrait { - // Specify the list of traits that need to be verified before the verification - // of this GenInternalOpTrait. - list dependentTraits = traits; -} -class PredOpTrait traits = []> - : PredTrait { - // Specify the list of traits that need to be verified before the verification - // of this PredOpTrait. - list dependentTraits = traits; -} - -// Op defines an affine scope. -def AffineScope : NativeOpTrait<"AffineScope">; -// Op defines an automatic allocation scope. -def AutomaticAllocationScope : - NativeOpTrait<"AutomaticAllocationScope">; -// Op supports operand broadcast behavior. -def ResultsBroadcastableShape : - NativeOpTrait<"ResultsBroadcastableShape">; -// X op Y == Y op X -def Commutative : NativeOpTrait<"IsCommutative">; -// op op X == op X (unary) / X op X == X (binary) -// FIXME: Idempotent should depend on SameOperandsAndResultType -def Idempotent : NativeOpTrait<"IsIdempotent">; -// op op X == X -// FIXME: Involution should depend on SameOperandsAndResultType -def Involution : NativeOpTrait<"IsInvolution">; -// Op behaves like a constant. -def ConstantLike : NativeOpTrait<"ConstantLike">; -// Op is isolated from above. -def IsolatedFromAbove : NativeOpTrait<"IsIsolatedFromAbove">; -// Op results are float or vectors/tensors thereof. -def ResultsAreFloatLike : NativeOpTrait<"ResultsAreFloatLike">; -// Op has the same operand type. 
-def SameTypeOperands : NativeOpTrait<"SameTypeOperands">; -// Op has same shape for all operands. -def SameOperandsShape : NativeOpTrait<"SameOperandsShape">; -// Op has same operand and result shape. -def SameOperandsAndResultShape : - NativeOpTrait<"SameOperandsAndResultShape">; -// Op has the same element type (or type itself, if scalar) for all operands. -def SameOperandsElementType : - NativeOpTrait<"SameOperandsElementType">; -// Op has the same operand and result element type (or type itself, if scalar). -def SameOperandsAndResultElementType : - NativeOpTrait<"SameOperandsAndResultElementType">; -// Op is a terminator. -def Terminator : NativeOpTrait<"IsTerminator">; -// Op can be safely normalized in the presence of MemRefs with -// non-identity maps. -def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">; -// Op is elementwise on tensor/vector operands and results. -def Elementwise : NativeOpTrait<"Elementwise">; -// Elementwise op can be applied to scalars instead tensor/vector operands. -def Scalarizable : NativeOpTrait<"Scalarizable", [Elementwise]>; -// Elementwise op can be applied to all-vector operands. -def Vectorizable : NativeOpTrait<"Vectorizable", [Elementwise]>; -// Elementwise op can be applied to all-tensor operands. -def Tensorizable : NativeOpTrait<"Tensorizable", [Elementwise]>; - -// Group together `Elementwise`, `Scalarizable`, `Vectorizable`, and -// `Tensorizable` for convenience. -def ElementwiseMappable : TraitList<[ - Elementwise, - Scalarizable, - Vectorizable, - Tensorizable, -]>; - -// Op's regions have a single block. -def SingleBlock : NativeOpTrait<"SingleBlock">, StructuralOpTrait; - -// Op's regions have a single block with the specified terminator. -class SingleBlockImplicitTerminator - : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>, - StructuralOpTrait; - -// Op's regions don't have terminator. 
-def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait; - -// Op's parent operation is the provided one. -class HasParent - : ParamNativeOpTrait<"HasParent", op>, StructuralOpTrait; - -class ParentOneOf ops> - : ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>, - StructuralOpTrait; - -// Op result type is derived from the first attribute. If the attribute is an -// subclass of `TypeAttrBase`, its value is used, otherwise, the type of the -// attribute content is used. -def FirstAttrDerivedResultType : - GenInternalOpTrait<"FirstAttrDerivedResultType">; - -// TODO: Turn the following into normal traits and generate verification for -// them. - -// All variadic operands of the op have the same number of values. -// A variadic operand contains an array of values whose array size is only -// known at runtime. This trait requires all variadic operands of an op -// to have the same array size. -def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">; -// All variadic results of the op have the same number of values. -// A variadic result contains an array of values whose array size is only -// known at runtime. This trait requires all variadic results of an op -// to have the same array size. -def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">; - -// Uses an attribute named `operand_segment_sizes` to specify how many actual -// operand each ODS-declared operand (variadic or not) corresponds to. -// This trait is used for ops that have multiple variadic operands but do -// not know statically their size relationship. The attribute must be a 1D -// vector that has the same number of elements as the number of ODS declared -// operands. That means even if some operands are non-variadic, the attribute -// still need to have an element for its size, which is always 1. 
-def AttrSizedOperandSegments : - NativeOpTrait<"AttrSizedOperandSegments">, StructuralOpTrait; -// Similar to AttrSizedOperandSegments, but used for results. The attribute -// should be named as `result_segment_sizes`. -def AttrSizedResultSegments : - NativeOpTrait<"AttrSizedResultSegments">, StructuralOpTrait; - -// Op attached regions have no arguments -def NoRegionArguments : NativeOpTrait<"NoRegionArguments">, StructuralOpTrait; - -//===----------------------------------------------------------------------===// -// Interface definitions -//===----------------------------------------------------------------------===// - -// Marker used to identify the argument list for an op or interface method. -def ins; - -// This class represents a typed argument with optional default value for C -// function signatures, e.g. builders or methods. -class CArg { - string type = ty; - string defaultValue = value; -} - -// InterfaceTrait corresponds to a specific 'Interface' class defined in C++. -// The purpose to wrap around C++ symbol string with this class is to make -// interfaces specified for ops in TableGen less alien and more integrated. -class InterfaceTrait : NativeTrait<"", ""> { - let trait = name # "::Trait"; - let cppNamespace = ""; - - // An optional code block containing extra declarations to place in the - // interface trait declaration. - code extraTraitClassDeclaration = ""; -} - -// OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in -// C++. The purpose to wrap around C++ symbol string with this class is to make -// interfaces specified for ops in TableGen less alien and more integrated. -class OpInterfaceTrait traits = []> - : InterfaceTrait { - // Specify the body of the verification function. `$_op` will be replaced with - // the operation being verified. - code verify = verifyBody; - - // A bit indicating if the verifier needs to access the ops in the regions. 
If - // it set to `1`, the region ops will be verified before invoking this - // verifier. - bit verifyWithRegions = 0; - - // Specify the list of traits that need to be verified before the verification - // of this OpInterfaceTrait. - list dependentTraits = traits; -} - -// This class represents a single, optionally static, interface method. -// Note: non-static interface methods have an implicit parameter, either -// $_op/$_attr/$_type corresponding to an instance of the derived value. -class InterfaceMethod { - // A human-readable description of what this method does. - string description = desc; - - // The name of the interface method. - string name = methodName; - - // The c++ type-name of the return type. - string returnType = retTy; - - // A dag of string that correspond to the arguments of the method. - dag arguments = args; - - // An optional body to the method. - code body = methodBody; - - // An optional default implementation of the method. - code defaultBody = defaultImplementation; -} - -// This class represents a single static interface method. -class StaticInterfaceMethod - : InterfaceMethod; - -// Interface represents a base interface. -class Interface baseInterfacesArg = []> { - // A human-readable description of what this interface does. - string description = ""; - - // The name given to the c++ interface class. - string cppInterfaceName = name; - - // The C++ namespace that this interface should be placed into. - // - // To specify nested namespaces, use "::" as the delimiter, e.g., given - // "A::B", ops will be placed in `namespace A { namespace B { } }`. - string cppNamespace = ""; - - // The list of methods defined by this interface. - list methods = []; - - // An optional code block containing extra declarations to place in the - // interface declaration. - code extraClassDeclaration = ""; - - // An optional code block containing extra declarations to place in both - // the interface and trait declaration. 
- code extraSharedClassDeclaration = ""; - - // An optional code block for adding additional "classof" logic. This can - // be used to better enable "optional" interfaces, where an entity only - // implements the interface if some dynamic characteristic holds. - // `$_attr`/`$_op`/`$_type` may be used to refer to an instance of the - // entity being checked. - code extraClassOf = ""; - - // An optional set of base interfaces that this interface - // "derives" from. - list baseInterfaces = baseInterfacesArg; -} - -// AttrInterface represents an interface registered to an attribute. -class AttrInterface baseInterfaces = []> - : Interface, InterfaceTrait, - Attr($_self)">, - name # " instance" - > { - let storageType = !if(!empty(cppNamespace), "", cppNamespace # "::") # name; - let returnType = storageType; - let convertFromStorage = "$_self"; -} - -// OpInterface represents an interface registered to an operation. -class OpInterface baseInterfaces = []> - : Interface, OpInterfaceTrait; - -// TypeInterface represents an interface registered to a type. -class TypeInterface baseInterfaces = []> - : Interface, InterfaceTrait, - Type($_self)">, - name # " instance", - !if(!empty(cppNamespace),"", cppNamespace # "::") # name - >; - -// Whether to declare the interface methods in the user entity's header. This -// class simply wraps an Interface but is used to indicate that the method -// declarations should be generated. This class takes an optional set of methods -// that should have declarations generated even if the method has a default -// implementation. -class DeclareInterfaceMethods overridenMethods = []> { - // This field contains a set of method names that should always have their - // declarations generated. This allows for generating declarations for - // methods with default implementations that need to be overridden. 
- list alwaysOverriddenMethods = overridenMethods; -} -class DeclareAttrInterfaceMethods overridenMethods = []> - : DeclareInterfaceMethods, - AttrInterface { - let description = interface.description; - let cppInterfaceName = interface.cppInterfaceName; - let cppNamespace = interface.cppNamespace; - let methods = interface.methods; - let baseInterfaces = interface.baseInterfaces; -} -class DeclareOpInterfaceMethods overridenMethods = []> - : DeclareInterfaceMethods, - OpInterface { - let description = interface.description; - let cppInterfaceName = interface.cppInterfaceName; - let cppNamespace = interface.cppNamespace; - let methods = interface.methods; - let baseInterfaces = interface.baseInterfaces; -} -class DeclareTypeInterfaceMethods overridenMethods = []> - : DeclareInterfaceMethods, - TypeInterface { - let description = interface.description; - let cppInterfaceName = interface.cppInterfaceName; - let cppNamespace = interface.cppNamespace; - let methods = interface.methods; - let baseInterfaces = interface.baseInterfaces; -} - //===----------------------------------------------------------------------===// // Op definitions //===----------------------------------------------------------------------===// -// Marker used to identify the result list for an op. -def outs; - -// Marker used to identify the region list for an op. -def region; - -// Marker used to identify the successor list for an op. -def successor; - // Class for defining a custom builder. // // TableGen generates several generic builders for each op by default (see diff --git a/mlir/include/mlir/IR/Properties.td b/mlir/include/mlir/IR/Properties.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/Properties.td @@ -0,0 +1,133 @@ +//===-- Properties.td - Properties definition file ----------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the base properties definition file. +// +//===----------------------------------------------------------------------===// + +#ifndef PROPERTIES +#define PROPERTIES + +// Base class for defining properties. +class Property { + // User-readable one-line summary used in error reporting messages. If empty, + // a generic message will be used. + string summary = desc; + // The full description of this property. + string description = ""; + code storageType = storageTypeParam; + code interfaceType = storageTypeParam; + + // The expression to convert from the storage type to the interface + // type. For example, an enum can be stored as an int but returned as an + // enum class. + // + // Format: + // - `$_storage` will contain the property in the storage type. + // - `$_ctxt` will contain an `MLIRContext *`. + code convertFromStorage = "$_storage"; + + // The expression that assigns an interface-type value to the storage. + // + // Format: + // - `$_storage` is the storage being assigned to. + // - `$_value` will contain the property in the user interface type. + code assignToStorage = "$_storage = $_value"; + + // The call expression to convert from the storage type to an attribute. + // + // Format: + // - `$_storage` is the storage type value. + // - `$_ctxt` is an `MLIRContext *`. + // + // The expression must result in an Attribute. + code convertToAttribute = [{ + convertToAttribute($_ctxt, $_storage) + }]; + + // The call expression to convert from an Attribute to the storage type. + // + // Format: + // - `$_storage` is the storage type value. + // - `$_attr` is the attribute. + // - `$_diag` is an optional Diagnostic pointer used to emit errors.
+ // + // The code must return a LogicalResult. + code convertFromAttribute = [{ + return convertFromAttribute($_storage, $_attr, $_diag); + }]; + + // The call expression to hash the property. + // + // Format: + // - `$_storage` is the variable to hash. + // + // The expression should produce an `llvm::hash_code`. + code hashProperty = [{ + llvm::hash_value($_storage); + }]; + + // The call expression to emit the storage type to bytecode. + // + // Format: + // - `$_storage` is the storage type value. + // - `$_writer` is a `DialectBytecodeWriter`. + // - `$_ctxt` is an `MLIRContext *`. + code writeToMlirBytecode = [{ + writeToMlirBytecode($_writer, $_storage) + }]; + + // The call expression to read the storage type from bytecode. + // + // Format: + // - `$_storage` is the storage type value. + // - `$_reader` is a `DialectBytecodeReader`. + // - `$_ctxt` is an `MLIRContext *`. + code readFromMlirBytecode = [{ + if (::mlir::failed(readFromMlirBytecode($_reader, $_storage))) + return ::mlir::failure(); + }]; + + // Default value for the property. + string defaultValue = ?; +} + +/// Implementation of the Property class's `readFromMlirBytecode` field using +/// the default `convertFromAttribute` implementation. +/// Users not wanting to implement their own `readFromMlirBytecode` and +/// `writeToMlirBytecode` implementations can opt into using this implementation +/// by writing: +/// +/// let writeToMlirBytecode = writeMlirBytecodeWithConvertToAttribute; +/// let readFromMlirBytecode = readMlirBytecodeUsingConvertFromAttribute; +/// +/// in their property definition. +/// Serialization writes the attribute produced by `convertToAttribute`, and +/// deserialization passes the attribute it reads to `convertFromAttribute`. +/// +/// WARNING: This implementation creates a less-than-optimal encoding. +/// Users caring about optimal encoding should not use this implementation and +/// implement `readFromMlirBytecode` and `writeToMlirBytecode` themselves.
+defvar readMlirBytecodeUsingConvertFromAttribute = [{ + ::mlir::Attribute attr; + if (::mlir::failed($_reader.readAttribute(attr))) + return ::mlir::failure(); + if (::mlir::failed(convertFromAttribute($_storage, attr, nullptr))) + return ::mlir::failure(); +}]; + +/// Implementation of the Property class's `writeToMlirBytecode` field using +/// the default `convertToAttribute` implementation. +/// See the description of `readMlirBytecodeUsingConvertFromAttribute` above +/// for details. +defvar writeMlirBytecodeWithConvertToAttribute = [{ + $_writer.writeAttribute(convertToAttribute($_ctxt, $_storage)) +}]; + + +#endif // PROPERTIES diff --git a/mlir/include/mlir/IR/Traits.td b/mlir/include/mlir/IR/Traits.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/Traits.td @@ -0,0 +1,222 @@ +//===-- Traits.td - Trait definitions file ------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains definitions for traits. +// +//===----------------------------------------------------------------------===// + +#ifndef TRAITS_TD +#define TRAITS_TD + +include "mlir/IR/Constraints.td" + +//===----------------------------------------------------------------------===// +// Trait definitions +//===----------------------------------------------------------------------===// + +// Trait represents a trait regarding an attribute, operation, or type. +class Trait; + +// Define a Trait corresponding to a list of Traits. This allows specifying a +// list of traits as a single trait, avoiding `[Traits, ...] # ListOfTraits +// # [Others, ...]` while still providing convenient groupings.
+class TraitList props> : Trait { + list traits = props; +} + +// NativeTrait corresponds to the MLIR C++ trait mechanism. Wrapping the C++ +// symbol string in this class makes traits specified for entities in +// TableGen less alien and more integrated. +// `extraConcreteClassDeclaration` and `extraConcreteClassDefinition` code +// gets injected into the entities for which the NativeTrait is specified. +class NativeTrait : Trait { + string trait = name; + string cppNamespace = "::mlir::" # entityType # "Trait"; + + code extraConcreteClassDeclaration = extraClassDeclaration; + code extraConcreteClassDefinition = extraClassDefinition; +} + +// ParamNativeTrait corresponds to the template-parameterized traits in the C++ +// implementation. MLIR uses nested class templates to implement such traits, +// leading to constructs of the form "TraitName::Impl". Use the +// value in `prop` as the trait name and the value in `params` as parameters to +// construct the native trait class name. +class ParamNativeTrait + : NativeTrait::Impl", entityType>; + +// GenInternalTrait is a trait that does not have a direct C++ mapping but +// affects the internals of an entity's definition generator, like how +// operation builders and operand/attribute/result getters are generated. +class GenInternalTrait : Trait { + string trait = "::mlir::" # entityType # "Trait::" # prop; +} + +// PredTrait is a trait implemented by way of a predicate on an entity. +class PredTrait : Trait { + string summary = descr; + Pred predicate = pred; +} + +//===----------------------------------------------------------------------===// +// OpTrait definitions +//===----------------------------------------------------------------------===// + +// Traits that describe the structure of an operation are marked with +// `StructuralOpTrait` and are verified first. +class StructuralOpTrait; + +// These classes are used to define operation-specific traits.
+ +// Specify op-specific declarations and definitions in `extraOpDeclaration` +// and `extraOpDefinition` template arguments. +class NativeOpTrait traits = [], + code extraOpDeclaration = [{}], + code extraOpDefinition = [{}]> + : NativeTrait { + // Specify the list of traits that need to be verified before the verification + // of this NativeOpTrait. + list dependentTraits = traits; +} +class ParamNativeOpTrait traits = []> + : ParamNativeTrait { + // Specify the list of traits that need to be verified before the verification + // of this ParamNativeOpTrait. + list dependentTraits = traits; +} +class GenInternalOpTrait traits = []> + : GenInternalTrait { + // Specify the list of traits that need to be verified before the verification + // of this GenInternalOpTrait. + list dependentTraits = traits; +} +class PredOpTrait traits = []> + : PredTrait { + // Specify the list of traits that need to be verified before the verification + // of this PredOpTrait. + list dependentTraits = traits; +} + +// Op defines an affine scope. +def AffineScope : NativeOpTrait<"AffineScope">; +// Op defines an automatic allocation scope. +def AutomaticAllocationScope : + NativeOpTrait<"AutomaticAllocationScope">; +// Op supports operand broadcast behavior. +def ResultsBroadcastableShape : + NativeOpTrait<"ResultsBroadcastableShape">; +// X op Y == Y op X +def Commutative : NativeOpTrait<"IsCommutative">; +// op op X == op X (unary) / X op X == X (binary) +// FIXME: Idempotent should depend on SameOperandsAndResultType +def Idempotent : NativeOpTrait<"IsIdempotent">; +// op op X == X +// FIXME: Involution should depend on SameOperandsAndResultType +def Involution : NativeOpTrait<"IsInvolution">; +// Op behaves like a constant. +def ConstantLike : NativeOpTrait<"ConstantLike">; +// Op is isolated from above. +def IsolatedFromAbove : NativeOpTrait<"IsIsolatedFromAbove">; +// Op results are floats or vectors/tensors thereof.
+def ResultsAreFloatLike : NativeOpTrait<"ResultsAreFloatLike">; +// All operands of the op have the same type. +def SameTypeOperands : NativeOpTrait<"SameTypeOperands">; +// Op has the same shape for all operands. +def SameOperandsShape : NativeOpTrait<"SameOperandsShape">; +// Op has the same shape for all operands and results. +def SameOperandsAndResultShape : + NativeOpTrait<"SameOperandsAndResultShape">; +// Op has the same element type (or type itself, if scalar) for all operands. +def SameOperandsElementType : + NativeOpTrait<"SameOperandsElementType">; +// Op has the same operand and result element type (or type itself, if scalar). +def SameOperandsAndResultElementType : + NativeOpTrait<"SameOperandsAndResultElementType">; +// Op is a terminator. +def Terminator : NativeOpTrait<"IsTerminator">; +// Op can be safely normalized in the presence of MemRefs with +// non-identity maps. +def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">; +// Op is elementwise on tensor/vector operands and results. +def Elementwise : NativeOpTrait<"Elementwise">; +// Elementwise op can be applied to scalars instead of tensor/vector operands. +def Scalarizable : NativeOpTrait<"Scalarizable", [Elementwise]>; +// Elementwise op can be applied to all-vector operands. +def Vectorizable : NativeOpTrait<"Vectorizable", [Elementwise]>; +// Elementwise op can be applied to all-tensor operands. +def Tensorizable : NativeOpTrait<"Tensorizable", [Elementwise]>; + +// Group together `Elementwise`, `Scalarizable`, `Vectorizable`, and +// `Tensorizable` for convenience. +def ElementwiseMappable : TraitList<[ + Elementwise, + Scalarizable, + Vectorizable, + Tensorizable, +]>; + +// Op's regions have a single block. +def SingleBlock : NativeOpTrait<"SingleBlock">, StructuralOpTrait; + +// Op's regions have a single block with the specified terminator. +class SingleBlockImplicitTerminator + : ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>, + StructuralOpTrait; + +// Op's regions don't have a terminator.
+def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait; + +// Op's parent operation is the provided one. +class HasParent + : ParamNativeOpTrait<"HasParent", op>, StructuralOpTrait; + +class ParentOneOf ops> + : ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>, + StructuralOpTrait; + +// Op result type is derived from the first attribute. If the attribute is a +// subclass of `TypeAttrBase`, its value is used; otherwise, the type of the +// attribute content is used. +def FirstAttrDerivedResultType : + GenInternalOpTrait<"FirstAttrDerivedResultType">; + +// TODO: Turn the following into normal traits and generate verification for +// them. + +// All variadic operands of the op have the same number of values. +// A variadic operand contains an array of values whose array size is only +// known at runtime. This trait requires all variadic operands of an op +// to have the same array size. +def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">; +// All variadic results of the op have the same number of values. +// A variadic result contains an array of values whose array size is only +// known at runtime. This trait requires all variadic results of an op +// to have the same array size. +def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">; + +// Uses an attribute named `operand_segment_sizes` to specify how many actual +// operands each ODS-declared operand (variadic or not) corresponds to. +// This trait is used for ops that have multiple variadic operands but do +// not know statically their size relationship. The attribute must be a 1D +// vector that has the same number of elements as the number of ODS-declared +// operands. That means even if some operands are non-variadic, the attribute +// still needs to have an element for its size, which is always 1.
+def AttrSizedOperandSegments : + NativeOpTrait<"AttrSizedOperandSegments">, StructuralOpTrait; +// Similar to AttrSizedOperandSegments, but used for results. The attribute +// should be named `result_segment_sizes`. +def AttrSizedResultSegments : + NativeOpTrait<"AttrSizedResultSegments">, StructuralOpTrait; + +// Op's attached regions have no arguments. +def NoRegionArguments : NativeOpTrait<"NoRegionArguments">, StructuralOpTrait; + +#endif // TRAITS_TD diff --git a/mlir/include/mlir/IR/Utils.td b/mlir/include/mlir/IR/Utils.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/Utils.td @@ -0,0 +1,75 @@ +//===-- Utils.td - General utilities file ------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains a number of utilities which can be used across tablegen +// files. +// +//===----------------------------------------------------------------------===// + +#ifndef UTILS_TD +#define UTILS_TD + +// Helper for marking deprecated classes or defs in TableGen. To mark a def as +// deprecated, mix in the `Deprecated` class with a reason. +// Usage of a deprecated def within TableGen will cause a warning with the +// given message. +class Deprecated { + string odsDeprecated = reason; +} + +// Helper for marking entities in ODS-generated C++ as deprecated. +// Usage of such an entity from C++ code will cause a warning to be emitted by +// the C++ compiler with the given message. +// +// Note: Support has to be implemented by the code generator of a given +// entity. +class CppDeprecated { + string odsCppDeprecated = reason; +} + +// A workaround for the inability to define functions in TableGen.
+// +// The template parameter defines a string that can be extracted from an +// instance of this class by accessing the "result" member. Subclasses can take +// their own template parameters as function "arguments" and use them to +// populate `result`. +// For example, if it didn't already exist, a concat function could be defined +// like: +// +// class StrConcat strings> : +// StrFunc +// +// and then called like +// +// StrConcat<["a", "b", "c"]>.result +// +// to get the string "abc". +class StrFunc { + string result = r; +} + +// Marker used to identify the argument list. +def ins; + +// Marker used to identify the result list. +def outs; + +// Marker used to identify the region list. +def region; + +// Marker used to identify the successor list. +def successor; + +// This class represents a typed argument with an optional default value for C +// function signatures, e.g. builders or methods. +class CArg { + string type = ty; + string defaultValue = value; +} + +#endif // UTILS_TD