diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -41,4 +41,151 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// ShapedType
+//===----------------------------------------------------------------------===//
+
+def ShapedTypeInterface : TypeInterface<"ShapedType"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This interface provides a common API for interacting with multi-dimensional
+    container types. These types contain a shape and an element type.
+
+    A shape is a list of sizes corresponding to the dimensions of the container.
+    If the number of dimensions in the shape is unknown, the shape is "unranked".
+    If the number of dimensions is known, the shape is "ranked". The sizes of
+    the dimensions of the shape must be positive, or kDynamicSize (in which
+    case the size of the dimension is dynamic, i.e. not statically known).
+  }];
+  let methods = [
+    InterfaceMethod<[{
+      Returns a clone of this type with the given shape and element
+      type. If a shape is not provided, the current shape of the type is used.
+    }],
+    "::mlir::ShapedType", "cloneWith", (ins
+      "::llvm::Optional<::llvm::ArrayRef<int64_t>>":$shape,
+      "::mlir::Type":$elementType
+    )>,
+
+    InterfaceMethod<[{
+      Returns the element type of this shaped type.
+    }],
+    "::mlir::Type", "getElementType">,
+
+    InterfaceMethod<[{
+      Returns if this type is ranked, i.e. it has a known number of dimensions.
+    }],
+    "bool", "hasRank">,
+
+    InterfaceMethod<[{
+      Returns the shape of this type if it is ranked, otherwise asserts.
+    }],
+    "::llvm::ArrayRef<int64_t>", "getShape">,
+  ];
+
+  let extraClassDeclaration = [{
+    // TODO: merge these two special values in a single one used everywhere.
+    // Unfortunately, uses of `-1` have crept deep into the codebase now and are
+    // hard to track.
+    static constexpr int64_t kDynamicSize = -1;
+    static constexpr int64_t kDynamicStrideOrOffset =
+        std::numeric_limits<int64_t>::min();
+
+    /// Whether the given dimension size indicates a dynamic dimension.
+    static constexpr bool isDynamic(int64_t dSize) {
+      return dSize == kDynamicSize;
+    }
+    static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
+      return dStrideOrOffset == kDynamicStrideOrOffset;
+    }
+
+    /// Return the number of elements present in the given shape.
+    static int64_t getNumElements(ArrayRef<int64_t> shape);
+
+    /// Returns the total amount of bits occupied by a value of this type. This
+    /// does not take into account any memory layout or widening constraints,
+    /// e.g. a vector<3xi57> may report to occupy 3x57=171 bits, even though in
+    /// practice it will likely be stored in a 4xi64 vector register. Fails
+    /// with an assertion if the size cannot be computed statically, e.g. if the
+    /// type has a dynamic shape or if its element type does not have a known
+    /// bit width.
+    int64_t getSizeInBits() const;
+  }];
+
+  let extraSharedClassDeclaration = [{
+    /// Return a clone of this type with the given new shape and element type.
+    auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
+      return $_type.cloneWith(shape, elementType);
+    }
+    /// Return a clone of this type with the given new shape.
+    auto clone(::llvm::ArrayRef<int64_t> shape) {
+      return $_type.cloneWith(shape, $_type.getElementType());
+    }
+    /// Return a clone of this type with the given new element type.
+    auto clone(::mlir::Type elementType) {
+      return $_type.cloneWith(/*shape=*/llvm::None, elementType);
+    }
+
+    /// If an element type is an integer or a float, return its width.
+    /// Otherwise, abort.
+    unsigned getElementTypeBitWidth() const {
+      return $_type.getElementType().getIntOrFloatBitWidth();
+    }
+
+    /// If this is a ranked type, return the rank. Otherwise, abort.
+    int64_t getRank() const {
+      assert($_type.hasRank() && "cannot query rank of unranked shaped type");
+      return $_type.getShape().size();
+    }
+
+    /// If it has static shape, return the number of elements. Otherwise, abort.
+    int64_t getNumElements() const {
+      assert(hasStaticShape() &&
+             "cannot get element count of dynamic shaped type");
+      return ::mlir::ShapedType::getNumElements($_type.getShape());
+    }
+
+    /// Returns true if this dimension has a dynamic size (for ranked types);
+    /// aborts for unranked types.
+    bool isDynamicDim(unsigned idx) const {
+      assert(idx < getRank() && "invalid index for shaped type");
+      return ::mlir::ShapedType::isDynamic($_type.getShape()[idx]);
+    }
+
+    /// Returns if this type has a static shape, i.e. if the type is ranked and
+    /// all dimensions have known size (>= 0).
+    bool hasStaticShape() const {
+      return $_type.hasRank() &&
+             llvm::none_of($_type.getShape(), ::mlir::ShapedType::isDynamic);
+    }
+
+    /// Returns true if this type has a static shape equal to `shape`.
+    bool hasStaticShape(::llvm::ArrayRef<int64_t> shape) const {
+      return hasStaticShape() && $_type.getShape() == shape;
+    }
+
+    /// If this is a ranked type, return the number of dimensions with dynamic
+    /// size. Otherwise, abort.
+    int64_t getNumDynamicDims() const {
+      return llvm::count_if($_type.getShape(), ::mlir::ShapedType::isDynamic);
+    }
+
+    /// If this is a ranked type, return the size of the specified dimension.
+    /// Otherwise, abort.
+    int64_t getDimSize(unsigned idx) const {
+      assert(idx < getRank() && "invalid index for shaped type");
+      return $_type.getShape()[idx];
+    }
+
+    /// Returns the position of the dynamic dimension relative to just the
+    /// dynamic dimensions, given its `index` within the shape.
+    unsigned getDynamicDimIndex(unsigned index) const {
+      assert(index < getRank() && "invalid index");
+      assert(::mlir::ShapedType::isDynamic(getDimSize(index)) &&
+             "invalid index");
+      return llvm::count_if($_type.getShape().take_front(index),
+                            ::mlir::ShapedType::isDynamic);
+    }
+  }];
+}
+
 #endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_
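Reviewer context, not part of the patch: once a type implements this interface, generic code can query and rebuild shaped types without knowing their concrete kind. A minimal sketch, assuming the `.dyn_cast<>`/`llvm::None` style of the MLIR C++ API at this revision (`rebuildAsF32` is a hypothetical helper):

```cpp
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

/// Rebuild any shaped builtin type (tensor, memref, or vector) with an f32
/// element type, keeping its shape (or unranked-ness) intact.
static Type rebuildAsF32(Type type) {
  ShapedType shaped = type.dyn_cast<ShapedType>();
  if (!shaped)
    return type; // Not a shaped type; leave it untouched.
  if (shaped.hasRank())
    (void)shaped.getNumDynamicDims(); // Shape queries require a rank.
  // cloneWith(None, ...) keeps the current shape and the concrete kind:
  // cloning a memref yields a memref, cloning a tensor yields a tensor.
  return shaped.cloneWith(/*shape=*/llvm::None,
                          FloatType::getF32(type.getContext()));
}
```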
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -16,6 +16,12 @@
 struct fltSemantics;
 } // namespace llvm
 
+//===----------------------------------------------------------------------===//
+// Tablegen Interface Declarations
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
+
 namespace mlir {
 class AffineExpr;
 class AffineMap;
@@ -56,118 +62,67 @@
 };
 
 //===----------------------------------------------------------------------===//
-// ShapedType
+// TensorType
 //===----------------------------------------------------------------------===//
 
-/// This is a common base class between Vector, UnrankedTensor, RankedTensor,
-/// and MemRef types because they share behavior and semantics around shape,
-/// rank, and fixed element type. Any type with these semantics should inherit
-/// from ShapedType.
-class ShapedType : public Type {
+/// Tensor types represent multi-dimensional arrays, and have two variants:
+/// RankedTensorType and UnrankedTensorType.
+/// Note: This class attaches the ShapedType trait to act as a mixin to
+///       provide many useful utility functions. This inheritance has no effect
+///       on derived tensor types.
+class TensorType : public Type, public ShapedType::Trait<TensorType> {
 public:
   using Type::Type;
 
-  // TODO: merge these two special values in a single one used everywhere.
-  // Unfortunately, uses of `-1` have crept deep into the codebase now and are
-  // hard to track.
-  static constexpr int64_t kDynamicSize = -1;
-  static constexpr int64_t kDynamicStrideOrOffset =
-      std::numeric_limits<int64_t>::min();
-
-  /// Return clone of this type with new shape and element type.
-  ShapedType clone(ArrayRef<int64_t> shape, Type elementType);
-  ShapedType clone(ArrayRef<int64_t> shape);
-  ShapedType clone(Type elementType);
-
-  /// Return the element type.
+  /// Returns the element type of this tensor type.
   Type getElementType() const;
 
-  /// If an element type is an integer or a float, return its width. Otherwise,
-  /// abort.
-  unsigned getElementTypeBitWidth() const;
-
-  /// If it has static shape, return the number of elements. Otherwise, abort.
-  int64_t getNumElements() const;
-
-  /// If this is a ranked type, return the rank. Otherwise, abort.
-  int64_t getRank() const;
-
-  /// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
-  /// have a rank, while unranked tensors do not.
+  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
   bool hasRank() const;
 
-  /// If this is a ranked type, return the shape. Otherwise, abort.
+  /// Returns the shape of this tensor type.
   ArrayRef<int64_t> getShape() const;
 
-  /// If this is unranked type or any dimension has unknown size (<0), it
-  /// doesn't have static shape. If all dimensions have known size (>= 0), it
-  /// has static shape.
-  bool hasStaticShape() const;
-
-  /// If this has a static shape and the shape is equal to `shape` return true.
-  bool hasStaticShape(ArrayRef<int64_t> shape) const;
-
-  /// If this is a ranked type, return the number of dimensions with dynamic
-  /// size. Otherwise, abort.
-  int64_t getNumDynamicDims() const;
-
-  /// If this is ranked type, return the size of the specified dimension.
-  /// Otherwise, abort.
-  int64_t getDimSize(unsigned idx) const;
-
-  /// Returns true if this dimension has a dynamic size (for ranked types);
-  /// aborts for unranked types.
-  bool isDynamicDim(unsigned idx) const;
-
-  /// Returns the position of the dynamic dimension relative to just the dynamic
-  /// dimensions, given its `index` within the shape.
-  unsigned getDynamicDimIndex(unsigned index) const;
+  /// Clone this type with the given shape and element type. If the
+  /// provided shape is `None`, the current shape of the type is used.
+  TensorType cloneWith(Optional<ArrayRef<int64_t>> shape,
+                       Type elementType) const;
 
-  /// Get the total amount of bits occupied by a value of this type. This does
-  /// not take into account any memory layout or widening constraints, e.g. a
-  /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
-  /// it will likely be stored in a 4xi64 vector register. Fail an assertion
-  /// if the size cannot be computed statically, i.e. if the type has a dynamic
-  /// shape or if its element type does not have a known bit width.
-  int64_t getSizeInBits() const;
+  /// Return true if the specified element type is ok in a tensor.
+  static bool isValidElementType(Type type);
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type);
 
-  /// Whether the given dimension size indicates a dynamic dimension.
-  static constexpr bool isDynamic(int64_t dSize) {
-    return dSize == kDynamicSize;
-  }
-  static constexpr bool isDynamicStrideOrOffset(int64_t dStrideOrOffset) {
-    return dStrideOrOffset == kDynamicStrideOrOffset;
-  }
+  /// Allow implicit conversion to ShapedType.
+  operator ShapedType() const { return cast<ShapedType>(); }
 };
 
 //===----------------------------------------------------------------------===//
-// TensorType
+// BaseMemRefType
 //===----------------------------------------------------------------------===//
 
-/// Tensor types represent multi-dimensional arrays, and have two variants:
-/// RankedTensorType and UnrankedTensorType.
-class TensorType : public ShapedType {
+/// This class provides a shared interface for ranked and unranked memref types.
+/// Note: This class attaches the ShapedType trait to act as a mixin to
+///       provide many useful utility functions. This inheritance has no effect
+///       on derived memref types.
+class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
 public:
-  using ShapedType::ShapedType;
+  using Type::Type;
 
-  /// Return true if the specified element type is ok in a tensor.
-  static bool isValidElementType(Type type);
+  /// Returns the element type of this memref type.
+  Type getElementType() const;
 
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-};
+  /// Returns if this type is ranked, i.e. it has a known number of dimensions.
+  bool hasRank() const;
 
-//===----------------------------------------------------------------------===//
-// BaseMemRefType
-//===----------------------------------------------------------------------===//
+  /// Returns the shape of this memref type.
+  ArrayRef<int64_t> getShape() const;
 
-/// Base MemRef for Ranked and Unranked variants
-class BaseMemRefType : public ShapedType {
-public:
-  using ShapedType::ShapedType;
+  /// Clone this type with the given shape and element type. If the
+  /// provided shape is `None`, the current shape of the type is used.
+  BaseMemRefType cloneWith(Optional<ArrayRef<int64_t>> shape,
+                           Type elementType) const;
 
   /// Return true if the specified element type is ok in a memref.
   static bool isValidElementType(Type type);
@@ -181,6 +136,9 @@
   /// [deprecated] Returns the memory space in old raw integer representation.
   /// New `Attribute getMemorySpace()` method should be used instead.
   unsigned getMemorySpaceAsInt() const;
+
+  /// Allow implicit conversion to ShapedType.
+  operator ShapedType() const { return cast<ShapedType>(); }
 };
 
 } // namespace mlir
 
@@ -192,12 +150,6 @@
 #define GET_TYPEDEF_CLASSES
 #include "mlir/IR/BuiltinTypes.h.inc"
 
-//===----------------------------------------------------------------------===//
-// Tablegen Interface Declarations
-//===----------------------------------------------------------------------===//
-
-#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
-
 namespace mlir {
 
 //===----------------------------------------------------------------------===//
@@ -439,11 +391,6 @@
   return Float128Type::get(ctx);
 }
 
-inline bool ShapedType::classof(Type type) {
-  return type.isa<VectorType, TensorType, BaseMemRefType>();
-}
-
 inline bool TensorType::classof(Type type) {
   return type.isa<RankedTensorType, UnrankedTensorType>();
 }
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -266,7 +266,7 @@
 //===----------------------------------------------------------------------===//
 
 def Builtin_MemRef : Builtin_Type<"MemRef", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
   ], "BaseMemRefType"> {
   let summary = "Shaped reference to a region of memory";
   let description = [{
@@ -541,6 +541,16 @@
       "unsigned":$memorySpaceInd)>
   ];
   let extraClassDeclaration = [{
+    using ShapedType::Trait<MemRefType>::clone;
+    using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
+    using ShapedType::Trait<MemRefType>::getRank;
+    using ShapedType::Trait<MemRefType>::getNumElements;
+    using ShapedType::Trait<MemRefType>::isDynamicDim;
+    using ShapedType::Trait<MemRefType>::hasStaticShape;
+    using ShapedType::Trait<MemRefType>::getNumDynamicDims;
+    using ShapedType::Trait<MemRefType>::getDimSize;
+    using ShapedType::Trait<MemRefType>::getDynamicDimIndex;
+
     /// This is a builder type that keeps local references to arguments.
     /// Arguments that are passed into the builder must outlive the builder.
     class Builder;
@@ -620,7 +630,7 @@
 //===----------------------------------------------------------------------===//
 
 def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
   ], "TensorType"> {
   let summary = "Multi-dimensional array with a fixed number of dimensions";
   let description = [{
@@ -702,6 +712,16 @@
     }]>
   ];
   let extraClassDeclaration = [{
+    using ShapedType::Trait<RankedTensorType>::clone;
+    using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
+    using ShapedType::Trait<RankedTensorType>::getRank;
+    using ShapedType::Trait<RankedTensorType>::getNumElements;
+    using ShapedType::Trait<RankedTensorType>::isDynamicDim;
+    using ShapedType::Trait<RankedTensorType>::hasStaticShape;
+    using ShapedType::Trait<RankedTensorType>::getNumDynamicDims;
+    using ShapedType::Trait<RankedTensorType>::getDimSize;
+    using ShapedType::Trait<RankedTensorType>::getDynamicDimIndex;
+
     /// This is a builder type that keeps local references to arguments.
     /// Arguments that are passed into the builder must outlive the builder.
     class Builder;
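A note on the blocks of `using ShapedType::Trait<...>` declarations added here and in the defs below: in C++, a member declared on the concrete type hides all inherited overloads with the same name, so the trait's helpers have to be re-exposed explicitly. A generic illustration of that language rule (hypothetical names, unrelated to MLIR):

```cpp
struct Base {
  int size(int scale) { return 4 * scale; }
};

struct Derived : Base {
  using Base::size; // Without this, Derived::size() hides Base::size(int).
  int size() { return 4; }
};

int main() {
  Derived d;
  return d.size() + d.size(2); // Compiles only thanks to `using Base::size;`.
}
```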
@@ -784,7 +804,7 @@
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
   ], "BaseMemRefType"> {
   let summary = "Shaped reference, with unknown rank, to a region of memory";
   let description = [{
@@ -831,6 +851,16 @@
     }]>
   ];
   let extraClassDeclaration = [{
+    using ShapedType::Trait<UnrankedMemRefType>::clone;
+    using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
+    using ShapedType::Trait<UnrankedMemRefType>::getRank;
+    using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
+    using ShapedType::Trait<UnrankedMemRefType>::isDynamicDim;
+    using ShapedType::Trait<UnrankedMemRefType>::hasStaticShape;
+    using ShapedType::Trait<UnrankedMemRefType>::getNumDynamicDims;
+    using ShapedType::Trait<UnrankedMemRefType>::getDimSize;
+    using ShapedType::Trait<UnrankedMemRefType>::getDynamicDimIndex;
+
     ArrayRef<int64_t> getShape() const { return llvm::None; }
 
     /// [deprecated] Returns the memory space in old raw integer representation.
@@ -846,7 +876,7 @@
 //===----------------------------------------------------------------------===//
 
 def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
   ], "TensorType"> {
   let summary = "Multi-dimensional array with unknown dimensions";
   let description = [{
@@ -874,6 +904,16 @@
     }]>
   ];
   let extraClassDeclaration = [{
+    using ShapedType::Trait<UnrankedTensorType>::clone;
+    using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
+    using ShapedType::Trait<UnrankedTensorType>::getRank;
+    using ShapedType::Trait<UnrankedTensorType>::getNumElements;
+    using ShapedType::Trait<UnrankedTensorType>::isDynamicDim;
+    using ShapedType::Trait<UnrankedTensorType>::hasStaticShape;
+    using ShapedType::Trait<UnrankedTensorType>::getNumDynamicDims;
+    using ShapedType::Trait<UnrankedTensorType>::getDimSize;
+    using ShapedType::Trait<UnrankedTensorType>::getDynamicDimIndex;
+
     ArrayRef<int64_t> getShape() const { return llvm::None; }
   }];
   let skipDefaultBuilders = 1;
@@ -885,8 +925,8 @@
 //===----------------------------------------------------------------------===//
 
 def Builtin_Vector : Builtin_Type<"Vector", [
-    DeclareTypeInterfaceMethods<SubElementTypeInterface>
-  ], "ShapedType"> {
+    DeclareTypeInterfaceMethods<SubElementTypeInterface>, ShapedTypeInterface
+  ], "Type"> {
   let summary = "Multi-dimensional SIMD vector type";
   let description = [{
     Syntax:
@@ -966,6 +1006,14 @@
     /// element type of bitwidth scaled by `scale`.
     /// Return null if the scaled element type cannot be represented.
     VectorType scaleElementBitwidth(unsigned scale);
+
+    /// Returns if this type is ranked (always true).
+    bool hasRank() const { return true; }
+
+    /// Clone this vector type with the given shape and element type. If the
+    /// provided shape is `None`, the current shape of the type is used.
+    VectorType cloneWith(Optional<ArrayRef<int64_t>> shape,
+                         Type elementType);
   }];
   let skipDefaultBuilders = 1;
   let genVerifyDecl = 1;
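Since VectorType now implements the interface directly (its base class changes from ShapedType to Type), the inherited `clone` helpers keep working and stay strongly typed. A small sketch of what that allows, assuming this revision's `VectorType::get` signature with a trailing, defaulted `numScalableDims` (not part of the patch):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

int main() {
  MLIRContext context;
  VectorType v4f32 = VectorType::get({4}, FloatType::getF32(&context));
  // Same shape, new element type; the result is still a VectorType because
  // cloneWith is covariant on the concrete type.
  VectorType v4i8 = v4f32.clone(IntegerType::get(&context, 8));
  // New shape, same element type.
  VectorType v8f32 = v4f32.clone({8});
  (void)v4i8;
  (void)v8f32;
}
```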
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -51,10 +51,10 @@
   auto result = getStridesAndOffset(type, strides, offset);
   (void)result;
   assert(succeeded(result) && "unexpected failure in stride computation");
-  assert(!MemRefType::isDynamicStrideOrOffset(offset) &&
+  assert(!ShapedType::isDynamicStrideOrOffset(offset) &&
          "expected static offset");
   assert(!llvm::any_of(strides, [](int64_t stride) {
-           return MemRefType::isDynamicStrideOrOffset(stride);
+           return ShapedType::isDynamicStrideOrOffset(stride);
          }) &&
          "expected static strides");
   auto convertedType = typeConverter.convertType(type);
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -79,14 +79,14 @@
 
   Value index;
   if (offset != 0) // Skip if offset is zero.
-    index = MemRefType::isDynamicStrideOrOffset(offset)
+    index = ShapedType::isDynamicStrideOrOffset(offset)
                 ? memRefDescriptor.offset(rewriter, loc)
                 : createIndexConstant(rewriter, loc, offset);
 
   for (int i = 0, e = indices.size(); i < e; ++i) {
     Value increment = indices[i];
     if (strides[i] != 1) { // Skip if stride is 1.
-      Value stride = MemRefType::isDynamicStrideOrOffset(strides[i])
+      Value stride = ShapedType::isDynamicStrideOrOffset(strides[i])
                          ? memRefDescriptor.stride(rewriter, loc, i)
                          : createIndexConstant(rewriter, loc, strides[i]);
       increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -106,7 +106,7 @@
                                          Operation *op) const {
   uint64_t sizeDivisor = getMemRefEltSizeInBytes(type, op);
   for (unsigned i = 0, e = type.getRank(); i < e; i++) {
-    if (type.isDynamic(type.getDimSize(i)))
+    if (ShapedType::isDynamic(type.getDimSize(i)))
       continue;
     sizeDivisor = sizeDivisor * type.getDimSize(i);
   }
@@ -1467,7 +1467,7 @@
                           ArrayRef<int64_t> strides, Value nextSize,
                           Value runningStride, unsigned idx) const {
   assert(idx < strides.size());
-  if (!MemRefType::isDynamicStrideOrOffset(strides[idx]))
+  if (!ShapedType::isDynamicStrideOrOffset(strides[idx]))
     return createIndexConstant(rewriter, loc, strides[idx]);
   if (nextSize)
     return runningStride
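The pattern being updated in these lowering files is always the same: a static stride or offset can be materialized as a compile-time constant, while the dynamic sentinel must be read from the memref descriptor at runtime; only the home of the sentinel moves to ShapedType. A reduced, standalone sketch of just the check (hypothetical values, no rewriter involved):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include <cassert>

using namespace mlir;

int main() {
  // getStridesAndOffset would produce values like these for a type such as
  // memref<?x4xf32> with a dynamic offset and static strides [4, 1].
  int64_t offset = ShapedType::kDynamicStrideOrOffset;
  int64_t strides[] = {4, 1};

  assert(ShapedType::isDynamicStrideOrOffset(offset) &&
         "dynamic offset: must be loaded from the descriptor at runtime");
  for (int64_t s : strides)
    assert(!ShapedType::isDynamicStrideOrOffset(s) &&
           "static stride: can be emitted as an index constant");
  return 0;
}
```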
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -342,22 +342,22 @@
   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
     auto ss = std::get<0>(it), st = std::get<1>(it);
     if (ss != st)
-      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
+      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
         return false;
   }
 
   // If cast is towards more static offset along any dimension, don't fold.
   if (sourceOffset != resultOffset)
-    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
-        !MemRefType::isDynamicStrideOrOffset(resultOffset))
+    if (ShapedType::isDynamicStrideOrOffset(sourceOffset) &&
+        !ShapedType::isDynamicStrideOrOffset(resultOffset))
       return false;
 
   // If cast is towards more static strides along any dimension, don't fold.
   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
     auto ss = std::get<0>(it), st = std::get<1>(it);
     if (ss != st)
-      if (MemRefType::isDynamicStrideOrOffset(ss) &&
-          !MemRefType::isDynamicStrideOrOffset(st))
+      if (ShapedType::isDynamicStrideOrOffset(ss) &&
+          !ShapedType::isDynamicStrideOrOffset(st))
         return false;
   }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -518,7 +518,7 @@
     // Find upper bound in current dimension.
     unsigned p = perm(enc, d);
     Value up = linalg::createOrFoldDimOp(rewriter, loc, t->get(), p);
-    if (shape[p] == MemRefType::kDynamicSize)
+    if (ShapedType::isDynamic(shape[p]))
       args.push_back(up);
     assert(codegen.highs[tensor][idx] == nullptr);
     codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -268,13 +268,12 @@
       fromElements.getResult().getType().cast<RankedTensorType>();
   // The case where the type encodes the size of the dimension is handled
   // above.
-  assert(resultType.getShape()[index.getInt()] ==
-         RankedTensorType::kDynamicSize);
+  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
 
   // Find the operand of the fromElements that corresponds to this index.
   auto dynExtents = fromElements.dynamicExtents().begin();
   for (auto dim : resultType.getShape().take_front(index.getInt()))
-    if (dim == RankedTensorType::kDynamicSize)
+    if (ShapedType::isDynamic(dim))
       dynExtents++;
 
   return Value{*dynExtents};
@@ -523,13 +522,13 @@
   auto operandsIt = tensorFromElements.dynamicExtents().begin();
 
   for (int64_t dim : resultType.getShape()) {
-    if (dim != RankedTensorType::kDynamicSize) {
+    if (!ShapedType::isDynamic(dim)) {
       newShape.push_back(dim);
       continue;
     }
     APInt index;
     if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
-      newShape.push_back(RankedTensorType::kDynamicSize);
+      newShape.push_back(ShapedType::kDynamicSize);
       newOperands.push_back(*operandsIt++);
       continue;
     }
@@ -661,7 +660,7 @@
     return op.emitOpError("source and destination tensor should have the "
                           "same number of elements");
   }
-  if (shapeSize == TensorType::kDynamicSize)
+  if (ShapedType::isDynamic(shapeSize))
     return op.emitOpError("cannot use shape operand with dynamic length to "
                           "reshape to statically-ranked tensor type");
   if (shapeSize != resultRankedType.getRank())
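These call-site updates are mechanical, but the idiom they settle on is worth spelling out: code should ask `ShapedType::isDynamic(dim)` rather than compare against one particular class's copy of the sentinel. A toy, standalone version of the shape-refinement step performed in the `tensor.from_elements` fold above (`refineShape` is a hypothetical helper, not part of the patch):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include <vector>

using namespace mlir;

/// Replace dynamic dimensions with known constant extents where available;
/// negative entries in `known` mean "still unknown".
static std::vector<int64_t> refineShape(ArrayRef<int64_t> shape,
                                        ArrayRef<int64_t> known) {
  std::vector<int64_t> result(shape.begin(), shape.end());
  for (size_t i = 0; i < result.size(); ++i)
    if (ShapedType::isDynamic(result[i]) && known[i] >= 0)
      result[i] = known[i]; // A constant extent makes the dimension static.
  return result;
}
```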
diff --git a/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransferSplitRewritePatterns.cpp
@@ -172,13 +172,13 @@
       resStrides(bT.getRank(), 0);
   for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
     resShape[idx] =
-        (aShape[idx] == bShape[idx]) ? aShape[idx] : MemRefType::kDynamicSize;
+        (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamicSize;
     resStrides[idx] =
-        (aStrides[idx] == bStrides[idx]) ? aStrides[idx]
-                                         : MemRefType::kDynamicStrideOrOffset;
+        (aStrides[idx] == bStrides[idx]) ? aStrides[idx]
+                                         : ShapedType::kDynamicStrideOrOffset;
   }
   resOffset =
-      (aOffset == bOffset) ? aOffset : MemRefType::kDynamicStrideOrOffset;
+      (aOffset == bOffset) ? aOffset : ShapedType::kDynamicStrideOrOffset;
   return MemRefType::get(
       resShape, aT.getElementType(),
       makeStridedLinearLayoutMap(resStrides, resOffset, aT.getContext()));
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -0,0 +1,51 @@
+//===- BuiltinTypeInterfaces.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/Sequence.h"
+
+using namespace mlir;
+using namespace mlir::detail;
+
+//===----------------------------------------------------------------------===//
+/// Tablegen Interface Definitions
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
+
+//===----------------------------------------------------------------------===//
+// ShapedType
+//===----------------------------------------------------------------------===//
+
+constexpr int64_t ShapedType::kDynamicSize;
+constexpr int64_t ShapedType::kDynamicStrideOrOffset;
+
+int64_t ShapedType::getNumElements(ArrayRef<int64_t> shape) {
+  int64_t num = 1;
+  for (int64_t dim : shape) {
+    num *= dim;
+    assert(num >= 0 && "integer overflow in element count computation");
+  }
+  return num;
+}
+
+int64_t ShapedType::getSizeInBits() const {
+  assert(hasStaticShape() &&
+         "cannot get the bit size of an aggregate with a dynamic shape");
+
+  auto elementType = getElementType();
+  if (elementType.isIntOrFloat())
+    return elementType.getIntOrFloatBitWidth() * getNumElements();
+
+  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
+    elementType = complexType.getElementType();
+    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
+  }
+  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
+}
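The relocated `getSizeInBits` keeps its recursive behavior for shaped element types such as tensors of vectors. A worked example of what it computes, as a sketch against this revision's API (the arithmetic follows directly from the code above):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include <cassert>

using namespace mlir;

int main() {
  MLIRContext context;
  Type f32 = FloatType::getF32(&context);

  // vector<3xi57>: 3 * 57 = 171 bits, ignoring any hardware widening.
  auto i57 = IntegerType::get(&context, 57);
  assert(VectorType::get({3}, i57).cast<ShapedType>().getSizeInBits() == 171);

  // tensor<2x4xvector<4xf32>> recurses into the vector element type:
  // (2 * 4) elements * (4 * 32) bits = 1024 bits.
  auto vec = VectorType::get({4}, f32);
  auto tensor = RankedTensorType::get({2, 4}, vec);
  assert(tensor.cast<ShapedType>().getSizeInBits() == 1024);
  return 0;
}
```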
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -32,12 +32,6 @@
 #define GET_TYPEDEF_CLASSES
 #include "mlir/IR/BuiltinTypes.cpp.inc"
 
-//===----------------------------------------------------------------------===//
-/// Tablegen Interface Definitions
-//===----------------------------------------------------------------------===//
-
-#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
-
 //===----------------------------------------------------------------------===//
 // BuiltinDialect
 //===----------------------------------------------------------------------===//
@@ -271,171 +265,6 @@
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// ShapedType
-//===----------------------------------------------------------------------===//
-constexpr int64_t ShapedType::kDynamicSize;
-constexpr int64_t ShapedType::kDynamicStrideOrOffset;
-
-ShapedType ShapedType::clone(ArrayRef<int64_t> shape, Type elementType) {
-  if (auto other = dyn_cast<MemRefType>()) {
-    MemRefType::Builder b(other);
-    b.setShape(shape);
-    b.setElementType(elementType);
-    return b;
-  }
-
-  if (auto other = dyn_cast<UnrankedMemRefType>()) {
-    MemRefType::Builder b(shape, elementType);
-    b.setMemorySpace(other.getMemorySpace());
-    return b;
-  }
-
-  if (isa<TensorType>())
-    return RankedTensorType::get(shape, elementType);
-
-  if (auto vecTy = dyn_cast<VectorType>())
-    return VectorType::get(shape, elementType, vecTy.getNumScalableDims());
-
-  llvm_unreachable("Unhandled ShapedType clone case");
-}
-
-ShapedType ShapedType::clone(ArrayRef<int64_t> shape) {
-  if (auto other = dyn_cast<MemRefType>()) {
-    MemRefType::Builder b(other);
-    b.setShape(shape);
-    return b;
-  }
-
-  if (auto other = dyn_cast<UnrankedMemRefType>()) {
-    MemRefType::Builder b(shape, other.getElementType());
-    b.setShape(shape);
-    b.setMemorySpace(other.getMemorySpace());
-    return b;
-  }
-
-  if (isa<TensorType>())
-    return RankedTensorType::get(shape, getElementType());
-
-  if (auto vecTy = dyn_cast<VectorType>())
-    return VectorType::get(shape, getElementType(), vecTy.getNumScalableDims());
-
-  llvm_unreachable("Unhandled ShapedType clone case");
-}
-
-ShapedType ShapedType::clone(Type elementType) {
-  if (auto other = dyn_cast<MemRefType>()) {
-    MemRefType::Builder b(other);
-    b.setElementType(elementType);
-    return b;
-  }
-
-  if (auto other = dyn_cast<UnrankedMemRefType>()) {
-    return UnrankedMemRefType::get(elementType, other.getMemorySpace());
-  }
-
-  if (isa<TensorType>()) {
-    if (hasRank())
-      return RankedTensorType::get(getShape(), elementType);
-    return UnrankedTensorType::get(elementType);
-  }
-
-  if (auto vecTy = dyn_cast<VectorType>())
-    return VectorType::get(getShape(), elementType, vecTy.getNumScalableDims());
-
-  llvm_unreachable("Unhandled ShapedType clone hit");
-}
-
-Type ShapedType::getElementType() const {
-  return TypeSwitch<Type, Type>(*this)
-      .Case<VectorType, RankedTensorType, UnrankedTensorType, MemRefType,
-            UnrankedMemRefType>([](auto ty) { return ty.getElementType(); });
-}
-
-unsigned ShapedType::getElementTypeBitWidth() const {
-  return getElementType().getIntOrFloatBitWidth();
-}
-
-int64_t ShapedType::getNumElements() const {
-  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
-  auto shape = getShape();
-  int64_t num = 1;
-  for (auto dim : shape) {
-    num *= dim;
-    assert(num >= 0 && "integer overflow in element count computation");
-  }
-  return num;
-}
-
-int64_t ShapedType::getRank() const {
-  assert(hasRank() && "cannot query rank of unranked shaped type");
-  return getShape().size();
-}
-
-bool ShapedType::hasRank() const {
-  return !isa<UnrankedMemRefType, UnrankedTensorType>();
-}
-
-int64_t ShapedType::getDimSize(unsigned idx) const {
-  assert(idx < getRank() && "invalid index for shaped type");
-  return getShape()[idx];
-}
-
-bool ShapedType::isDynamicDim(unsigned idx) const {
-  assert(idx < getRank() && "invalid index for shaped type");
-  return isDynamic(getShape()[idx]);
-}
-
-unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
-  assert(index < getRank() && "invalid index");
-  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
-  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
-}
-
-/// Get the number of bits required to store a value of the given shaped type.
-/// Compute the value recursively since tensors are allowed to have vectors as
-/// elements.
-int64_t ShapedType::getSizeInBits() const {
-  assert(hasStaticShape() &&
-         "cannot get the bit size of an aggregate with a dynamic shape");
-
-  auto elementType = getElementType();
-  if (elementType.isIntOrFloat())
-    return elementType.getIntOrFloatBitWidth() * getNumElements();
-
-  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
-    elementType = complexType.getElementType();
-    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
-  }
-
-  // Tensors can have vectors and other tensors as elements, other shaped types
-  // cannot.
-  assert(isa<TensorType>() && "unsupported element type");
-  assert((elementType.isa<VectorType, TensorType>()) &&
-         "unsupported tensor element type");
-  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
-}
-
-ArrayRef<int64_t> ShapedType::getShape() const {
-  if (auto vectorType = dyn_cast<VectorType>())
-    return vectorType.getShape();
-  if (auto tensorType = dyn_cast<TensorType>())
-    return tensorType.getShape();
-  return cast<BaseMemRefType>().getShape();
-}
-
-int64_t ShapedType::getNumDynamicDims() const {
-  return llvm::count_if(getShape(), isDynamic);
-}
-
-bool ShapedType::hasStaticShape() const {
-  return hasRank() && llvm::none_of(getShape(), isDynamic);
-}
-
-bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
-  return hasStaticShape() && getShape() == shape;
-}
-
 //===----------------------------------------------------------------------===//
 // VectorType
 //===----------------------------------------------------------------------===//
@@ -474,10 +303,44 @@
   walkTypesFn(getElementType());
 }
 
+VectorType VectorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
+                                 Type elementType) {
+  return VectorType::get(shape.getValueOr(getShape()), elementType,
+                         getNumScalableDims());
+}
+
 //===----------------------------------------------------------------------===//
 // TensorType
 //===----------------------------------------------------------------------===//
 
+Type TensorType::getElementType() const {
+  return llvm::TypeSwitch<TensorType, Type>(*this)
+      .Case<RankedTensorType, UnrankedTensorType>(
+          [](auto type) { return type.getElementType(); });
+}
+
+bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
+
+ArrayRef<int64_t> TensorType::getShape() const {
+  return cast<RankedTensorType>().getShape();
+}
+
+TensorType TensorType::cloneWith(Optional<ArrayRef<int64_t>> shape,
+                                 Type elementType) const {
+  if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
+    if (shape)
+      return RankedTensorType::get(*shape, elementType);
+    return UnrankedTensorType::get(elementType);
+  }
+
+  auto rankedTy = cast<RankedTensorType>();
+  if (!shape)
+    return RankedTensorType::get(rankedTy.getShape(), elementType,
+                                 rankedTy.getEncoding());
+  return RankedTensorType::get(shape.getValueOr(rankedTy.getShape()),
+                               elementType, rankedTy.getEncoding());
+}
+
 // Check if "elementType" can be an element type of a tensor.
 static LogicalResult
 checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
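One behavioral detail worth a reviewer's attention in `TensorType::cloneWith`: supplying a shape to an unranked tensor produces a ranked one, while cloning a ranked tensor preserves its encoding attribute. A small standalone sketch under this revision's API (`llvm::makeArrayRef`, `llvm::None`):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include <cassert>

using namespace mlir;

int main() {
  MLIRContext context;
  Type f32 = FloatType::getF32(&context);
  TensorType unranked = UnrankedTensorType::get(f32);

  // Supplying a shape ranks the result; dynamic sizes are allowed.
  int64_t shape[] = {2, ShapedType::kDynamicSize};
  TensorType ranked = unranked.cloneWith(llvm::makeArrayRef(shape), f32);
  assert(ranked.hasRank() && ranked.isDynamicDim(1));

  // Omitting the shape keeps the type unranked.
  assert(!unranked.cloneWith(llvm::None, f32).hasRank());
  return 0;
}
```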
@@ -542,6 +405,35 @@
 // BaseMemRefType
 //===----------------------------------------------------------------------===//
 
+Type BaseMemRefType::getElementType() const {
+  return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
+      .Case<MemRefType, UnrankedMemRefType>(
+          [](auto type) { return type.getElementType(); });
+}
+
+bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
+
+ArrayRef<int64_t> BaseMemRefType::getShape() const {
+  return cast<MemRefType>().getShape();
+}
+
+BaseMemRefType BaseMemRefType::cloneWith(Optional<ArrayRef<int64_t>> shape,
+                                         Type elementType) const {
+  if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
+    if (!shape)
+      return UnrankedMemRefType::get(elementType, getMemorySpace());
+    MemRefType::Builder builder(*shape, elementType);
+    builder.setMemorySpace(getMemorySpace());
+    return builder;
+  }
+
+  MemRefType::Builder builder(cast<MemRefType>());
+  if (shape)
+    builder.setShape(*shape);
+  builder.setElementType(elementType);
+  return builder;
+}
+
 Attribute BaseMemRefType::getMemorySpace() const {
   if (auto rankedMemRefTy = dyn_cast<MemRefType>())
     return rankedMemRefTy.getMemorySpace();
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -9,6 +9,7 @@
   BuiltinAttributes.cpp
   BuiltinDialect.cpp
+  BuiltinTypeInterfaces.cpp
   BuiltinTypes.cpp
   Diagnostics.cpp
   Dialect.cpp
   Dominance.cpp
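The unit-test churn below is a direct consequence of `clone` now returning the ShapedType interface rather than a concrete type: gtest's ASSERT_EQ needs both operands brought to a common type before comparing. In sketch form (a hypothetical minimal test, not from the patch):

```cpp
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "gtest/gtest.h"

using namespace mlir;

TEST(ShapedTypeExample, CloneReturnsInterface) {
  MLIRContext context;
  Type f32 = FloatType::getF32(&context);
  ShapedType tensor = RankedTensorType::get({2, 3}, f32);
  // clone() yields a ShapedType, so cast the expected value to match.
  ASSERT_EQ(tensor.clone(f32), (ShapedType)RankedTensorType::get({2, 3}, f32));
}
```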
diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp
--- a/mlir/unittests/IR/ShapedTypeTest.cpp
+++ b/mlir/unittests/IR/ShapedTypeTest.cpp
@@ -30,14 +30,14 @@
 
   AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);
   ShapedType memrefType =
-      MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
+      (ShapedType)MemRefType::Builder(memrefOriginalShape, memrefOriginalType)
           .setMemorySpace(memSpace)
           .setLayout(AffineMapAttr::get(map));
   // Update shape.
   llvm::SmallVector<int64_t> memrefNewShape({30, 40});
   ASSERT_NE(memrefOriginalShape, memrefNewShape);
   ASSERT_EQ(memrefType.clone(memrefNewShape),
-            (MemRefType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
+            (ShapedType)MemRefType::Builder(memrefNewShape, memrefOriginalType)
                 .setMemorySpace(memSpace)
                 .setLayout(AffineMapAttr::get(map)));
   // Update type.
@@ -81,25 +81,29 @@
   // Update shape.
   llvm::SmallVector<int64_t> tensorNewShape({30, 40});
   ASSERT_NE(tensorOriginalShape, tensorNewShape);
-  ASSERT_EQ(tensorType.clone(tensorNewShape),
-            RankedTensorType::get(tensorNewShape, tensorOriginalType));
+  ASSERT_EQ(
+      tensorType.clone(tensorNewShape),
+      (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
   // Update type.
   Type tensorNewType = f32;
   ASSERT_NE(tensorOriginalType, tensorNewType);
-  ASSERT_EQ(tensorType.clone(tensorNewType),
-            RankedTensorType::get(tensorOriginalShape, tensorNewType));
+  ASSERT_EQ(
+      tensorType.clone(tensorNewType),
+      (ShapedType)RankedTensorType::get(tensorOriginalShape, tensorNewType));
   // Update both.
   ASSERT_EQ(tensorType.clone(tensorNewShape, tensorNewType),
-            RankedTensorType::get(tensorNewShape, tensorNewType));
+            (ShapedType)RankedTensorType::get(tensorNewShape, tensorNewType));
 
   // Test unranked tensor cloning.
   ShapedType unrankedTensorType = UnrankedTensorType::get(tensorOriginalType);
-  ASSERT_EQ(unrankedTensorType.clone(tensorNewShape),
-            RankedTensorType::get(tensorNewShape, tensorOriginalType));
+  ASSERT_EQ(
+      unrankedTensorType.clone(tensorNewShape),
+      (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
   ASSERT_EQ(unrankedTensorType.clone(tensorNewType),
-            UnrankedTensorType::get(tensorNewType));
-  ASSERT_EQ(unrankedTensorType.clone(tensorNewShape),
-            RankedTensorType::get(tensorNewShape, tensorOriginalType));
+            (ShapedType)UnrankedTensorType::get(tensorNewType));
+  ASSERT_EQ(
+      unrankedTensorType.clone(tensorNewShape),
+      (ShapedType)RankedTensorType::get(tensorNewShape, tensorOriginalType));
 }
 
 TEST(ShapedTypeTest, CloneVector) {