diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -538,9 +538,9 @@ case TensorExp::Kind::kCIm: case TensorExp::Kind::kCRe: case TensorExp::Kind::kBitCast: + case TensorExp::Kind::kUnary: return isSingleCondition(t, expr.children.e0); case TensorExp::Kind::kBinaryBranch: - case TensorExp::Kind::kUnary: case TensorExp::Kind::kSelect: return false; // Binary operations. @@ -559,6 +559,7 @@ case TensorExp::Kind::kMulC: case TensorExp::Kind::kMulI: case TensorExp::Kind::kAndI: + case TensorExp::Kind::kReduce: if (isSingleCondition(t, expr.children.e0)) return isSingleCondition(t, expr.children.e1) || isInvariant(expr.children.e1); @@ -576,7 +577,6 @@ case TensorExp::Kind::kOrI: case TensorExp::Kind::kXorI: case TensorExp::Kind::kBinary: - case TensorExp::Kind::kReduce: return false; } llvm_unreachable("unexpected kind"); @@ -783,6 +783,7 @@ llvm::dbgs() << " " << kindToOpSymbol(expr.kind) << " "; dumpExp(expr.children.e1); llvm::dbgs() << ")"; + break; } } @@ -917,11 +918,11 @@ UnaryOp unop = cast(expr.op); const LatSetId child0 = buildLattices(e0, i); Region &absentRegion = unop.getAbsentRegion(); - if (absentRegion.empty()) { // Simple mapping over existing values. return mapSet(kind, child0, Value(), unop); - } // Use a disjunction with `unop` on the left and the absent value as an + } + // Use a disjunction with `unop` on the left and the absent value as an // invariant on the right. 
Block &absentBlock = absentRegion.front(); YieldOp absentYield = cast(absentBlock.getTerminator()); diff --git a/mlir/test/Dialect/SparseTensor/semi_ring.mlir b/mlir/test/Dialect/SparseTensor/semi_ring.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SparseTensor/semi_ring.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-opt %s -sparsification | FileCheck %s + +#SM = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }> + +#trait = { + indexing_maps = [ + affine_map<(i,j) -> (i,j)> // A + ], + iterator_types = ["parallel", "parallel"], + doc = "A(i,j) += 2.0 where A(i,j) != 0" +} + +module { + // Example of a semi-ring operation that only adds a + // constant at stored values (something that would + // typically not sparsify since it would densify the + // implicit zeros in the normal case). The sparse + // compiler should see that this is a "simply dynamic" + // operation, and the values can be changed "in-place". + // + // CHECK-LABEL: func.func @add_only_where_nonzero( + // CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> { + // CHECK-DAG: %[[VAL_1:.*]] = arith.constant 8 : index + // CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : index + // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2.000000e+00 : f64 + // CHECK-DAG: %[[VAL_5:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref + // CHECK-DAG: %[[VAL_6:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref + // CHECK: scf.for %[[VAL_7:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_3]] { + // CHECK: %[[VAL_8:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_7]]] : memref + // CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_7]], %[[VAL_3]] : index + // CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_9]]] : memref + // CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_8]] to %[[VAL_10]] step %[[VAL_3]] { + // CHECK: 
%[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref + // CHECK: %[[VAL_13:.*]] = arith.addf %[[VAL_12]], %[[VAL_4]] : f64 + // CHECK: memref.store %[[VAL_13]], %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref + // CHECK: } {"Emitted from" = "linalg.generic"} + // CHECK: } {"Emitted from" = "linalg.generic"} + // CHECK: %[[VAL_14:.*]] = sparse_tensor.load %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> + // CHECK: return %[[VAL_14]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> + // CHECK: } + func.func @add_only_where_nonzero(%argA: tensor<8x8xf64, #SM>) -> tensor<8x8xf64, #SM> { + %c = arith.constant 2.0 : f64 + %result = linalg.generic #trait + outs(%argA: tensor<8x8xf64, #SM>) { + ^bb(%a: f64): + %u = sparse_tensor.unary %a : f64 to f64 + present={ + ^bb0(%p: f64): + %add = arith.addf %p, %c : f64 + sparse_tensor.yield %add : f64 + } + absent={} + linalg.yield %u : f64 + } -> tensor<8x8xf64, #SM> + return %result : tensor<8x8xf64, #SM> + } +}