diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -120,7 +120,7 @@ /// Returns the memory space of the src memref. unsigned getSrcMemorySpace() { - return getSrcMemRef().getType().cast().getMemorySpace(); + return getSrcMemRef().getType().cast().getMemorySpaceAsInt(); } /// Returns the operand index of the dst memref. @@ -141,7 +141,7 @@ /// Returns the memory space of the src memref. unsigned getDstMemorySpace() { - return getDstMemRef().getType().cast().getMemorySpace(); + return getDstMemRef().getType().cast().getMemorySpaceAsInt(); } /// Returns the affine map used to access the dst memref. diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -177,10 +177,10 @@ return getDstMemRef().getType().cast().getRank(); } unsigned getSrcMemorySpace() { - return getSrcMemRef().getType().cast().getMemorySpace(); + return getSrcMemRef().getType().cast().getMemorySpaceAsInt(); } unsigned getDstMemorySpace() { - return getDstMemRef().getType().cast().getMemorySpace(); + return getDstMemRef().getType().cast().getMemorySpaceAsInt(); } // Returns the destination memref indices for this DMA operation. diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -293,7 +293,7 @@ static bool classof(Type type); /// Returns the memory space in which data referred to by this memref resides. - unsigned getMemorySpace() const; + unsigned getMemorySpaceAsInt() const; }; //===----------------------------------------------------------------------===// @@ -314,7 +314,7 @@ explicit Builder(MemRefType other) : shape(other.getShape()), elementType(other.getElementType()), affineMaps(other.getAffineMaps()), - memorySpace(other.getMemorySpace()) {} + memorySpace(other.getMemorySpaceAsInt()) {} // Build from scratch. Builder(ArrayRef shape, Type elementType) diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -270,7 +270,7 @@ } unsigned mlirMemRefTypeGetMemorySpace(MlirType type) { - return unwrap(type).cast().getMemorySpace(); + return unwrap(type).cast().getMemorySpaceAsInt(); } bool mlirTypeIsAUnrankedMemRef(MlirType type) { @@ -289,7 +289,7 @@ } unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) { - return unwrap(type).cast().getMemorySpace(); + return unwrap(type).cast().getMemorySpaceAsInt(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -118,7 +118,8 @@ /// converter drops the private memory space to support the use case above. LLVMTypeConverter converter(m.getContext(), options); converter.addConversion([&](MemRefType type) -> Optional { - if (type.getMemorySpace() != gpu::GPUDialect::getPrivateAddressSpace()) + if (type.getMemorySpaceAsInt() != + gpu::GPUDialect::getPrivateAddressSpace()) return llvm::None; return converter.convertType(MemRefType::Builder(type).setMemorySpace(0)); }); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -316,7 +316,8 @@ Type elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; - auto ptrTy = LLVM::LLVMPointerType::get(elementType, type.getMemorySpace()); + auto ptrTy = + LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt()); auto indexTy = getIndexType(); SmallVector results = {ptrTy, ptrTy, indexTy}; @@ -388,7 +389,7 @@ Type elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; - return LLVM::LLVMPointerType::get(elementType, type.getMemorySpace()); + return LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt()); } /// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type @@ -1081,7 +1082,8 @@ Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { auto elementType = type.getElementType(); auto structElementType = unwrap(typeConverter->convertType(elementType)); - return LLVM::LLVMPointerType::get(structElementType, type.getMemorySpace()); + return LLVM::LLVMPointerType::get(structElementType, + type.getMemorySpaceAsInt()); } void ConvertToLLVMPattern::getMemRefDescriptorSizes( @@ -1899,7 +1901,7 @@ Value alignedPtr = allocatedPtr; if (alignment) { - auto intPtrType = getIntPtrType(memRefType.getMemorySpace()); + auto intPtrType = getIntPtrType(memRefType.getMemorySpaceAsInt()); // Compute the aligned type pointer. Value allocatedInt = rewriter.create(loc, intPtrType, allocatedPtr); @@ -2247,7 +2249,7 @@ rewriter.replaceOpWithNewOp( global, arrayTy, global.constant(), linkage, global.sym_name(), - initialValue, type.getMemorySpace()); + initialValue, type.getMemorySpaceAsInt()); return success(); } }; @@ -2266,7 +2268,7 @@ Operation *op) const override { auto getGlobalOp = cast(op); MemRefType type = getGlobalOp.result().getType().cast(); - unsigned memSpace = type.getMemorySpace(); + unsigned memSpace = type.getMemorySpaceAsInt(); Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); auto addressOf = rewriter.create( @@ -2462,7 +2464,7 @@ } unsigned memorySpace = - operandType.cast().getMemorySpace(); + operandType.cast().getMemorySpaceAsInt(); Type elementType = operandType.cast().getElementType(); Type llvmElementType = unwrap(typeConverter.convertType(elementType)); Type elementPtrPtrType = LLVM::LLVMPointerType::get( @@ -2591,7 +2593,7 @@ // Extract address space and element type. auto targetType = reshapeOp.getResult().getType().cast(); - unsigned addressSpace = targetType.getMemorySpace(); + unsigned addressSpace = targetType.getMemorySpaceAsInt(); Type elementType = targetType.getElementType(); // Create the unranked memref descriptor that holds the ranked one. The @@ -2751,7 +2753,7 @@ auto unrankedMemRefType = operandType.cast(); auto scalarMemRefType = MemRefType::get({}, unrankedMemRefType.getElementType()); - unsigned addressSpace = unrankedMemRefType.getMemorySpace(); + unsigned addressSpace = unrankedMemRefType.getMemorySpaceAsInt(); // Extract pointer to the underlying ranked descriptor and bitcast it to a // memref descriptor pointer to minimize the number of GEP @@ -3265,7 +3267,7 @@ Value bitcastPtr = rewriter.create( loc, LLVM::LLVMPointerType::get(targetElementTy, - viewMemRefType.getMemorySpace()), + viewMemRefType.getMemorySpaceAsInt()), extracted); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); @@ -3274,7 +3276,7 @@ bitcastPtr = rewriter.create( loc, LLVM::LLVMPointerType::get(targetElementTy, - viewMemRefType.getMemorySpace()), + viewMemRefType.getMemorySpaceAsInt()), extracted); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); @@ -3491,7 +3493,7 @@ Value bitcastPtr = rewriter.create( loc, LLVM::LLVMPointerType::get(targetElementTy, - srcMemRefType.getMemorySpace()), + srcMemRefType.getMemorySpaceAsInt()), allocatedPtr); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); @@ -3502,7 +3504,7 @@ bitcastPtr = rewriter.create( loc, LLVM::LLVMPointerType::get(targetElementTy, - srcMemRefType.getMemorySpace()), + srcMemRefType.getMemorySpaceAsInt()), alignedPtr); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -194,7 +194,7 @@ // shape and int or float or vector of int or float element type. if (!(t.hasStaticShape() && SPIRVTypeConverter::getMemorySpaceForStorageClass( - spirv::StorageClass::Workgroup) == t.getMemorySpace())) + spirv::StorageClass::Workgroup) == t.getMemorySpaceAsInt())) return false; Type elementType = t.getElementType(); if (auto vecType = elementType.dyn_cast()) @@ -207,7 +207,8 @@ /// type. Returns None on failure. static Optional getAtomicOpScope(MemRefType t) { Optional storageClass = - SPIRVTypeConverter::getStorageClassForMemorySpace(t.getMemorySpace()); + SPIRVTypeConverter::getStorageClassForMemorySpace( + t.getMemorySpaceAsInt()); if (!storageClass) return {}; switch (*storageClass) { diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -189,7 +189,7 @@ SmallVector strides; auto successStrides = getStridesAndOffset(memRefType, strides, offset); if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 || - offset != 0 || memRefType.getMemorySpace() != 0) + offset != 0 || memRefType.getMemorySpaceAsInt() != 0) return failure(); base = MemRefDescriptor(memref).alignedPtr(rewriter, loc); return success(); @@ -213,7 +213,7 @@ // will be in the same address space as the incoming memref type. static Value castDataPtr(ConversionPatternRewriter &rewriter, Location loc, Value ptr, MemRefType memRefType, Type vt) { - auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpace()); + auto pType = LLVM::LLVMPointerType::get(vt, memRefType.getMemorySpaceAsInt()); return rewriter.create(loc, pType, ptr); } diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -94,8 +94,8 @@ // MUBUF instruction operate only on addresspace 0(unified) or 1(global) // In case of 3(LDS): fall back to vector->llvm pass // In case of 5(VGPR): wrong - if ((memRefType.getMemorySpace() != 0) && - (memRefType.getMemorySpace() != 1)) + if ((memRefType.getMemorySpaceAsInt() != 0) && + (memRefType.getMemorySpaceAsInt() != 1)) return failure(); // Note that the dataPtr starts at the offset address specified by diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -115,9 +115,11 @@ VectorType::get(vectorType.getShape().take_back(minorRank), vectorType.getElementType()); /// Memref of minor vector type is used for individual transfers. - memRefMinorVectorType = MemRefType::get( - majorVectorType.getShape(), minorVectorType, {}, - xferOp.getShapedType().template cast().getMemorySpace()); + memRefMinorVectorType = + MemRefType::get(majorVectorType.getShape(), minorVectorType, {}, + xferOp.getShapedType() + .template cast() + .getMemorySpaceAsInt()); } LogicalResult doReplace(); diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -727,7 +727,7 @@ if (!type) return op->emitOpError() << "expected memref type in attribution"; - if (type.getMemorySpace() != memorySpace) { + if (type.getMemorySpaceAsInt() != memorySpace) { return op->emitOpError() << "expected memory space " << memorySpace << " in attribution"; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1345,7 +1345,7 @@ if (!memrefType.hasStaticShape()) return op->emitOpError( "unexpected bare pointer for dynamically shaped memref"); - if (memrefType.getMemorySpace() != ptrType.getAddressSpace()) + if (memrefType.getMemorySpaceAsInt() != ptrType.getAddressSpace()) return op->emitError("invalid conversion between memref and pointer in " "different memory spaces"); @@ -1369,7 +1369,7 @@ // The first two elements are pointers to the element type. auto allocatedPtr = structType.getBody()[0].dyn_cast(); if (!allocatedPtr || - allocatedPtr.getAddressSpace() != memrefType.getMemorySpace()) + allocatedPtr.getAddressSpace() != memrefType.getMemorySpaceAsInt()) return op->emitOpError("expected first element of a memref descriptor to " "be a pointer in the address space of the memref"); if (failed(verifyCast(op, allocatedPtr.getElementType(), @@ -1378,7 +1378,7 @@ auto alignedPtr = structType.getBody()[1].dyn_cast(); if (!alignedPtr || - alignedPtr.getAddressSpace() != memrefType.getMemorySpace()) + alignedPtr.getAddressSpace() != memrefType.getMemorySpaceAsInt()) return op->emitOpError( "expected second element of a memref descriptor to " "be a pointer in the address space of the memref"); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -344,7 +344,8 @@ static Optional convertMemrefType(const spirv::TargetEnv &targetEnv, MemRefType type) { Optional storageClass = - SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace()); + SPIRVTypeConverter::getStorageClassForMemorySpace( + type.getMemorySpaceAsInt()); if (!storageClass) { LLVM_DEBUG(llvm::dbgs() << type << " illegal: cannot convert memory space\n"); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -2096,7 +2096,7 @@ if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) return false; } - if (aT.getMemorySpace() != bT.getMemorySpace()) + if (aT.getMemorySpaceAsInt() != bT.getMemorySpaceAsInt()) return false; // They must have the same rank, and any specified dimensions must match. @@ -2123,8 +2123,10 @@ if (aEltType != bEltType) return false; - auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); - auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); + auto aMemSpace = + (aT) ? aT.getMemorySpaceAsInt() : uaT.getMemorySpaceAsInt(); + auto bMemSpace = + (bT) ? bT.getMemorySpaceAsInt() : ubT.getMemorySpaceAsInt(); if (aMemSpace != bMemSpace) return false; @@ -2201,7 +2203,7 @@ // The source and result memrefs should be in the same memory space. auto srcType = op.source().getType().cast(); auto resultType = op.getType().cast(); - if (srcType.getMemorySpace() != resultType.getMemorySpace()) + if (srcType.getMemorySpaceAsInt() != resultType.getMemorySpaceAsInt()) return op.emitError("different memory spaces specified for source type ") << srcType << " and result memref type " << resultType; if (srcType.getElementType() != resultType.getElementType()) @@ -2875,7 +2877,7 @@ staticSizes, sourceMemRefType.getElementType(), makeStridedLinearLayoutMap(targetStrides, targetOffset, sourceMemRefType.getContext()), - sourceMemRefType.getMemorySpace()); + sourceMemRefType.getMemorySpaceAsInt()); } Type SubViewOp::inferResultType(MemRefType sourceMemRefType, @@ -2932,7 +2934,7 @@ map = getProjectedMap(maps.front(), dimsToProject); inferredType = MemRefType::get(projectedShape, inferredType.getElementType(), map, - inferredType.getMemorySpace()); + inferredType.getMemorySpaceAsInt()); } return inferredType; } @@ -3154,7 +3156,7 @@ // Strided layout logic is relevant for MemRefType only. MemRefType original = originalType.cast(); MemRefType candidateReduced = candidateReducedType.cast(); - if (original.getMemorySpace() != candidateReduced.getMemorySpace()) + if (original.getMemorySpaceAsInt() != candidateReduced.getMemorySpaceAsInt()) return SubViewVerificationResult::MemSpaceMismatch; llvm::SmallDenseSet unusedDims = optionalUnusedDimsMask.getValue(); @@ -3228,7 +3230,7 @@ MemRefType subViewType = op.getType(); // The base memref and the view memref should be in the same memory space. - if (baseType.getMemorySpace() != subViewType.getMemorySpace()) + if (baseType.getMemorySpaceAsInt() != subViewType.getMemorySpaceAsInt()) return op.emitError("different memory spaces specified for base memref " "type ") << baseType << " and subview memref type " << subViewType; @@ -4090,7 +4092,7 @@ return op.emitError("unsupported map for result memref type ") << viewType; // The base memref and the view memref should be in the same memory space. - if (baseType.getMemorySpace() != viewType.getMemorySpace()) + if (baseType.getMemorySpaceAsInt() != viewType.getMemorySpaceAsInt()) return op.emitError("different memory spaces specified for base memref " "type ") << baseType << " and view memref type " << viewType; diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -3170,7 +3170,7 @@ VectorType::get(extractShape(memRefType), getElementTypeOrSelf(getElementTypeOrSelf(memRefType))); result.addTypes( - MemRefType::get({}, vectorType, {}, memRefType.getMemorySpace())); + MemRefType::get({}, vectorType, {}, memRefType.getMemorySpaceAsInt())); } static LogicalResult verify(TypeCastOp op) { @@ -3179,8 +3179,8 @@ return op.emitOpError("expects operand to be a memref with no layout"); if (!op.getResultMemRefType().getAffineMaps().empty()) return op.emitOpError("expects result to be a memref with no layout"); - if (op.getResultMemRefType().getMemorySpace() != - op.getMemRefType().getMemorySpace()) + if (op.getResultMemRefType().getMemorySpaceAsInt() != + op.getMemRefType().getMemorySpaceAsInt()) return op.emitOpError("expects result in same memory space"); auto sourceType = op.getMemRefType(); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1882,16 +1882,16 @@ printAttribute(AffineMapAttr::get(map)); } // Only print the memory space if it is the non-default one. - if (memrefTy.getMemorySpace()) - os << ", " << memrefTy.getMemorySpace(); + if (memrefTy.getMemorySpaceAsInt()) + os << ", " << memrefTy.getMemorySpaceAsInt(); os << '>'; }) .Case([&](UnrankedMemRefType memrefTy) { os << "memref<*x"; printType(memrefTy.getElementType()); // Only print the memory space if it is the non-default one. - if (memrefTy.getMemorySpace()) - os << ", " << memrefTy.getMemorySpace(); + if (memrefTy.getMemorySpaceAsInt()) + os << ", " << memrefTy.getMemorySpaceAsInt(); os << '>'; }) .Case([&](ComplexType complexTy) { diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -206,7 +206,7 @@ if (auto other = dyn_cast()) { MemRefType::Builder b(shape, elementType); - b.setMemorySpace(other.getMemorySpace()); + b.setMemorySpace(other.getMemorySpaceAsInt()); return b; } @@ -229,7 +229,7 @@ if (auto other = dyn_cast()) { MemRefType::Builder b(shape, other.getElementType()); b.setShape(shape); - b.setMemorySpace(other.getMemorySpace()); + b.setMemorySpace(other.getMemorySpaceAsInt()); return b; } @@ -250,7 +250,7 @@ } if (auto other = dyn_cast()) { - return UnrankedMemRefType::get(elementType, other.getMemorySpace()); + return UnrankedMemRefType::get(elementType, other.getMemorySpaceAsInt()); } if (isa()) { @@ -472,7 +472,7 @@ // BaseMemRefType //===----------------------------------------------------------------------===// -unsigned BaseMemRefType::getMemorySpace() const { +unsigned BaseMemRefType::getMemorySpaceAsInt() const { return static_cast(impl)->memorySpace; } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -947,7 +947,7 @@ if (bufSize <= localBufSizeThreshold && fastMemorySpace.hasValue()) { newMemSpace = fastMemorySpace.getValue(); } else { - newMemSpace = oldMemRefType.getMemorySpace(); + newMemSpace = oldMemRefType.getMemorySpaceAsInt(); } auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(), {}, newMemSpace); diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -2725,12 +2725,12 @@ // Gather regions to allocate to buffers in faster memory space. if (auto loadOp = dyn_cast(opInst)) { if ((filterMemRef.hasValue() && filterMemRef != loadOp.getMemRef()) || - (loadOp.getMemRefType().getMemorySpace() != + (loadOp.getMemRefType().getMemorySpaceAsInt() != copyOptions.slowMemorySpace)) return; } else if (auto storeOp = dyn_cast(opInst)) { if ((filterMemRef.hasValue() && filterMemRef != storeOp.getMemRef()) || - storeOp.getMemRefType().getMemorySpace() != + storeOp.getMemRefType().getMemorySpaceAsInt() != copyOptions.slowMemorySpace) return; } else {