diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -909,6 +909,16 @@
     Value nnz = createFuncCall(rewriter, loc, "getSparseTensorReaderNNZ",
                                {indexTp}, {reader}, EmitCInterface::Off)
                     .getResult(0);
+    // We assume only rank 2 tensors may have the isSymmetric flag set, so the
+    // reader is queried only for them. A null `symmetric` suppresses the
+    // generation of the mirrored insertion inside the loop below.
+    Value symmetric;
+    if (rank == 2) {
+      symmetric =
+          createFuncCall(rewriter, loc, "getSparseTensorReaderIsSymmetric",
+                         {rewriter.getI1Type()}, {reader}, EmitCInterface::Off)
+              .getResult(0);
+    }
     Type eltTp = dstTp.getElementType();
     Value value = genAllocaScalar(rewriter, loc, eltTp);
     scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1,
@@ -927,8 +937,26 @@
           loc, indices, constantIndex(rewriter, loc, i)));
     }
     Value v = rewriter.create<memref::LoadOp>(loc, value);
-    auto t = rewriter.create<InsertOp>(loc, v, forOp.getRegionIterArg(0),
-                                       indicesArray);
+    Value t = rewriter.create<InsertOp>(loc, v, forOp.getRegionIterArg(0),
+                                        indicesArray);
+    if (symmetric) {
+      // When the reader reports a symmetric tensor, element (i, j) with
+      // i != j is also inserted at (j, i); diagonal elements are inserted
+      // only once. The runtime flag and the index test are combined into one
+      // scf.if condition.
+      Value notDiagonal = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::ne, indicesArray[0], indicesArray[1]);
+      Value cond = rewriter.create<arith::AndIOp>(loc, symmetric, notDiagonal);
+      scf::IfOp ifOp = rewriter.create<scf::IfOp>(loc, t.getType(), cond,
+                                                  /*withElseRegion=*/true);
+      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      rewriter.create<scf::YieldOp>(
+          loc, Value(rewriter.create<InsertOp>(
+                   loc, v, t, ValueRange{indicesArray[1], indicesArray[0]})));
+      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      rewriter.create<scf::YieldOp>(loc, t);
+      t = ifOp.getResult(0);
+      rewriter.setInsertionPointAfter(ifOp);
+    }
     rewriter.create<scf::YieldOp>(loc, ArrayRef<Value>(t));
     rewriter.setInsertionPointAfter(forOp);
     // Link SSA chain.
diff --git a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" |\
+// RUN: mlir-opt %s -sparse-tensor-rewrite="enable-runtime-library=false enable-convert=false" | \
 // RUN: FileCheck %s
 
 #CSR = #sparse_tensor.encoding<{
@@ -17,6 +17,7 @@
 // CHECK: %[[D1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
 // CHECK: %[[T:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]])
 // CHECK: %[[N:.*]] = call @getSparseTensorReaderNNZ(%[[R]])
+// CHECK: %[[S:.*]] = call @getSparseTensorReaderIsSymmetric(%[[R]])
 // CHECK: %[[VB:.*]] = memref.alloca()
 // CHECK: %[[T2:.*]] = scf.for %{{.*}} = %[[C0]] to %[[N]] step %[[C1]] iter_args(%[[A2:.*]] = %[[T]])
 // CHECK: func.call @getSparseTensorReaderNextF32(%[[R]], %[[DS]], %[[VB]])
@@ -24,12 +25,19 @@
 // CHECK: %[[E1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
 // CHECK: %[[V:.*]] = memref.load %[[VB]][]
 // CHECK: %[[T1:.*]] = sparse_tensor.insert %[[V]] into %[[A2]]{{\[}}%[[E0]], %[[E1]]]
-// CHECK: scf.yield %[[T1]]
+// CHECK: %[[NE:.*]] = arith.cmpi ne, %[[E0]], %[[E1]]
+// CHECK: %[[COND:.*]] = arith.andi %[[S]], %[[NE]]
+// CHECK: %[[T3:.*]] = scf.if %[[COND]]
+// CHECK: %[[T4:.*]] = sparse_tensor.insert %[[V]] into %[[T1]]{{\[}}%[[E1]], %[[E0]]]
+// CHECK: scf.yield %[[T4]]
+// CHECK: } else {
+// CHECK: scf.yield %[[T1]]
+// CHECK: scf.yield %[[T3]]
 // CHECK: }
 // CHECK: call @delSparseTensorReader(%[[R]])
-// CHECK: %[[T3:.*]] = sparse_tensor.load %[[T2]] hasInserts
-// CHECK: %[[R:.*]] = sparse_tensor.convert %[[T3]]
-// CHECK: bufferization.dealloc_tensor %[[T3]]
+// CHECK: %[[T5:.*]] = sparse_tensor.load %[[T2]] hasInserts
+// CHECK: %[[R:.*]] = sparse_tensor.convert %[[T5]]
+// CHECK: bufferization.dealloc_tensor %[[T5]]
 // CHECK: return %[[R]]
 func.func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CSR>
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
@@ -1,9 +1,16 @@
-// RUN: mlir-opt %s --sparse-compiler | \
-// RUN: TENSOR0="%mlir_src_dir/test/Integration/data/test_symmetric.mtx" \
-// RUN: mlir-cpu-runner \
-// RUN: -e entry -entry-point-result=void \
-// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: TENSOR0="%mlir_src_dir/test/Integration/data/test_symmetric.mtx" \
+// DEFINE: mlir-cpu-runner \
+// DEFINE: -e entry -entry-point-result=void \
+// DEFINE: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = enable-runtime-library=false
+// RUN: %{command}
 
 !Filename = !llvm.ptr<i8>
 