diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -32,6 +32,9 @@ // Operation. kMulF, kMulI, + kDivF, + kDivS, // signed + kDivU, // unsigned kAddF, kAddI, kSubF, @@ -197,6 +200,9 @@ Optional buildTensorExpFromLinalg(linalg::GenericOp op); private: + /// Returns false only if e is a constant known to be nonzero (and not NaN). + bool maybeZero(unsigned e); + /// Traverses the SSA tree (possibly a DAG) to build a tensor expression. Optional buildTensorExp(linalg::GenericOp op, Value val); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -646,6 +646,12 @@ return rewriter.create(loc, v0, v1); case Kind::kMulI: return rewriter.create(loc, v0, v1); + case Kind::kDivF: + return rewriter.create(loc, v0, v1); + case Kind::kDivS: + return rewriter.create(loc, v0, v1); + case Kind::kDivU: + return rewriter.create(loc, v0, v1); case Kind::kAddF: return rewriter.create(loc, v0, v1); case Kind::kAddI: diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -201,6 +201,10 @@ case Kind::kMulF: case Kind::kMulI: return '*'; + case Kind::kDivF: + case Kind::kDivS: + case Kind::kDivU: + return '/'; case Kind::kAddF: case Kind::kAddI: return '+'; @@ -302,17 +306,51 @@ } case Kind::kMulF: case Kind::kMulI: + // A multiplicative operation only needs to be performed + // for the conjunction of sparse iteration spaces. 
+ // + // x*y|!y | y | + // ---+---+---+ + // !x | 0 | 0 | + // x | 0 |x*y| return takeConj(kind, // take binary conjunction buildLattices(tensorExps[e].children.e0, idx), buildLattices(tensorExps[e].children.e1, idx)); + case Kind::kDivF: + case Kind::kDivS: + case Kind::kDivU: { + // A division is tricky, since 0/0, 0/c, c/0 all have + // specific outcomes for floating-point and integers. + // Thus, we need to traverse the full iteration space. + // + // x/y|!y | y | + // ---+---+---+ + // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero + // x |x/0|x/y| INT: x/0=exception for any x + // + // TODO: for now we "fixed" this by only accepting x/c cases + // during expression building, so that the conjunction + // rule applies (viz. x/c = x*(1/c)). + return takeConj(kind, // take binary conjunction + buildLattices(tensorExps[e].children.e0, idx), + buildLattices(tensorExps[e].children.e1, idx)); + } case Kind::kSubF: case Kind::kSubI: + // Special case: 0-y is -y. if (tensorExps[tensorExps[e].children.e0].kind == Kind::kZero) return mapZero(kind, // maps to 0-y with just y's lattices buildLattices(tensorExps[e].children.e1, idx)); LLVM_FALLTHROUGH; case Kind::kAddF: case Kind::kAddI: + // An additive operation needs to be performed + // for the disjunction of sparse iteration spaces. 
+ // + // x+y|!y | y | x-y|!y | y | + // ---+---+---+ ---+---+---+ + // !x | 0 | y | !x | 0 |-y | + // x | x |x+y| x | x |x-y| return takeDisj(kind, // take binary disjunction buildLattices(tensorExps[e].children.e0, idx), buildLattices(tensorExps[e].children.e1, idx)); @@ -325,6 +363,18 @@ return buildTensorExp(op, yield->getOperand(0)); } +/// Returns true unless e is an invariant constant known to be nonzero and not +/// NaN, i.e., unless 0/e is guaranteed to be 0 (note that 0/NaN is NaN). +bool Merger::maybeZero(unsigned e) { + if (tensorExps[e].kind == Kind::kInvariant) { + if (auto c = tensorExps[e].val.getDefiningOp()) + return c.getValue() == 0; + if (auto c = tensorExps[e].val.getDefiningOp()) + return c.getValue().isZero() || c.getValue().isNaN(); + } + return true; +} + Optional Merger::buildTensorExp(linalg::GenericOp op, Value val) { if (auto arg = val.dyn_cast()) { unsigned argN = arg.getArgNumber(); @@ -357,6 +407,7 @@ } } // Construct binary operations if subexpressions can be built. + // TODO: see buildLattices() for an explanation of rejecting certain divisions if (def->getNumOperands() == 2) { auto x = buildTensorExp(op, def->getOperand(0)); auto y = buildTensorExp(op, def->getOperand(1)); @@ -367,6 +418,21 @@ return addExp(Kind::kMulF, e0, e1); if (isa(def)) return addExp(Kind::kMulI, e0, e1); + if (isa(def)) { + if (maybeZero(e1)) + return None; + return addExp(Kind::kDivF, e0, e1); + } + if (isa(def)) { + if (maybeZero(e1)) + return None; + return addExp(Kind::kDivS, e0, e1); + } + if (isa(def)) { + if (maybeZero(e1)) + return None; + return addExp(Kind::kDivU, e0, e1); + } if (isa(def)) return addExp(Kind::kAddF, e0, e1); if (isa(def)) diff --git a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_fp_ops.mlir @@ -22,6 +22,15 @@ doc = "x(i) = a(i) OP b(i)" } +#traitc = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a + affine_map<(i) -> (i)> // x (out) + ], + iterator_types = ["parallel"], + doc = "x(i) = a(i) OP c" +} + // CHECK-LABEL: func @neg( // CHECK-SAME: 
%[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> { @@ -213,3 +222,38 @@ } -> tensor<32xf64> return %0 : tensor<32xf64> } + +// CHECK-LABEL: func @divbyc( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> { +// CHECK: %[[VAL_2:.*]] = constant 2.000000e+00 : f64 +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xf64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xf64> +// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] { +// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[VAL_14:.*]] = divf %[[VAL_13]], %[[VAL_2]] : f64 +// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xf64> +// CHECK: } +// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xf64> +// CHECK: return %[[VAL_15]] : tensor<32xf64> +// CHECK: } +func @divbyc(%arga: tensor<32xf64, #SV>, + %argx: tensor<32xf64> {linalg.inplaceable = true}) -> tensor<32xf64> { + %c = constant 2.0 : f64 + %0 = linalg.generic #traitc + ins(%arga: tensor<32xf64, #SV>) + outs(%argx: tensor<32xf64>) { + ^bb(%a: f64, %x: f64): + %0 = divf %a, %c : f64 + 
linalg.yield %0 : f64 + } -> tensor<32xf64> + return %0 : tensor<32xf64> +} + diff --git a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir --- a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir +++ b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir @@ -13,6 +13,15 @@ doc = "x(i) = a(i) OP b(i)" } +#traitc = { + indexing_maps = [ + affine_map<(i) -> (i)>, // a + affine_map<(i) -> (i)> // x (out) + ], + iterator_types = ["parallel"], + doc = "x(i) = a(i) OP c" +} + // CHECK-LABEL: func @add( // CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>, // CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64>, @@ -171,3 +180,71 @@ } -> tensor<32xi64> return %0 : tensor<32xi64> } + +// CHECK-LABEL: func @divsbyc( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> { +// CHECK: %[[VAL_2:.*]] = constant 2 : i64 +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64> +// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] { +// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[VAL_14:.*]] = divi_signed %[[VAL_13]], %[[VAL_2]] : i64 +// 
CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64> +// CHECK: } +// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : memref<32xi64> +// CHECK: return %[[VAL_15]] : tensor<32xi64> +// CHECK: } +func @divsbyc(%arga: tensor<32xi64, #SV>, + %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> { + %c = constant 2 : i64 + %0 = linalg.generic #traitc + ins(%arga: tensor<32xi64, #SV>) + outs(%argx: tensor<32xi64>) { + ^bb(%a: i64, %x: i64): + %0 = divi_signed %a, %c : i64 + linalg.yield %0 : i64 + } -> tensor<32xi64> + return %0 : tensor<32xi64> +} + +// CHECK-LABEL: func @divubyc( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> { +// CHECK: %[[VAL_2:.*]] = constant 2 : i64 +// CHECK: %[[VAL_3:.*]] = constant 0 : index +// CHECK: %[[VAL_4:.*]] = constant 1 : index +// CHECK: %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> +// CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64> +// CHECK: %[[VAL_9:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref +// CHECK: %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref +// CHECK: scf.for %[[VAL_11:.*]] = %[[VAL_9]] to %[[VAL_10]] step %[[VAL_4]] { +// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_11]]] : memref +// CHECK: %[[VAL_14:.*]] = divi_unsigned %[[VAL_13]], %[[VAL_2]] : i64 +// CHECK: memref.store %[[VAL_14]], %[[VAL_8]]{{\[}}%[[VAL_12]]] : memref<32xi64> +// CHECK: } +// CHECK: %[[VAL_15:.*]] = memref.tensor_load %[[VAL_8]] : 
memref<32xi64> +// CHECK: return %[[VAL_15]] : tensor<32xi64> +// CHECK: } +func @divubyc(%arga: tensor<32xi64, #SV>, + %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> { + %c = constant 2 : i64 + %0 = linalg.generic #traitc + ins(%arga: tensor<32xi64, #SV>) + outs(%argx: tensor<32xi64>) { + ^bb(%a: i64, %x: i64): + %0 = divi_unsigned %a, %c : i64 + linalg.yield %0 : i64 + } -> tensor<32xi64> + return %0 : tensor<32xi64> +}