diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -38,7 +38,9 @@
   kAddF,
   kAddI,
   kSubF,
-  kSubI
+  kSubI,
+  kAndI,
+  kOrI,
 };
 
 /// Children subexpressions of tensor operations.
@@ -171,6 +173,11 @@
   /// Returns true if any set bit corresponds to queried dim.
   bool hasAnyDimOf(const llvm::BitVector &bits, Dim d) const;
 
+  /// Returns true if given tensor co-iterates with conjunction only in the
+  /// given tensor expression. For the output tensor, this defines a "simply
+  /// dynamic" operation [Bik96]. For instance: a(i) *= b(i) * c(i)
+  bool isConjunction(unsigned t, unsigned e) const;
+
   /// Dimension setter.
   void setDim(unsigned t, unsigned i, Dim d) { dims[t][i] = d; }
 
@@ -193,17 +200,21 @@
   /// Builds the iteration lattices in a bottom-up traversal given the remaining
   /// tensor (sub)expression and the next loop index in the iteration graph.
   /// Returns index of the root expression.
-  unsigned buildLattices(unsigned exp, unsigned idx);
+  unsigned buildLattices(unsigned e, unsigned i);
 
   /// Builds a tensor expression from the given Linalg operation.
   /// Returns index of the root expression on success.
   Optional<unsigned> buildTensorExpFromLinalg(linalg::GenericOp op);
 
+  /// Rebuilds SSA format from a tensor expression.
+  Value buildExp(PatternRewriter &rewriter, Location loc, unsigned e, Value v0,
+                 Value v1);
+
 private:
   bool maybeZero(unsigned e);
 
   /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
-  Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value val);
+  Optional<unsigned> buildTensorExp(linalg::GenericOp op, Value v);
 
   const unsigned outTensor;
   const unsigned syntheticTensor;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -208,22 +208,6 @@
   return true;
 }
 
-/// Returns true if given tensor co-iterates with conjunction only.
-/// For the output tensor, this defines a "simply dynamic" operation.
-/// For instance: A(I) = A(I) * B(I) * C(I)
-static unsigned isConjunction(Merger &merger, unsigned tensor, unsigned exp) {
-  switch (merger.exp(exp).kind) {
-  case Kind::kTensor:
-    return merger.exp(exp).tensor == tensor;
-  case Kind::kMulF:
-  case Kind::kMulI:
-    return isConjunction(merger, tensor, merger.exp(exp).children.e0) ||
-           isConjunction(merger, tensor, merger.exp(exp).children.e1);
-  default:
-    return false;
-  }
-}
-
 /// Returns true when the tensor expression is admissable for codegen.
 /// Since all sparse input tensors are admissable, we just need to check
 /// whether the output tensor in the tensor expression codegen is admissable.
@@ -250,7 +234,7 @@
   // A tensor expression with a sparse output tensor that changes its values
   // but not its nonzero structure, an operation called "simply dynamic" in
   // [Bik96,Ch9], is also admissable without special codegen.
-  if (isConjunction(merger, tensor, exp))
+  if (merger.isConjunction(tensor, exp))
     return true;
   // Reject for now since this requires changes to the nonzero structure.
   // TODO: implement "workspaces" [Kjolstad2019]
@@ -637,31 +621,7 @@
   }
   Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0);
   Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1);
-  switch (merger.exp(exp).kind) {
-  case Kind::kTensor:
-  case Kind::kInvariant:
-  case Kind::kZero:
-    llvm_unreachable("handled above");
-  case Kind::kMulF:
-    return rewriter.create<MulFOp>(loc, v0, v1);
-  case Kind::kMulI:
-    return rewriter.create<MulIOp>(loc, v0, v1);
-  case Kind::kDivF:
-    return rewriter.create<DivFOp>(loc, v0, v1);
-  case Kind::kDivS:
-    return rewriter.create<SignedDivIOp>(loc, v0, v1);
-  case Kind::kDivU:
-    return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
-  case Kind::kAddF:
-    return rewriter.create<AddFOp>(loc, v0, v1);
-  case Kind::kAddI:
-    return rewriter.create<AddIOp>(loc, v0, v1);
-  case Kind::kSubF:
-    return rewriter.create<SubFOp>(loc, v0, v1);
-  case Kind::kSubI:
-    return rewriter.create<SubIOp>(loc, v0, v1);
-  }
-  llvm_unreachable("unexpected expression kind");
+  return merger.buildExp(rewriter, loc, exp, v0, v1);
 }
 
 /// Hoists loop invariant tensor loads for which indices have been exhausted.
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -190,6 +190,23 @@
   return false;
 }
 
