diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -623,4 +623,57 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Storage Operations. These operations are used internally by
+// sparse tensor codegen to progressively lower sparse tensors.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_StorageGetOp : SparseTensor_Op<"storage_get", []>,
+    Arguments<(ins AnyTuple:$storage,
+                   IndexAttr:$idx)>,
+    Results<(outs AnyType:$result)> {
+  let summary = "Get the data stored in the sparse tensor storage at the given index";
+  let description = [{
+    Get the data stored in the sparse tensor storage (represented as a tuple)
+    at the given index.
+
+    The result type should match the corresponding element type in the tuple.
+
+    Example:
+
+    ```mlir
+    %0 = sparse_tensor.storage_get %arg0[0] : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
+    ```
+  }];
+
+  let assemblyFormat = "$storage attr-dict `[`$idx`]` `:` type($storage) `to` type($result)";
+  let hasVerifier = 1;
+}
+
+def SparseTensor_StorageSetOp : SparseTensor_Op<"storage_set", []>,
+    Arguments<(ins AnyTuple:$storage,
+                   AnyType:$value,
+                   IndexAttr:$idx)>,
+    Results<(outs AnyTuple:$result)> {
+  let summary = "Set the data stored in the sparse tensor storage at the given index";
+  let description = [{
+    Set the data stored in the sparse tensor storage (represented as a tuple)
+    at the given index. Return a new SSA value with the corresponding element
+    updated (all other elements remain unchanged).
+
+    The result type should match the original tuple type, with only the updated
+    element type changed accordingly.
+
+    Example:
+
+    ```mlir
+    %0 = sparse_tensor.storage_set %arg0[0], %arg1 : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to tuple<memref<?xf64>, memref<?xf64>, f64>
+    ```
+  }];
+
+  let assemblyFormat = "$storage attr-dict `[`$idx`]` `,` $value `:` type($storage) `,` type($value) `to` type($result)";
+  let hasVerifier = 1;
+}
+
+
 #endif // SPARSETENSOR_OPS
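Note (not part of the patch): a minimal usage sketch of the two ops defined above, assuming the same storage tuple type as in the op examples (two `memref<?xf64>` buffers plus a trailing `f64`); the function and operand names are made up for illustration. `storage_get` reads one tuple element, while `storage_set` yields a fresh tuple SSA value with that element replaced.

```mlir
// Hypothetical illustration only: read the buffer stored at index 1, then
// write a replacement buffer back, producing a new storage tuple value.
func.func @storage_update_sketch(%storage: tuple<memref<?xf64>, memref<?xf64>, f64>,
                                 %new_buf: memref<?xf64>)
    -> (memref<?xf64>, tuple<memref<?xf64>, memref<?xf64>, f64>) {
  // The result type must match the element type at the accessed index.
  %old = sparse_tensor.storage_get %storage[1]
       : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
  // Setting index 1 with a value of the same type keeps the tuple type unchanged.
  %updated = sparse_tensor.storage_set %storage[1], %new_buf
       : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
         tuple<memref<?xf64>, memref<?xf64>, f64>
  return %old, %updated : memref<?xf64>, tuple<memref<?xf64>, memref<?xf64>, f64>
}
```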
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -482,6 +482,48 @@
       "expected parent op to be sparse_tensor unary, binary, or reduce");
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Storage Operations.
+//===----------------------------------------------------------------------===//
+
+LogicalResult StorageGetOp::verify() {
+  uint64_t extractIdx = getIdx().getZExtValue();
+  auto innerTypeArray = getStorage().getType().getTypes();
+  if (extractIdx >= innerTypeArray.size())
+    return emitError(llvm::formatv(
+        "Out-of-bound access with index={0} on tuple with length={1}",
+        extractIdx, innerTypeArray.size()));
+
+  auto expectedTy = getStorage().getType().getType(extractIdx);
+  auto returnTy = getResult().getType();
+  if (expectedTy != returnTy)
+    return emitError(llvm::formatv(
+        "Type mismatch between the returning type (type={0}) and the "
+        "corresponding element type at index {1} (type={2})",
+        returnTy, extractIdx, expectedTy));
+  return success();
+}
+
+LogicalResult StorageSetOp::verify() {
+  uint64_t setIdx = getIdx().getZExtValue();
+  SmallVector<Type> expectedElemTy(getStorage().getType().getTypes());
+  if (setIdx >= expectedElemTy.size())
+    return emitError(llvm::formatv(
+        "Out-of-bound access with index={0} on tuple with length={1}", setIdx,
+        expectedElemTy.size()));
+
+  // Update the element type to reflect the newly set value.
+  expectedElemTy[setIdx] = getValue().getType();
+  auto expectedTy = TupleType::get(getContext(), expectedElemTy);
+  auto returnTy = getResult().getType();
+  if (expectedTy != returnTy)
+    return emitError(
+        llvm::formatv("Type mismatch between the returning type "
+                      "(type={0}) and the expected type (type={1})",
+                      returnTy, expectedTy));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TensorDialect Methods.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -443,3 +443,42 @@
   return %0 : tensor<9x4xf64, #DC>
 }
 
+// -----
+
+func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
+  // expected-error@+1{{Out-of-bound access}}
+  %0 = sparse_tensor.storage_get %arg0[3]
+     : tuple<memref<?xf64>, memref<?xf64>, f64> to
+       memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
+// -----
+
+func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
+  // expected-error@+1{{Type mismatch}}
+  %0 = sparse_tensor.storage_get %arg0[2]
+     : tuple<memref<?xf64>, memref<?xf64>, f64> to
+       memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
+// -----
+
+func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  // expected-error@+1{{Out-of-bound access}}
+  %0 = sparse_tensor.storage_set %arg0[3], %arg1
+     : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
+       tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
+
+// -----
+
+func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  // expected-error@+1{{Type mismatch}}
+  %0 = sparse_tensor.storage_set %arg0[2], %arg1
+     : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
+       tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -314,3 +314,34 @@
        tensor<4x4xf64, #SparseMatrix> to tensor<9x4xf64, #SparseMatrix>
   return %0 : tensor<9x4xf64, #SparseMatrix>
 }
+
+// -----
+
+// CHECK-LABEL: func @sparse_storage_get(
+// CHECK-SAME: %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>
+// CHECK: %[[TMP0:.*]] = sparse_tensor.storage_get %[[A0]][0] :
+// CHECK-SAME: tuple<memref<?xf64>, memref<?xf64>, f64>
+// CHECK-SAME: to memref<?xf64>
+// CHECK: return %[[TMP0]] : memref<?xf64>
+func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
+  %0 = sparse_tensor.storage_get %arg0[0]
+     : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func @sparse_storage_set(
+// CHECK-SAME: %[[A0:.*]]: tuple<memref<?xf64>, memref<?xf64>, f64>,
+// CHECK-SAME: %[[A1:.*]]: memref<?xf64>
+// CHECK: %[[TMP0:.*]] = sparse_tensor.storage_set %[[A0]][0], %[[A1]] :
+// CHECK-SAME: tuple<memref<?xf64>, memref<?xf64>, f64>,
+// CHECK-SAME: memref<?xf64>
+// CHECK-SAME: to tuple<memref<?xf64>, memref<?xf64>, f64>
+// CHECK: return %[[TMP0]] : tuple<memref<?xf64>, memref<?xf64>, f64>
+func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>, %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  %0 = sparse_tensor.storage_set %arg0[0], %arg1
+     : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
+       tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
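A further hand-written sketch (again not part of the patch, same assumed element types): because every `storage_set` returns a new tuple value, successive updates chain through SSA, and setting an element with a value of a different type changes that element's type in the result tuple, which is exactly what the `StorageSetOp` verifier's expected-type computation permits.

```mlir
// Hypothetical illustration only: two chained updates; the second swaps the
// trailing f64 element for a memref, so the result tuple type changes.
func.func @storage_chain_sketch(%s0: tuple<memref<?xf64>, memref<?xf64>, f64>,
                                %buf: memref<?xf64>)
    -> tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>> {
  // Same element type at index 1: the tuple type stays the same.
  %s1 = sparse_tensor.storage_set %s0[1], %buf
      : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
        tuple<memref<?xf64>, memref<?xf64>, f64>
  // Different element type at index 2: the third element becomes memref<?xf64>.
  %s2 = sparse_tensor.storage_set %s1[2], %buf
      : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
        tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
  return %s2 : tuple<memref<?xf64>, memref<?xf64>, memref<?xf64>>
}
```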