diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -271,6 +271,15 @@
     return ldx >= numNativeLoops;
   }
 
+  /// Returns true if the expression contains the tensor `t` as an operand.
+  bool expContainsTensor(unsigned e, unsigned t) const;
+
+  /// Returns true if the expression contains a negation on the output
+  /// tensor, i.e., `-outTensor` or `exp - outTensor`.
+  /// NOTE: this is a trivial test in that it does not handle recursive
+  /// negation, i.e., it still returns true for `-(-tensor)`.
+  bool hasNegateOnOut(unsigned e) const;
+
   /// Returns true if given tensor iterates *only* in the given tensor
   /// expression. For the output tensor, this defines a "simply dynamic"
   /// operation [Bik96]. For instance: a(i) *= 2.0 or a(i) += a(i) for
@@ -348,9 +357,9 @@
   void dumpBits(const BitVector &bits) const;
 #endif
 
-  /// Builds the iteration lattices in a bottom-up traversal given the remaining
-  /// tensor (sub)expression and the next loop index in the iteration graph.
-  /// Returns index of the root expression.
+  /// Builds the iteration lattices in a bottom-up traversal given the
+  /// remaining tensor (sub)expression and the next loop index in the
+  /// iteration graph. Returns index of the root expression.
   unsigned buildLattices(unsigned e, unsigned i);
 
   /// Builds a tensor expression from the given Linalg operation.
@@ -380,7 +389,8 @@
   // Map that converts pair <tensor id, loop id> to the corresponding
   // dimension level type.
   std::vector<std::vector<DimLevelType>> dimTypes;
-  // Map that converts pair <tensor id, loop id> to the corresponding dimension.
+  // Map that converts pair <tensor id, loop id> to the corresponding
+  // dimension.
   std::vector<std::vector<Optional<unsigned>>> loopIdxToDim;
   // Map that converts pair <tensor id, dim> to the corresponding loop id.
   std::vector<std::vector<Optional<unsigned>>> dimToLoopIdx;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -583,6 +583,19 @@
                                   std::vector<unsigned> &topSort, unsigned exp,
                                   OpOperand **sparseOut,
                                   unsigned &outerParNest) {
+  // We reject any expression that makes a reduction from `-outTensor`, as
+  // such expressions create a dependency between the current iteration (i)
+  // and the previous iteration (i-1). It would then require iterating over
+  // the whole coordinate space, which prevents us from exploiting sparsity
+  // for faster code.
+  for (utils::IteratorType it : op.getIteratorTypesArray()) {
+    if (it == utils::IteratorType::reduction) {
+      if (merger.hasNegateOnOut(exp))
+        return false;
+      break;
+    }
+  }
+
   OpOperand *lhs = op.getDpsInitOperand(0);
   unsigned tensor = lhs->getOperandNumber();
   auto enc = getSparseTensorEncoding(lhs->get().getType());
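Aside (not part of the patch): a minimal scalar model of the dependency the
comment above describes; the names `reduceSubOut`, `a`, `x0`, and `n` are
hypothetical, chosen only for illustration.

  #include <cstddef>
  #include <cstdint>

  // Models `x = a(i) - x` under a reduction loop: each iteration consumes
  // the value the previous one produced, so even a[i] == 0 flips the sign
  // of x. Zero entries therefore cannot be skipped, and the loop must
  // visit the whole coordinate space.
  int32_t reduceSubOut(const int32_t *a, size_t n, int32_t x0) {
    int32_t x = x0;
    for (size_t i = 0; i < n; ++i)
      x = a[i] - x; // x(i) depends on x(i-1)
    return x;
  }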
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -18,6 +18,81 @@
 namespace mlir {
 namespace sparse_tensor {
 
+enum class ExpArity {
+  kNullary,
+  kUnary,
+  kBinary,
+};
+
+static ExpArity getExpArity(Kind k) {
+  switch (k) {
+  // Leaf.
+  case kTensor:
+  case kInvariant:
+  case kIndex:
+    return ExpArity::kNullary;
+  case kAbsF:
+  case kAbsC:
+  case kAbsI:
+  case kCeilF:
+  case kFloorF:
+  case kSqrtF:
+  case kSqrtC:
+  case kExpm1F:
+  case kExpm1C:
+  case kLog1pF:
+  case kLog1pC:
+  case kSinF:
+  case kSinC:
+  case kTanhF:
+  case kTanhC:
+  case kTruncF:
+  case kExtF:
+  case kCastFS:
+  case kCastFU:
+  case kCastSF:
+  case kCastUF:
+  case kCastS:
+  case kCastU:
+  case kCastIdx:
+  case kTruncI:
+  case kCIm:
+  case kCRe:
+  case kBitCast:
+  case kBinaryBranch:
+  case kUnary:
+  case kSelect:
+  case kNegF:
+  case kNegC:
+  case kNegI:
+    return ExpArity::kUnary;
+  // Binary operations.
+  case kDivF:
+  case kDivC:
+  case kDivS:
+  case kDivU:
+  case kShrS:
+  case kShrU:
+  case kShlI:
+  case kMulF:
+  case kMulC:
+  case kMulI:
+  case kAndI:
+  case kAddF:
+  case kAddC:
+  case kAddI:
+  case kOrI:
+  case kXorI:
+  case kBinary:
+  case kReduce:
+  case kSubF:
+  case kSubC:
+  case kSubI:
+    return ExpArity::kBinary;
+  }
+  llvm_unreachable("unexpected kind");
+}
+
 //===----------------------------------------------------------------------===//
 // Constructors.
 //===----------------------------------------------------------------------===//
@@ -310,6 +385,57 @@
   return !hasAnySparse(tmp);
 }
 
+bool Merger::expContainsTensor(unsigned e, unsigned t) const {
+  if (tensorExps[e].kind == kTensor)
+    return tensorExps[e].tensor == t;
+
+  switch (getExpArity(tensorExps[e].kind)) {
+  case ExpArity::kNullary:
+    return false;
+  case ExpArity::kUnary: {
+    unsigned op = tensorExps[e].children.e0;
+    if (tensorExps[op].kind == kTensor && tensorExps[op].tensor == t)
+      return true;
+    return expContainsTensor(op, t);
+  }
+  case ExpArity::kBinary: {
+    unsigned op1 = tensorExps[e].children.e0;
+    unsigned op2 = tensorExps[e].children.e1;
+    if ((tensorExps[op1].kind == kTensor && tensorExps[op1].tensor == t) ||
+        (tensorExps[op2].kind == kTensor && tensorExps[op2].tensor == t))
+      return true;
+    return expContainsTensor(op1, t) || expContainsTensor(op2, t);
+  }
+  }
+  llvm_unreachable("unexpected arity");
+}
+
+bool Merger::hasNegateOnOut(unsigned e) const {
+  switch (tensorExps[e].kind) {
+  case kNegF:
+  case kNegC:
+  case kNegI:
+    return expContainsTensor(tensorExps[e].children.e0, outTensor);
+  case kSubF:
+  case kSubC:
+  case kSubI:
+    return expContainsTensor(tensorExps[e].children.e1, outTensor) ||
+           hasNegateOnOut(tensorExps[e].children.e0);
+  default: {
+    switch (getExpArity(tensorExps[e].kind)) {
+    case ExpArity::kNullary:
+      return false;
+    case ExpArity::kUnary:
+      return hasNegateOnOut(tensorExps[e].children.e0);
+    case ExpArity::kBinary:
+      return hasNegateOnOut(tensorExps[e].children.e0) ||
+             hasNegateOnOut(tensorExps[e].children.e1);
+    }
+  }
+  }
+  llvm_unreachable("unexpected kind");
+}
+
 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
   switch (tensorExps[e].kind) {
   // Leaf.
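Aside (not part of the patch): the two walks above compose; `hasNegateOnOut`
only consults `expContainsTensor` underneath a negation, while deeper negations
are reached by its own recursion. A self-contained sketch of the same idea over
a toy expression tree (all types and names hypothetical, not the Merger API):

  #include <memory>

  // Toy expression tree: a leaf names a tensor id; kNeg and kSub are the
  // only operators that can place a tensor under a negation.
  struct Exp {
    enum Kind { kTensorLeaf, kNeg, kSub, kAdd } kind;
    unsigned tensor = 0;         // only meaningful for kTensorLeaf
    std::unique_ptr<Exp> e0, e1; // children; e1 is null for unary kNeg
  };

  static bool containsTensor(const Exp &e, unsigned t) {
    if (e.kind == Exp::kTensorLeaf)
      return e.tensor == t;
    return (e.e0 && containsTensor(*e.e0, t)) ||
           (e.e1 && containsTensor(*e.e1, t));
  }

  // True iff tensor `out` occurs negated: directly under a kNeg, or as the
  // subtrahend of a kSub (i.e., `exp - out`).
  static bool hasNegateOnOut(const Exp &e, unsigned out) {
    switch (e.kind) {
    case Exp::kNeg:
      return containsTensor(*e.e0, out);
    case Exp::kSub:
      return containsTensor(*e.e1, out) || hasNegateOnOut(*e.e0, out);
    default:
      return (e.e0 && hasNegateOnOut(*e.e0, out)) ||
             (e.e1 && hasNegateOnOut(*e.e1, out));
    }
  }

Like the patch, the sketch is deliberately shallow about double negation:
`-(-tensor)` still reports true, matching the NOTE in Merger.h.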
diff --git a/mlir/test/Dialect/SparseTensor/rejected.mlir b/mlir/test/Dialect/SparseTensor/rejected.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/rejected.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+
+// This file contains examples that will be rejected by the sparse compiler
+// (we expect the linalg.generic to remain unchanged).
+#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // a (in)
+    affine_map<(i) -> ()>   // x (out)
+  ],
+  iterator_types = ["reduction"]
+}
+
+// CHECK-LABEL: func.func @sparse_reduction_subi(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<i32>,
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<?xi32, #sparse_tensor.encoding<{{.*}}>>) -> tensor<i32> {
+// CHECK:         %[[VAL_2:.*]] = linalg.generic
+// CHECK:         ^bb0(%[[VAL_3:.*]]: i32, %[[VAL_4:.*]]: i32):
+// CHECK:           %[[VAL_5:.*]] = arith.subi %[[VAL_3]], %[[VAL_4]] : i32
+// CHECK:           linalg.yield %[[VAL_5]] : i32
+// CHECK:         } -> tensor<i32>
+// CHECK:         return %[[VAL_6:.*]] : tensor<i32>
func.func @sparse_reduction_subi(%argx: tensor<i32>,
+                                 %arga: tensor<?xi32, #SparseVector>)
+  -> tensor<i32> {
+  %0 = linalg.generic #trait
+     ins(%arga: tensor<?xi32, #SparseVector>)
+      outs(%argx: tensor<i32>) {
+      ^bb(%a: i32, %x: i32):
+        // NOTE: `subi %a, %x` is the reason the program is rejected by the
+        // sparse compiler: it computes `a - outTensor`, and we do not allow
+        // `-outTensor` in reduction loops as it creates a cyclic dependence.
+        %t = arith.subi %a, %x : i32
+        linalg.yield %t : i32
+    } -> tensor<i32>
+  return %0 : tensor<i32>
+}
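Aside (not part of the patch): the opposite subtraction direction remains
admissible, which is why `hasNegateOnOut` rejects only `-outTensor` and
`exp - outTensor`. A scalar model, with hypothetical names again:

  #include <cstddef>
  #include <cstdint>

  // Models `x = x - a(i)`: this unrolls to x0 - a[0] - a[1] - ..., so an
  // iteration with a[i] == 0 is a no-op and may be skipped; the sparsifier
  // only has to visit the stored entries of a sparse `a`.
  int32_t reduceOutMinusIn(const int32_t *a, size_t n, int32_t x0) {
    int32_t x = x0;
    for (size_t i = 0; i < n; ++i)
      x -= a[i]; // skipping a[i] == 0 leaves x unchanged
    return x;
  }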