diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -211,10 +211,14 @@
 // Whether a type is a MemRefType.
 def IsMemRefTypePred : CPred<"$_self.isa<::mlir::MemRefType>()">;
 
-// Whether a type is an IsUnrankedMemRefType
+// Whether a type is an UnrankedMemRefType
 def IsUnrankedMemRefTypePred
         : CPred<"$_self.isa<::mlir::UnrankedMemRefType>()">;
 
+// Whether a type is a BaseMemRefType
+def IsBaseMemRefTypePred
+        : CPred<"$_self.isa<::mlir::BaseMemRefType>()">;
+
 // Whether a type is a ShapedType.
 def IsShapedTypePred : CPred<"$_self.isa<::mlir::ShapedType>()">;
 
@@ -651,10 +655,13 @@
 class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
 
 // Unranked Memref type
-def AnyUnrankedMemRef :
-    ShapedContainerType<[AnyType],
+class UnrankedMemRefOf<list<Type> allowedTypes> :
+    ShapedContainerType<allowedTypes,
                         IsUnrankedMemRefTypePred, "unranked.memref",
                         "::mlir::UnrankedMemRefType">;
+
+def AnyUnrankedMemRef : UnrankedMemRefOf<[AnyType]>;
+
 // Memref type.
 
 // Memrefs are blocks of data with fixed type and rank.
@@ -664,6 +671,9 @@
 def AnyMemRef : MemRefOf<[AnyType]>;
 
 
+class RankedOrUnrankedMemRefOf<list<Type> allowedTypes> :
+    AnyTypeOf<[UnrankedMemRefOf<allowedTypes>, MemRefOf<allowedTypes>]>;
+
 def AnyRankedOrUnrankedMemRef: AnyTypeOf<[AnyUnrankedMemRef, AnyMemRef]>;
 
 // Memref declarations handle any memref, independent of rank, size, (static or