+bool Merger::isConjunction(unsigned t, unsigned e) const {
+  switch (tensorExps[e].kind) {
+  case Kind::kTensor:
+    return tensorExps[e].tensor == t;
+  case Kind::kMulF:
+  case Kind::kMulI:
+  case Kind::kAndI:
+  case Kind::kDivF: // note: x / c only
+  case Kind::kDivS:
+  case Kind::kDivU:
+    return isConjunction(t, tensorExps[e].children.e0) ||
+           isConjunction(t, tensorExps[e].children.e1);
+  default:
+    return false;
+  }
+}
+
 #ifndef NDEBUG
 
 //
@@ -211,6 +228,10 @@
   case Kind::kSubF:
   case Kind::kSubI:
     return '-';
+  case Kind::kAndI:
+    return '&';
+  case Kind::kOrI:
+    return '|';
   default:
     break;
   }
@@ -290,7 +311,7 @@
 // Builder methods.
 //
 
-unsigned Merger::buildLattices(unsigned e, unsigned idx) {
+unsigned Merger::buildLattices(unsigned e, unsigned i) {
   Kind kind = tensorExps[e].kind;
   switch (kind) {
   case Kind::kTensor:
@@ -301,11 +322,12 @@
     // is set to a synthetic tensor with undefined indices only.
     unsigned s = addSet();
     unsigned t = kind == Kind::kTensor ? tensorExps[e].tensor : syntheticTensor;
-    latSets[s].push_back(addLat(t, idx, e));
+    latSets[s].push_back(addLat(t, i, e));
     return s;
   }
   case Kind::kMulF:
   case Kind::kMulI:
+  case Kind::kAndI:
     // A multiplicative operation only needs to be performed
     // for the conjunction of sparse iteration spaces.
     //
@@ -314,8 +336,8 @@
     //  !x | 0 | 0 |
     //   x | 0 |x*y|
     return takeConj(kind, // take binary conjunction
-                    buildLattices(tensorExps[e].children.e0, idx),
-                    buildLattices(tensorExps[e].children.e1, idx));
+                    buildLattices(tensorExps[e].children.e0, i),
+                    buildLattices(tensorExps[e].children.e1, i));
   case Kind::kDivF:
   case Kind::kDivS:
   case Kind::kDivU:
@@ -333,17 +355,18 @@
     //       rules applies (viz. x/c = x*(1/c) as far as lattice
     //       construction is concerned).
     return takeConj(kind, // take binary conjunction
-                    buildLattices(tensorExps[e].children.e0, idx),
-                    buildLattices(tensorExps[e].children.e1, idx));
+                    buildLattices(tensorExps[e].children.e0, i),
+                    buildLattices(tensorExps[e].children.e1, i));
   case Kind::kSubF:
   case Kind::kSubI:
     // Special case: 0-y is -y.
     if (tensorExps[tensorExps[e].children.e0].kind == Kind::kZero)
       return mapZero(kind, // maps to 0-y with just y's lattices
-                     buildLattices(tensorExps[e].children.e1, idx));
+                     buildLattices(tensorExps[e].children.e1, i));
     LLVM_FALLTHROUGH;
   case Kind::kAddF:
   case Kind::kAddI:
+  case Kind::kOrI:
     // An additive operation needs to be performed
     // for the disjunction of sparse iteration spaces.
     //
@@ -352,8 +375,8 @@
     //  !x | 0 | y |    !x | 0 |-y |
     //   x | x |x+y|     x | x |x-y|
     return takeDisj(kind, // take binary disjunction
-                    buildLattices(tensorExps[e].children.e0, idx),
-                    buildLattices(tensorExps[e].children.e1, idx));
+                    buildLattices(tensorExps[e].children.e0, i),
+                    buildLattices(tensorExps[e].children.e1, i));
   }
   llvm_unreachable("unexpected expression kind");
 }
@@ -373,8 +396,8 @@
   return true;
 }
 
-Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
-  if (auto arg = val.dyn_cast<BlockArgument>()) {
+Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
+  if (auto arg = v.dyn_cast<BlockArgument>()) {
     unsigned argN = arg.getArgNumber();
     // Any argument of the generic op that is not marked as a scalar
     // argument is considered a tensor, indexed by the implicit loop
@@ -383,16 +406,16 @@
       OpOperand *t = op.getInputAndOutputOperands()[argN];
       if (!op.isScalar(t))
         return addExp(Kind::kTensor, argN);
-      val = t->get(); // get scalar value
+      v = t->get(); // get scalar value
     }
     // Any other argument (marked as scalar argument for the generic op
     // or belonging to an enveloping op) is considered invariant.
-    return addExp(Kind::kInvariant, val);
+    return addExp(Kind::kInvariant, v);
   }
   // Something defined outside is invariant.
-  Operation *def = val.getDefiningOp();
+  Operation *def = v.getDefiningOp();
   if (def->getBlock() != &op.region().front())
-    return addExp(Kind::kInvariant, val);
+    return addExp(Kind::kInvariant, v);
   // Construct unary operations if subexpression can be built.
   if (def->getNumOperands() == 1) {
     auto x = buildTensorExp(op, def->getOperand(0));
@@ -430,11 +453,48 @@
       return addExp(Kind::kSubF, e0, e1);
     if (isa<SubIOp>(def))
       return addExp(Kind::kSubI, e0, e1);
+    if (isa<AndOp>(def))
+      return addExp(Kind::kAndI, e0, e1);
+    if (isa<OrOp>(def))
+      return addExp(Kind::kOrI, e0, e1);
   }
 }
 // Cannot build.
 return None;
 }
 
+Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
+                       Value v0, Value v1) {
+  switch (tensorExps[e].kind) {
+  case Kind::kTensor:
+  case Kind::kInvariant:
+  case Kind::kZero:
+    llvm_unreachable("unexpected non-op");
+  case Kind::kMulF:
+    return rewriter.create<MulFOp>(loc, v0, v1);
+  case Kind::kMulI:
+    return rewriter.create<MulIOp>(loc, v0, v1);
+  case Kind::kDivF:
+    return rewriter.create<DivFOp>(loc, v0, v1);
+  case Kind::kDivS:
+    return rewriter.create<SignedDivIOp>(loc, v0, v1);
+  case Kind::kDivU:
+    return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
+  case Kind::kAddF:
+    return rewriter.create<AddFOp>(loc, v0, v1);
+  case Kind::kAddI:
+    return rewriter.create<AddIOp>(loc, v0, v1);
+  case Kind::kSubF:
+    return rewriter.create<SubFOp>(loc, v0, v1);
+  case Kind::kSubI:
+    return rewriter.create<SubIOp>(loc, v0, v1);
+  case Kind::kAndI:
+    return rewriter.create<AndOp>(loc, v0, v1);
+  case Kind::kOrI:
+    return rewriter.create<OrOp>(loc, v0, v1);
+  }
+  llvm_unreachable("unexpected expression kind in build");
+}
+
 } // namespace sparse_tensor
 } // namespace mlir
diff --git a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_int_ops.mlir
@@ -248,3 +248,100 @@
   } -> tensor<32xi64>
   return %0 : tensor<32xi64>
 }
