Patch summary (description text only; patch tools ignore text before the
first "diff --git" header).

What the hunks below change:
  * BufferUtils.h / BufferUtils.cpp: bufferization::getGlobalFor gains a
    trailing parameter `Attribute memorySpace`. It defaults to `{}` in the
    header declaration, so existing callers are unaffected. When
    `memorySpace` is non-null, the memref type produced by
    BufferizeTypeConverter is rebuilt via
    MemRefType::Builder(...).setMemorySpace(memorySpace) before the
    memref.global is created. When it is null, the converted type is used
    unchanged.
  * Arith BufferizableOpInterfaceImpl.cpp (arith.constant bufferization):
      - Removes the old rejection, which emitted "memory space not
        implemented yet" whenever options.defaultMemorySpace !=
        Attribute().
      - If options.defaultMemorySpace has a value, that attribute (which
        may itself be null) is forwarded to getGlobalFor.
      - Otherwise the op emits "could not infer memory space" and fails.
      - Adds #include "mlir/IR/Attributes.h".
  * one-shot-bufferize-memory-space-invalid.mlir: renames
    @constant_memory_space to @unknown_memory_space and updates the
    expected diagnostic to "could not infer memory space". Presumably the
    RUN line for this split configures no default memory space; the RUN
    line is not part of this diff, so confirm.

NOTE(review): the header comment says getGlobalFor avoids duplicates. The
  lookup that reuses an existing global is not shown in these hunks.
  Confirm it also compares the memory space. If it matches only on
  initial value/alignment, a constant requested in one memory space could
  return a global created in another.
NOTE(review): the test file still ends without a trailing newline (see the
  "\ No newline at end of file" marker). Consider adding one.
NOTE(review): optional restyle of the new Arith code as a guard clause:
    if (!options.defaultMemorySpace.has_value())
      return constantOp->emitError("could not infer memory space");
    Attribute memorySpace = *options.defaultMemorySpace;
NOTE(review): this copy of the patch appears whitespace-collapsed (line
  breaks lost), and template arguments appear stripped (e.g.
  "FailureOr getGlobalFor", "cast(op)", "isa()"). It will not apply
  as-is; regenerate it from the source tree before use.

diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h @@ -125,7 +125,8 @@ // Globals are created lazily at the top of the enclosing ModuleOp with pretty // names. Duplicates are avoided. FailureOr getGlobalFor(arith::ConstantOp constantOp, - uint64_t alignment); + uint64_t alignment, + Attribute memorySpace = {}); } // namespace bufferization } // namespace mlir diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" @@ -26,10 +27,11 @@ const BufferizationOptions &options) const { auto constantOp = cast(op); - // TODO: Implement memory space for this op. E.g., by adding a memory_space - // attribute to ConstantOp. - if (options.defaultMemorySpace != Attribute()) - return op->emitError("memory space not implemented yet"); + Attribute memorySpace; + if (options.defaultMemorySpace.has_value()) + memorySpace = *options.defaultMemorySpace; + else + return constantOp->emitError("could not infer memory space"); // Only ranked tensors are supported. if (!constantOp.getType().isa()) @@ -43,7 +45,7 @@ // Create global memory segment and replace tensor with memref pointing to // that memory segment. 
FailureOr globalOp = - getGlobalFor(constantOp, options.bufferAlignment); + getGlobalFor(constantOp, options.bufferAlignment, memorySpace); if (failed(globalOp)) return failure(); memref::GlobalOp globalMemref = *globalOp; diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp @@ -147,7 +147,8 @@ //===----------------------------------------------------------------------===// FailureOr -bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment) { +bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, + Attribute memorySpace) { auto type = constantOp.getType().cast(); auto moduleOp = constantOp->getParentOfType(); if (!moduleOp) @@ -184,10 +185,13 @@ : IntegerAttr(); BufferizeTypeConverter typeConverter; + auto memrefType = typeConverter.convertType(type).cast(); + if (memorySpace) + memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); auto global = globalBuilder.create( constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), /*sym_visibility=*/globalBuilder.getStringAttr("private"), - /*type=*/typeConverter.convertType(type).cast(), + /*type=*/memrefType, /*initial_value=*/constantOp.getValue().cast(), /*constant=*/true, /*alignment=*/memrefAlignment); diff --git a/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir b/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir --- a/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir +++ b/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir @@ -13,10 +13,10 @@ // ----- -func.func @constant_memory_space(%idx: index, %v: i32) -> tensor<3xi32> { - // expected-error @+2 {{memory space not implemented yet}} +func.func @unknown_memory_space(%idx: index, %v: i32) -> tensor<3xi32> { + // 
expected-error @+2 {{could not infer memory space}} // expected-error @+1 {{failed to bufferize op}} %cst = arith.constant dense<[5, 1000, 20]> : tensor<3xi32> %0 = tensor.insert %v into %cst[%idx] : tensor<3xi32> return %0 : tensor<3xi32> } \ No newline at end of file