diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -49,13 +49,17 @@ Both dense and sparse tensor types are supported. The result of a `bufferization.alloc_tensor` is a tensor value that can be used like any other tensor value. In practice, it is often used as the "out" operand of - another op. E.g.: + another op. Sparse tensor allocations should always be used in a local + construction operation and never escape the function boundary directly. + + Example: ```mlir %c = bufferization.alloc_tensor [%d1, %d2] : tensor<?x?xf32> %0 = linalg.matmul ins(%a, %b: tensor<?x?xf32>, tensor<?x?xf32>) outs(%c: tensor<?x?xf32>) -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> ``` }]; diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -9,8 +9,10 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Matchers.h" @@ -250,6 +252,16 @@ << getType().getNumDynamicDims() << " dynamic sizes"; if (getCopy() && getCopy().getType() != getType()) return emitError("expected that `copy` and return type match"); + + // For sparse tensor allocation, we require that none of its + // uses escapes the function boundary directly. 
+ if (sparse_tensor::getSparseTensorEncoding(getType())) { + for (auto &use : getOperation()->getUses()) + if (isa<func::ReturnOp, func::CallOp, func::CallIndirectOp>( + use.getOwner())) + return emitError("sparse tensor allocation should not escape function"); + } + return success(); } diff --git a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt @@ -16,6 +16,7 @@ MLIRDialect MLIRFuncDialect MLIRIR + MLIRSparseTensorDialect MLIRTensorDialect MLIRMemRefDialect ) diff --git a/mlir/test/Dialect/Bufferization/invalid.mlir b/mlir/test/Dialect/Bufferization/invalid.mlir --- a/mlir/test/Dialect/Bufferization/invalid.mlir +++ b/mlir/test/Dialect/Bufferization/invalid.mlir @@ -54,4 +54,28 @@ // expected-error @+1{{'bufferization.escape' only valid on bufferizable ops}} %0 = memref.cast %m0 {bufferization.escape = [true]} : memref<?xf32> to memref<10xf32> return -} \ No newline at end of file +} + +// ----- + +#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> + +func.func @sparse_alloc_direct_return() -> tensor<20x40xf32, #DCSR> { + // expected-error @+1{{sparse tensor allocation should not escape function}} + %0 = bufferization.alloc_tensor() : tensor<20x40xf32, #DCSR> + return %0 : tensor<20x40xf32, #DCSR> +} + +// ----- + +#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }> + +func.func private @foo(tensor<20x40xf32, #DCSR>) -> () + +func.func @sparse_alloc_call() { + // expected-error @+1{{sparse tensor allocation should not escape function}} + %0 = bufferization.alloc_tensor() : tensor<20x40xf32, #DCSR> + call @foo(%0) : (tensor<20x40xf32, #DCSR>) -> () + return +} + diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir --- a/mlir/test/Dialect/SparseTensor/conversion.mlir +++ b/mlir/test/Dialect/SparseTensor/conversion.mlir @@ -136,7 +136,8 @@ // 
CHECK: return %[[T]] : !llvm.ptr<i8> func.func @sparse_init(%arg0: index, %arg1: index) -> tensor<?x?xf64, #SparseMatrix> { %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix> - return %0 : tensor<?x?xf64, #SparseMatrix> + %1 = sparse_tensor.load %0 : tensor<?x?xf64, #SparseMatrix> + return %1 : tensor<?x?xf64, #SparseMatrix> } // CHECK-LABEL: func @sparse_release( @@ -580,6 +581,7 @@ func.func @sparse_and_dense_init(%arg0: index, %arg1: index) -> (tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64>) { %0 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64, #SparseMatrix> - %1 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64> - return %0, %1 : tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64> + %1 = sparse_tensor.load %0 : tensor<?x?xf64, #SparseMatrix> + %2 = bufferization.alloc_tensor(%arg0, %arg1) : tensor<?x?xf64> + return %1, %2 : tensor<?x?xf64, #SparseMatrix>, tensor<?x?xf64> } diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8958,6 +8958,7 @@ ":IR", ":InferTypeOpInterface", ":MemRefDialect", + ":SparseTensorDialect", ":Support", ":TensorDialect", "//llvm:Support",