diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -517,19 +517,20 @@
 /// canonicalizations are currently not implemented.
 BaseMemRefType getMemRefType(TensorType tensorType,
                              const BufferizationOptions &options,
-                             MemRefLayoutAttrInterface layout = {},
-                             Attribute memorySpace = {});
+                             Optional<AffineMap> layout = None,
+                             Optional<unsigned> memorySpace = None);
 
 /// Return a MemRef type with fully dynamic layout. If the given tensor type
 /// is unranked, return an unranked MemRef type.
-BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
-                                                   Attribute memorySpace = {});
+BaseMemRefType
+getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
+                                    Optional<unsigned> memorySpace = None);
 
 /// Return a MemRef type with a static identity layout (i.e., no layout map).
 /// If the given tensor type is unranked, return an unranked MemRef type.
 BaseMemRefType
 getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
-                                      Attribute memorySpace = {});
+                                      Optional<unsigned> memorySpace = None);
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -595,21 +595,27 @@
 BaseMemRefType
 bufferization::getMemRefType(TensorType tensorType,
                              const BufferizationOptions &options,
-                             MemRefLayoutAttrInterface layout,
-                             Attribute memorySpace) {
+                             Optional<AffineMap> layout,
+                             Optional<unsigned> memorySpace) {
   // Case 1: Unranked memref type.
   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
-    assert(!layout && "UnrankedTensorType cannot have a layout map");
+    assert(!layout.hasValue() && "UnrankedTensorType cannot have a layout map");
+    if (memorySpace.hasValue())
+      return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
+                                     *memorySpace);
     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
-                                   memorySpace);
+                                   /*memorySpace=*/{});
   }
 
   // Case 2: Ranked memref type with specified layout.
   auto rankedTensorType = tensorType.cast<RankedTensorType>();
-  if (layout) {
+  if (layout.hasValue()) {
+    if (memorySpace.hasValue())
+      return MemRefType::get(rankedTensorType.getShape(),
+                             rankedTensorType.getElementType(), *layout,
+                             *memorySpace);
     return MemRefType::get(rankedTensorType.getShape(),
-                           rankedTensorType.getElementType(), layout,
-                           memorySpace);
+                           rankedTensorType.getElementType(), *layout);
   }
 
   // Case 3: Configured with "fully dynamic layout maps".
@@ -625,13 +631,15 @@
   llvm_unreachable("InferLayoutMap is an invalid option");
 }
 
-BaseMemRefType
-bufferization::getMemRefTypeWithFullyDynamicLayout(TensorType tensorType,
-                                                   Attribute memorySpace) {
+BaseMemRefType bufferization::getMemRefTypeWithFullyDynamicLayout(
+    TensorType tensorType, Optional<unsigned> memorySpace) {
   // Case 1: Unranked memref type.
   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
+    if (memorySpace.hasValue())
+      return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
+                                     *memorySpace);
     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
-                                   memorySpace);
+                                   /*memorySpace=*/{});
   }
 
   // Case 2: Ranked memref type.
@@ -641,26 +649,33 @@
                                       ShapedType::kDynamicStrideOrOffset);
   AffineMap stridedLayout = makeStridedLinearLayoutMap(
       dynamicStrides, dynamicOffset, rankedTensorType.getContext());
+  if (memorySpace.hasValue())
+    return MemRefType::get(rankedTensorType.getShape(),
+                           rankedTensorType.getElementType(), stridedLayout,
+                           *memorySpace);
   return MemRefType::get(rankedTensorType.getShape(),
-                         rankedTensorType.getElementType(), stridedLayout,
-                         memorySpace);
+                         rankedTensorType.getElementType(), stridedLayout);
 }
 
 /// Return a MemRef type with a static identity layout (i.e., no layout map).
 /// If the given tensor type is unranked, return an unranked MemRef type.
-BaseMemRefType
-bufferization::getMemRefTypeWithStaticIdentityLayout(TensorType tensorType,
-                                                     Attribute memorySpace) {
+BaseMemRefType bufferization::getMemRefTypeWithStaticIdentityLayout(
+    TensorType tensorType, Optional<unsigned> memorySpace) {
   // Case 1: Unranked memref type.
   if (auto unrankedTensorType = tensorType.dyn_cast<UnrankedTensorType>()) {
+    if (memorySpace.hasValue())
+      return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
+                                     *memorySpace);
     return UnrankedMemRefType::get(unrankedTensorType.getElementType(),
-                                   memorySpace);
+                                   /*memorySpace=*/{});
   }
 
   // Case 2: Ranked memref type.
   auto rankedTensorType = tensorType.cast<RankedTensorType>();
-  MemRefLayoutAttrInterface layout = {};
+  if (memorySpace.hasValue())
+    return MemRefType::get(rankedTensorType.getShape(),
+                           rankedTensorType.getElementType(), AffineMap(),
+                           *memorySpace);
   return MemRefType::get(rankedTensorType.getShape(),
-                         rankedTensorType.getElementType(), layout,
-                         memorySpace);
+                         rankedTensorType.getElementType());
 }
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -54,18 +54,18 @@
     // The result buffer still has the old (pre-cast) type.
     Value resultBuffer = getBuffer(rewriter, castOp.getSource(), options);
     auto sourceMemRefType = resultBuffer.getType().cast<BaseMemRefType>();
-    Attribute memorySpace = sourceMemRefType.getMemorySpace();
     TensorType resultTensorType =
         castOp.getResult().getType().cast<TensorType>();
-    MemRefLayoutAttrInterface layout;
+    Optional<AffineMap> layout = None;
     if (auto rankedMemRefType = sourceMemRefType.dyn_cast<MemRefType>())
       if (resultTensorType.isa<RankedTensorType>())
-        layout = rankedMemRefType.getLayout();
+        layout = rankedMemRefType.getLayout().getAffineMap();
 
     // Compute the new memref type.
     Type resultMemRefType =
-        getMemRefType(resultTensorType, options, layout, memorySpace);
+        getMemRefType(resultTensorType, options, layout,
+                      sourceMemRefType.getMemorySpaceAsInt());
 
     // Replace the op with a memref.cast.
     assert(memref::CastOp::areCastCompatible(resultBuffer.getType(),
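For reference, a minimal usage sketch of the Optional-based signatures changed above. It is not part of the patch; `ctx` is an assumed in-scope `MLIRContext *`, `using namespace mlir;` is assumed, and the shape and memory-space values are illustrative only.

// Usage sketch (hypothetical caller code, not from the patch).
RankedTensorType tensorType =
    RankedTensorType::get({4, 8}, FloatType::getF32(ctx));
// Identity layout in memory space 3 -> memref<4x8xf32, 3>.
BaseMemRefType m1 = bufferization::getMemRefTypeWithStaticIdentityLayout(
    tensorType, /*memorySpace=*/3);
// Omitting the memory space (None) keeps the default memory space; the
// layout is a fully dynamic strided map.
BaseMemRefType m2 =
    bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType);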