Patch summary: add shape-taking `clone` overloads to TensorType and
BaseMemRefType that return the ranked types (RankedTensorType / MemRefType),
move the shape-taking ShapedType `clone` helpers above the
`extraSharedClassDeclaration` block, add element-type-only `clone` overloads
in BuiltinTypes.td, and document when `cloneWith` / `clone` yield ranked types.

NOTE(review): line breaks, indentation and the `<...>` template arguments in
this patch were reconstructed from a whitespace-collapsed copy; run
`git apply --check` (optionally with `--ignore-whitespace`) against the target
tree before applying.

diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -59,8 +59,13 @@
   }];
   let methods = [
     InterfaceMethod<[{
-      Returns a clone of this type with the given shape and element
-      type. If a shape is not provided, the current shape of the type is used.
+      Returns a clone of this type with the given shape and element type.
+
+      If no shape is provided, the shape of this type is used. In that case, if
+      this type is unranked, so is the resulting type.
+
+      If a shape is provided, the resulting type is always ranked, even if this
+      type is unranked.
     }],
     "::mlir::ShapedType", "cloneWith", (ins
       "::std::optional<::llvm::ArrayRef<int64_t>>":$shape,
@@ -89,7 +94,7 @@
 
     /// Whether the given dimension size indicates a dynamic dimension.
     static constexpr bool isDynamic(int64_t dValue) {
-       return dValue == kDynamic;
+      return dValue == kDynamic;
     }
 
     /// Whether the given shape has any size that indicates a dynamic dimension.
@@ -99,18 +104,24 @@
 
     /// Return the number of elements present in the given shape.
     static int64_t getNumElements(ArrayRef<int64_t> shape);
-  }];
 
-  let extraSharedClassDeclaration = [{
     /// Return a clone of this type with the given new shape and element type.
+    /// The returned type is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
-      return $_type.cloneWith(shape, elementType);
+      return cloneWith(shape, elementType);
     }
-    /// Return a clone of this type with the given new shape.
+
+    /// Return a clone of this type with the given new shape. The returned type
+    /// is ranked, even if this type is unranked.
     auto clone(::llvm::ArrayRef<int64_t> shape) {
-      return $_type.cloneWith(shape, $_type.getElementType());
+      return cloneWith(shape, getElementType());
     }
-    /// Return a clone of this type with the given new element type.
+  }];
+
+  let extraSharedClassDeclaration = [{
+    /// Return a clone of this type with the given new element type. The
+    /// returned type is ranked if and only if this type is ranked. In that
+    /// case, the returned type has the same shape as this type.
     auto clone(::mlir::Type elementType) {
       return $_type.cloneWith(/*shape=*/std::nullopt, elementType);
     }
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -27,6 +27,8 @@
 class FloatType;
 class IndexType;
 class IntegerType;
+class MemRefType;
+class RankedTensorType;
 class StringAttr;
 class TypeRange;
 
@@ -95,6 +97,17 @@
   TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                        Type elementType) const;
 
+  // Make sure that base class overloads are visible.
+  using ShapedType::Trait<TensorType>::clone;
+
+  /// Return a clone of this type with the given new shape and element type.
+  /// The returned type is ranked, even if this type is unranked.
+  RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
+
+  /// Return a clone of this type with the given new shape. The returned type
+  /// is ranked, even if this type is unranked.
+  RankedTensorType clone(ArrayRef<int64_t> shape) const;
+
   /// Return true if the specified element type is ok in a tensor.
   static bool isValidElementType(Type type);
 
@@ -131,6 +144,17 @@
   BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
                            Type elementType) const;
 
+  // Make sure that base class overloads are visible.
+  using ShapedType::Trait<BaseMemRefType>::clone;
+
+  /// Return a clone of this type with the given new shape and element type.
+  /// The returned type is ranked, even if this type is unranked.
+  MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
+
+  /// Return a clone of this type with the given new shape. The returned type
+  /// is ranked, even if this type is unranked.
+  MemRefType clone(ArrayRef<int64_t> shape) const;
+
   /// Return true if the specified element type is ok in a memref.
   static bool isValidElementType(Type type);
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -629,7 +629,7 @@
       "unsigned":$memorySpaceInd)>
   ];
   let extraClassDeclaration = [{
-    using ShapedType::Trait<MemRefType>::clone;
+    using BaseMemRefType::clone;
     using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<MemRefType>::getRank;
     using ShapedType::Trait<MemRefType>::getNumElements;
@@ -794,7 +794,7 @@
     }]>
   ];
   let extraClassDeclaration = [{
-    using ShapedType::Trait<RankedTensorType>::clone;
+    using TensorType::clone;
     using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<RankedTensorType>::getRank;
     using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -807,6 +807,12 @@
     /// This is a builder type that keeps local references to arguments.
     /// Arguments that are passed into the builder must outlive the builder.
     class Builder;
+
+    /// Return a clone of this type with the given new element type and the same
+    /// shape as this type.
+    RankedTensorType clone(::mlir::Type elementType) {
+      return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
+    }
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -931,7 +937,7 @@
     }]>
   ];
   let extraClassDeclaration = [{
-    using ShapedType::Trait<UnrankedMemRefType>::clone;
+    using BaseMemRefType::clone;
     using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedMemRefType>::getRank;
     using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -946,6 +952,12 @@
     /// [deprecated] Returns the memory space in old raw integer representation.
     /// New `Attribute getMemorySpace()` method should be used instead.
     unsigned getMemorySpaceAsInt() const;
+
+    /// Return a clone of this type with the given new element type and the same
+    /// shape as this type.
+    MemRefType clone(::mlir::Type elementType) {
+      return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
+    }
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
@@ -984,7 +996,7 @@
     }]>
   ];
   let extraClassDeclaration = [{
-    using ShapedType::Trait<UnrankedTensorType>::clone;
+    using TensorType::clone;
     using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
     using ShapedType::Trait<UnrankedTensorType>::getRank;
     using ShapedType::Trait<UnrankedTensorType>::getNumElements;
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -291,6 +291,15 @@
                                rankedTy.getEncoding());
 }
 
+RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape,
+                                   Type elementType) const {
+  return ::llvm::cast<RankedTensorType>(cloneWith(shape, elementType));
+}
+
+RankedTensorType TensorType::clone(::llvm::ArrayRef<int64_t> shape) const {
+  return ::llvm::cast<RankedTensorType>(cloneWith(shape, getElementType()));
+}
+
 // Check if "elementType" can be an element type of a tensor.
 static LogicalResult
 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
@@ -370,6 +379,15 @@
   return builder;
 }
 
+MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape,
+                                 Type elementType) const {
+  return ::llvm::cast<MemRefType>(cloneWith(shape, elementType));
+}
+
+MemRefType BaseMemRefType::clone(::llvm::ArrayRef<int64_t> shape) const {
+  return ::llvm::cast<MemRefType>(cloneWith(shape, getElementType()));
+}
+
 Attribute BaseMemRefType::getMemorySpace() const {
   if (auto rankedMemRefTy = dyn_cast<MemRefType>())
     return rankedMemRefTy.getMemorySpace();