diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
--- a/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
+++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h
@@ -125,7 +125,10 @@
 // Globals are created lazily at the top of the enclosing ModuleOp with pretty
 // names. Duplicates are avoided.
+// The memref type of a newly created global uses `memorySpace`; a null
+// attribute (the default) yields a memref type in the default memory space.
 FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
-                                         uint64_t alignment);
+                                         uint64_t alignment,
+                                         Attribute memorySpace = {});
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -26,10 +26,12 @@
                           const BufferizationOptions &options) const {
     auto constantOp = cast<arith::ConstantOp>(op);
 
-    // TODO: Implement memory space for this op. E.g., by adding a memory_space
-    // attribute to ConstantOp.
-    if (options.defaultMemorySpace != Attribute())
-      return op->emitError("memory space not implemented yet");
+    // The global is placed in the default memory space from the options. If
+    // no default is set (e.g., with "must-infer-memory-space"), there is
+    // nothing to infer the memory space of the constant from.
+    if (!options.defaultMemorySpace.has_value())
+      return op->emitError("could not infer memory space");
+    Attribute memorySpace = *options.defaultMemorySpace;
 
     // Only ranked tensors are supported.
     if (!constantOp.getType().isa<RankedTensorType>())
@@ -43,7 +45,7 @@
     // Create global memory segment and replace tensor with memref pointing to
     // that memory segment.
     FailureOr<memref::GlobalOp> globalOp =
-        getGlobalFor(constantOp, options.bufferAlignment);
+        getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
     if (failed(globalOp))
       return failure();
     memref::GlobalOp globalMemref = *globalOp;
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp
@@ -147,7 +147,8 @@
 //===----------------------------------------------------------------------===//
 
 FailureOr<memref::GlobalOp>
-bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment) {
+bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
+                            Attribute memorySpace) {
   auto type = constantOp.getType().cast<RankedTensorType>();
   auto moduleOp = constantOp->getParentOfType<ModuleOp>();
   if (!moduleOp)
@@ -184,10 +185,14 @@
                     : IntegerAttr();
 
   BufferizeTypeConverter typeConverter;
+  // Place the global in the requested memory space. A null `memorySpace`
+  // yields a memref type in the default memory space.
+  auto memrefType = typeConverter.convertType(type).cast<MemRefType>();
+  memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace);
   auto global = globalBuilder.create<memref::GlobalOp>(
       constantOp.getLoc(), (Twine("__constant_") + os.str()).str(),
       /*sym_visibility=*/globalBuilder.getStringAttr("private"),
-      /*type=*/typeConverter.convertType(type).cast<MemRefType>(),
+      /*type=*/memrefType,
       /*initial_value=*/constantOp.getValue().cast<ElementsAttr>(),
       /*constant=*/true,
       /*alignment=*/memrefAlignment);
diff --git a/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir b/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir
--- a/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir
+++ b/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space-invalid.mlir
@@ -14,7 +14,7 @@
 // -----
 
 func.func @constant_memory_space(%idx: index, %v: i32) -> tensor<3xi32> {
-  // expected-error @+2 {{memory space not implemented yet}}
+  // expected-error @+2 {{could not infer memory space}}
   // expected-error @+1 {{failed to bufferize op}}
   %cst = arith.constant dense<[5, 1000, 20]> : tensor<3xi32>
   %0 = tensor.insert %v into %cst[%idx] : tensor<3xi32>
diff --git a/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space.mlir b/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Arith/one-shot-bufferize-memory-space.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -one-shot-bufferize | FileCheck %s
+
+// With the default bufferization options, the global is created in the
+// default memory space (no memory space on the memref type).
+
+// CHECK: memref.global "private" constant @__constant_3xi32 : memref<3xi32> = dense<[5, 1000, 20]> {alignment = 64 : i64}
+// CHECK-LABEL: func.func @constant_memory_space(
+// CHECK: memref.get_global @__constant_3xi32 : memref<3xi32>
+func.func @constant_memory_space(%idx: index, %v: i32) -> tensor<3xi32> {
+  %cst = arith.constant dense<[5, 1000, 20]> : tensor<3xi32>
+  %0 = tensor.insert %v into %cst[%idx] : tensor<3xi32>
+  return %0 : tensor<3xi32>
+}