diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h --- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h +++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h @@ -271,6 +271,15 @@ return ldx >= numNativeLoops; } + /// Return true if the expression contains the tensor `t` as an operand. + bool expContainsTensor(unsigned e, unsigned t) const; + + /// Return true if the expression contains a negation on output tensor. + /// I.e., `- outTensor` or `exp - outputTensor` + /// NOTE: this is a trivial test in that it does not handle recursive + /// negation, i.e., it returns true when the expression is `-(-tensor)`. + bool hasNegateOnOut(unsigned e) const; + /// Returns true if given tensor iterates *only* in the given tensor /// expression. For the output tensor, this defines a "simply dynamic" /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for @@ -348,9 +357,9 @@ void dumpBits(const BitVector &bits) const; #endif - /// Builds the iteration lattices in a bottom-up traversal given the remaining - /// tensor (sub)expression and the next loop index in the iteration graph. - /// Returns index of the root expression. + /// Builds the iteration lattices in a bottom-up traversal given the + /// remaining tensor (sub)expression and the next loop index in the + /// iteration graph. Returns index of the root expression. unsigned buildLattices(unsigned e, unsigned i); /// Builds a tensor expression from the given Linalg operation. @@ -380,7 +389,8 @@ // Map that converts pair to the corresponding dimension // level type. std::vector> dimTypes; // Map that converts pair to the corresponding // dimension. std::vector>> loopIdxToDim; // Map that converts pair to the corresponding loop id. 
std::vector>> dimToLoopIdx; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -583,6 +583,19 @@ std::vector &topSort, unsigned exp, OpOperand **sparseOut, unsigned &outerParNest) { + // We reject any expression that makes a reduction from `-outTensor`, as those + // expressions create a dependency between the current iteration (i) and the + // previous iteration (i-1). It would then require iterating over the whole + // coordinate space, which prevents us from exploiting sparsity for faster + // code. + for (utils::IteratorType it : op.getIteratorTypesArray()) { + if (it == utils::IteratorType::reduction) { + if (merger.hasNegateOnOut(exp)) + return false; + break; + } + } + OpOperand *lhs = op.getDpsInitOperand(0); unsigned tensor = lhs->getOperandNumber(); auto enc = getSparseTensorEncoding(lhs->get().getType()); diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -18,6 +18,81 @@ namespace mlir { namespace sparse_tensor { +enum class ExpArity { + kNullary, + kUnary, + kBinary, +}; + +static ExpArity getExpArity(Kind k) { + switch (k) { + // Leaf. 
+ case kTensor: + case kInvariant: + case kIndex: + return ExpArity::kNullary; + case kAbsF: + case kAbsC: + case kAbsI: + case kCeilF: + case kFloorF: + case kSqrtF: + case kSqrtC: + case kExpm1F: + case kExpm1C: + case kLog1pF: + case kLog1pC: + case kSinF: + case kSinC: + case kTanhF: + case kTanhC: + case kTruncF: + case kExtF: + case kCastFS: + case kCastFU: + case kCastSF: + case kCastUF: + case kCastS: + case kCastU: + case kCastIdx: + case kTruncI: + case kCIm: + case kCRe: + case kBitCast: + case kBinaryBranch: + case kUnary: + case kSelect: + case kNegF: + case kNegC: + case kNegI: + return ExpArity::kUnary; + // Binary operations. + case kDivF: + case kDivC: + case kDivS: + case kDivU: + case kShrS: + case kShrU: + case kShlI: + case kMulF: + case kMulC: + case kMulI: + case kAndI: + case kAddF: + case kAddC: + case kAddI: + case kOrI: + case kXorI: + case kBinary: + case kReduce: + case kSubF: + case kSubC: + case kSubI: + return ExpArity::kBinary; + } + llvm_unreachable("unexpected kind"); +} + //===----------------------------------------------------------------------===// // Constructors. 
//===----------------------------------------------------------------------===// @@ -310,6 +385,56 @@ return !hasAnySparse(tmp); } +bool Merger::expContainsTensor(unsigned e, unsigned t) const { + ExpArity arity = getExpArity(tensorExps[e].kind); + switch (arity) { + case ExpArity::kNullary: + return false; + case ExpArity::kUnary: { + unsigned op = tensorExps[e].children.e0; + if (tensorExps[op].kind == kTensor && tensorExps[op].tensor == t) + return true; + return expContainsTensor(op, t); + } + case ExpArity::kBinary: { + unsigned op1 = tensorExps[e].children.e0; + unsigned op2 = tensorExps[e].children.e1; + if ((tensorExps[op1].kind == kTensor && tensorExps[op1].tensor == t) || + (tensorExps[op2].kind == kTensor && tensorExps[op2].tensor == t)) + return true; + return expContainsTensor(op1, t) || expContainsTensor(op2, t); + } + } + llvm_unreachable("unexpected arity"); +} + +bool Merger::hasNegateOnOut(unsigned e) const { + switch (tensorExps[e].kind) { + case kNegF: + case kNegC: + case kNegI: + return expContainsTensor(tensorExps[e].children.e0, outTensor); + case kSubF: + case kSubC: + case kSubI: + return expContainsTensor(tensorExps[e].children.e1, outTensor) || + hasNegateOnOut(tensorExps[e].children.e0); + default: { + ExpArity arity = getExpArity(tensorExps[e].kind); + switch (arity) { + case ExpArity::kNullary: + return false; + case ExpArity::kUnary: + return hasNegateOnOut(tensorExps[e].children.e0); + case ExpArity::kBinary: + return hasNegateOnOut(tensorExps[e].children.e0) || + hasNegateOnOut(tensorExps[e].children.e1); + } + } + } + llvm_unreachable("unexpected kind"); +} + bool Merger::isSingleCondition(unsigned t, unsigned e) const { switch (tensorExps[e].kind) { // Leaf.