diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -1003,14 +1003,9 @@ For the moment, a `memref` supports parsing a strided form which is converted to a semi-affine map automatically. -The memory space of a memref is specified by a target-specific integer index. If -no memory space is specified, then the default memory space (0) is used. The -default space is target specific but always at index 0. - -TODO: MLIR will eventually have target-dialects which allow symbolic use of -memory hierarchy names (e.g. L3, L2, L1, ...) but we have not spec'd the details -of that mechanism yet. Until then, this document pretends that it is valid to -refer to these memories by `bare-id`. +The memory space of a memref is specified by a target-specific attribute. +It may be an integer, string, or dictionary attribute, or a custom dialect attribute. +An empty (null) memory space attribute selects the target-specific default. The notionally dynamic value of a memref value includes the address of the buffer allocated, as well as the symbols referred to by the shape, layout map, diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -224,38 +224,38 @@ /// same context as element type. The type is owned by the context. MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet( MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, - MlirAffineMap const *affineMaps, unsigned memorySpace); + MlirAffineMap const *affineMaps, MlirAttribute memorySpace); /// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o /// illegal arguments, emitting appropriate diagnostics.
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked( MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - intptr_t numMaps, MlirAffineMap const *affineMaps, unsigned memorySpace); + intptr_t numMaps, MlirAffineMap const *affineMaps, + MlirAttribute memorySpace); /// Creates a MemRef type with the given rank, shape, memory space and element /// type in the same context as the element type. The type has no affine maps, /// i.e. represents a default row-major contiguous memref. The type is owned by /// the context. -MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGet(MlirType elementType, - intptr_t rank, - const int64_t *shape, - unsigned memorySpace); +MLIR_CAPI_EXPORTED MlirType +mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, + const int64_t *shape, MlirAttribute memorySpace); /// Same as "mlirMemRefTypeContiguousGet" but returns a nullptr wrapping /// MlirType on illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGetChecked( MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace); + MlirAttribute memorySpace); /// Creates an Unranked MemRef type with the given element type and in the given /// memory space. The type is owned by the context of element type. -MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, - unsigned memorySpace); +MLIR_CAPI_EXPORTED MlirType +mlirUnrankedMemRefTypeGet(MlirType elementType, MlirAttribute memorySpace); /// Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping /// MlirType on illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked( - MlirLocation loc, MlirType elementType, unsigned memorySpace); + MlirLocation loc, MlirType elementType, MlirAttribute memorySpace); /// Returns the number of affine layout maps in the given MemRef type. 
MLIR_CAPI_EXPORTED intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type); @@ -265,10 +265,11 @@ intptr_t pos); /// Returns the memory space of the given MemRef type. -MLIR_CAPI_EXPORTED unsigned mlirMemRefTypeGetMemorySpace(MlirType type); +MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type); /// Returns the memory spcae of the given Unranked MemRef type. -MLIR_CAPI_EXPORTED unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type); +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnrankedMemrefGetMemorySpace(MlirType type); //===----------------------------------------------------------------------===// // Tuple type. diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -9,6 +9,7 @@ #ifndef MLIR_IR_BUILTINTYPES_H #define MLIR_IR_BUILTINTYPES_H +#include "mlir/IR/Attributes.h" #include "mlir/IR/Types.h" namespace llvm { @@ -175,6 +176,10 @@ static bool classof(Type type); /// Returns the memory space in which data referred to by this memref resides. + Attribute getMemorySpace() const; + + /// [deprecated] Returns the memory space in old raw integer representation. + /// New `Attribute getMemorySpace()` method should be used instead. unsigned getMemorySpaceAsInt() const; }; @@ -199,12 +204,12 @@ // Build from another MemRefType. explicit Builder(MemRefType other) : shape(other.getShape()), elementType(other.getElementType()), - affineMaps(other.getAffineMaps()), - memorySpace(other.getMemorySpaceAsInt()) {} + affineMaps(other.getAffineMaps()), memorySpace(other.getMemorySpace()) { + } // Build from scratch. 
Builder(ArrayRef shape, Type elementType) - : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) {} + : shape(shape), elementType(elementType), affineMaps() {} Builder &setShape(ArrayRef newShape) { shape = newShape; @@ -221,11 +226,14 @@ return *this; } - Builder &setMemorySpace(unsigned newMemorySpace) { + Builder &setMemorySpace(Attribute newMemorySpace) { memorySpace = newMemorySpace; return *this; } + // [deprecated] `setMemorySpace(Attribute)` should be used instead. + Builder &setMemorySpace(unsigned newMemorySpace); + operator MemRefType() { return MemRefType::get(shape, elementType, affineMaps, memorySpace); } @@ -234,7 +242,7 @@ ArrayRef shape; Type elementType; ArrayRef affineMaps; - unsigned memorySpace; + Attribute memorySpace; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -268,7 +268,7 @@ strided-layout ::= `offset:` dimension `,` `strides: ` stride-list semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map layout-specification ::= semi-affine-map-composition | strided-layout - memory-space ::= integer-literal /* | TODO: address-space-id */ + memory-space ::= attribute-value ``` A `memref` type is a reference to a region of memory (similar to a buffer @@ -335,14 +335,9 @@ intuitive to read. For the moment, a `memref` supports parsing a strided form which is converted to a semi-affine map automatically. - The memory space of a memref is specified by a target-specific integer - index. If no memory space is specified, then the default memory space (0) - is used. The default space is target specific but always at index 0. - - TODO: MLIR will eventually have target-dialects which allow symbolic use of - memory hierarchy names (e.g. L3, L2, L1, ...) but we have not spec'd the - details of that mechanism yet. 
Until then, this document pretends that it - is valid to refer to these memories by `bare-id`. + The memory space of a memref is specified by a target-specific attribute. + It may be an integer, string, or dictionary attribute, or a custom dialect attribute. + An empty (null) memory space attribute selects the target-specific default. The notionally dynamic value of a memref value includes the address of the buffer allocated, as well as the symbols referred to by the shape, layout @@ -527,22 +522,34 @@ ArrayRefParameter<"int64_t">:$shape, "Type":$elementType, ArrayRefParameter<"AffineMap">:$affineMaps, - "unsigned":$memorySpaceAsInt + "Attribute":$memorySpace ); - let builders = [ TypeBuilderWithInferredContext<(ins "ArrayRef":$shape, "Type":$elementType, CArg<"ArrayRef", "{}">:$affineMaps, - CArg<"unsigned", "0">:$memorySpace + CArg<"Attribute", "{}">:$memorySpace ), [{ // Drop identity maps from the composition. This may lead to the // composition becoming empty, which is interpreted as an implicit // identity. auto nonIdentityMaps = llvm::make_filter_range(affineMaps, [](AffineMap map) { return !map.isIdentity(); }); + // Drop default memory space value and replace it with empty attribute. + Attribute nonDefaultMemorySpace = skipDefaultMemorySpace(memorySpace); return $_get(elementType.getContext(), shape, elementType, - llvm::to_vector<4>(nonIdentityMaps), memorySpace); + llvm::to_vector<4>(nonIdentityMaps), nonDefaultMemorySpace); + }]>, + /// [deprecated] `Attribute`-based form should be used instead. + TypeBuilderWithInferredContext<(ins + "ArrayRef":$shape, "Type":$elementType, + "ArrayRef":$affineMaps, + "unsigned":$memorySpace + ), [{ + // Convert deprecated integer-like memory space to Attribute.
+ Attribute memorySpaceAttr = + wrapIntegerMemorySpace(memorySpace, elementType.getContext()); + return MemRefType::get(shape, elementType, affineMaps, memorySpaceAttr); }]> ]; let extraClassDeclaration = [{ @@ -550,6 +557,10 @@ /// Arguments that are passed into the builder must out-live the builder. class Builder; + /// [deprecated] Returns the memory space in old raw integer representation. + /// New `Attribute getMemorySpace()` method should be used instead. + unsigned getMemorySpaceAsInt() const; + // TODO: merge these two special values in a single one used everywhere. // Unfortunately, uses of `-1` have crept deep into the codebase now and are // hard to track. @@ -767,7 +778,7 @@ ``` unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` - memory-space ::= integer-literal /* | TODO: address-space-id */ + memory-space ::= attribute-value ``` A `memref` type with an unknown rank (e.g. `memref<*xf32>`). The purpose of @@ -787,16 +798,30 @@ memref<*f32, 10> ``` }]; - let parameters = (ins "Type":$elementType, "unsigned":$memorySpaceAsInt); + let parameters = (ins "Type":$elementType, "Attribute":$memorySpace); let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$elementType, + "Attribute":$memorySpace), [{ + // Drop default memory space value and replace it with empty attribute. + Attribute nonDefaultMemorySpace = skipDefaultMemorySpace(memorySpace); + return $_get(elementType.getContext(), elementType, nonDefaultMemorySpace); + }]>, + /// [deprecated] `Attribute`-based form should be used instead. TypeBuilderWithInferredContext<(ins "Type":$elementType, "unsigned":$memorySpace), [{ - return $_get(elementType.getContext(), elementType, memorySpace); + // Convert deprecated integer-like memory space to Attribute. 
+ Attribute memorySpaceAttr = + wrapIntegerMemorySpace(memorySpace, elementType.getContext()); + return UnrankedMemRefType::get(elementType, memorySpaceAttr); }]> ]; let extraClassDeclaration = [{ ArrayRef getShape() const { return llvm::None; } + + /// [deprecated] Returns the memory space in old raw integer representation. + /// New `Attribute getMemorySpace()` method should be used instead. + unsigned getMemorySpaceAsInt() const; }]; let skipDefaultBuilders = 1; let genVerifyDecl = 1; diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -2861,16 +2861,20 @@ c.def_static( "get", [](std::vector shape, PyType &elementType, - std::vector layout, unsigned memorySpace, + std::vector layout, PyAttribute *memorySpace, DefaultingPyLocation loc) { SmallVector maps; maps.reserve(layout.size()); for (PyAffineMap &map : layout) maps.push_back(map); + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(), shape.data(), maps.size(), - maps.data(), memorySpace); + maps.data(), memSpaceAttr); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. 
if (mlirTypeIsNull(t)) { @@ -2885,14 +2889,15 @@ return PyMemRefType(elementType.getContext(), t); }, py::arg("shape"), py::arg("element_type"), - py::arg("layout") = py::list(), py::arg("memory_space") = 0, + py::arg("layout") = py::list(), py::arg("memory_space") = py::none(), py::arg("loc") = py::none(), "Create a memref type") .def_property_readonly("layout", &PyMemRefType::getLayout, "The list of layout maps of the MemRef type.") .def_property_readonly( "memory_space", - [](PyMemRefType &self) -> unsigned { - return mlirMemRefTypeGetMemorySpace(self); + [](PyMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + return PyAttribute(self.getContext(), a); }, "Returns the memory space of the given MemRef type."); } @@ -2944,10 +2949,14 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyType &elementType, unsigned memorySpace, + [](PyType &elementType, PyAttribute *memorySpace, DefaultingPyLocation loc) { + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + MlirType t = - mlirUnrankedMemRefTypeGetChecked(loc, elementType, memorySpace); + mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. 
if (mlirTypeIsNull(t)) { @@ -2965,8 +2974,9 @@ py::arg("loc") = py::none(), "Create a unranked memref type") .def_property_readonly( "memory_space", - [](PyUnrankedMemRefType &self) -> unsigned { - return mlirUnrankedMemrefGetMemorySpace(self); + [](PyUnrankedMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); + return PyAttribute(self.getContext(), a); }, "Returns the memory space of the given Unranked MemRef type."); } diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -223,41 +223,41 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, MlirAffineMap const *affineMaps, - unsigned memorySpace) { + MlirAttribute memorySpace) { SmallVector maps; (void)unwrapList(numMaps, affineMaps, maps); return wrap( MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), maps, memorySpace)); + unwrap(elementType), maps, unwrap(memorySpace))); } MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, MlirAffineMap const *affineMaps, - unsigned memorySpace) { + MlirAttribute memorySpace) { SmallVector maps; (void)unwrapList(numMaps, affineMaps, maps); return wrap(MemRefType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), maps, memorySpace)); + unwrap(elementType), maps, unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace) { + MlirAttribute memorySpace) { return wrap( MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), llvm::None, memorySpace)); + unwrap(elementType), llvm::None, unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const
int64_t *shape, - unsigned memorySpace) { + MlirAttribute memorySpace) { return wrap(MemRefType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), llvm::None, memorySpace)); + unwrap(elementType), llvm::None, unwrap(memorySpace))); } intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) { @@ -269,27 +269,29 @@ return wrap(unwrap(type).cast().getAffineMaps()[pos]); } -unsigned mlirMemRefTypeGetMemorySpace(MlirType type) { - return unwrap(type).cast().getMemorySpaceAsInt(); +MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) { + return wrap(unwrap(type).cast().getMemorySpace()); } bool mlirTypeIsAUnrankedMemRef(MlirType type) { return unwrap(type).isa(); } -MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) { - return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace)); +MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, + MlirAttribute memorySpace) { + return wrap( + UnrankedMemRefType::get(unwrap(elementType), unwrap(memorySpace))); } MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, - unsigned memorySpace) { + MlirAttribute memorySpace) { return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType), - memorySpace)); + unwrap(memorySpace))); } -unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) { - return unwrap(type).cast().getMemorySpaceAsInt(); +MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) { + return wrap(unwrap(type).cast().getMemorySpace()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1887,16 +1887,20 @@ printAttribute(AffineMapAttr::get(map)); } // Only print the memory space if it is the non-default one. 
- if (memrefTy.getMemorySpaceAsInt()) - os << ", " << memrefTy.getMemorySpaceAsInt(); + if (memrefTy.getMemorySpace()) { + os << ", "; + printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); + } os << '>'; }) .Case([&](UnrankedMemRefType memrefTy) { os << "memref<*x"; printType(memrefTy.getElementType()); // Only print the memory space if it is the non-default one. - if (memrefTy.getMemorySpaceAsInt()) - os << ", " << memrefTy.getMemorySpaceAsInt(); + if (memrefTy.getMemorySpace()) { + os << ", "; + printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); + } os << '>'; }) .Case([&](ComplexType complexTy) { diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -10,6 +10,8 @@ #include "TypeDetail.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "llvm/ADT/APFloat.h" @@ -207,7 +209,7 @@ if (auto other = dyn_cast()) { MemRefType::Builder b(shape, elementType); - b.setMemorySpace(other.getMemorySpaceAsInt()); + b.setMemorySpace(other.getMemorySpace()); return b; } @@ -230,7 +232,7 @@ if (auto other = dyn_cast()) { MemRefType::Builder b(shape, other.getElementType()); b.setShape(shape); - b.setMemorySpace(other.getMemorySpaceAsInt()); + b.setMemorySpace(other.getMemorySpace()); return b; } @@ -251,7 +253,7 @@ } if (auto other = dyn_cast()) { - return UnrankedMemRefType::get(elementType, other.getMemorySpaceAsInt()); + return UnrankedMemRefType::get(elementType, other.getMemorySpace()); } if (isa()) { @@ -436,6 +438,12 @@ // BaseMemRefType //===----------------------------------------------------------------------===// +Attribute BaseMemRefType::getMemorySpace() const { + if (auto rankedMemRefTy = dyn_cast()) + return rankedMemRefTy.getMemorySpace(); + return cast().getMemorySpace(); +} + unsigned 
BaseMemRefType::getMemorySpaceAsInt() const { if (auto rankedMemRefTy = dyn_cast()) return rankedMemRefTy.getMemorySpaceAsInt(); @@ -446,10 +454,63 @@ // MemRefType //===----------------------------------------------------------------------===// +bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { + // Empty attribute is allowed as default memory space. + if (!memorySpace) + return true; + + // Supported built-in attributes. + if (memorySpace.isa()) + return true; + + // Allow custom dialect attributes. + if (!::mlir::isa(memorySpace.getDialect())) + return true; + + return false; +} + +Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace, + MLIRContext *ctx) { + if (memorySpace == 0) + return nullptr; + + return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); +} + +Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { + IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null(); + if (intMemorySpace && intMemorySpace.getValue() == 0) + return nullptr; + + return memorySpace; +} + +unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) { + if (!memorySpace) + return 0; + + assert(memorySpace.isa() && + "Using `getMemorySpaceAsInt` with non-Integer attribute"); + + return static_cast(memorySpace.cast().getInt()); +} + +MemRefType::Builder & +MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) { + memorySpace = + wrapIntegerMemorySpace(newMemorySpace, elementType.getContext()); + return *this; +} + +unsigned MemRefType::getMemorySpaceAsInt() const { + return detail::getMemorySpaceAsInt(getMemorySpace()); +} + LogicalResult MemRefType::verify(function_ref emitError, ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace) { + Attribute memorySpace) { if (!BaseMemRefType::isValidElementType(elementType)) return emitError() << "invalid memref element type"; @@ -474,6 +535,11 @@ << " and affine map" << it.index() + 1 << ": " << dim << " != " <<
map.getNumDims(); } + + if (!isSupportedMemorySpace(memorySpace)) { + return emitError() << "unsupported memory space Attribute"; + } + return success(); } @@ -481,11 +547,19 @@ // UnrankedMemRefType //===----------------------------------------------------------------------===// +unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { + return detail::getMemorySpaceAsInt(getMemorySpace()); +} + LogicalResult UnrankedMemRefType::verify(function_ref emitError, - Type elementType, unsigned memorySpace) { + Type elementType, Attribute memorySpace) { if (!BaseMemRefType::isValidElementType(elementType)) return emitError() << "invalid memref element type"; + + if (!isSupportedMemorySpace(memorySpace)) + return emitError() << "unsupported memory space Attribute"; + return success(); } diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -130,6 +130,19 @@ unsigned numElements; }; +/// Checks if the memorySpace has supported Attribute type. +bool isSupportedMemorySpace(Attribute memorySpace); + +/// Wraps deprecated integer memory space to the new Attribute form. +Attribute wrapIntegerMemorySpace(unsigned memorySpace, MLIRContext *ctx); + +/// Replaces default memorySpace (integer == `0`) with empty Attribute. +Attribute skipDefaultMemorySpace(Attribute memorySpace); + +/// [deprecated] Returns the memory space in old raw integer representation. +/// New `Attribute getMemorySpace()` method should be used instead. +unsigned getMemorySpaceAsInt(Attribute memorySpace); + } // namespace detail } // namespace mlir diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -224,27 +224,14 @@ // Parse semi-affine-map-composition. 
SmallVector affineMapComposition; - Optional memorySpace; + Attribute memorySpace; unsigned numDims = dimensions.size(); auto parseElt = [&]() -> ParseResult { - // Check for the memory space. - if (getToken().is(Token::integer)) { - if (memorySpace) - return emitError("multiple memory spaces specified in memref type"); - memorySpace = getToken().getUnsignedIntegerValue(); - if (!memorySpace.hasValue()) - return emitError("invalid memory space in memref type"); - consumeToken(Token::integer); - return success(); - } - if (isUnranked) - return emitError("cannot have affine map for unranked memref type"); - if (memorySpace) - return emitError("expected memory space to be last in memref type"); - AffineMap map; llvm::SMLoc mapLoc = getToken().getLoc(); + + // Check for AffineMap as offset/strides. if (getToken().is(Token::kw_offset)) { int64_t offset; SmallVector strides; @@ -253,16 +240,26 @@ // Construct strided affine map. map = makeStridedLinearLayoutMap(strides, offset, state.context); } else { - // Parse an affine map attribute. - auto affineMap = parseAttribute(); - if (!affineMap) + // Either it is AffineMapAttr or memory space attribute. 
+ Attribute attr = parseAttribute(); + if (!attr) return failure(); - auto affineMapAttr = affineMap.dyn_cast(); - if (!affineMapAttr) - return emitError("expected affine map in memref type"); - map = affineMapAttr.getValue(); + + if (AffineMapAttr affineMapAttr = attr.dyn_cast()) { + map = affineMapAttr.getValue(); + } else if (memorySpace) { + return emitError("multiple memory spaces specified in memref type"); + } else { + memorySpace = attr; + return success(); + } } + if (isUnranked) + return emitError("cannot have affine map for unranked memref type"); + if (memorySpace) + return emitError("expected memory space to be last in memref type"); + if (map.getNumDims() != numDims) { size_t i = affineMapComposition.size(); return emitError(mapLoc, "memref affine map dimension mismatch between ") @@ -285,11 +282,15 @@ } } - if (isUnranked) - return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0)); + if (isUnranked) { + return UnrankedMemRefType::getChecked( + [&]() -> InFlightDiagnostic { return emitError(); }, elementType, + memorySpace); + } - return MemRefType::get(dimensions, elementType, affineMapComposition, - memorySpace.getValueOr(0)); + return MemRefType::getChecked( + [&]() -> InFlightDiagnostic { return emitError(); }, dimensions, + elementType, affineMapComposition, memorySpace); } /// Parse any type except the function type. 
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -326,7 +326,7 @@ f32 = F32Type.get() shape = [2, 3] loc = Location.unknown() - memref = MemRefType.get(shape, f32, memory_space=2) + memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2")) # CHECK: memref type: memref<2x3xf32, 2> print("memref type:", memref) # CHECK: number of affine layout maps: 0 @@ -341,7 +341,7 @@ assert len(memref_layout.layout) == 1 # CHECK: memref layout: (d0, d1) -> (d1, d0) print("memref layout:", memref_layout.layout[0]) - # CHECK: memory space: 0 + # CHECK: memory space: <> print("memory space:", memref_layout.memory_space) none = NoneType.get() @@ -361,7 +361,7 @@ with Context(), Location.unknown(): f32 = F32Type.get() loc = Location.unknown() - unranked_memref = UnrankedMemRefType.get(f32, 2) + unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2")) # CHECK: unranked memref type: memref<*xf32, 2> print("unranked memref type:", unranked_memref) try: @@ -388,7 +388,7 @@ none = NoneType.get() try: - memref_invalid = UnrankedMemRefType.get(none, 2) + memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2")) except ValueError as e: # CHECK: invalid 'Type(none)' and expected floating point, integer, vector # CHECK: or complex type. diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -707,21 +707,24 @@ // CHECK: tensor<*xf32> // MemRef type. 
+ MlirAttribute memSpace2 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 2); MlirType memRef = mlirMemRefTypeContiguousGet( - f32, sizeof(shape) / sizeof(int64_t), shape, 2); + f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2); if (!mlirTypeIsAMemRef(memRef) || mlirMemRefTypeGetNumAffineMaps(memRef) != 0 || - mlirMemRefTypeGetMemorySpace(memRef) != 2) + !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2)) return 18; mlirTypeDump(memRef); fprintf(stderr, "\n"); // CHECK: memref<2x3xf32, 2> // Unranked MemRef type. - MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, 4); + MlirAttribute memSpace4 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 4); + MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, memSpace4); if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) || mlirTypeIsAMemRef(unrankedMemRef) || - mlirUnrankedMemrefGetMemorySpace(unrankedMemRef) != 4) + !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef), + memSpace4)) return 19; mlirTypeDump(unrankedMemRef); fprintf(stderr, "\n"); diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -36,8 +36,8 @@ func @memrefs(memref<2x4xi8, #map7>) // expected-error {{undefined symbol alias id 'map7'}} // ----- -// Test non affine map in memref type. -func @memrefs(memref<2x4xi8, i8>) // expected-error {{expected affine map in memref type}} +// Test unsupported memory space. +func @memrefs(memref<2x4xi8, i8>) // expected-error {{unsupported memory space Attribute}} // ----- // Test non-existent map in map composition of memref type. 
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -137,11 +137,35 @@ func private @memrefs_compose_with_id(memref<2x2xi8, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>>) +// Test memref with custom memory space + +// CHECK: func private @memrefs_nomap_nospace(memref<5x6x7xf32>) +func private @memrefs_nomap_nospace(memref<5x6x7xf32>) + +// CHECK: func private @memrefs_map_nospace(memref<5x6x7xf32, #map{{[0-9]+}}>) +func private @memrefs_map_nospace(memref<5x6x7xf32, #map3>) + +// CHECK: func private @memrefs_nomap_intspace(memref<5x6x7xf32, 3>) +func private @memrefs_nomap_intspace(memref<5x6x7xf32, 3>) + +// CHECK: func private @memrefs_map_intspace(memref<5x6x7xf32, #map{{[0-9]+}}, 5>) +func private @memrefs_map_intspace(memref<5x6x7xf32, #map3, 5>) + +// CHECK: func private @memrefs_nomap_strspace(memref<5x6x7xf32, "local">) +func private @memrefs_nomap_strspace(memref<5x6x7xf32, "local">) + +// CHECK: func private @memrefs_map_strspace(memref<5x6x7xf32, #map{{[0-9]+}}, "private">) +func private @memrefs_map_strspace(memref<5x6x7xf32, #map3, "private">) + +// CHECK: func private @memrefs_nomap_dictspace(memref<5x6x7xf32, {memSpace = "special", subIndex = 1 : i64}>) +func private @memrefs_nomap_dictspace(memref<5x6x7xf32, {memSpace = "special", subIndex = 1}>) + +// CHECK: func private @memrefs_map_dictspace(memref<5x6x7xf32, #map{{[0-9]+}}, {memSpace = "special", subIndex = 3 : i64}>) +func private @memrefs_map_dictspace(memref<5x6x7xf32, #map3, {memSpace = "special", subIndex = 3}>) // CHECK: func private @complex_types(complex) -> complex func private @complex_types(complex) -> complex - // CHECK: func private @memref_with_index_elems(memref<1x?xindex>) func private @memref_with_index_elems(memref<1x?xindex>) diff --git a/mlir/unittests/IR/ShapedTypeTest.cpp b/mlir/unittests/IR/ShapedTypeTest.cpp --- a/mlir/unittests/IR/ShapedTypeTest.cpp +++ 
b/mlir/unittests/IR/ShapedTypeTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectInterface.h" @@ -23,7 +24,7 @@ Type i32 = IntegerType::get(&context, 32); Type f32 = FloatType::getF32(&context); - int memSpace = 7; + Attribute memSpace = IntegerAttr::get(IntegerType::get(&context, 64), 7); Type memrefOriginalType = i32; llvm::SmallVector memrefOriginalShape({10, 20}); AffineMap map = makeStridedLinearLayoutMap({2, 3}, 5, &context);