diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -9,6 +9,7 @@ #ifndef MLIR_IR_BUILTINTYPES_H #define MLIR_IR_BUILTINTYPES_H +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Types.h" namespace llvm { @@ -287,6 +288,7 @@ /// Returns the memory space in which data referred to by this memref resides. unsigned getMemorySpace() const; + Attribute getMemorySpaceAttr() const; }; //===----------------------------------------------------------------------===// @@ -307,12 +309,11 @@ explicit Builder(MemRefType other) : shape(other.getShape()), elementType(other.getElementType()), affineMaps(other.getAffineMaps()), - memorySpace(other.getMemorySpace()) {} + memorySpace(other.getMemorySpaceAttr()) {} // Build from scratch. Builder(ArrayRef shape, Type elementType) - : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) { - } + : shape(shape), elementType(elementType), affineMaps() {} Builder &setShape(ArrayRef newShape) { shape = newShape; @@ -329,7 +330,8 @@ return *this; } - Builder &setMemorySpace(unsigned newMemorySpace) { + Builder &setMemorySpace(unsigned newMemorySpace); + Builder &setMemorySpace(Attribute newMemorySpace) { memorySpace = newMemorySpace; return *this; } @@ -342,7 +344,7 @@ ArrayRef shape; Type elementType; ArrayRef affineMaps; - unsigned memorySpace; + Attribute memorySpace; }; using Base::Base; @@ -354,6 +356,9 @@ static MemRefType get(ArrayRef shape, Type elementType, ArrayRef affineMapComposition = {}, unsigned memorySpace = 0); + static MemRefType get(ArrayRef shape, Type elementType, + ArrayRef affineMapComposition, + Attribute memorySpace); /// Get or create a new MemRefType based on shape, element type, affine /// map composition, and memory space declared at the given location. @@ -365,6 +370,10 @@ Type elementType, ArrayRef affineMapComposition, unsigned memorySpace); + static MemRefType getChecked(Location location, ArrayRef shape, + Type elementType, + ArrayRef affineMapComposition, + Attribute memorySpace); ArrayRef getShape() const; @@ -385,7 +394,7 @@ /// emit detailed error messages. static MemRefType getImpl(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace, Optional location); + Attribute memorySpace, Optional location); using Base::getImpl; }; @@ -403,6 +412,7 @@ /// Get or create a new UnrankedMemRefType of the provided element /// type and memory space static UnrankedMemRefType get(Type elementType, unsigned memorySpace); + static UnrankedMemRefType get(Type elementType, Attribute memorySpace); /// Get or create a new UnrankedMemRefType of the provided element /// type and memory space declared at the given, potentially unknown, @@ -410,11 +420,13 @@ /// ill-formed, emit errors and return a nullptr-wrapping type. static UnrankedMemRefType getChecked(Location location, Type elementType, unsigned memorySpace); + static UnrankedMemRefType getChecked(Location location, Type elementType, + Attribute memorySpace); /// Verify the construction of a unranked memref type. static LogicalResult verifyConstructionInvariants(Location loc, Type elementType, - unsigned memorySpace); + Attribute memorySpace); ArrayRef getShape() const { return llvm::None; } }; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1803,6 +1803,14 @@ printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); } +static bool isDefaultMemorySpace(Attribute memorySpace) { + if (memorySpace == nullptr) + return true; + if (!memorySpace.isa()) + return false; + return memorySpace.cast().getValue() == 0; +} + void ModulePrinter::printType(Type type) { if (!type) { os << "<>"; @@ -1882,16 +1890,20 @@ printAttribute(AffineMapAttr::get(map)); } // Only print the memory space if it is the non-default one. - if (memrefTy.getMemorySpace()) - os << ", " << memrefTy.getMemorySpace(); + if (!isDefaultMemorySpace(memrefTy.getMemorySpaceAttr())) { + os << ", "; + printAttribute(memrefTy.getMemorySpaceAttr(), AttrTypeElision::May); + } os << '>'; }) .Case([&](UnrankedMemRefType memrefTy) { os << "memref<*x"; printType(memrefTy.getElementType()); // Only print the memory space if it is the non-default one. - if (memrefTy.getMemorySpace()) - os << ", " << memrefTy.getMemorySpace(); + if (!isDefaultMemorySpace(memrefTy.getMemorySpaceAttr())) { + os << ", "; + printAttribute(memrefTy.getMemorySpaceAttr(), AttrTypeElision::May); + } os << '>'; }) .Case([&](ComplexType complexTy) { diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -400,6 +400,15 @@ //===----------------------------------------------------------------------===// unsigned BaseMemRefType::getMemorySpace() const { + auto attr = getMemorySpaceAttr(); + if (attr == nullptr) + return 0; + assert(attr.isa() && + "Using old MemorySpace API with non-Integer attribute"); + return static_cast(attr.cast().getInt()); +} + +Attribute BaseMemRefType::getMemorySpaceAttr() const { return static_cast(impl)->memorySpace; } @@ -407,6 +416,25 @@ // MemRefType //===----------------------------------------------------------------------===// +static Attribute wrapMemorySpace(unsigned memorySpace, MLIRContext *ctx) { + if (memorySpace == 0) + return nullptr; + return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); +} + +static Attribute normalizeMemorySpace(Attribute memorySpace) { + if (memorySpace != nullptr && memorySpace.isa() && + memorySpace.cast().getValue() == 0) + return nullptr; + return memorySpace; +} + +MemRefType::Builder & +MemRefType::Builder::setMemorySpace(unsigned newMemorySpace) { + memorySpace = wrapMemorySpace(newMemorySpace, elementType.getContext()); + return *this; +} + /// Get or create a new MemRefType based on shape, element type, affine /// map composition, and memory space. Assumes the arguments define a /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType @@ -414,6 +442,12 @@ MemRefType MemRefType::get(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, unsigned memorySpace) { + auto memorySpaceAttr = wrapMemorySpace(memorySpace, elementType.getContext()); + return get(shape, elementType, affineMapComposition, memorySpaceAttr); +} +MemRefType MemRefType::get(ArrayRef shape, Type elementType, + ArrayRef affineMapComposition, + Attribute memorySpace) { auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, /*location=*/llvm::None); assert(result && "Failed to construct instance of MemRefType."); @@ -430,6 +464,14 @@ Type elementType, ArrayRef affineMapComposition, unsigned memorySpace) { + auto memorySpaceAttr = wrapMemorySpace(memorySpace, elementType.getContext()); + return getChecked(location, shape, elementType, affineMapComposition, + memorySpaceAttr); +} +MemRefType MemRefType::getChecked(Location location, ArrayRef shape, + Type elementType, + ArrayRef affineMapComposition, + Attribute memorySpace) { return getImpl(shape, elementType, affineMapComposition, memorySpace, location); } @@ -440,7 +482,7 @@ /// pass in an instance of UnknownLoc. MemRefType MemRefType::getImpl(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace, + Attribute memorySpace, Optional location) { auto *context = elementType.getContext(); @@ -485,8 +527,14 @@ cleanedAffineMapComposition.push_back(map); } + if (memorySpace != nullptr && memorySpace.isa()) { + if (location) + emitError(*location, "AffineMap can't be used as memory space"); + return nullptr; + } + return Base::get(context, shape, elementType, cleanedAffineMapComposition, - memorySpace); + normalizeMemorySpace(memorySpace)); } ArrayRef MemRefType::getShape() const { return getImpl()->getShape(); } @@ -501,20 +549,38 @@ UnrankedMemRefType UnrankedMemRefType::get(Type elementType, unsigned memorySpace) { - return Base::get(elementType.getContext(), elementType, memorySpace); + auto memorySpaceAttr = wrapMemorySpace(memorySpace, elementType.getContext()); + return get(elementType, memorySpaceAttr); +} +UnrankedMemRefType UnrankedMemRefType::get(Type elementType, + Attribute memorySpace) { + return Base::get(elementType.getContext(), elementType, + normalizeMemorySpace(memorySpace)); } UnrankedMemRefType UnrankedMemRefType::getChecked(Location location, Type elementType, unsigned memorySpace) { - return Base::getChecked(location, elementType, memorySpace); + auto memorySpaceAttr = wrapMemorySpace(memorySpace, elementType.getContext()); + return getChecked(location, elementType, memorySpaceAttr); +} +UnrankedMemRefType UnrankedMemRefType::getChecked(Location location, + Type elementType, + Attribute memorySpace) { + return Base::getChecked(location, elementType, + normalizeMemorySpace(memorySpace)); } LogicalResult UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType, - unsigned memorySpace) { + Attribute memorySpace) { if (!BaseMemRefType::isValidElementType(elementType)) return emitError(loc, "invalid memref element type"); + + if (memorySpace != nullptr && memorySpace.isa()) { + return emitError(loc, "AffineMap can't be used as memory space"); + } + return success(); } diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -183,17 +183,17 @@ }; struct BaseMemRefTypeStorage : public ShapedTypeStorage { - BaseMemRefTypeStorage(Type elementType, unsigned memorySpace) + BaseMemRefTypeStorage(Type elementType, Attribute memorySpace) : ShapedTypeStorage(elementType), memorySpace(memorySpace) {} /// Memory space in which data referenced by memref resides. - const unsigned memorySpace; + Attribute memorySpace; }; struct MemRefTypeStorage : public BaseMemRefTypeStorage { MemRefTypeStorage(unsigned shapeSize, Type elementType, const int64_t *shapeElements, const unsigned numAffineMaps, - AffineMap const *affineMapList, const unsigned memorySpace) + AffineMap const *affineMapList, Attribute memorySpace) : BaseMemRefTypeStorage(elementType, memorySpace), shapeElements(shapeElements), shapeSize(shapeSize), numAffineMaps(numAffineMaps), affineMapList(affineMapList) {} @@ -202,7 +202,7 @@ // MemRefs are uniqued based on their shape, element type, affine map // composition, and memory space. using KeyTy = - std::tuple, Type, ArrayRef, unsigned>; + std::tuple, Type, ArrayRef, Attribute>; bool operator==(const KeyTy &key) const { return key == KeyTy(getShape(), elementType, getAffineMaps(), memorySpace); } @@ -246,11 +246,11 @@ /// Only element type and memory space are known struct UnrankedMemRefTypeStorage : public BaseMemRefTypeStorage { - UnrankedMemRefTypeStorage(Type elementTy, const unsigned memorySpace) + UnrankedMemRefTypeStorage(Type elementTy, Attribute memorySpace) : BaseMemRefTypeStorage(elementTy, memorySpace) {} /// The hash key used for uniquing. - using KeyTy = std::tuple; + using KeyTy = std::tuple; bool operator==(const KeyTy &key) const { return key == KeyTy(elementType, memorySpace); } diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -225,27 +225,14 @@ // Parse semi-affine-map-composition. SmallVector affineMapComposition; - Optional memorySpace; + Attribute memorySpace; unsigned numDims = dimensions.size(); auto parseElt = [&]() -> ParseResult { - // Check for the memory space. - if (getToken().is(Token::integer)) { - if (memorySpace) - return emitError("multiple memory spaces specified in memref type"); - memorySpace = getToken().getUnsignedIntegerValue(); - if (!memorySpace.hasValue()) - return emitError("invalid memory space in memref type"); - consumeToken(Token::integer); - return success(); - } - if (isUnranked) - return emitError("cannot have affine map for unranked memref type"); - if (memorySpace) - return emitError("expected memory space to be last in memref type"); - AffineMap map; llvm::SMLoc mapLoc = getToken().getLoc(); + + // Check for AffineMap as offset/strides if (getToken().is(Token::kw_offset)) { int64_t offset; SmallVector strides; @@ -254,16 +241,26 @@ // Construct strided affine map. map = makeStridedLinearLayoutMap(strides, offset, state.context); } else { - // Parse an affine map attribute. - auto affineMap = parseAttribute(); - if (!affineMap) + // Either it is AffineMapAttr or memory space attribute + auto attr = parseAttribute(); + if (!attr) return failure(); - auto affineMapAttr = affineMap.dyn_cast(); - if (!affineMapAttr) - return emitError("expected affine map in memref type"); - map = affineMapAttr.getValue(); + + if (auto affineMapAttr = attr.dyn_cast()) { + map = affineMapAttr.getValue(); + } else { + if (memorySpace) + return emitError("multiple memory spaces specified in memref type"); + memorySpace = attr; + return success(); + } } + if (isUnranked) + return emitError("cannot have affine map for unranked memref type"); + if (memorySpace) + return emitError("expected memory space to be last in memref type"); + if (map.getNumDims() != numDims) { size_t i = affineMapComposition.size(); return emitError(mapLoc, "memref affine map dimension mismatch between ") @@ -287,10 +284,10 @@ } if (isUnranked) - return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0)); + return UnrankedMemRefType::get(elementType, memorySpace); return MemRefType::get(dimensions, elementType, affineMapComposition, - memorySpace.getValueOr(0)); + memorySpace); } /// Parse any type except the function type. diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -35,10 +35,6 @@ // Test non-existent map in memref type. func @memrefs(memref<2x4xi8, #map7>) // expected-error {{undefined symbol alias id 'map7'}} -// ----- -// Test non affine map in memref type. -func @memrefs(memref<2x4xi8, i8>) // expected-error {{expected affine map in memref type}} - // ----- // Test non-existent map in map composition of memref type. #map0 = affine_map<(d0, d1) -> (d0, d1)> diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -137,11 +137,35 @@ func private @memrefs_compose_with_id(memref<2x2xi8, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>>) +// Test memref with custom memory space + +// CHECK: func private @memrefs_nomap_nospace(memref<5x6x7xf32>) +func private @memrefs_nomap_nospace(memref<5x6x7xf32>) + +// CHECK: func private @memrefs_map_nospace(memref<5x6x7xf32, #map{{[0-9]+}}>) +func private @memrefs_map_nospace(memref<5x6x7xf32, #map3>) + +// CHECK: func private @memrefs_nomap_intspace(memref<5x6x7xf32, 3>) +func private @memrefs_nomap_intspace(memref<5x6x7xf32, 3>) + +// CHECK: func private @memrefs_map_intspace(memref<5x6x7xf32, #map{{[0-9]+}}, 5>) +func private @memrefs_map_intspace(memref<5x6x7xf32, #map3, 5>) + +// CHECK: func private @memrefs_nomap_strspace(memref<5x6x7xf32, "local">) +func private @memrefs_nomap_strspace(memref<5x6x7xf32, "local">) + +// CHECK: func private @memrefs_map_strspace(memref<5x6x7xf32, #map{{[0-9]+}}, "private">) +func private @memrefs_map_strspace(memref<5x6x7xf32, #map3, "private">) + +// CHECK: func private @memrefs_nomap_dictspace(memref<5x6x7xf32, {memSpace = "special", subIndex = 1 : i64}>) +func private @memrefs_nomap_dictspace(memref<5x6x7xf32, {memSpace = "special", subIndex = 1}>) + +// CHECK: func private @memrefs_map_dictspace(memref<5x6x7xf32, #map{{[0-9]+}}, {memSpace = "special", subIndex = 3 : i64}>) +func private @memrefs_map_dictspace(memref<5x6x7xf32, #map3, {memSpace = "special", subIndex = 3}>) // CHECK: func private @complex_types(complex) -> complex func private @complex_types(complex) -> complex - // CHECK: func private @memref_with_index_elems(memref<1x?xindex>) func private @memref_with_index_elems(memref<1x?xindex>)