diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -271,6 +271,10 @@
     return ldx >= numNativeLoops;
   }
 
+  /// Returns true if the expression contains a negation on the output
+  /// tensor, i.e., `-outTensor` or `exp - outTensor`.
+  bool hasNegateOnOut(unsigned e) const;
+
   /// Returns true if given tensor iterates *only* in the given tensor
   /// expression. For the output tensor, this defines a "simply dynamic"
   /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
@@ -348,9 +352,9 @@
   void dumpBits(const BitVector &bits) const;
 #endif
 
-  /// Builds the iteration lattices in a bottom-up traversal given the remaining
-  /// tensor (sub)expression and the next loop index in the iteration graph.
-  /// Returns index of the root expression.
+  /// Builds the iteration lattices in a bottom-up traversal given the
+  /// remaining tensor (sub)expression and the next loop index in the
+  /// iteration graph. Returns the index of the root expression.
   unsigned buildLattices(unsigned e, unsigned i);
 
   /// Builds a tensor expression from the given Linalg operation.
@@ -380,7 +384,8 @@
   // Map that converts pair<tensor id, loop id> to the corresponding dimension
   // level type.
   std::vector<std::vector<DimLevelType>> dimTypes;
-  // Map that converts pair<tensor id, loop id> to the corresponding dimension.
+  // Map that converts pair<tensor id, loop id> to the corresponding
+  // dimension.
   std::vector<std::vector<std::optional<unsigned>>> loopIdxToDim;
   // Map that converts pair<tensor id, dim> to the corresponding loop id.
   std::vector<std::vector<std::optional<unsigned>>> dimToLoopIdx;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -583,6 +583,19 @@
                                  std::vector<unsigned> &topSort, unsigned exp,
                                  OpOperand **sparseOut,
                                  unsigned &outerParNest) {
+  // We reject any expression that makes a reduction from `-outTensor`, as
+  // such expressions create a dependency between the current iteration (i)
+  // and the previous iteration (i-1). They would then require iterating
+  // over the whole coordinate space, which prevents us from exploiting
+  // sparsity for faster code.
+  for (utils::IteratorType it : op.getIteratorTypesArray()) {
+    if (it == utils::IteratorType::reduction) {
+      if (merger.hasNegateOnOut(exp))
+        return false;
+      break;
+    }
+  }
+
   OpOperand *lhs = op.getDpsInitOperand(0);
   unsigned tensor = lhs->getOperandNumber();
   auto enc = getSparseTensorEncoding(lhs->get().getType());
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -310,6 +310,91 @@
   return !hasAnySparse(tmp);
 }
 
+bool Merger::hasNegateOnOut(unsigned e) const {
+  switch (tensorExps[e].kind) {
+  // Leaf.
+  case kTensor:
+  case kInvariant:
+  case kIndex:
+    return false;
+  case kAbsF:
+  case kAbsC:
+  case kAbsI:
+  case kCeilF:
+  case kFloorF:
+  case kSqrtF:
+  case kSqrtC:
+  case kExpm1F:
+  case kExpm1C:
+  case kLog1pF:
+  case kLog1pC:
+  case kSinF:
+  case kSinC:
+  case kTanhF:
+  case kTanhC:
+  case kTruncF:
+  case kExtF:
+  case kCastFS:
+  case kCastFU:
+  case kCastSF:
+  case kCastUF:
+  case kCastS:
+  case kCastU:
+  case kCastIdx:
+  case kTruncI:
+  case kCIm:
+  case kCRe:
+  case kBitCast:
+  case kBinaryBranch:
+  case kUnary:
+  case kSelect:
+    return hasNegateOnOut(tensorExps[e].children.e0);
+  case kNegF:
+  case kNegC:
+  case kNegI: {
+    unsigned operand = tensorExps[e].children.e0;
+    if (tensorExps[operand].kind == kTensor &&
+        tensorExps[operand].tensor == outTensor) {
+      return true;
+    }
+    return hasNegateOnOut(operand);
+  }
+  // Binary operations.
+  case kDivF:
+  case kDivC:
+  case kDivS:
+  case kDivU:
+  case kShrS:
+  case kShrU:
+  case kShlI:
+  case kMulF:
+  case kMulC:
+  case kMulI:
+  case kAndI:
+  case kAddF:
+  case kAddC:
+  case kAddI:
+  case kOrI:
+  case kXorI:
+  case kBinary:
+  case kReduce:
+    return hasNegateOnOut(tensorExps[e].children.e0) ||
+           hasNegateOnOut(tensorExps[e].children.e1);
+  case kSubF:
+  case kSubC:
+  case kSubI: {
+    unsigned operand = tensorExps[e].children.e1;
+    if (tensorExps[operand].kind == kTensor &&
+        tensorExps[operand].tensor == outTensor) {
+      return true;
+    }
+    return hasNegateOnOut(tensorExps[e].children.e0) ||
+           hasNegateOnOut(tensorExps[e].children.e1);
+  }
+  }
+  llvm_unreachable("unexpected kind");
+}
+
 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   switch (tensorExps[e].kind) {
   // Leaf.
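
Editor's note: the rejection added in Sparsification.cpp hinges on the loop-carried dependence that a negated output tensor creates inside a reduction. The standalone C++ sketch below (not part of the patch; the input data and variable names are made up for illustration) shows why a kernel of the shape `x = a(i) - x` cannot skip zero entries the way a sparse loop would:

// Sketch only: the recurrence `x = a[i] - x` negates the running value on
// every iteration, so even an entry with a[i] == 0 changes the result
// (x becomes -x). A sparse loop that skips zeros therefore computes a
// different value, which is why such expressions are non-admissible.
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> a = {3.0, 0.0, 5.0};

  double dense = 0.0;
  for (double v : a)
    dense = v - dense; // visits every coordinate, including zeros

  double sparse = 0.0;
  for (double v : a)
    if (v != 0.0)
      sparse = v - sparse; // skips zeros, as a sparse kernel would

  // Prints "dense = 8, sparse = 2": skipping the zero changes the answer.
  std::printf("dense = %g, sparse = %g\n", dense, sparse);
  return 0;
}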