diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -225,38 +225,38 @@ /// same context as element type. The type is owned by the context. MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet( MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, - MlirAffineMap const *affineMaps, unsigned memorySpace); + MlirAffineMap const *affineMaps, MlirAttribute memorySpace); /// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o /// illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked( MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, - MlirAffineMap const *affineMaps, unsigned memorySpace, MlirLocation loc); + MlirAffineMap const *affineMaps, MlirAttribute memorySpace, + MlirLocation loc); /// Creates a MemRef type with the given rank, shape, memory space and element /// type in the same context as the element type. The type has no affine maps, /// i.e. represents a default row-major contiguous memref. The type is owned by /// the context. -MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGet(MlirType elementType, - intptr_t rank, - const int64_t *shape, - unsigned memorySpace); +MLIR_CAPI_EXPORTED MlirType +mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, + const int64_t *shape, MlirAttribute memorySpace); /// Same as "mlirMemRefTypeContiguousGet" but returns a nullptr wrapping /// MlirType on illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGetChecked( - MlirType elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace, MlirLocation loc); + MlirType elementType, intptr_t rank, const int64_t *shape, + MlirAttribute memorySpace, MlirLocation loc); /// Creates an Unranked MemRef type with the given element type and in the given /// memory space. 
The type is owned by the context of element type. -MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, - unsigned memorySpace); +MLIR_CAPI_EXPORTED MlirType +mlirUnrankedMemRefTypeGet(MlirType elementType, MlirAttribute memorySpace); /// Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping /// MlirType on illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked( - MlirType elementType, unsigned memorySpace, MlirLocation loc); + MlirType elementType, MlirAttribute memorySpace, MlirLocation loc); /// Returns the number of affine layout maps in the given MemRef type. MLIR_CAPI_EXPORTED intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type); @@ -266,10 +266,11 @@ intptr_t pos); /// Returns the memory space of the given MemRef type. -MLIR_CAPI_EXPORTED unsigned mlirMemRefTypeGetMemorySpace(MlirType type); +MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type); /// Returns the memory spcae of the given Unranked MemRef type. -MLIR_CAPI_EXPORTED unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type); +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnrankedMemrefGetMemorySpace(MlirType type); //===----------------------------------------------------------------------===// // Tuple type. diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -117,11 +117,6 @@ getSrcMap().getNumInputs()}; } - /// Returns the memory space of the src memref. - unsigned getSrcMemorySpace() { - return getSrcMemRef().getType().cast().getMemorySpace(); - } - /// Returns the operand index of the dst memref. unsigned getDstMemRefOperandIndex() { return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs(); @@ -139,9 +134,9 @@ } /// Returns the memory space of the src memref. 
- unsigned getDstMemorySpace() { - return getDstMemRef().getType().cast().getMemorySpace(); - } + unsigned getSrcMemorySpace(); + /// Returns the memory space of the dst memref. + unsigned getDstMemorySpace(); /// Returns the affine map used to access the dst memref. AffineMap getDstMap() { return getDstMapAttr().getValue(); } diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -176,12 +176,11 @@ unsigned getDstMemRefRank() { return getDstMemRef().getType().cast().getRank(); } - unsigned getSrcMemorySpace() { - return getSrcMemRef().getType().cast().getMemorySpace(); - } - unsigned getDstMemorySpace() { - return getDstMemRef().getType().cast().getMemorySpace(); - } + + /// Returns the memory space of the src memref. + unsigned getSrcMemorySpace(); + /// Returns the memory space of the dst memref. + unsigned getDstMemorySpace(); // Returns the destination memref indices for this DMA operation. operand_range getDstIndices() { diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -9,6 +9,7 @@ #ifndef MLIR_IR_BUILTINTYPES_H #define MLIR_IR_BUILTINTYPES_H +#include "mlir/IR/Attributes.h" #include "mlir/IR/Types.h" namespace llvm { @@ -291,7 +292,10 @@ static bool classof(Type type); /// Returns the memory space in which data referred to by this memref resides. - unsigned getMemorySpace() const; + Attribute getMemorySpace() const; + + /// Checks if the memorySpace has supported Attribute type. + static bool isSupportedMemorySpace(Attribute memorySpace); }; //===----------------------------------------------------------------------===// @@ -316,8 +320,7 @@ // Build from scratch. 
Builder(ArrayRef shape, Type elementType) - : shape(shape), elementType(elementType), affineMaps(), memorySpace(0) { - } + : shape(shape), elementType(elementType), affineMaps() {} Builder &setShape(ArrayRef newShape) { shape = newShape; @@ -334,7 +337,7 @@ return *this; } - Builder &setMemorySpace(unsigned newMemorySpace) { + Builder &setMemorySpace(Attribute newMemorySpace) { memorySpace = newMemorySpace; return *this; } @@ -347,7 +350,7 @@ ArrayRef shape; Type elementType; ArrayRef affineMaps; - unsigned memorySpace; + Attribute memorySpace; }; using Base::Base; @@ -358,7 +361,7 @@ /// construction failures. static MemRefType get(ArrayRef shape, Type elementType, ArrayRef affineMapComposition = {}, - unsigned memorySpace = 0); + Attribute memorySpace = {}); /// Get or create a new MemRefType based on shape, element type, affine /// map composition, and memory space declared at the given location. @@ -369,7 +372,7 @@ static MemRefType getChecked(Location location, ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace); + Attribute memorySpace); ArrayRef getShape() const; @@ -390,7 +393,7 @@ /// emit detailed error messages. static MemRefType getImpl(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace, Optional location); + Attribute memorySpace, Optional location); using Base::getImpl; }; @@ -407,19 +410,19 @@ /// Get or create a new UnrankedMemRefType of the provided element /// type and memory space - static UnrankedMemRefType get(Type elementType, unsigned memorySpace); + static UnrankedMemRefType get(Type elementType, Attribute memorySpace); /// Get or create a new UnrankedMemRefType of the provided element /// type and memory space declared at the given, potentially unknown, /// location. If the UnrankedMemRefType defined by the arguments would be /// ill-formed, emit errors and return a nullptr-wrapping type. 
static UnrankedMemRefType getChecked(Location location, Type elementType, - unsigned memorySpace); + Attribute memorySpace); /// Verify the construction of a unranked memref type. static LogicalResult verifyConstructionInvariants(Location loc, Type elementType, - unsigned memorySpace); + Attribute memorySpace); ArrayRef getShape() const { return llvm::None; } }; diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -2845,16 +2845,20 @@ c.def_static( "get", [](std::vector shape, PyType &elementType, - std::vector layout, unsigned memorySpace, + std::vector layout, PyAttribute *memorySpace, DefaultingPyLocation loc) { SmallVector maps; maps.reserve(layout.size()); for (PyAffineMap &map : layout) maps.push_back(map); - MlirType t = mlirMemRefTypeGetChecked(elementType, shape.size(), - shape.data(), maps.size(), - maps.data(), memorySpace, loc); + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + + MlirType t = mlirMemRefTypeGetChecked( + elementType, shape.size(), shape.data(), maps.size(), + maps.data(), memSpaceAttr, loc); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. 
if (mlirTypeIsNull(t)) { @@ -2869,14 +2873,15 @@ return PyMemRefType(elementType.getContext(), t); }, py::arg("shape"), py::arg("element_type"), - py::arg("layout") = py::list(), py::arg("memory_space") = 0, + py::arg("layout") = py::list(), py::arg("memory_space") = py::none(), py::arg("loc") = py::none(), "Create a memref type") .def_property_readonly("layout", &PyMemRefType::getLayout, "The list of layout maps of the MemRef type.") .def_property_readonly( "memory_space", - [](PyMemRefType &self) -> unsigned { - return mlirMemRefTypeGetMemorySpace(self); + [](PyMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + return PyAttribute(self.getContext(), a); }, "Returns the memory space of the given MemRef type."); } @@ -2928,7 +2933,7 @@ static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyType &elementType, unsigned memorySpace, + [](PyType &elementType, PyAttribute &memorySpace, DefaultingPyLocation loc) { MlirType t = mlirUnrankedMemRefTypeGetChecked(elementType, memorySpace, loc); @@ -2949,8 +2954,9 @@ py::arg("loc") = py::none(), "Create a unranked memref type") .def_property_readonly( "memory_space", - [](PyUnrankedMemRefType &self) -> unsigned { - return mlirUnrankedMemrefGetMemorySpace(self); + [](PyUnrankedMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); + return PyAttribute(self.getContext(), a); }, "Returns the memory space of the given Unranked MemRef type."); } diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -223,40 +223,40 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, MlirAffineMap const *affineMaps, - unsigned memorySpace) { + MlirAttribute memorySpace) { SmallVector maps; (void)unwrapList(numMaps, affineMaps, maps); return wrap( MemRefType::get(llvm::makeArrayRef(shape, 
static_cast(rank)), - unwrap(elementType), maps, memorySpace)); + unwrap(elementType), maps, unwrap(memorySpace))); } MlirType mlirMemRefTypeGetChecked(MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, MlirAffineMap const *affineMaps, - unsigned memorySpace, MlirLocation loc) { + MlirAttribute memorySpace, MlirLocation loc) { SmallVector maps; (void)unwrapList(numMaps, affineMaps, maps); return wrap(MemRefType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), maps, memorySpace)); + unwrap(elementType), maps, unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace) { + MlirAttribute memorySpace) { return wrap( MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), llvm::None, memorySpace)); + unwrap(elementType), llvm::None, unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace, + MlirAttribute memorySpace, MlirLocation loc) { return wrap(MemRefType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), llvm::None, memorySpace)); + unwrap(elementType), llvm::None, unwrap(memorySpace))); } intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) { @@ -268,27 +268,29 @@ return wrap(unwrap(type).cast().getAffineMaps()[pos]); } -unsigned mlirMemRefTypeGetMemorySpace(MlirType type) { - return unwrap(type).cast().getMemorySpace(); +MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) { + return wrap(unwrap(type).cast().getMemorySpace()); } bool mlirTypeIsAUnrankedMemRef(MlirType type) { return unwrap(type).isa(); } -MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) { - return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace)); +MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, + MlirAttribute 
memorySpace) { + return wrap( + UnrankedMemRefType::get(unwrap(elementType), unwrap(memorySpace))); } MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType, - unsigned memorySpace, + MlirAttribute memorySpace, MlirLocation loc) { return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType), - memorySpace)); + unwrap(memorySpace))); } -unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) { - return unwrap(type).cast().getMemorySpace(); +MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) { + return wrap(unwrap(type).cast().getMemorySpace()); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -118,9 +118,13 @@ /// converter drops the private memory space to support the use case above. 
LLVMTypeConverter converter(m.getContext(), options); converter.addConversion([&](MemRefType type) -> Optional { - if (type.getMemorySpace() != gpu::GPUDialect::getPrivateAddressSpace()) + unsigned memSpaceInd = 0; + if (Attribute memSpace = type.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); + if (memSpaceInd != gpu::GPUDialect::getPrivateAddressSpace()) return llvm::None; - return converter.convertType(MemRefType::Builder(type).setMemorySpace(0)); + return converter.convertType( + MemRefType::Builder(type).setMemorySpace(nullptr)); }); OwningRewritePatternList patterns, llvmPatterns; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -316,7 +316,10 @@ Type elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; - auto ptrTy = LLVM::LLVMPointerType::get(elementType, type.getMemorySpace()); + unsigned memSpaceInd = 0; + if (Attribute memSpace = type.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); + auto ptrTy = LLVM::LLVMPointerType::get(elementType, memSpaceInd); auto indexTy = getIndexType(); SmallVector results = {ptrTy, ptrTy, indexTy}; @@ -388,7 +391,10 @@ Type elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; - return LLVM::LLVMPointerType::get(elementType, type.getMemorySpace()); + unsigned memSpaceInd = 0; + if (Attribute memSpace = type.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); + return LLVM::LLVMPointerType::get(elementType, memSpaceInd); } /// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type @@ -1081,7 +1087,10 @@ Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { auto elementType = type.getElementType(); auto structElementType = unwrap(typeConverter->convertType(elementType)); - return 
LLVM::LLVMPointerType::get(structElementType, type.getMemorySpace()); + unsigned memSpaceInd = 0; + if (Attribute memSpace = type.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); + return LLVM::LLVMPointerType::get(structElementType, memSpaceInd); } void ConvertToLLVMPattern::getMemRefDescriptorSizes( @@ -1896,7 +1905,10 @@ Value alignedPtr = allocatedPtr; if (alignment) { - auto intPtrType = getIntPtrType(memRefType.getMemorySpace()); + unsigned memSpaceInd = 0; + if (Attribute memSpace = memRefType.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); + auto intPtrType = getIntPtrType(memSpaceInd); // Compute the aligned type pointer. Value allocatedInt = rewriter.create(loc, intPtrType, allocatedPtr); @@ -2242,9 +2254,13 @@ initialValue = elementsAttr.getValue({}); } + unsigned memSpaceInd = 0; + if (Attribute memSpace = type.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); + rewriter.replaceOpWithNewOp( global, arrayTy, global.constant(), linkage, global.sym_name(), - initialValue, type.getMemorySpace()); + initialValue, memSpaceInd); return success(); } }; @@ -2263,17 +2279,20 @@ Operation *op) const override { auto getGlobalOp = cast(op); MemRefType type = getGlobalOp.result().getType().cast(); - unsigned memSpace = type.getMemorySpace(); + unsigned memSpaceInd = 0; + if (Attribute memSpace = type.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); auto addressOf = rewriter.create( - loc, LLVM::LLVMPointerType::get(arrayTy, memSpace), getGlobalOp.name()); + loc, LLVM::LLVMPointerType::get(arrayTy, memSpaceInd), + getGlobalOp.name()); // Get the address of the first element in the array by creating a GEP with // the address of the GV as the base, and (rank + 1) number of 0 indices. 
Type elementType = unwrap(typeConverter->convertType(type.getElementType())); - Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace); + Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpaceInd); SmallVector operands = {addressOf}; operands.insert(operands.end(), type.getRank() + 1, @@ -2283,7 +2302,7 @@ // We do not expect the memref obtained using `get_global_memref` to be // ever deallocated. Set the allocated pointer to be known bad value to // help debug if that ever happens. - auto intPtrType = getIntPtrType(memSpace); + auto intPtrType = getIntPtrType(memSpaceInd); Value deadBeefConst = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef); auto deadBeefPtr = @@ -2458,12 +2477,14 @@ return; } - unsigned memorySpace = - operandType.cast().getMemorySpace(); + unsigned memSpaceInd = 0; + if (Attribute memSpace = + operandType.cast().getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); Type elementType = operandType.cast().getElementType(); Type llvmElementType = unwrap(typeConverter.convertType(elementType)); Type elementPtrPtrType = LLVM::LLVMPointerType::get( - LLVM::LLVMPointerType::get(llvmElementType, memorySpace)); + LLVM::LLVMPointerType::get(llvmElementType, memSpaceInd)); // Extract pointer to the underlying ranked memref descriptor and cast it to // ElemType**. @@ -2588,7 +2609,9 @@ // Extract address space and element type. auto targetType = reshapeOp.getResult().getType().cast(); - unsigned addressSpace = targetType.getMemorySpace(); + unsigned memSpaceInd = 0; + if (Attribute memSpace = targetType.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); Type elementType = targetType.getElementType(); // Create the unranked memref descriptor that holds the ranked one. The @@ -2612,7 +2635,7 @@ // Set pointers and offset. 
Type llvmElementType = unwrap(typeConverter->convertType(elementType)); auto elementPtrPtrType = LLVM::LLVMPointerType::get( - LLVM::LLVMPointerType::get(llvmElementType, addressSpace)); + LLVM::LLVMPointerType::get(llvmElementType, memSpaceInd)); UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, elementPtrPtrType, allocatedPtr); UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, *getTypeConverter(), @@ -2748,7 +2771,9 @@ auto unrankedMemRefType = operandType.cast(); auto scalarMemRefType = MemRefType::get({}, unrankedMemRefType.getElementType()); - unsigned addressSpace = unrankedMemRefType.getMemorySpace(); + unsigned memSpaceInd = 0; + if (Attribute memSpace = unrankedMemRefType.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); // Extract pointer to the underlying ranked descriptor and bitcast it to a // memref descriptor pointer to minimize the number of GEP @@ -2758,12 +2783,12 @@ Value scalarMemRefDescPtr = rewriter.create( loc, LLVM::LLVMPointerType::get(typeConverter->convertType(scalarMemRefType), - addressSpace), + memSpaceInd), underlyingRankedDesc); // Get pointer to offset field of memref descriptor. Type indexPtrTy = LLVM::LLVMPointerType::get( - getTypeConverter()->getIndexType(), addressSpace); + getTypeConverter()->getIndexType(), memSpaceInd); Value two = rewriter.create( loc, typeConverter->convertType(rewriter.getI32Type()), rewriter.getI32IntegerAttr(2)); @@ -3257,21 +3282,21 @@ MemRefDescriptor sourceMemRef(operands.front()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); + unsigned memSpaceInd = 0; + if (Attribute memSpace = viewMemRefType.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); + // Copy the buffer pointer from the old descriptor to the new one. 
Value extracted = sourceMemRef.allocatedPtr(rewriter, loc); Value bitcastPtr = rewriter.create( - loc, - LLVM::LLVMPointerType::get(targetElementTy, - viewMemRefType.getMemorySpace()), + loc, LLVM::LLVMPointerType::get(targetElementTy, memSpaceInd), extracted); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); // Copy the aligned pointer from the old descriptor to the new one. extracted = sourceMemRef.alignedPtr(rewriter, loc); bitcastPtr = rewriter.create( - loc, - LLVM::LLVMPointerType::get(targetElementTy, - viewMemRefType.getMemorySpace()), + loc, LLVM::LLVMPointerType::get(targetElementTy, memSpaceInd), extracted); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); @@ -3485,10 +3510,11 @@ // Field 1: Copy the allocated pointer, used for malloc/free. Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); auto srcMemRefType = viewOp.source().getType().cast(); + unsigned memSpaceInd = 0; + if (Attribute memSpace = srcMemRefType.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); Value bitcastPtr = rewriter.create( - loc, - LLVM::LLVMPointerType::get(targetElementTy, - srcMemRefType.getMemorySpace()), + loc, LLVM::LLVMPointerType::get(targetElementTy, memSpaceInd), allocatedPtr); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); @@ -3497,9 +3523,7 @@ alignedPtr = rewriter.create(loc, alignedPtr.getType(), alignedPtr, adaptor.byte_shift()); bitcastPtr = rewriter.create( - loc, - LLVM::LLVMPointerType::get(targetElementTy, - srcMemRefType.getMemorySpace()), + loc, LLVM::LLVMPointerType::get(targetElementTy, memSpaceInd), alignedPtr); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -192,10 +192,16 @@ static bool isAllocationSupported(MemRefType t) { // Currently only support 
workgroup local memory allocations with static // shape and int or float or vector of int or float element type. - if (!(t.hasStaticShape() && - SPIRVTypeConverter::getMemorySpaceForStorageClass( - spirv::StorageClass::Workgroup) == t.getMemorySpace())) + if (!t.hasStaticShape()) return false; + + unsigned memSpaceInd = 0; + if (Attribute memSpace = t.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); + if (memSpaceInd != SPIRVTypeConverter::getMemorySpaceForStorageClass( + spirv::StorageClass::Workgroup)) + return false; + Type elementType = t.getElementType(); if (auto vecType = elementType.dyn_cast()) elementType = vecType.getElementType(); @@ -206,8 +212,11 @@ /// operations of unsupported integer bitwidths, based on the memref /// type. Returns None on failure. static Optional getAtomicOpScope(MemRefType t) { + unsigned memSpaceInd = 0; + if (Attribute memSpace = t.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); Optional storageClass = - SPIRVTypeConverter::getStorageClassForMemorySpace(t.getMemorySpace()); + SPIRVTypeConverter::getStorageClassForMemorySpace(memSpaceInd); if (!storageClass) return {}; switch (*storageClass) { diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -94,8 +94,10 @@ // MUBUF instruction operate only on addresspace 0(unified) or 1(global) // In case of 3(LDS): fall back to vector->llvm pass // In case of 5(VGPR): wrong - if ((memRefType.getMemorySpace() != 0) && - (memRefType.getMemorySpace() != 1)) + unsigned memSpaceInd = 0; + if (Attribute memSpace = memRefType.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); + if ((memSpaceInd != 0) && (memSpaceInd != 1)) return failure(); // Note that the dataPtr starts at the offset address specified by diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp 
b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -956,6 +956,18 @@ } } +unsigned AffineDmaStartOp::getSrcMemorySpace() { + Attribute memSpace = + getSrcMemRef().getType().cast().getMemorySpace(); + return memSpace ? memSpace.cast().getInt() : 0; +} + +unsigned AffineDmaStartOp::getDstMemorySpace() { + Attribute memSpace = + getDstMemRef().getType().cast().getMemorySpace(); + return memSpace ? memSpace.cast().getInt() : 0; +} + void AffineDmaStartOp::print(OpAsmPrinter &p) { p << "affine.dma_start " << getSrcMemRef() << '['; p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices()); @@ -1059,6 +1071,17 @@ if (!getOperand(getTagMemRefOperandIndex()).getType().isa()) return emitOpError("expected DMA tag to be of memref type"); + if (Attribute memSpace = + getSrcMemRef().getType().cast().getMemorySpace()) { + if (!memSpace.isa() || !memSpace.getType().isSignlessInteger()) + return emitOpError("expected memory space to be integer value"); + } + if (Attribute memSpace = + getDstMemRef().getType().cast().getMemorySpace()) { + if (!memSpace.isa() || !memSpace.getType().isSignlessInteger()) + return emitOpError("expected memory space to be integer value"); + } + // DMAs from different memory spaces supported. 
if (getSrcMemorySpace() == getDstMemorySpace()) { return emitOpError("DMA should be between different memory spaces"); @@ -2137,6 +2160,10 @@ auto memrefType = op.getMemRefType(); if (op.getType() != memrefType.getElementType()) return op.emitOpError("result type must match element type of memref"); + if (Attribute memSpace = memrefType.getMemorySpace()) { + if (!memSpace.isa() || !memSpace.getType().isSignlessInteger()) + return op.emitOpError("expected memory space to be integer value"); + } if (failed(verifyMemoryOpIndexing( op.getOperation(), @@ -2226,6 +2253,10 @@ if (op.getValueToStore().getType() != memrefType.getElementType()) return op.emitOpError( "first operand must have same type memref element type"); + if (Attribute memSpace = memrefType.getMemorySpace()) { + if (!memSpace.isa() || !memSpace.getType().isSignlessInteger()) + return op.emitOpError("expected memory space to be integer value"); + } if (failed(verifyMemoryOpIndexing( op.getOperation(), diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -727,7 +727,15 @@ if (!type) return op->emitOpError() << "expected memref type in attribution"; - if (type.getMemorySpace() != memorySpace) { + unsigned memSpaceInd = 0; + if (Attribute memSpace = type.getMemorySpace()) { + if (!memSpace.isa()) { + return op->emitOpError() << "expected integer memory space"; + } + memSpaceInd = memSpace.cast().getInt(); + } + + if (memSpaceInd != memorySpace) { return op->emitOpError() << "expected memory space " << memorySpace << " in attribution"; } diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp --- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp @@ -153,9 +153,10 @@ Value createWorkgroupBuffer() { int workgroupMemoryAddressSpace = 
gpu::GPUDialect::getWorkgroupAddressSpace(); + Builder b(valueType.getContext()); auto bufferType = MemRefType::get({kSubgroupSize}, valueType, ArrayRef{}, - workgroupMemoryAddressSpace); + b.getI64IntegerAttr(workgroupMemoryAddressSpace)); return funcOp.addWorkgroupAttribution(bufferType); } diff --git a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp --- a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp @@ -163,8 +163,10 @@ // Get the type of the buffer in the workgroup memory. int workgroupMemoryAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace(); - auto bufferType = MemRefType::get(type.getShape(), type.getElementType(), {}, - workgroupMemoryAddressSpace); + Builder b(op.getContext()); + auto bufferType = + MemRefType::get(type.getShape(), type.getElementType(), {}, + b.getI64IntegerAttr(workgroupMemoryAddressSpace)); Value attribution = op.addWorkgroupAttribution(bufferType); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1310,13 +1310,20 @@ } if (auto memrefType = type.dyn_cast()) { + unsigned memSpaceInd = 0; + if (Attribute memSpace = memrefType.getMemorySpace()) { + if (!memSpace.isa()) + return op->emitOpError("unsupported memory space attribute"); + memSpaceInd = memSpace.cast().getInt(); + } + // Bare pointer convention: statically-shaped memref is compatible with an // LLVM pointer to the element type. 
if (auto ptrType = llvmType.dyn_cast()) { if (!memrefType.hasStaticShape()) return op->emitOpError( "unexpected bare pointer for dynamically shaped memref"); - if (memrefType.getMemorySpace() != ptrType.getAddressSpace()) + if (memSpaceInd != ptrType.getAddressSpace()) return op->emitError("invalid conversion between memref and pointer in " "different memory spaces"); @@ -1339,8 +1346,7 @@ // The first two elements are pointers to the element type. auto allocatedPtr = structType.getBody()[0].dyn_cast(); - if (!allocatedPtr || - allocatedPtr.getAddressSpace() != memrefType.getMemorySpace()) + if (!allocatedPtr || allocatedPtr.getAddressSpace() != memSpaceInd) return op->emitOpError("expected first element of a memref descriptor to " "be a pointer in the address space of the memref"); if (failed(verifyCast(op, allocatedPtr.getElementType(), @@ -1348,8 +1354,7 @@ return failure(); auto alignedPtr = structType.getBody()[1].dyn_cast(); - if (!alignedPtr || - alignedPtr.getAddressSpace() != memrefType.getMemorySpace()) + if (!alignedPtr || alignedPtr.getAddressSpace() != memSpaceInd) return op->emitOpError( "expected second element of a memref descriptor to " "be a pointer in the address space of the memref"); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -343,8 +343,11 @@ static Optional convertMemrefType(const spirv::TargetEnv &targetEnv, MemRefType type) { + unsigned memSpaceInd = 0; + if (Attribute memSpace = type.getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); Optional storageClass = - SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace()); + SPIRVTypeConverter::getStorageClassForMemorySpace(memSpaceInd); if (!storageClass) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot convert memory space\n"); diff --git 
a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1545,6 +1545,18 @@ result.addOperands({stride, elementsPerStride}); } +unsigned DmaStartOp::getSrcMemorySpace() { + Attribute memSpace = + getSrcMemRef().getType().cast().getMemorySpace(); + return memSpace ? memSpace.cast().getInt() : 0; +} + +unsigned DmaStartOp::getDstMemorySpace() { + Attribute memSpace = + getDstMemRef().getType().cast().getMemorySpace(); + return memSpace ? memSpace.cast().getInt() : 0; +} + void DmaStartOp::print(OpAsmPrinter &p) { p << "dma_start " << getSrcMemRef() << '[' << getSrcIndices() << "], " << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() @@ -1638,6 +1650,11 @@ // 1. Source memref. if (!getSrcMemRef().getType().isa()) return emitOpError("expected source to be of memref type"); + if (Attribute memSpace = + getSrcMemRef().getType().cast().getMemorySpace()) { + if (!memSpace.isa() || !memSpace.getType().isSignlessInteger()) + return emitOpError("expected memory space to be integer value"); + } if (numOperands < getSrcMemRefRank() + 4) return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 << " operands"; @@ -1649,6 +1666,11 @@ // 2. Destination memref. 
if (!getDstMemRef().getType().isa()) return emitOpError("expected destination to be of memref type"); + if (Attribute memSpace = + getDstMemRef().getType().cast().getMemorySpace()) { + if (!memSpace.isa() || !memSpace.getType().isSignlessInteger()) + return emitOpError("expected memory space to be integer value"); + } unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; if (numOperands < numExpectedOperands) return emitOpError() << "expected at least " << numExpectedOperands diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1882,16 +1882,20 @@ printAttribute(AffineMapAttr::get(map)); } // Only print the memory space if it is the non-default one. - if (memrefTy.getMemorySpace()) - os << ", " << memrefTy.getMemorySpace(); + if (memrefTy.getMemorySpace()) { + os << ", "; + printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); + } os << '>'; }) .Case([&](UnrankedMemRefType memrefTy) { os << "memref<*x"; printType(memrefTy.getElementType()); // Only print the memory space if it is the non-default one. 
- if (memrefTy.getMemorySpace()) - os << ", " << memrefTy.getMemorySpace(); + if (memrefTy.getMemorySpace()) { + os << ", "; + printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May); + } os << '>'; }) .Case([&](ComplexType complexTy) { diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -10,6 +10,7 @@ #include "TypeDetail.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "llvm/ADT/APFloat.h" @@ -470,21 +471,42 @@ // BaseMemRefType //===----------------------------------------------------------------------===// -unsigned BaseMemRefType::getMemorySpace() const { +Attribute BaseMemRefType::getMemorySpace() const { return static_cast(impl)->memorySpace; } +bool BaseMemRefType::isSupportedMemorySpace(Attribute memorySpace) { + if (!memorySpace) + return true; + + if (memorySpace.isa()) + return true; + + if (!memorySpace.getDialect().getNamespace().empty()) + return true; + + return false; +} + //===----------------------------------------------------------------------===// // MemRefType //===----------------------------------------------------------------------===// +// Replaces integer `0` memorySpace with `nullptr` Attribute. +static Attribute skipDefaultIntMemorySpace(Attribute memorySpace) { + if (memorySpace && memorySpace.isa() && + memorySpace.cast().getValue() == 0) + return nullptr; + return memorySpace; +} + /// Get or create a new MemRefType based on shape, element type, affine /// map composition, and memory space. Assumes the arguments define a /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType /// construction failures. 
MemRefType MemRefType::get(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace) { + Attribute memorySpace) { auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, /*location=*/llvm::None); assert(result && "Failed to construct instance of MemRefType."); @@ -500,7 +522,7 @@ MemRefType MemRefType::getChecked(Location location, ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace) { + Attribute memorySpace) { return getImpl(shape, elementType, affineMapComposition, memorySpace, location); } @@ -511,7 +533,7 @@ /// pass in an instance of UnknownLoc. MemRefType MemRefType::getImpl(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace, + Attribute memorySpace, Optional location) { auto *context = elementType.getContext(); @@ -556,8 +578,13 @@ cleanedAffineMapComposition.push_back(map); } + if (!isSupportedMemorySpace(memorySpace)) { + (void)emitOptionalError(location, "unsupported memory space Attribute"); + return nullptr; + } + return Base::get(context, shape, elementType, cleanedAffineMapComposition, - memorySpace); + skipDefaultIntMemorySpace(memorySpace)); } ArrayRef MemRefType::getShape() const { return getImpl()->getShape(); } @@ -571,21 +598,28 @@ //===----------------------------------------------------------------------===// UnrankedMemRefType UnrankedMemRefType::get(Type elementType, - unsigned memorySpace) { - return Base::get(elementType.getContext(), elementType, memorySpace); + Attribute memorySpace) { + return Base::get(elementType.getContext(), elementType, + skipDefaultIntMemorySpace(memorySpace)); } UnrankedMemRefType UnrankedMemRefType::getChecked(Location location, Type elementType, - unsigned memorySpace) { - return Base::getChecked(location, elementType, memorySpace); + Attribute memorySpace) { + return Base::getChecked(location, elementType, + skipDefaultIntMemorySpace(memorySpace)); } LogicalResult 
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType, - unsigned memorySpace) { + Attribute memorySpace) { if (!BaseMemRefType::isValidElementType(elementType)) return emitError(loc, "invalid memref element type"); + + if (!isSupportedMemorySpace(memorySpace)) { + return emitError(loc, "unsupported memory space Attribute"); + } + return success(); } diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -183,17 +183,17 @@ }; struct BaseMemRefTypeStorage : public ShapedTypeStorage { - BaseMemRefTypeStorage(Type elementType, unsigned memorySpace) + BaseMemRefTypeStorage(Type elementType, Attribute memorySpace) : ShapedTypeStorage(elementType), memorySpace(memorySpace) {} /// Memory space in which data referenced by memref resides. - const unsigned memorySpace; + Attribute memorySpace; }; struct MemRefTypeStorage : public BaseMemRefTypeStorage { MemRefTypeStorage(unsigned shapeSize, Type elementType, const int64_t *shapeElements, const unsigned numAffineMaps, - AffineMap const *affineMapList, const unsigned memorySpace) + AffineMap const *affineMapList, Attribute memorySpace) : BaseMemRefTypeStorage(elementType, memorySpace), shapeElements(shapeElements), shapeSize(shapeSize), numAffineMaps(numAffineMaps), affineMapList(affineMapList) {} @@ -202,7 +202,7 @@ // MemRefs are uniqued based on their shape, element type, affine map // composition, and memory space. 
using KeyTy = - std::tuple, Type, ArrayRef, unsigned>; + std::tuple, Type, ArrayRef, Attribute>; bool operator==(const KeyTy &key) const { return key == KeyTy(getShape(), elementType, getAffineMaps(), memorySpace); } @@ -246,11 +246,11 @@ /// Only element type and memory space are known struct UnrankedMemRefTypeStorage : public BaseMemRefTypeStorage { - UnrankedMemRefTypeStorage(Type elementTy, const unsigned memorySpace) + UnrankedMemRefTypeStorage(Type elementTy, Attribute memorySpace) : BaseMemRefTypeStorage(elementTy, memorySpace) {} /// The hash key used for uniquing. - using KeyTy = std::tuple; + using KeyTy = std::tuple; bool operator==(const KeyTy &key) const { return key == KeyTy(elementType, memorySpace); } diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -225,27 +225,14 @@ // Parse semi-affine-map-composition. SmallVector affineMapComposition; - Optional memorySpace; + Attribute memorySpace; unsigned numDims = dimensions.size(); auto parseElt = [&]() -> ParseResult { - // Check for the memory space. - if (getToken().is(Token::integer)) { - if (memorySpace) - return emitError("multiple memory spaces specified in memref type"); - memorySpace = getToken().getUnsignedIntegerValue(); - if (!memorySpace.hasValue()) - return emitError("invalid memory space in memref type"); - consumeToken(Token::integer); - return success(); - } - if (isUnranked) - return emitError("cannot have affine map for unranked memref type"); - if (memorySpace) - return emitError("expected memory space to be last in memref type"); - AffineMap map; llvm::SMLoc mapLoc = getToken().getLoc(); + + // Check for AffineMap as offset/strides. if (getToken().is(Token::kw_offset)) { int64_t offset; SmallVector strides; @@ -254,16 +241,26 @@ // Construct strided affine map. map = makeStridedLinearLayoutMap(strides, offset, state.context); } else { - // Parse an affine map attribute. 
- auto affineMap = parseAttribute(); - if (!affineMap) + // Either it is AffineMapAttr or memory space attribute. + Attribute attr = parseAttribute(); + if (!attr) return failure(); - auto affineMapAttr = affineMap.dyn_cast(); - if (!affineMapAttr) - return emitError("expected affine map in memref type"); - map = affineMapAttr.getValue(); + + if (AffineMapAttr affineMapAttr = attr.dyn_cast()) { + map = affineMapAttr.getValue(); + } else if (memorySpace) { + return emitError("multiple memory spaces specified in memref type"); + } else { + memorySpace = attr; + return success(); + } } + if (isUnranked) + return emitError("cannot have affine map for unranked memref type"); + if (memorySpace) + return emitError("expected memory space to be last in memref type"); + if (map.getNumDims() != numDims) { size_t i = affineMapComposition.size(); return emitError(mapLoc, "memref affine map dimension mismatch between ") @@ -286,11 +283,13 @@ } } + Location loc = getEncodedSourceLocation(typeLoc); + if (isUnranked) - return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0)); + return UnrankedMemRefType::getChecked(loc, elementType, memorySpace); - return MemRefType::get(dimensions, elementType, affineMapComposition, - memorySpace.getValueOr(0)); + return MemRefType::getChecked(loc, dimensions, elementType, + affineMapComposition, memorySpace); } /// Parse any type except the function type. diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -916,9 +916,9 @@ // by 'srcStoreOpInst'. 
uint64_t bufSize = getMemRefEltSizeInBytes(oldMemRefType) * numElements.getValue(); - unsigned newMemSpace; + Attribute newMemSpace; if (bufSize <= localBufSizeThreshold && fastMemorySpace.hasValue()) { - newMemSpace = fastMemorySpace.getValue(); + newMemSpace = top.getI64IntegerAttr(fastMemorySpace.getValue()); } else { newMemSpace = oldMemRefType.getMemorySpace(); } diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -2464,9 +2464,9 @@ bool existingBuf = fastBufferMap.count(memref) > 0; if (!existingBuf) { AffineMap fastBufferLayout = b.getMultiDimIdentityMap(rank); - auto fastMemRefType = - MemRefType::get(fastBufferShape, memRefType.getElementType(), - fastBufferLayout, copyOptions.fastMemorySpace); + auto fastMemRefType = MemRefType::get( + fastBufferShape, memRefType.getElementType(), fastBufferLayout, + top.getI64IntegerAttr(copyOptions.fastMemorySpace)); // Create the fast memory space buffer just before the 'affine.for' // operation. @@ -2538,8 +2538,9 @@ } else { // DMA generation. // Create a tag (single element 1-d memref) for the DMA. - auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {}, - copyOptions.tagMemorySpace); + auto tagMemRefType = + MemRefType::get({1}, top.getIntegerType(32), {}, + top.getI64IntegerAttr(copyOptions.tagMemorySpace)); auto tagMemRef = prologue.create(loc, tagMemRefType); SmallVector tagIndices({zeroIndex}); @@ -2718,14 +2719,18 @@ block->walk(begin, end, [&](Operation *opInst) { // Gather regions to allocate to buffers in faster memory space. 
if (auto loadOp = dyn_cast(opInst)) { + unsigned memSpaceInd = 0; + if (Attribute memSpace = loadOp.getMemRefType().getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); if ((filterMemRef.hasValue() && filterMemRef != loadOp.getMemRef()) || - (loadOp.getMemRefType().getMemorySpace() != - copyOptions.slowMemorySpace)) + (memSpaceInd != copyOptions.slowMemorySpace)) return; } else if (auto storeOp = dyn_cast(opInst)) { + unsigned memSpaceInd = 0; + if (Attribute memSpace = storeOp.getMemRefType().getMemorySpace()) + memSpaceInd = memSpace.cast().getInt(); if ((filterMemRef.hasValue() && filterMemRef != storeOp.getMemRef()) || - storeOp.getMemRefType().getMemorySpace() != - copyOptions.slowMemorySpace) + memSpaceInd != copyOptions.slowMemorySpace) return; } else { // Neither load nor a store op. diff --git a/mlir/test/Bindings/Python/dialects/linalg.py b/mlir/test/Bindings/Python/dialects/linalg.py --- a/mlir/test/Bindings/Python/dialects/linalg.py +++ b/mlir/test/Bindings/Python/dialects/linalg.py @@ -39,7 +39,7 @@ with Context() as ctx, Location.unknown(): module = Module.create() f32 = F32Type.get() - memref_type = MemRefType.get((2, 3, 4), f32) + memref_type = MemRefType.get([2, 3, 4], f32) with InsertionPoint.at_block_terminator(module.body): func = builtin.FuncOp(name="matmul_test", type=FunctionType.get( diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -326,7 +326,7 @@ f32 = F32Type.get() shape = [2, 3] loc = Location.unknown() - memref = MemRefType.get(shape, f32, memory_space=2) + memref = MemRefType.get(shape, f32, memory_space=Attribute.parse("2")) # CHECK: memref type: memref<2x3xf32, 2> print("memref type:", memref) # CHECK: number of affine layout maps: 0 @@ -341,7 +341,7 @@ assert len(memref_layout.layout) == 1 # CHECK: memref layout: (d0, d1) -> (d1, d0) print("memref layout:", memref_layout.layout[0]) - # CHECK: 
memory space: 0 + # CHECK: memory space: <> print("memory space:", memref_layout.memory_space) none = NoneType.get() @@ -361,7 +361,7 @@ with Context(), Location.unknown(): f32 = F32Type.get() loc = Location.unknown() - unranked_memref = UnrankedMemRefType.get(f32, 2) + unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2")) # CHECK: unranked memref type: memref<*xf32, 2> print("unranked memref type:", unranked_memref) try: @@ -388,7 +388,7 @@ none = NoneType.get() try: - memref_invalid = UnrankedMemRefType.get(none, 2) + memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2")) except ValueError as e: # CHECK: invalid 'Type(none)' and expected floating point, integer, vector # CHECK: or complex type. diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -707,21 +707,24 @@ // CHECK: tensor<*xf32> // MemRef type. + MlirAttribute memSpace2 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 2); MlirType memRef = mlirMemRefTypeContiguousGet( - f32, sizeof(shape) / sizeof(int64_t), shape, 2); + f32, sizeof(shape) / sizeof(int64_t), shape, memSpace2); if (!mlirTypeIsAMemRef(memRef) || mlirMemRefTypeGetNumAffineMaps(memRef) != 0 || - mlirMemRefTypeGetMemorySpace(memRef) != 2) + !mlirAttributeEqual(mlirMemRefTypeGetMemorySpace(memRef), memSpace2)) return 18; mlirTypeDump(memRef); fprintf(stderr, "\n"); // CHECK: memref<2x3xf32, 2> // Unranked MemRef type. 
- MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, 4); + MlirAttribute memSpace4 = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 64), 4); + MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, memSpace4); if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) || mlirTypeIsAMemRef(unrankedMemRef) || - mlirUnrankedMemrefGetMemorySpace(unrankedMemRef) != 4) + !mlirAttributeEqual(mlirUnrankedMemrefGetMemorySpace(unrankedMemRef), + memSpace4)) return 19; mlirTypeDump(unrankedMemRef); fprintf(stderr, "\n"); diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -36,8 +36,8 @@ func @memrefs(memref<2x4xi8, #map7>) // expected-error {{undefined symbol alias id 'map7'}} // ----- -// Test non affine map in memref type. -func @memrefs(memref<2x4xi8, i8>) // expected-error {{expected affine map in memref type}} +// Test unsupported memory space. +func @memrefs(memref<2x4xi8, i8>) // expected-error {{unsupported memory space Attribute}} // ----- // Test non-existent map in map composition of memref type. 
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -137,11 +137,35 @@ func private @memrefs_compose_with_id(memref<2x2xi8, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>>) +// Test memref with custom memory space + +// CHECK: func private @memrefs_nomap_nospace(memref<5x6x7xf32>) +func private @memrefs_nomap_nospace(memref<5x6x7xf32>) + +// CHECK: func private @memrefs_map_nospace(memref<5x6x7xf32, #map{{[0-9]+}}>) +func private @memrefs_map_nospace(memref<5x6x7xf32, #map3>) + +// CHECK: func private @memrefs_nomap_intspace(memref<5x6x7xf32, 3>) +func private @memrefs_nomap_intspace(memref<5x6x7xf32, 3>) + +// CHECK: func private @memrefs_map_intspace(memref<5x6x7xf32, #map{{[0-9]+}}, 5>) +func private @memrefs_map_intspace(memref<5x6x7xf32, #map3, 5>) + +// CHECK: func private @memrefs_nomap_strspace(memref<5x6x7xf32, "local">) +func private @memrefs_nomap_strspace(memref<5x6x7xf32, "local">) + +// CHECK: func private @memrefs_map_strspace(memref<5x6x7xf32, #map{{[0-9]+}}, "private">) +func private @memrefs_map_strspace(memref<5x6x7xf32, #map3, "private">) + +// CHECK: func private @memrefs_nomap_dictspace(memref<5x6x7xf32, {memSpace = "special", subIndex = 1 : i64}>) +func private @memrefs_nomap_dictspace(memref<5x6x7xf32, {memSpace = "special", subIndex = 1}>) + +// CHECK: func private @memrefs_map_dictspace(memref<5x6x7xf32, #map{{[0-9]+}}, {memSpace = "special", subIndex = 3 : i64}>) +func private @memrefs_map_dictspace(memref<5x6x7xf32, #map3, {memSpace = "special", subIndex = 3}>) // CHECK: func private @complex_types(complex) -> complex func private @complex_types(complex) -> complex - // CHECK: func private @memref_with_index_elems(memref<1x?xindex>) func private @memref_with_index_elems(memref<1x?xindex>) diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp --- 
a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -265,11 +265,12 @@ ArrayRef boundingSubViewSize, OperationFolder *folder) { SmallVector shape(boundingSubViewSize.size(), -1); + IntegerAttr memSpace = b.getI64IntegerAttr(3); return b .create(subView.getLoc(), MemRefType::get(shape, subView.getType().getElementType(), - /*affineMapComposition =*/{}, 3), + /*affineMapComposition =*/{}, memSpace), boundingSubViewSize) .getResult(); }