diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -604,11 +604,13 @@ ShapedContainerType; +class RankedTensorOf allowedTypes> : + ShapedContainerType, + "ranked tensor", "::mlir::TensorType">; + def AnyTensor : TensorOf<[AnyType]>; -def AnyRankedTensor : - ShapedContainerType<[AnyType], And<[IsTensorTypePred, HasRankPred]>, - "ranked tensor", "::mlir::TensorType">; +def AnyRankedTensor : RankedTensorOf<[AnyType]>; // TODO: Have an easy way to add another constraint to a type. class StaticShapeTensorOf allowedTypes> @@ -675,11 +677,13 @@ class MemRefRankOf allowedTypes, list ranks> : Type.predicate, HasAnyRankOfPred]>, !interleave(!foreach(rank, ranks, rank # "D"), "/") # " " # - MemRefOf.summary>; + MemRefOf.summary, + "::mlir::MemRefType">; class StaticShapeMemRefOf allowedTypes> : Type.predicate, HasStaticShapePred]>, - "statically shaped " # MemRefOf.summary>; + "statically shaped " # MemRefOf.summary, + "::mlir::MemRefType">; def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;