diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -1003,14 +1003,9 @@ For the moment, a `memref` supports parsing a strided form which is converted to a semi-affine map automatically. -The memory space of a memref is specified by a target-specific integer index. If -no memory space is specified, then the default memory space (0) is used. The -default space is target specific but always at index 0. - -TODO: MLIR will eventually have target-dialects which allow symbolic use of -memory hierarchy names (e.g. L3, L2, L1, ...) but we have not spec'd the details -of that mechanism yet. Until then, this document pretends that it is valid to -refer to these memories by `bare-id`. +The memory space of a memref is specified by a target-specific attribute. +It might be an integer value, string, dictionary or custom dialect attribute. +The empty memory space (attribute is None) is target specific. The notionally dynamic value of a memref value includes the address of the buffer allocated, as well as the symbols referred to by the shape, layout map, diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -224,38 +224,38 @@ /// same context as element type. The type is owned by the context. MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet( MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, - MlirAffineMap const *affineMaps, unsigned memorySpace); + MlirAffineMap const *affineMaps, MlirAttribute memorySpace); /// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o /// illegal arguments, emitting appropriate diagnostics. 
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked( MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - intptr_t numMaps, MlirAffineMap const *affineMaps, unsigned memorySpace); + intptr_t numMaps, MlirAffineMap const *affineMaps, + MlirAttribute memorySpace); /// Creates a MemRef type with the given rank, shape, memory space and element /// type in the same context as the element type. The type has no affine maps, /// i.e. represents a default row-major contiguous memref. The type is owned by /// the context. -MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGet(MlirType elementType, - intptr_t rank, - const int64_t *shape, - unsigned memorySpace); +MLIR_CAPI_EXPORTED MlirType +mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, + const int64_t *shape, MlirAttribute memorySpace); /// Same as "mlirMemRefTypeContiguousGet" but returns a nullptr wrapping /// MlirType on illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGetChecked( MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace); + MlirAttribute memorySpace); /// Creates an Unranked MemRef type with the given element type and in the given /// memory space. The type is owned by the context of element type. -MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, - unsigned memorySpace); +MLIR_CAPI_EXPORTED MlirType +mlirUnrankedMemRefTypeGet(MlirType elementType, MlirAttribute memorySpace); /// Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping /// MlirType on illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked( - MlirLocation loc, MlirType elementType, unsigned memorySpace); + MlirLocation loc, MlirType elementType, MlirAttribute memorySpace); /// Returns the number of affine layout maps in the given MemRef type. 
MLIR_CAPI_EXPORTED intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type); @@ -265,10 +265,11 @@ intptr_t pos); /// Returns the memory space of the given MemRef type. -MLIR_CAPI_EXPORTED unsigned mlirMemRefTypeGetMemorySpace(MlirType type); +MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type); /// Returns the memory spcae of the given Unranked MemRef type. -MLIR_CAPI_EXPORTED unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type); +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnrankedMemrefGetMemorySpace(MlirType type); //===----------------------------------------------------------------------===// // Tuple type. diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -9,6 +9,7 @@ #ifndef MLIR_IR_BUILTINTYPES_H #define MLIR_IR_BUILTINTYPES_H +#include "mlir/IR/Attributes.h" #include "mlir/IR/Types.h" namespace llvm { @@ -293,6 +294,13 @@ static bool classof(Type type); /// Returns the memory space in which data referred to by this memref resides. + Attribute getMemorySpace() const; + + /// Checks if the memorySpace has supported Attribute type. + static bool isSupportedMemorySpace(Attribute memorySpace); + + /// [deprecated] Returns the memory space in old raw integer representation. + /// New `Attribute getMemorySpace()` method should be used instead. unsigned getMemorySpaceAsInt() const; }; @@ -314,12 +322,11 @@ explicit Builder(MemRefType other) : shape(other.getShape()), elementType(other.getElementType()), affineMaps(other.getAffineMaps()), - memorySpace(other.getMemorySpaceAsInt()) {} + memorySpace(other.getMemorySpace()) {} // Build from scratch. 
Builder(ArrayRef shape, Type elementType) - : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) { - } + : shape(shape), elementType(elementType), affineMaps() {} Builder &setShape(ArrayRef newShape) { shape = newShape; @@ -336,11 +343,14 @@ return *this; } - Builder &setMemorySpace(unsigned newMemorySpace) { + Builder &setMemorySpace(Attribute newMemorySpace) { memorySpace = newMemorySpace; return *this; } + // [deprecated] `setMemorySpace(Attribute)` should be used instead. + Builder &setMemorySpace(unsigned newMemorySpace); + operator MemRefType() { return MemRefType::get(shape, elementType, affineMaps, memorySpace); } @@ -349,7 +359,7 @@ ArrayRef shape; Type elementType; ArrayRef affineMaps; - unsigned memorySpace; + Attribute memorySpace; }; using Base::Base; @@ -361,12 +371,21 @@ /// construction failures. static MemRefType get(ArrayRef shape, Type elementType, ArrayRef affineMapComposition = {}, - unsigned memorySpace = 0); + Attribute memorySpace = {}); + // [deprecated] `Attribute`-based form should be used instead. + static MemRefType get(ArrayRef shape, Type elementType, + ArrayRef affineMapComposition, + unsigned memorySpace); /// Get or create a new MemRefType based on shape, element type, affine /// map composition, and memory space. If the MemRefType defined by the /// arguments would be ill-formed, an error is emitted to `emitError` and a /// null type is returned. + static MemRefType getChecked(function_ref emitError, + ArrayRef shape, Type elementType, + ArrayRef affineMapComposition, + Attribute memorySpace); + // [deprecated] `Attribute`-based form should be used instead. static MemRefType getChecked(function_ref emitError, ArrayRef shape, Type elementType, ArrayRef affineMapComposition, @@ -390,7 +409,7 @@ /// type would be ill-formed, return nullptr. 
static MemRefType getImpl(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace, + Attribute memorySpace, function_ref emitError); using Base::getImpl; }; @@ -409,6 +428,8 @@ /// Get or create a new UnrankedMemRefType of the provided element /// type and memory space + static UnrankedMemRefType get(Type elementType, Attribute memorySpace); + // [deprecated] `Attribute`-based form should be used instead. static UnrankedMemRefType get(Type elementType, unsigned memorySpace); /// Get or create a new UnrankedMemRefType of the provided element @@ -416,12 +437,16 @@ /// would be ill-formed, an error is emitted to `emitError` and a null type is /// returned. static UnrankedMemRefType + getChecked(function_ref emitError, Type elementType, + Attribute memorySpace); + // [deprecated] `Attribute`-based form should be used instead. + static UnrankedMemRefType getChecked(function_ref emitError, Type elementType, unsigned memorySpace); /// Verify the construction of a unranked memref type. 
static LogicalResult verify(function_ref emitError, - Type elementType, unsigned memorySpace); + Type elementType, Attribute memorySpace); ArrayRef getShape() const { return llvm::None; } }; diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -2861,16 +2861,20 @@ c.def_static( "get", [](std::vector shape, PyType &elementType, - std::vector layout, unsigned memorySpace, + std::vector layout, PyAttribute *memorySpace, DefaultingPyLocation loc) { SmallVector maps; maps.reserve(layout.size()); for (PyAffineMap &map : layout) maps.push_back(map); + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(), shape.data(), maps.size(), - maps.data(), memorySpace); + maps.data(), memSpaceAttr); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. 
if (mlirTypeIsNull(t)) { @@ -2885,14 +2889,15 @@ return PyMemRefType(elementType.getContext(), t); }, py::arg("shape"), py::arg("element_type"), - py::arg("layout") = py::list(), py::arg("memory_space") = 0, + py::arg("layout") = py::list(), py::arg("memory_space") = py::none(), py::arg("loc") = py::none(), "Create a memref type") .def_property_readonly("layout", &PyMemRefType::getLayout, "The list of layout maps of the MemRef type.") .def_property_readonly( "memory_space", - [](PyMemRefType &self) -> unsigned { - return mlirMemRefTypeGetMemorySpace(self); + [](PyMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + return PyAttribute(self.getContext(), a); }, "Returns the memory space of the given MemRef type."); } @@ -2944,10 +2949,14 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyType &elementType, unsigned memorySpace, + [](PyType &elementType, PyAttribute *memorySpace, DefaultingPyLocation loc) { + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + MlirType t = - mlirUnrankedMemRefTypeGetChecked(loc, elementType, memorySpace); + mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. 
if (mlirTypeIsNull(t)) { @@ -2965,8 +2974,9 @@ py::arg("loc") = py::none(), "Create a unranked memref type") .def_property_readonly( "memory_space", - [](PyUnrankedMemRefType &self) -> unsigned { - return mlirUnrankedMemrefGetMemorySpace(self); + [](PyUnrankedMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); + return PyAttribute(self.getContext(), a); }, "Returns the memory space of the given Unranked MemRef type."); } diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -223,41 +223,41 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, MlirAffineMap const *affineMaps, - unsigned memorySpace) { + MlirAttribute memorySpace) { SmallVector maps; (void)unwrapList(numMaps, affineMaps, maps); return wrap( MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), maps, memorySpace)); + unwrap(elementType), maps, unwrap(memorySpace))); } MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, MlirAffineMap const *affineMaps, - unsigned memorySpace) { + MlirAttribute memorySpace) { SmallVector maps; (void)unwrapList(numMaps, affineMaps, maps); return wrap(MemRefType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), maps, memorySpace)); + unwrap(elementType), maps, unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace) { + MlirAttribute memorySpace) { return wrap( MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), llvm::None, memorySpace)); + unwrap(elementType), llvm::None, unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const 
int64_t *shape, - unsigned memorySpace) { + MlirAttribute memorySpace) { return wrap(MemRefType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), llvm::None, memorySpace)); + unwrap(elementType), llvm::None, unwrap(memorySpace))); } intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) { @@ -269,27 +269,29 @@ return wrap(unwrap(type).cast().getAffineMaps()[pos]); } -unsigned mlirMemRefTypeGetMemorySpace(MlirType type) { - return unwrap(type).cast().getMemorySpaceAsInt(); +MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) { + return wrap(unwrap(type).cast().getMemorySpace()); } bool mlirTypeIsAUnrankedMemRef(MlirType type) { return unwrap(type).isa(); } -MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) { - return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace)); +MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, + MlirAttribute memorySpace) { + return wrap( + UnrankedMemRefType::get(unwrap(elementType), unwrap(memorySpace))); } MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, - unsigned memorySpace) { + MlirAttribute memorySpace) { return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType), - memorySpace)); + unwrap(memorySpace))); } -unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) { - return unwrap(type).cast().getMemorySpaceAsInt(); +MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) { + return wrap(unwrap(type).cast().getMemorySpace()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1882,16 +1882,20 @@ printAttribute(AffineMapAttr::get(map)); } // Only print the memory space if it is the non-default one. 
- if (memrefTy.getMemorySpaceAsInt()) - os << ", " << memrefTy.getMemorySpaceAsInt(); + if (memrefTy.getMemorySpace()) { + os << ", "; + printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); + } os << '>'; }) .Case([&](UnrankedMemRefType memrefTy) { os << "memref<*x"; printType(memrefTy.getElementType()); // Only print the memory space if it is the non-default one. - if (memrefTy.getMemorySpaceAsInt()) - os << ", " << memrefTy.getMemorySpaceAsInt(); + if (memrefTy.getMemorySpace()) { + os << ", "; + printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); + } os << '>'; }) .Case([&](ComplexType complexTy) { diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -10,6 +10,8 @@ #include "TypeDetail.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "llvm/ADT/APFloat.h" @@ -206,7 +208,7 @@ if (auto other = dyn_cast()) { MemRefType::Builder b(shape, elementType); - b.setMemorySpace(other.getMemorySpaceAsInt()); + b.setMemorySpace(other.getMemorySpace()); return b; } @@ -229,7 +231,7 @@ if (auto other = dyn_cast()) { MemRefType::Builder b(shape, other.getElementType()); b.setShape(shape); - b.setMemorySpace(other.getMemorySpaceAsInt()); + b.setMemorySpace(other.getMemorySpace()); return b; } @@ -250,7 +252,7 @@ } if (auto other = dyn_cast()) { - return UnrankedMemRefType::get(elementType, other.getMemorySpaceAsInt()); + return UnrankedMemRefType::get(elementType, other.getMemorySpace()); } if (isa()) { @@ -473,20 +475,73 @@ //===----------------------------------------------------------------------===// unsigned BaseMemRefType::getMemorySpaceAsInt() const { + Attribute attr = getMemorySpace(); + + if (!attr) + return 0; + + assert(attr.isa() && + "Using `getMemorySpaceAsInt` with non-Integer 
attribute"); + + return static_cast(attr.cast().getInt()); +} + +Attribute BaseMemRefType::getMemorySpace() const { return static_cast(impl)->memorySpace; } +bool BaseMemRefType::isSupportedMemorySpace(Attribute memorySpace) { + // Empty attribute is allowed as default memory space. + if (!memorySpace) + return true; + + // Supported built-in attributes. + if (memorySpace.isa()) + return true; + + // Allow custom dialect attributes. + if (!::mlir::isa(memorySpace.getDialect())) + return true; + + return false; +} + //===----------------------------------------------------------------------===// // MemRefType //===----------------------------------------------------------------------===// +/// Wraps deprecated integer memory space to the new Attribute form. +static Attribute wrapIntegerMemorySpace(unsigned memorySpace, + MLIRContext *ctx) { + if (memorySpace == 0) + return nullptr; + + return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); +} + +/// Replaces default memorySpace (integer == `0`) with empty Attribute. +static Attribute skipDefaultMemorySpace(Attribute memorySpace) { + IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null(); + if (intMemorySpace && intMemorySpace.getValue() == 0) + return nullptr; + + return memorySpace; +} + +MemRefType::Builder & +MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) { + memorySpace = + wrapIntegerMemorySpace(newMemorySpace, elementType.getContext()); + return *this; +} + /// Get or create a new MemRefType based on shape, element type, affine /// map composition, and memory space. Assumes the arguments define a /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType /// construction failures. 
MemRefType MemRefType::get(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace) { + Attribute memorySpace) { auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, [=] { return emitError(UnknownLoc::get(elementType.getContext())); @@ -494,6 +549,13 @@ assert(result && "Failed to construct instance of MemRefType."); return result; } +MemRefType MemRefType::get(ArrayRef shape, Type elementType, + ArrayRef affineMapComposition, + unsigned memorySpace) { + Attribute memorySpaceAttr = + wrapIntegerMemorySpace(memorySpace, elementType.getContext()); + return get(shape, elementType, affineMapComposition, memorySpaceAttr); +} /// Get or create a new MemRefType based on shape, element type, affine /// map composition, and memory space declared at the given location. @@ -504,10 +566,19 @@ MemRefType MemRefType::getChecked(function_ref emitError, ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace) { + Attribute memorySpace) { return getImpl(shape, elementType, affineMapComposition, memorySpace, emitError); } +MemRefType MemRefType::getChecked(function_ref emitError, + ArrayRef shape, Type elementType, + ArrayRef affineMapComposition, + unsigned memorySpace) { + Attribute memorySpaceAttr = + wrapIntegerMemorySpace(memorySpace, elementType.getContext()); + return getChecked(emitError, shape, elementType, affineMapComposition, + memorySpaceAttr); +} /// Get or create a new MemRefType defined by the arguments. If the resulting /// type would be ill-formed, return nullptr. If the location is provided, @@ -515,7 +586,7 @@ /// pass in an instance of UnknownLoc. 
MemRefType MemRefType::getImpl(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace, + Attribute memorySpace, function_ref emitError) { auto *context = elementType.getContext(); @@ -556,8 +627,13 @@ cleanedAffineMapComposition.push_back(map); } + if (!isSupportedMemorySpace(memorySpace)) { + emitError() << "unsupported memory space Attribute"; + return nullptr; + } + return Base::get(context, shape, elementType, cleanedAffineMapComposition, - memorySpace); + skipDefaultMemorySpace(memorySpace)); } ArrayRef MemRefType::getShape() const { return getImpl()->getShape(); } @@ -570,23 +646,41 @@ // UnrankedMemRefType //===----------------------------------------------------------------------===// +UnrankedMemRefType UnrankedMemRefType::get(Type elementType, + Attribute memorySpace) { + return Base::get(elementType.getContext(), elementType, + skipDefaultMemorySpace(memorySpace)); +} UnrankedMemRefType UnrankedMemRefType::get(Type elementType, unsigned memorySpace) { - return Base::get(elementType.getContext(), elementType, memorySpace); + Attribute memorySpaceAttr = + wrapIntegerMemorySpace(memorySpace, elementType.getContext()); + return get(elementType, memorySpaceAttr); } UnrankedMemRefType UnrankedMemRefType::getChecked(function_ref emitError, - Type elementType, unsigned memorySpace) { + Type elementType, Attribute memorySpace) { return Base::getChecked(emitError, elementType.getContext(), elementType, - memorySpace); + skipDefaultMemorySpace(memorySpace)); +} +UnrankedMemRefType +UnrankedMemRefType::getChecked(function_ref emitError, + Type elementType, unsigned memorySpace) { + Attribute memorySpaceAttr = + wrapIntegerMemorySpace(memorySpace, elementType.getContext()); + return getChecked(emitError, elementType, memorySpaceAttr); } LogicalResult UnrankedMemRefType::verify(function_ref emitError, - Type elementType, unsigned memorySpace) { + Type elementType, Attribute memorySpace) { if 
(!BaseMemRefType::isValidElementType(elementType)) return emitError() << "invalid memref element type"; + + if (!isSupportedMemorySpace(memorySpace)) + return emitError() << "unsupported memory space Attribute"; + return success(); } diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -183,17 +183,17 @@ }; struct BaseMemRefTypeStorage : public ShapedTypeStorage { - BaseMemRefTypeStorage(Type elementType, unsigned memorySpace) + BaseMemRefTypeStorage(Type elementType, Attribute memorySpace) : ShapedTypeStorage(elementType), memorySpace(memorySpace) {} /// Memory space in which data referenced by memref resides. - const unsigned memorySpace; + Attribute memorySpace; }; struct MemRefTypeStorage : public BaseMemRefTypeStorage { MemRefTypeStorage(unsigned shapeSize, Type elementType, const int64_t *shapeElements, const unsigned numAffineMaps, - AffineMap const *affineMapList, const unsigned memorySpace) + AffineMap const *affineMapList, Attribute memorySpace) : BaseMemRefTypeStorage(elementType, memorySpace), shapeElements(shapeElements), shapeSize(shapeSize), numAffineMaps(numAffineMaps), affineMapList(affineMapList) {} @@ -202,7 +202,7 @@ // MemRefs are uniqued based on their shape, element type, affine map // composition, and memory space. using KeyTy = - std::tuple, Type, ArrayRef, unsigned>; + std::tuple, Type, ArrayRef, Attribute>; bool operator==(const KeyTy &key) const { return key == KeyTy(getShape(), elementType, getAffineMaps(), memorySpace); } @@ -246,11 +246,11 @@ /// Only element type and memory space are known struct UnrankedMemRefTypeStorage : public BaseMemRefTypeStorage { - UnrankedMemRefTypeStorage(Type elementTy, const unsigned memorySpace) + UnrankedMemRefTypeStorage(Type elementTy, Attribute memorySpace) : BaseMemRefTypeStorage(elementTy, memorySpace) {} /// The hash key used for uniquing. 
- using KeyTy = std::tuple; + using KeyTy = std::tuple; bool operator==(const KeyTy &key) const { return key == KeyTy(elementType, memorySpace); } diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -225,27 +225,14 @@ // Parse semi-affine-map-composition. SmallVector affineMapComposition; - Optional memorySpace; + Attribute memorySpace; unsigned numDims = dimensions.size(); auto parseElt = [&]() -> ParseResult { - // Check for the memory space. - if (getToken().is(Token::integer)) { - if (memorySpace) - return emitError("multiple memory spaces specified in memref type"); - memorySpace = getToken().getUnsignedIntegerValue(); - if (!memorySpace.hasValue()) - return emitError("invalid memory space in memref type"); - consumeToken(Token::integer); - return success(); - } - if (isUnranked) - return emitError("cannot have affine map for unranked memref type"); - if (memorySpace) - return emitError("expected memory space to be last in memref type"); - AffineMap map; llvm::SMLoc mapLoc = getToken().getLoc(); + + // Check for AffineMap as offset/strides. if (getToken().is(Token::kw_offset)) { int64_t offset; SmallVector strides; @@ -254,16 +241,26 @@ // Construct strided affine map. map = makeStridedLinearLayoutMap(strides, offset, state.context); } else { - // Parse an affine map attribute. - auto affineMap = parseAttribute(); - if (!affineMap) + // Either it is AffineMapAttr or memory space attribute. 
+ Attribute attr = parseAttribute(); + if (!attr) return failure(); - auto affineMapAttr = affineMap.dyn_cast(); - if (!affineMapAttr) - return emitError("expected affine map in memref type"); - map = affineMapAttr.getValue(); + + if (AffineMapAttr affineMapAttr = attr.dyn_cast()) { + map = affineMapAttr.getValue(); + } else if (memorySpace) { + return emitError("multiple memory spaces specified in memref type"); + } else { + memorySpace = attr; + return success(); + } } + if (isUnranked) + return emitError("cannot have affine map for unranked memref type"); + if (memorySpace) + return emitError("expected memory space to be last in memref type"); + if (map.getNumDims() != numDims) { size_t i = affineMapComposition.size(); return emitError(mapLoc, "memref affine map dimension mismatch between ") @@ -286,11 +283,15 @@ } } - if (isUnranked) - return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0)); + if (isUnranked) { + return UnrankedMemRefType::getChecked( + [&]() -> InFlightDiagnostic { return emitError(); }, elementType, + memorySpace); + } - return MemRefType::get(dimensions, elementType, affineMapComposition, - memorySpace.getValueOr(0)); + return MemRefType::getChecked( + [&]() -> InFlightDiagnostic { return emitError(); }, dimensions, + elementType, affineMapComposition, memorySpace); } /// Parse any type except the function type. 
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -326,7 +326,7 @@ f32 = F32Type.get() shape = [2, 3] loc = Location.unknown() - memref = MemRefType.get(shape, f32, memory_space=2) + memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2")) # CHECK: memref type: memref<2x3xf32, 2> print("memref type:", memref) # CHECK: number of affine layout maps: 0 @@ -341,7 +341,7 @@ assert len(memref_layout.layout) == 1 # CHECK: memref layout: (d0, d1) -> (d1, d0) print("memref layout:", memref_layout.layout[0]) - # CHECK: memory space: 0 + # CHECK: memory space: <> print("memory space:", memref_layout.memory_space) none = NoneType.get() @@ -361,7 +361,7 @@ with Context(), Location.unknown(): f32 = F32Type.get() loc = Location.unknown() - unranked_memref = UnrankedMemRefType.get(f32, 2) + unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2")) # CHECK: unranked memref type: memref<*xf32, 2> print("unranked memref type:", unranked_memref) try: @@ -388,7 +388,7 @@ none = NoneType.get() try: - memref_invalid = UnrankedMemRefType.get(none, 2) + memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2")) except ValueError as e: # CHECK: invalid 'Type(none)' and expected floating point, integer, vector # CHECK: or complex type. diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -707,21 +707,24 @@ // CHECK: tensor<*xf32> // MemRef type. 
+ MlirAttribute memSpace2 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 2); MlirType memRef = mlirMemRefTypeContiguousGet( - f32, sizeof(shape) / sizeof(int64_t), shape, 2); + f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2); if (!mlirTypeIsAMemRef(memRef) || mlirMemRefTypeGetNumAffineMaps(memRef) != 0 || - mlirMemRefTypeGetMemorySpace(memRef) != 2) + !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2)) return 18; mlirTypeDump(memRef); fprintf(stderr, "\n"); // CHECK: memref<2x3xf32, 2> // Unranked MemRef type. - MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, 4); + MlirAttribute memSpace4 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 4); + MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, memSpace4); if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) || mlirTypeIsAMemRef(unrankedMemRef) || - mlirUnrankedMemrefGetMemorySpace(unrankedMemRef) != 4) + !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef), + memSpace4)) return 19; mlirTypeDump(unrankedMemRef); fprintf(stderr, "\n"); diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -36,8 +36,8 @@ func @memrefs(memref<2x4xi8, #map7>) // expected-error {{undefined symbol alias id 'map7'}} // ----- -// Test non affine map in memref type. -func @memrefs(memref<2x4xi8, i8>) // expected-error {{expected affine map in memref type}} +// Test unsupported memory space. +func @memrefs(memref<2x4xi8, i8>) // expected-error {{unsupported memory space Attribute}} // ----- // Test non-existent map in map composition of memref type. 
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -137,11 +137,35 @@ func private @memrefs_compose_with_id(memref<2x2xi8, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>>) +// Test memref with custom memory space + +// CHECK: func private @memrefs_nomap_nospace(memref<5x6x7xf32>) +func private @memrefs_nomap_nospace(memref<5x6x7xf32>) + +// CHECK: func private @memrefs_map_nospace(memref<5x6x7xf32, #map{{[0-9]+}}>) +func private @memrefs_map_nospace(memref<5x6x7xf32, #map3>) + +// CHECK: func private @memrefs_nomap_intspace(memref<5x6x7xf32, 3>) +func private @memrefs_nomap_intspace(memref<5x6x7xf32, 3>) + +// CHECK: func private @memrefs_map_intspace(memref<5x6x7xf32, #map{{[0-9]+}}, 5>) +func private @memrefs_map_intspace(memref<5x6x7xf32, #map3, 5>) + +// CHECK: func private @memrefs_nomap_strspace(memref<5x6x7xf32, "local">) +func private @memrefs_nomap_strspace(memref<5x6x7xf32, "local">) + +// CHECK: func private @memrefs_map_strspace(memref<5x6x7xf32, #map{{[0-9]+}}, "private">) +func private @memrefs_map_strspace(memref<5x6x7xf32, #map3, "private">) + +// CHECK: func private @memrefs_nomap_dictspace(memref<5x6x7xf32, {memSpace = "special", subIndex = 1 : i64}>) +func private @memrefs_nomap_dictspace(memref<5x6x7xf32, {memSpace = "special", subIndex = 1}>) + +// CHECK: func private @memrefs_map_dictspace(memref<5x6x7xf32, #map{{[0-9]+}}, {memSpace = "special", subIndex = 3 : i64}>) +func private @memrefs_map_dictspace(memref<5x6x7xf32, #map3, {memSpace = "special", subIndex = 3}>) // CHECK: func private @complex_types(complex) -> complex func private @complex_types(complex) -> complex - // CHECK: func private @memref_with_index_elems(memref<1x?xindex>) func private @memref_with_index_elems(memref<1x?xindex>) diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp --- a/mlir/unittests/IR/ShapedTypeTest.cpp +++ 
b/mlir/unittests/IR/ShapedTypeTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectInterface.h" @@ -23,7 +24,7 @@ Type i32 = IntegerType::get(&context, 32); Type f32 = FloatType::getF32(&context); - int memSpace = 7; + Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7); Type memrefOriginalType = i32; llvm::SmallVector memrefOriginalShape({10, 20}); AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);