Pass MemRef memory spaces to the bufferization getMemRefType* helpers as unsigned

Changes the `memorySpace` parameter of bufferization::getMemRefType,
getMemRefTypeWithFullyDynamicLayout and getMemRefTypeWithStaticIdentityLayout
from `Attribute` (default `{}`) to `unsigned` (default `0`).

- On the ranked paths, and on the unranked path of getMemRefType, the value is
  wrapped into an i64 IntegerAttr before building the MemRef type.
- The unranked paths of the other two helpers pass the unsigned value straight
  to UnrankedMemRefType::get.
- The tensor.cast bufferization now forwards
  `sourceMemRefType.getMemorySpaceAsInt()` instead of the source buffer's
  memory-space Attribute.

NOTE(review): getMemorySpaceAsInt() presumably asserts when the source buffer
carries a non-integer memory-space attribute. If so, tensor.cast on such
buffers can no longer bufferize after this change. Confirm this narrowing is
intended before landing.

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -513,18 +513,17 @@
 BaseMemRefType getMemRefType(TensorType tensorType,
                              const BufferizationOptions &options,
                              MemRefLayoutAttrInterface layout = {},
-                             Attribute memorySpace = {});
+                             unsigned memorySpace = 0);
 
 /// Return a MemRef type with fully dynamic layout. If the given tensor type
 /// is unranked, return an unranked MemRef type.
 BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
-                                                   Attribute memorySpace = {});
+                                                   unsigned memorySpace = 0);
 
 /// Return a MemRef type with a static identity layout (i.e., no layout map). If
 /// the given tensor type is unranked, return an unranked MemRef type.
-BaseMemRefType
-getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
-                                      Attribute memorySpace = {});
+BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
+                                                     unsigned memorySpace = 0);
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -596,12 +596,15 @@
 BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
                                             const BufferizationOptions &options,
                                             MemRefLayoutAttrInterface layout,
-                                            Attribute memorySpace) {
+                                            unsigned memorySpace) {
+  auto memorySpaceAttr = IntegerAttr::get(
+      IntegerType::get(tensorType.getContext(), 64), memorySpace);
+
   // Case 1: Unranked memref type.
   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
     assert(!layout && "UnrankedTensorType cannot have a layout map");
     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
-                                   memorySpace);
+                                   memorySpaceAttr);
   }
 
   // Case 2: Ranked memref type with specified layout.
@@ -609,7 +612,7 @@
   if (layout) {
     return MemRefType::get(rankedTensorType.getShape(),
                            rankedTensorType.getElementType(), layout,
-                           memorySpace);
+                           memorySpaceAttr);
   }
 
   // Case 3: Configured with "fully dynamic layout maps".
@@ -627,7 +630,7 @@
 
 BaseMemRefType
 bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
-                                                   Attribute memorySpace) {
+                                                   unsigned memorySpace) {
   // Case 1: Unranked memref type.
   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
@@ -635,6 +638,8 @@
   }
 
   // Case 2: Ranked memref type.
+  auto memorySpaceAttr = IntegerAttr::get(
+      IntegerType::get(tensorType.getContext(), 64), memorySpace);
   auto rankedTensorType = tensorType.cast<RankedTensorType>();
   int64_t dynamicOffset = ShapedType::kDynamicStrideOrOffset;
   SmallVector<int64_t> dynamicStrides(rankedTensorType.getRank(),
@@ -643,14 +648,14 @@
       dynamicStrides, dynamicOffset, rankedTensorType.getContext());
   return MemRefType::get(rankedTensorType.getShape(),
                          rankedTensorType.getElementType(), stridedLayout,
-                         memorySpace);
+                         memorySpaceAttr);
 }
 
 /// Return a MemRef type with a static identity layout (i.e., no layout map). If
 /// the given tensor type is unranked, return an unranked MemRef type.
 BaseMemRefType
 bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
-                                                     Attribute memorySpace) {
+                                                     unsigned memorySpace) {
   // Case 1: Unranked memref type.
   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
@@ -659,8 +664,10 @@
 
   // Case 2: Ranked memref type.
   auto rankedTensorType = tensorType.cast<RankedTensorType>();
+  auto memorySpaceAttr = IntegerAttr::get(
+      IntegerType::get(tensorType.getContext(), 64), memorySpace);
   MemRefLayoutAttrInterface layout = {};
   return MemRefType::get(rankedTensorType.getShape(),
                          rankedTensorType.getElementType(), layout,
-                         memorySpace);
+                         memorySpaceAttr);
 }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -54,7 +54,6 @@
     // The result buffer still has the old (pre-cast) type.
     Value resultBuffer = getBuffer(rewriter, castOp.getSource(), options);
     auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
-    Attribute memorySpace = sourceMemRefType.getMemorySpace();
     TensorType resultTensorType =
         castOp.getResult().getType().cast<TensorType>();
     MemRefLayoutAttrInterface layout;
@@ -65,7 +64,8 @@
 
     // Compute the new memref type.
     Type resultMemRefType =
-        getMemRefType(resultTensorType, options, layout, memorySpace);
+        getMemRefType(resultTensorType, options, layout,
+                      sourceMemRefType.getMemorySpaceAsInt());
 
     // Replace the op with a memref.cast.
     assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),