diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -911,6 +911,16 @@
     Value nnz = createFuncCall(rewriter, loc, "getSparseTensorReaderNNZ",
                                {indexTp}, {reader}, EmitCInterface::Off)
                     .getResult(0);
+    Value symmetric;
+    // We assume only rank 2 tensors may have the isSymmetric flag set.
+    if (rank == 2) {
+      symmetric =
+          createFuncCall(rewriter, loc, "getSparseTensorReaderIsSymmetric",
+                         {rewriter.getI1Type()}, {reader}, EmitCInterface::Off)
+              .getResult(0);
+    } else {
+      symmetric = Value();
+    }
     Type eltTp = dstTp.getElementType();
     Value value = genAllocaScalar(rewriter, loc, eltTp);
     scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1,
@@ -929,8 +939,23 @@
           loc, indices, constantIndex(rewriter, loc, i)));
     }
     Value v = rewriter.create<memref::LoadOp>(loc, value);
-    auto t = rewriter.create<InsertOp>(loc, v, forOp.getRegionIterArg(0),
-                                       indicesArray);
+    Value t = rewriter.create<InsertOp>(loc, v, forOp.getRegionIterArg(0),
+                                        indicesArray);
+    if (symmetric) {
+      Value eq = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::ne, indicesArray[0], indicesArray[1]);
+      Value cond = rewriter.create<arith::AndIOp>(loc, symmetric, eq);
+      scf::IfOp ifOp =
+          rewriter.create<scf::IfOp>(loc, t.getType(), cond, /*else*/ true);
+      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      rewriter.create<scf::YieldOp>(
+          loc, Value(rewriter.create<InsertOp>(
+                   loc, v, t, ValueRange{indicesArray[1], indicesArray[0]})));
+      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      rewriter.create<scf::YieldOp>(loc, t);
+      t = ifOp.getResult(0);
+      rewriter.setInsertionPointAfter(ifOp);
+    }
     rewriter.create<scf::YieldOp>(loc, ArrayRef<Value>(t));
     rewriter.setInsertionPointAfter(forOp);
     // Link SSA chain.
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sum.mlir
@@ -1,9 +1,16 @@
-// RUN: mlir-opt %s --sparse-compiler | \
-// RUN: TENSOR0="%mlir_src_dir/test/Integration/data/test_symmetric.mtx" \
-// RUN: mlir-cpu-runner \
-// RUN:  -e entry -entry-point-result=void \
-// RUN:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
-// RUN: FileCheck %s
+// DEFINE: %{option} = enable-runtime-library=true
+// DEFINE: %{command} = mlir-opt %s --sparse-compiler=%{option} | \
+// DEFINE: TENSOR0="%mlir_src_dir/test/Integration/data/test_symmetric.mtx" \
+// DEFINE: mlir-cpu-runner \
+// DEFINE:  -e entry -entry-point-result=void \
+// DEFINE:  -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
+// DEFINE: FileCheck %s
+//
+// RUN: %{command}
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{option} = enable-runtime-library=false
+// RUN: %{command}
 
 !Filename = !llvm.ptr<i8>
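
In scalar terms, the loop nest emitted by the second hunk implements the
following per-entry rule: every (i, j, v) triple read from the file is
inserted as-is, and, when the reader reports a symmetric matrix, each
off-diagonal entry is additionally mirrored to (j, i); the `ne` comparison
guards against inserting diagonal entries twice. A minimal self-contained
C++ sketch of that rule (the Entry struct and expandSymmetric helper are
illustrative only, not part of this patch):

  #include <cstdint>
  #include <vector>

  // Illustrative COO entry; the rewritten IR performs sparse_tensor.insert
  // ops inside an scf.for loop instead of pushing into a vector.
  struct Entry {
    uint64_t i, j;
    double v;
  };

  // Mirrors the generated IR: insert each entry, and for a symmetric input
  // also insert the transposed coordinate, but only off the diagonal.
  std::vector<Entry> expandSymmetric(const std::vector<Entry> &read,
                                     bool symmetric) {
    std::vector<Entry> coo;
    for (const Entry &e : read) {
      coo.push_back(e);
      if (symmetric && e.i != e.j)
        coo.push_back({e.j, e.i, e.v});
    }
    return coo;
  }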
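
The getSparseTensorReaderIsSymmetric call emitted in the first hunk binds to
an entry point in the sparse-tensor runtime library that reports whether the
MatrixMarket banner declared the symmetric storage scheme, e.g.
"%%MatrixMarket matrix coordinate real symmetric". A hedged sketch of such a
wrapper, assuming an opaque reader handle and an isSymmetric() accessor on
the reader class (the in-tree runtime implementation is not part of this
diff):

  // Sketch only; assumes the reader records the "symmetric" keyword while
  // parsing the MatrixMarket banner.
  class SparseTensorReader {
  public:
    explicit SparseTensorReader(bool symmetric) : symmetric_(symmetric) {}
    bool isSymmetric() const { return symmetric_; }
  private:
    bool symmetric_; // set while parsing the header
  };

  extern "C" bool getSparseTensorReaderIsSymmetric(void *p) {
    return static_cast<SparseTensorReader *>(p)->isSymmetric();
  }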