diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -26,7 +26,7 @@
 //===----------------------------------------------------------------------===//
 
 def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
-    Arguments<(ins AnyType:$source)>,
+    Arguments<(ins AnyType:$source, UnitAttr:$expandSymmetry)>,
     Results<(outs AnySparseTensor:$result)> {
   string summary = "Materializes a new sparse tensor from given source";
   string description = [{
@@ -39,13 +39,22 @@
    code. The operation is provided as an anchor that materializes a properly
    typed sparse tensor with inital contents into a computation.
 
+    An optional attribute `expandSymmetry` can be used to extend this operation
+    to make the symmetry in external formats explicit in the storage. That is,
+    when the attribute is present and a non-zero value is discovered at (i, j)
+    where i != j, the same value is also added at (j, i). This uses more
+    storage than a purely symmetric storage scheme, and thus may incur a
+    performance hit. True symmetric storage is planned for the future.
+
    Example:
 
    ```mlir
    sparse_tensor.new %source : !Source to tensor<1024x1024xf64, #CSR>
    ```
  }];
-  let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
+  let assemblyFormat = "(`expand_symmetry` $expandSymmetry^)? $source attr-dict"
+                       "`:` type($source) `to` type($result)";
+  let hasVerifier = 1;
 }
 
 def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -334,6 +334,13 @@
   return failure();
 }
 
+LogicalResult NewOp::verify() {
+  if (getExpandSymmetry() &&
+      getResult().getType().cast<RankedTensorType>().getRank() != 2)
+    return emitOpError("expand_symmetry can only be used for 2D tensors");
+  return success();
+}
+
 LogicalResult ConvertOp::verify() {
   if (auto tp1 = getSource().getType().dyn_cast<RankedTensorType>()) {
     if (auto tp2 = getDest().getType().dyn_cast<RankedTensorType>()) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -915,8 +915,8 @@
                              {indexTp}, {reader}, EmitCInterface::Off)
                   .getResult(0);
   Value symmetric;
-  // We assume only rank 2 tensors may have the isSymmetric flag set.
-  if (rank == 2) {
+  // The verifier ensures only 2D tensors can have the expandSymmetry flag.
+  if (rank == 2 && op.getExpandSymmetry()) {
     symmetric =
         createFuncCall(rewriter, loc, "getSparseTensorReaderIsSymmetric",
                        {rewriter.getI1Type()}, {reader}, EmitCInterface::Off)
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -180,6 +180,16 @@
 
 // -----
 
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+func.func @sparse_new(%arg0: !llvm.ptr) {
+  // expected-error@+1 {{expand_symmetry can only be used for 2D tensors}}
+  %0 = sparse_tensor.new expand_symmetry %arg0 : !llvm.ptr to tensor<128xf64, #SparseVector>
+  return
+}
+
+// -----
+
 func.func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> {
   // expected-error@+1 {{unexpected type in convert}}
   %0 = sparse_tensor.convert %arg0 : tensor<*xf32> to tensor<10xf32>
diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -5,7 +5,7 @@
   dimLevelType = ["dense", "compressed"]
 }>
 
-// CHECK-LABEL: func.func @sparse_new(
+// CHECK-LABEL: func.func @sparse_new_symmetry(
 // CHECK-SAME: %[[A:.*]]: !llvm.ptr) -> tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> {
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
@@ -39,6 +39,37 @@
 // CHECK: %[[R:.*]] = sparse_tensor.convert %[[T5]]
 // CHECK: bufferization.dealloc_tensor %[[T5]]
 // CHECK: return %[[R]]
+func.func @sparse_new_symmetry(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CSR> {
+  %0 = sparse_tensor.new expand_symmetry %arg0 : !llvm.ptr to tensor<?x?xf32, #CSR>
+  return %0 : tensor<?x?xf32, #CSR>
+}
+
+// CHECK-LABEL: func.func @sparse_new(
+// CHECK-SAME: %[[A:.*]]: !llvm.ptr) -> tensor<?x?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>> {
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[R:.*]] = call @createSparseTensorReader(%[[A]])
+// CHECK: %[[DS:.*]] = memref.alloca(%[[C2]]) : memref<?xindex>
+// CHECK: call @getSparseTensorReaderDimSizes(%[[R]], %[[DS]])
+// CHECK: %[[D0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
+// CHECK: %[[D1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
+// CHECK: %[[T:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]])
+// CHECK: %[[N:.*]] = call @getSparseTensorReaderNNZ(%[[R]])
+// CHECK: %[[VB:.*]] = memref.alloca()
+// CHECK: %[[T2:.*]] = scf.for %{{.*}} = %[[C0]] to %[[N]] step %[[C1]] iter_args(%[[A2:.*]] = %[[T]])
+// CHECK: func.call @getSparseTensorReaderNextF32(%[[R]], %[[DS]], %[[VB]])
+// CHECK: %[[E0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
+// CHECK: %[[E1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
+// CHECK: %[[V:.*]] = memref.load %[[VB]][]
+// CHECK: %[[T1:.*]] = sparse_tensor.insert %[[V]] into %[[A2]]{{\[}}%[[E0]], %[[E1]]]
+// CHECK: scf.yield %[[T1]]
+// CHECK: }
+// CHECK: call @delSparseTensorReader(%[[R]])
+// CHECK: %[[T4:.*]] = sparse_tensor.load %[[T2]] hasInserts
+// CHECK: %[[R:.*]] = sparse_tensor.convert %[[T4]]
+// CHECK: bufferization.dealloc_tensor %[[T4]]
+// CHECK: return %[[R]]
 func.func @sparse_new(%arg0: !llvm.ptr) -> tensor<?x?xf32, #CSR> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr to tensor<?x?xf32, #CSR>
   return %0 : tensor<?x?xf32, #CSR>
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -13,6 +13,19 @@
 
 // -----
 
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
+// CHECK-LABEL: func @sparse_new_symmetry(
+// CHECK-SAME: %[[A:.*]]: !llvm.ptr)
+// CHECK: %[[T:.*]] = sparse_tensor.new expand_symmetry %[[A]] : !llvm.ptr to tensor<?x?xf64, #{{.*}}>
+// CHECK: return %[[T]] : tensor<?x?xf64, #{{.*}}>
+func.func @sparse_new_symmetry(%arg0: !llvm.ptr) -> tensor<?x?xf64, #SparseMatrix> {
+  %0 = sparse_tensor.new expand_symmetry %arg0 : !llvm.ptr to tensor<?x?xf64, #SparseMatrix>
+  return %0 : tensor<?x?xf64, #SparseMatrix>
+}
+
+// -----
+
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
 // CHECK-LABEL: func @sparse_dealloc(
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
@@ -63,7 +63,7 @@
 
     // Read the sparse matrix from file, construct sparse storage.
     %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
-    %a = sparse_tensor.new %fileName : !Filename to tensor<?x?xf64, #SparseMatrix>
+    %a = sparse_tensor.new expand_symmetry %fileName : !Filename to tensor<?x?xf64, #SparseMatrix>
 
     // Call the kernel.
     %0 = call @kernel_sum_reduce(%a, %x)
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_c32.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_c32.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_c32.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum_c32.mlir
@@ -66,7 +66,7 @@
 
     // Read the sparse matrix from file, construct sparse storage.
     %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
-    %a = sparse_tensor.new %fileName : !Filename to tensor<?x?xcomplex<f32>, #SparseMatrix>
+    %a = sparse_tensor.new expand_symmetry %fileName : !Filename to tensor<?x?xcomplex<f32>, #SparseMatrix>
 
     // Call the kernel.
     %0 = call @kernel_sum_reduce(%a, %x)
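Note on the semantics (not part of the patch): the `expandSymmetry` behavior documented in the op description can be summarized by a small standalone C++ sketch. The `Triple` struct and `expandSymmetry` helper below are hypothetical names used purely for illustration and are not part of the MLIR API; the sketch only shows what "expanding symmetry" means for a stream of COO entries read from a symmetric external format such as Matrix Market.

```cpp
// Standalone illustration: every off-diagonal nonzero (i, j) read from a
// symmetric file is mirrored to (j, i); diagonal entries are kept once.
#include <cstdint>
#include <vector>

struct Triple {        // hypothetical COO entry: value v at coordinates (i, j)
  uint64_t i, j;
  double v;
};

std::vector<Triple> expandSymmetry(const std::vector<Triple> &in) {
  std::vector<Triple> out;
  out.reserve(2 * in.size()); // worst case: all entries are off-diagonal
  for (const Triple &t : in) {
    out.push_back(t);
    if (t.i != t.j)
      out.push_back({t.j, t.i, t.v}); // mirrored entry claims extra storage
  }
  return out;
}
```

As the op documentation above notes, this stores both halves of the matrix explicitly rather than exploiting symmetry in the storage scheme itself, which is why the attribute may increase memory use.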