diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -629,6 +629,29 @@
 // sparse tensor codegen to progressively lower sparse tensors.
 //===----------------------------------------------------------------------===//
 
+def SparseTensor_StorageNewOp : SparseTensor_Op<"storage", []>,
+    Arguments<(ins Variadic<AnyType>:$inputs)>,
+    Results<(outs AnyTuple:$result)> {
+  let summary = "Pack a list of values into one sparse tensor storage value";
+  let description = [{
+    Pack a list of values into one sparse tensor storage value (represented
+    as a tuple).
+
+    The type of each element in the result tuple must match the type of the
+    corresponding input value.
+
+    Example:
+
+    ```mlir
+    %0 = sparse_tensor.storage(%1, %2): memref<?xf64>, memref<?xf64>
+         to tuple<memref<?xf64>, memref<?xf64>>
+    ```
+  }];
+
+  let assemblyFormat = " attr-dict `(` $inputs `)``:` type($inputs) `to` type($result)";
+  let hasVerifier = 1;
+}
+
 def SparseTensor_StorageGetOp : SparseTensor_Op<"storage_get", []>,
     Arguments<(ins AnyTuple:$storage,
                    IndexAttr:$idx)>,
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -486,6 +486,23 @@
 // Sparse Tensor Storage Operation.
 //===----------------------------------------------------------------------===//
 
+LogicalResult StorageNewOp::verify() {
+  auto retTypes = getResult().getType().getTypes();
+  if (retTypes.size() != getInputs().size())
+    return emitError("The number of inputs is inconsistent with output tuple");
+
+  for (auto pair : llvm::zip(getInputs(), retTypes)) {
+    auto input = std::get<0>(pair);
+    auto retTy = std::get<1>(pair);
+
+    if (input.getType() != retTy)
+      return emitError(llvm::formatv("Type mismatch between input (type={0}) "
+                                     "and output tuple element (type={1})",
+                                     input.getType(), retTy));
+  }
+  return success();
+}
+
 LogicalResult StorageGetOp::verify() {
   uint64_t extractIdx = getIdx().getZExtValue();
   auto innerTypeArray = getStorage().getType().getTypes();
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -445,6 +445,26 @@
 
 // -----
 
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+  tuple<memref<?xf64>, memref<?xf64>> {
+  // expected-error@+1{{The number of inputs is inconsistent with output}}
+  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+    : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xf64>>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>>
+}
+
+// -----
+
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+  tuple<memref<?xf64>, memref<?xi64>, f64> {
+  // expected-error@+1{{Type mismatch between}}
+  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+    : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xi64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xi64>, f64>
+}
+
+// -----
+
 func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
   // expected-error@+1{{Out-of-bound access}}
   %0 = sparse_tensor.storage_get %arg0[3]
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -317,6 +317,22 @@
 
 // -----
 
+
+// CHECK-LABEL: func @sparse_storage_new(
+// CHECK-SAME: %[[A0:.*0]]: memref<?xf64>,
+// CHECK-SAME: %[[A1:.*1]]: memref<?xf64>,
+// CHECK-SAME: %[[A2:.*]]: f64
+// CHECK: %[[TMP_0:.*]] = sparse_tensor.storage(%[[A0]], %[[A1]], %[[A2]])
+// CHECK: return %[[TMP_0]] : tuple<memref<?xf64>, memref<?xf64>, f64>
+func.func @sparse_storage_new(%arg0: memref<?xf64>, %arg1: memref<?xf64>, %arg2: f64) ->
+  tuple<memref<?xf64>, memref<?xf64>, f64> {
+  %0 = sparse_tensor.storage(%arg0, %arg1, %arg2)
+    : memref<?xf64>, memref<?xf64>, f64 to tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
+
+// -----
+
 // CHECK-LABEL: func @sparse_storage_get(
 // CHECK-SAME: %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>
 // CHECK: %[[TMP0:.*]] = sparse_tensor.storage_get %[[A0]][0] :
@@ -329,7 +345,7 @@
   return %0 : memref<?xf64>
 }
 
-// ----
+// -----
 
 // CHECK-LABEL: func @sparse_storage_set(
 // CHECK-SAME: %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>,