+
+// CHECK-LABEL: func @and(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xi64>,
+// CHECK-SAME:    %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK:         %[[VAL_3:.*]] = constant 0 : index
+// CHECK:         %[[VAL_4:.*]] = constant 1 : index
+// CHECK:         %[[VAL_5:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:         %[[VAL_6:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_3]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:         %[[VAL_7:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK:         %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK:         %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xi64>
+// CHECK:         %[[VAL_10:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_3]]] : memref<?xindex>
+// CHECK:         %[[VAL_11:.*]] = memref.load %[[VAL_5]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:         scf.for %[[VAL_12:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_4]] {
+// CHECK:           %[[VAL_13:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_12]]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_12]]] : memref<?xi64>
+// CHECK:           %[[VAL_15:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_13]]] : memref<32xi64>
+// CHECK:           %[[VAL_16:.*]] = and %[[VAL_14]], %[[VAL_15]] : i64
+// CHECK:           memref.store %[[VAL_16]], %[[VAL_9]]{{\[}}%[[VAL_13]]] : memref<32xi64>
+// CHECK:         }
+// CHECK:         %[[VAL_17:.*]] = memref.tensor_load %[[VAL_9]] : memref<32xi64>
+// CHECK:         return %[[VAL_17]] : tensor<32xi64>
+// CHECK:       }
+func @and(%arga: tensor<32xi64, #SV>,
+          %argb: tensor<32xi64>,
+          %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %b: i64, %x: i64):
+        %0 = and %a, %b : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
+
+// CHECK-LABEL: func @or(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<32xi64>,
+// CHECK-SAME:    %[[VAL_2:.*]]: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+// CHECK:         %[[VAL_3:.*]] = constant 32 : index
+// CHECK:         %[[VAL_4:.*]] = constant 0 : index
+// CHECK:         %[[VAL_5:.*]] = constant true
+// CHECK:         %[[VAL_6:.*]] = constant 1 : index
+// CHECK:         %[[VAL_7:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:         %[[VAL_8:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK:         %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<32xi64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xi64>
+// CHECK:         %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<32xi64>
+// CHECK:         %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<32xi64>
+// CHECK:         %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:         %[[VAL_13:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:         %[[VAL_14:.*]]:2 = scf.while (%[[VAL_15:.*]] = %[[VAL_12]], %[[VAL_16:.*]] = %[[VAL_4]]) : (index, index) -> (index, index) {
+// CHECK:           %[[VAL_17:.*]] = cmpi ult, %[[VAL_15]], %[[VAL_13]] : index
+// CHECK:           scf.condition(%[[VAL_17]]) %[[VAL_15]], %[[VAL_16]] : index, index
+// CHECK:         } do {
+// CHECK:         ^bb0(%[[VAL_18:.*]]: index, %[[VAL_19:.*]]: index):
+// CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_8]]{{\[}}%[[VAL_18]]] : memref<?xindex>
+// CHECK:           %[[VAL_21:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:           scf.if %[[VAL_21]] {
+// CHECK:             %[[VAL_22:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_18]]] : memref<?xi64>
+// CHECK:             %[[VAL_23:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:             %[[VAL_24:.*]] = or %[[VAL_22]], %[[VAL_23]] : i64
+// CHECK:             memref.store %[[VAL_24]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:           } else {
+// CHECK:             scf.if %[[VAL_5]] {
+// CHECK:               %[[VAL_25:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:               memref.store %[[VAL_25]], %[[VAL_11]]{{\[}}%[[VAL_19]]] : memref<32xi64>
+// CHECK:             } else {
+// CHECK:             }
+// CHECK:           }
+// CHECK:           %[[VAL_26:.*]] = cmpi eq, %[[VAL_20]], %[[VAL_19]] : index
+// CHECK:           %[[VAL_27:.*]] = addi %[[VAL_18]], %[[VAL_6]] : index
+// CHECK:           %[[VAL_28:.*]] = select %[[VAL_26]], %[[VAL_27]], %[[VAL_18]] : index
+// CHECK:           %[[VAL_29:.*]] = addi %[[VAL_19]], %[[VAL_6]] : index
+// CHECK:           scf.yield %[[VAL_28]], %[[VAL_29]] : index, index
+// CHECK:         }
+// CHECK:         scf.for %[[VAL_30:.*]] = %[[VAL_31:.*]]#1 to %[[VAL_3]] step %[[VAL_6]] {
+// CHECK:           %[[VAL_32:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_30]]] : memref<32xi64>
+// CHECK:           memref.store %[[VAL_32]], %[[VAL_11]]{{\[}}%[[VAL_30]]] : memref<32xi64>
+// CHECK:         }
+// CHECK:         %[[VAL_33:.*]] = memref.tensor_load %[[VAL_11]] : memref<32xi64>
+// CHECK:         return %[[VAL_33]] : tensor<32xi64>
+// CHECK:       }
+func @or(%arga: tensor<32xi64, #SV>,
+         %argb: tensor<32xi64>,
+         %argx: tensor<32xi64> {linalg.inplaceable = true}) -> tensor<32xi64> {
+  %0 = linalg.generic #trait2
+     ins(%arga, %argb: tensor<32xi64, #SV>, tensor<32xi64>)
+    outs(%argx: tensor<32xi64>) {
+      ^bb(%a: i64, %b: i64, %x: i64):
+        %0 = or %a, %b : i64
+        linalg.yield %0 : i64
+  } -> tensor<32xi64>
+  return %0 : tensor<32xi64>
+}
+