diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -340,13 +340,29 @@ // Any type from the given list class AnyTypeOf allowedTypes, string summary = "", string cppClassName = "::mlir::Type"> : Type< - // Satisfy any of the allowed type's condition + // Satisfy any of the allowed types' conditions. Or, !if(!eq(summary, ""), !interleave(!foreach(t, allowedTypes, t.summary), " or "), summary), cppClassName>; +// A type that satisfies the constraints of all given types. +class AllOfType allowedTypes, string summary = "", + string cppClassName = "::mlir::Type"> : Type< + // Satisfy all of the allowed types' conditions. + And, + !if(!eq(summary, ""), + !interleave(!foreach(t, allowedTypes, t.summary), " and "), + summary), + cppClassName>; + +// A type that satisfies additional predicates. +class ConfinedType predicates, string summary = "", + string cppClassName = "::mlir::Type"> : Type< + And, + summary, cppClassName>; + // Integer types. // Any integer type irrespective of its width and signedness semantics. 
@@ -475,12 +491,14 @@ def BF16 : Type, "bfloat16 type">, BuildableType<"$_builder.getBF16Type()">; +def AnyComplex : Type()">, + "complex-type", "::mlir::ComplexType">; + class Complex - : Type()">, + : ConfinedType().getElementType()", - type.predicate>]>, + type.predicate>], "complex type with " # type.summary # " elements", "::mlir::ComplexType">, SameBuildabilityAs()">, - "complex-type", "::mlir::ComplexType">; - class OpaqueType : Type, summary, "::mlir::OpaqueType">, @@ -572,9 +587,8 @@ // Any vector where the rank is from the given `allowedRanks` list and the type // is from the given `allowedTypes` list class VectorOfRankAndType allowedRanks, - list allowedTypes> : Type< - And<[VectorOf.predicate, - VectorOfRank.predicate]>, + list allowedTypes> : AllOfType< + [VectorOf, VectorOfRank], VectorOf.summary # VectorOfRank.summary, "::mlir::VectorType">; @@ -630,18 +644,16 @@ // `allowedLengths` list and the type is from the given `allowedTypes` // list class VectorOfLengthAndType allowedLengths, - list allowedTypes> : Type< - And<[VectorOf.predicate, - VectorOfLength.predicate]>, + list allowedTypes> : AllOfType< + [VectorOf, VectorOfLength], VectorOf.summary # VectorOfLength.summary, "::mlir::VectorType">; // Any fixed-length vector where the number of elements is from the given // `allowedLengths` list and the type is from the given `allowedTypes` list class FixedVectorOfLengthAndType allowedLengths, - list allowedTypes> : Type< - And<[FixedVectorOf.predicate, - FixedVectorOfLength.predicate]>, + list allowedTypes> : AllOfType< + [FixedVectorOf, FixedVectorOfLength], FixedVectorOf.summary # FixedVectorOfLength.summary, "::mlir::VectorType">; @@ -649,9 +661,8 @@ // Any scalable vector where the number of elements is from the given // `allowedLengths` list and the type is from the given `allowedTypes` list class ScalableVectorOfLengthAndType allowedLengths, - list allowedTypes> : Type< - And<[ScalableVectorOf.predicate, - ScalableVectorOfLength.predicate]>, + 
list allowedTypes> : AllOfType< + [ScalableVectorOf, ScalableVectorOfLength], ScalableVectorOf.summary # ScalableVectorOfLength.summary, "::mlir::VectorType">; @@ -768,34 +779,33 @@ // TODO: Have an easy way to add another constraint to a type. class MemRefRankOf allowedTypes, list ranks> : - Type.predicate, HasAnyRankOfPred]>, + ConfinedType, [HasAnyRankOfPred], !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " # MemRefOf.summary, "::mlir::MemRefType">; -class StaticShapeMemRefOf allowedTypes> - : Type.predicate, HasStaticShapePred]>, - "statically shaped " # MemRefOf.summary, - "::mlir::MemRefType">; +class StaticShapeMemRefOf allowedTypes> : + ConfinedType, [HasStaticShapePred], + "statically shaped " # MemRefOf.summary, + "::mlir::MemRefType">; def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>; // For a MemRefType, verify that it has strides. def HasStridesPred : CPred<[{ isStrided($_self.cast<::mlir::MemRefType>()) }]>; -class StridedMemRefOf allowedTypes> - : Type.predicate, HasStridesPred]>, - "strided " # MemRefOf.summary>; +class StridedMemRefOf allowedTypes> : + ConfinedType, [HasStridesPred], + "strided " # MemRefOf.summary>; def AnyStridedMemRef : StridedMemRefOf<[AnyType]>; class AnyStridedMemRefOfRank : - Type.predicate]>, + AllOfType<[AnyStridedMemRef, MemRefRankOf<[AnyType], [rank]>], AnyStridedMemRef.summary # " of rank " # rank>; class StridedMemRefRankOf allowedTypes, list ranks> : - Type.predicate, HasAnyRankOfPred]>, + ConfinedType, [HasAnyRankOfPred], !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " # MemRefOf.summary>;