diff --git a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
--- a/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -70,6 +70,9 @@
   /// and kSelect, this holds the original operation with all regions. For
   /// kBinaryBranch, this holds the YieldOp for the left or right half
   /// to be merged into a nested scf loop.
+  ///
+  /// For kDenseOp, this holds the operation that we cannot sparsify but
+  /// whose operands are all dense.
   Operation *op;
 
   /// An optional attribute that is required to determine the semantics of the
@@ -157,8 +160,9 @@
     kShrS, // signed
     kShrU, // unsigned
     kShlI,
-    kBinary, // semiring binary op
-    kReduce, // semiring reduction op
+    kBinary,  // semiring binary op
+    kReduce,  // semiring reduction op
+    kDenseOp, // special category of operations requiring all dense operands
   };
 
 //===----------------------------------------------------------------------===//
@@ -645,7 +649,11 @@
   Type inferType(ExprId e, Value src) const;
 
   /// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
-  std::optional<ExprId> buildTensorExp(linalg::GenericOp op, Value v);
+  /// The returned boolean indicates whether the result of the expression
+  /// being built depends on any value that is loaded from a sparse
+  /// tensor.
+  std::pair<std::optional<ExprId>, bool> buildTensorExp(linalg::GenericOp op,
+                                                        Value v);
 
   /// Merger data structures.
   const TensorId outTensor;
diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
--- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -92,6 +92,7 @@
   case TensorExp::Kind::kSubI:
   case TensorExp::Kind::kCmpF:
   case TensorExp::Kind::kCmpI:
+  case TensorExp::Kind::kDenseOp: // kDenseOp has *at most* two operands
     return ExpArity::kBinary;
   }
   llvm_unreachable("unexpected kind");
 }
@@ -210,6 +211,11 @@
     children.e0 = x;
     children.e1 = y;
     return;
+  case TensorExp::Kind::kDenseOp:
+    assert(x != detail::kInvalidId && !v && o);
+    children.e0 = x;
+    children.e1 = y;
+    return;
   }
   llvm_unreachable("unexpected kind");
 }
@@ -393,7 +399,8 @@
 
 LatSetId Merger::mapSet(TensorExp::Kind kind, LatSetId s0, Value v,
                         Operation *op) {
-  assert(TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect);
+  assert((TensorExp::Kind::kAbsF <= kind && kind <= TensorExp::Kind::kSelect) ||
+         TensorExp::Kind::kDenseOp == kind);
   const LatSetId sNew = addSet();
   auto &setNew = latSets[sNew];
   for (const LatPointId p : set(s0)) {
@@ -546,6 +553,12 @@
   case TensorExp::Kind::kSubI:
     return expContainsTensor(expr.children.e1, outTensor) ||
            hasNegateOnOut(expr.children.e0);
+  case TensorExp::Kind::kDenseOp: {
+    bool lhsNeg = hasNegateOnOut(expr.children.e0);
+    if (!lhsNeg && expr.children.e1 != detail::kInvalidId)
+      return hasNegateOnOut(expr.children.e1);
+    return lhsNeg;
+  }
   default: {
     switch (getExpArity(expr.kind)) {
     case ExpArity::kNullary:
@@ -646,6 +659,10 @@
   case TensorExp::Kind::kCmpI:
   case TensorExp::Kind::kBinary:
     return false;
+  case TensorExp::Kind::kDenseOp:
+    // Since the Merger guarantees that all operands of a kDenseOp are dense,
+    // the operation must be single-condition.
+    return true;
   }
   llvm_unreachable("unexpected kind");
 }
@@ -771,6 +788,8 @@
     return "binary";
   case TensorExp::Kind::kReduce:
     return "reduce";
+  case TensorExp::Kind::kDenseOp:
+    return "dense";
   }
   llvm_unreachable("unexpected kind for symbol");
 }
@@ -857,14 +876,19 @@
   case TensorExp::Kind::kCmpI:
   case TensorExp::Kind::kBinary:
   case TensorExp::Kind::kReduce:
+  case TensorExp::Kind::kDenseOp:
     llvm::dbgs() << "(";
     dumpExp(expr.children.e0);
     llvm::dbgs() << " " << kindToOpSymbol(expr.kind);
     if (expr.attr)
       llvm::dbgs() << "{" << expr.attr << "}";
-    llvm::dbgs() << " ";
-    dumpExp(expr.children.e1);
-    llvm::dbgs() << ")";
+    if (expr.children.e1 != detail::kInvalidId) {
+      llvm::dbgs() << " ";
+      dumpExp(expr.children.e1);
+      llvm::dbgs() << ")";
+    } else {
+      assert(expr.kind == TensorExp::Kind::kDenseOp);
+    }
     break;
   }
 }
@@ -1142,6 +1166,21 @@
     Operation *const op = expr.op;
     return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
   }
+  case TensorExp::Kind::kDenseOp: {
+    // It does not really matter whether we use a conjunctive or disjunctive
+    // set here: since all operands of a kDenseOp must be dense, a disjunctive
+    // set would be optimized into a conjunctive set anyway.
+    if (expr.children.e1 == detail::kInvalidId) {
+      const ExprId e0 = expr.children.e0;
+      Operation *const op = expr.op;
+      return mapSet(kind, buildLattices(e0, i), Value(), op);
+    }
+
+    const ExprId e0 = expr.children.e0;
+    const ExprId e1 = expr.children.e1;
+    Operation *const op = expr.op;
+    return conjSet(e, buildLattices(e0, i), buildLattices(e1, i), op);
+  }
   }
   llvm_unreachable("unexpected expression kind");
 }
@@ -1150,7 +1189,7 @@
   // Build the linalg semantics backward from yield.
   Operation *yield = op.getRegion().front().getTerminator();
   assert(isa<linalg::YieldOp>(yield));
-  return buildTensorExp(op, yield->getOperand(0));
+  return buildTensorExp(op, yield->getOperand(0)).first;
 }
 
 /// Only returns false if we are certain this is a nonzero.
@@ -1210,7 +1249,9 @@
   return isAdmissibleBranchExp(op, &region.front(), yield->getOperand(0));
 }
 
-std::optional<ExprId> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
+std::pair<std::optional<ExprId>, bool>
+Merger::buildTensorExp(linalg::GenericOp op, Value v) {
+  // Recursion leaves.
   if (auto arg = dyn_cast<BlockArgument>(v)) {
     const TensorId tid = makeTensorId(arg.getArgNumber());
     // Any argument of the generic op that is not marked as a scalar
@@ -1218,96 +1259,98 @@
     // bounds. This includes rank-0 tensor arguments.
     if (arg.getOwner()->getParentOp() == op) {
       OpOperand &t = op->getOpOperand(tid);
+      bool hasSpDep = getSparseTensorEncoding(t.get().getType()) != nullptr;
       if (!op.isScalar(&t))
-        return addTensorExp(tid);
+        return {addTensorExp(tid), hasSpDep};
       v = t.get(); // get scalar value
     }
     // Any other argument (marked as scalar argument for the generic op
     // or belonging to an enveloping op) is considered invariant.
-    return addInvariantExp(v);
+    return {addInvariantExp(v), /*hasSpDep=*/false};
   }
 
   // Something defined outside is invariant.
   Operation *def = v.getDefiningOp();
   if (def->getBlock() != &op.getRegion().front())
-    return addInvariantExp(v);
+    return {addInvariantExp(v), /*hasSpDep=*/false};
 
   // Construct index operations.
   if (def->getNumOperands() == 0) {
     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
-      return addLoopVarExp(makeLoopId(indexOp.getDim()));
+      return {addLoopVarExp(makeLoopId(indexOp.getDim())), /*hasSpDep=*/false};
   }
+
   // Construct unary operations if subexpression can be built.
   if (def->getNumOperands() == 1) {
-    const auto x = buildTensorExp(op, def->getOperand(0));
+    const auto [x, hasSpDep] = buildTensorExp(op, def->getOperand(0));
     if (x.has_value()) {
       const ExprId e = *x;
       if (isa<math::AbsFOp>(def))
-        return addExp(TensorExp::Kind::kAbsF, e);
+        return {addExp(TensorExp::Kind::kAbsF, e), hasSpDep};
       if (isa<complex::AbsOp>(def))
-        return addExp(TensorExp::Kind::kAbsC, e);
+        return {addExp(TensorExp::Kind::kAbsC, e), hasSpDep};
       if (isa<math::AbsIOp>(def))
-        return addExp(TensorExp::Kind::kAbsI, e);
+        return {addExp(TensorExp::Kind::kAbsI, e), hasSpDep};
       if (isa<math::CeilOp>(def))
-        return addExp(TensorExp::Kind::kCeilF, e);
+        return {addExp(TensorExp::Kind::kCeilF, e), hasSpDep};
       if (isa<math::FloorOp>(def))
-        return addExp(TensorExp::Kind::kFloorF, e);
+        return {addExp(TensorExp::Kind::kFloorF, e), hasSpDep};
       if (isa<math::SqrtOp>(def))
-        return addExp(TensorExp::Kind::kSqrtF, e);
+        return {addExp(TensorExp::Kind::kSqrtF, e), hasSpDep};
       if (isa<complex::SqrtOp>(def))
-        return addExp(TensorExp::Kind::kSqrtC, e);
+        return {addExp(TensorExp::Kind::kSqrtC, e), hasSpDep};
       if (isa<math::ExpM1Op>(def))
-        return addExp(TensorExp::Kind::kExpm1F, e);
+        return {addExp(TensorExp::Kind::kExpm1F, e), hasSpDep};
       if (isa<complex::Expm1Op>(def))
-        return addExp(TensorExp::Kind::kExpm1C, e);
+        return {addExp(TensorExp::Kind::kExpm1C, e), hasSpDep};
       if (isa<math::Log1pOp>(def))
-        return addExp(TensorExp::Kind::kLog1pF, e);
+        return {addExp(TensorExp::Kind::kLog1pF, e), hasSpDep};
       if (isa<complex::Log1pOp>(def))
-        return addExp(TensorExp::Kind::kLog1pC, e);
+        return {addExp(TensorExp::Kind::kLog1pC, e), hasSpDep};
       if (isa<math::SinOp>(def))
-        return addExp(TensorExp::Kind::kSinF, e);
+        return {addExp(TensorExp::Kind::kSinF, e), hasSpDep};
       if (isa<complex::SinOp>(def))
-        return addExp(TensorExp::Kind::kSinC, e);
+        return {addExp(TensorExp::Kind::kSinC, e), hasSpDep};
       if (isa<math::TanhOp>(def))
-        return addExp(TensorExp::Kind::kTanhF, e);
+        return {addExp(TensorExp::Kind::kTanhF, e), hasSpDep};
       if (isa<complex::TanhOp>(def))
-        return addExp(TensorExp::Kind::kTanhC, e);
+        return {addExp(TensorExp::Kind::kTanhC, e), hasSpDep};
       if (isa<arith::NegFOp>(def))
-        return addExp(TensorExp::Kind::kNegF, e); // no negi in std
+        return {addExp(TensorExp::Kind::kNegF, e), hasSpDep}; // no negi in std
       if (isa<complex::NegOp>(def))
-        return addExp(TensorExp::Kind::kNegC, e);
+        return {addExp(TensorExp::Kind::kNegC, e), hasSpDep};
       if (isa<arith::TruncFOp>(def))
-        return addExp(TensorExp::Kind::kTruncF, e, v);
+        return {addExp(TensorExp::Kind::kTruncF, e, v), hasSpDep};
       if (isa<arith::ExtFOp>(def))
-        return addExp(TensorExp::Kind::kExtF, e, v);
+        return {addExp(TensorExp::Kind::kExtF, e, v), hasSpDep};
       if (isa<arith::FPToSIOp>(def))
-        return addExp(TensorExp::Kind::kCastFS, e, v);
+        return {addExp(TensorExp::Kind::kCastFS, e, v), hasSpDep};
       if (isa<arith::FPToUIOp>(def))
-        return addExp(TensorExp::Kind::kCastFU, e, v);
+        return {addExp(TensorExp::Kind::kCastFU, e, v), hasSpDep};
       if (isa<arith::SIToFPOp>(def))
-        return addExp(TensorExp::Kind::kCastSF, e, v);
+        return {addExp(TensorExp::Kind::kCastSF, e, v), hasSpDep};
       if (isa<arith::UIToFPOp>(def))
-        return addExp(TensorExp::Kind::kCastUF, e, v);
+        return {addExp(TensorExp::Kind::kCastUF, e, v), hasSpDep};
       if (isa<arith::ExtSIOp>(def))
-        return addExp(TensorExp::Kind::kCastS, e, v);
+        return {addExp(TensorExp::Kind::kCastS, e, v), hasSpDep};
       if (isa<arith::ExtUIOp>(def))
-        return addExp(TensorExp::Kind::kCastU, e, v);
+        return {addExp(TensorExp::Kind::kCastU, e, v), hasSpDep};
       if (isa<arith::IndexCastOp>(def))
-        return addExp(TensorExp::Kind::kCastIdx, e, v);
+        return {addExp(TensorExp::Kind::kCastIdx, e, v), hasSpDep};
       if (isa<arith::TruncIOp>(def))
-        return addExp(TensorExp::Kind::kTruncI, e, v);
+        return {addExp(TensorExp::Kind::kTruncI, e, v), hasSpDep};
       if (isa<complex::ImOp>(def))
-        return addExp(TensorExp::Kind::kCIm, e);
+        return {addExp(TensorExp::Kind::kCIm, e), hasSpDep};
       if (isa<complex::ReOp>(def))
-        return addExp(TensorExp::Kind::kCRe, e);
+        return {addExp(TensorExp::Kind::kCRe, e), hasSpDep};
       if (isa<arith::BitcastOp>(def))
-        return addExp(TensorExp::Kind::kBitCast, e, v);
+        return {addExp(TensorExp::Kind::kBitCast, e, v), hasSpDep};
       if (auto unop = dyn_cast<UnaryOp>(def)) {
         if (isAdmissibleBranch(unop, unop.getPresentRegion()) &&
             isAdmissibleBranch(unop, unop.getAbsentRegion()))
-          return addExp(TensorExp::Kind::kUnary, e, Value(), def);
+          return {addExp(TensorExp::Kind::kUnary, e, Value(), def), hasSpDep};
       }
       if (auto selop = dyn_cast<SelectOp>(def)) {
         if (isAdmissibleBranch(selop, selop.getRegion()))
-          return addExp(TensorExp::Kind::kSelect, e, Value(), def);
+          return {addExp(TensorExp::Kind::kSelect, e, Value(), def), hasSpDep};
       }
     }
   }
@@ -1315,49 +1358,50 @@
   // See buildLattices() for an explanation of rejecting certain
   // division and shift operations.
   if (def->getNumOperands() == 2) {
-    const auto x = buildTensorExp(op, def->getOperand(0));
-    const auto y = buildTensorExp(op, def->getOperand(1));
+    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
+    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
+    bool hasSpDep = xDepSp || yDepSp;
     if (x.has_value() && y.has_value()) {
       const ExprId e0 = *x;
       const ExprId e1 = *y;
       if (isa<arith::MulFOp>(def))
-        return addExp(TensorExp::Kind::kMulF, e0, e1);
+        return {addExp(TensorExp::Kind::kMulF, e0, e1), hasSpDep};
       if (isa<complex::MulOp>(def))
-        return addExp(TensorExp::Kind::kMulC, e0, e1);
+        return {addExp(TensorExp::Kind::kMulC, e0, e1), hasSpDep};
      if (isa<arith::MulIOp>(def))
-        return addExp(TensorExp::Kind::kMulI, e0, e1);
+        return {addExp(TensorExp::Kind::kMulI, e0, e1), hasSpDep};
       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
-        return addExp(TensorExp::Kind::kDivF, e0, e1);
+        return {addExp(TensorExp::Kind::kDivF, e0, e1), hasSpDep};
       if (isa<complex::DivOp>(def) && !maybeZero(e1))
-        return addExp(TensorExp::Kind::kDivC, e0, e1);
+        return {addExp(TensorExp::Kind::kDivC, e0, e1), hasSpDep};
       if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
-        return addExp(TensorExp::Kind::kDivS, e0, e1);
+        return {addExp(TensorExp::Kind::kDivS, e0, e1), hasSpDep};
       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
-        return addExp(TensorExp::Kind::kDivU, e0, e1);
+        return {addExp(TensorExp::Kind::kDivU, e0, e1), hasSpDep};
       if (isa<arith::AddFOp>(def))
-        return addExp(TensorExp::Kind::kAddF, e0, e1);
+        return {addExp(TensorExp::Kind::kAddF, e0, e1), hasSpDep};
       if (isa<complex::AddOp>(def))
-        return addExp(TensorExp::Kind::kAddC, e0, e1);
+        return {addExp(TensorExp::Kind::kAddC, e0, e1), hasSpDep};
       if (isa<arith::AddIOp>(def))
-        return addExp(TensorExp::Kind::kAddI, e0, e1);
+        return {addExp(TensorExp::Kind::kAddI, e0, e1), hasSpDep};
       if (isa<arith::SubFOp>(def))
-        return addExp(TensorExp::Kind::kSubF, e0, e1);
+        return {addExp(TensorExp::Kind::kSubF, e0, e1), hasSpDep};
       if (isa<complex::SubOp>(def))
-        return addExp(TensorExp::Kind::kSubC, e0, e1);
+        return {addExp(TensorExp::Kind::kSubC, e0, e1), hasSpDep};
       if (isa<arith::SubIOp>(def))
-        return addExp(TensorExp::Kind::kSubI, e0, e1);
+        return {addExp(TensorExp::Kind::kSubI, e0, e1), hasSpDep};
       if (isa<arith::AndIOp>(def))
-        return addExp(TensorExp::Kind::kAndI, e0, e1);
+        return {addExp(TensorExp::Kind::kAndI, e0, e1), hasSpDep};
       if (isa<arith::OrIOp>(def))
-        return addExp(TensorExp::Kind::kOrI, e0, e1);
+        return {addExp(TensorExp::Kind::kOrI, e0, e1), hasSpDep};
       if (isa<arith::XOrIOp>(def))
-        return addExp(TensorExp::Kind::kXorI, e0, e1);
+        return {addExp(TensorExp::Kind::kXorI, e0, e1), hasSpDep};
       if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
-        return addExp(TensorExp::Kind::kShrS, e0, e1);
+        return {addExp(TensorExp::Kind::kShrS, e0, e1), hasSpDep};
       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
-        return addExp(TensorExp::Kind::kShrU, e0, e1);
+        return {addExp(TensorExp::Kind::kShrU, e0, e1), hasSpDep};
       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
-        return addExp(TensorExp::Kind::kShlI, e0, e1);
+        return {addExp(TensorExp::Kind::kShlI, e0, e1), hasSpDep};
       if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
         if (ci.getPredicate() == arith::CmpIPredicate::eq &&
             ci.getPredicate() == arith::CmpIPredicate::sle &&
@@ -1366,11 +1410,12 @@
             ci.getPredicate() == arith::CmpIPredicate::uge) {
           // We can not sparsify comparison with equal, this is because 0 <= 0
           // yields true, and thus densifies the result.
-          return std::nullopt;
+          return {std::nullopt, false};
         }
-        return addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
-                      ci.getPredicateAttr());
+        auto e = addExp(TensorExp::Kind::kCmpI, e0, e1, nullptr,
+                        ci.getPredicateAttr());
+        return {e, hasSpDep};
       }
       if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
         if (cf.getPredicate() == arith::CmpFPredicate::OEQ &&
@@ -1384,10 +1429,11 @@
             cf.getPredicate() == arith::CmpFPredicate::UNO) {
           // We can not sparsify comparison with equal, this is because 0 <= 0
           // yields true, and thus densifies the result.
-          return std::nullopt;
+          return {std::nullopt, false};
         }
-        return addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
-                      cf.getPredicateAttr());
+        auto e = addExp(TensorExp::Kind::kCmpF, e0, e1, nullptr,
+                        cf.getPredicateAttr());
+        return {e, hasSpDep};
       }
       if (auto binop = dyn_cast<BinaryOp>(def)) {
         if (isAdmissibleBranch(binop, binop.getOverlapRegion()) &&
@@ -1395,26 +1441,54 @@
              isAdmissibleBranch(binop, binop.getLeftRegion())) &&
             (binop.getRightIdentity() ||
              isAdmissibleBranch(binop, binop.getRightRegion())))
-          return addExp(TensorExp::Kind::kBinary, e0, e1, def);
+          return {addExp(TensorExp::Kind::kBinary, e0, e1, def), hasSpDep};
       }
     }
   }
   // Construct ternary operations if subexpressions can be built.
   if (def->getNumOperands() == 3) {
-    const auto x = buildTensorExp(op, def->getOperand(0));
-    const auto y = buildTensorExp(op, def->getOperand(1));
-    const auto z = buildTensorExp(op, def->getOperand(2));
+    const auto [x, xDepSp] = buildTensorExp(op, def->getOperand(0));
+    const auto [y, yDepSp] = buildTensorExp(op, def->getOperand(1));
+    const auto [z, zDepSp] = buildTensorExp(op, def->getOperand(2));
+    bool hasSpDep = xDepSp || yDepSp || zDepSp;
     if (x.has_value() && y.has_value() && z.has_value()) {
       const ExprId e0 = *x;
       const ExprId e1 = *y;
       if (auto redop = dyn_cast<ReduceOp>(def)) {
        if (isAdmissibleBranch(redop, redop.getRegion()))
-          return addExp(TensorExp::Kind::kReduce, e0, e1, def);
+          return {addExp(TensorExp::Kind::kReduce, e0, e1, def), hasSpDep};
       }
     }
   }
+
+  // If we reach here, we are dealing with an operation that is not currently
+  // sparsifiable. We can still generate code for it if all of its operands
+  // have only dense dependencies (i.e., all the values are loaded from dense
+  // tensors).
+  if (def->getNumResults() != 1) // only handle single-result operations.
+    return {std::nullopt, false};
+
+  SmallVector<std::pair<std::optional<ExprId>, bool>, 2> subExp;
+  // Build all the subexpressions.
+  for (Value operand : def->getOperands())
+    subExp.push_back(buildTensorExp(op, operand));
+
+  if (llvm::all_of(subExp,
+                   [](auto e) { return e.first.has_value() && !e.second; })) {
+    // All the subexpressions can be built and have *no* sparse dependencies.
+    if (subExp.size() == 2) {
+      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
+                      *subExp[1].first, def);
+      return {e, false};
+    }
+    if (subExp.size() == 1) {
+      auto e = addExp(TensorExp::Kind::kDenseOp, *subExp[0].first,
+                      detail::kInvalidId, def);
+      return {e, false};
+    }
+  }
   // Cannot build.
-  return std::nullopt;
+  return {std::nullopt, false};
 }
 
 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
@@ -1609,6 +1683,14 @@
     ReduceOp redOp = cast<ReduceOp>(expr.op);
     return insertYieldOp(rewriter, loc, redOp.getRegion(), {v0, v1});
   }
+  case TensorExp::Kind::kDenseOp: {
+    Operation *actualOp = expr.op;
+    IRMapping mapping;
+    mapping.map(actualOp->getOperand(0), v0);
+    if (actualOp->getNumOperands() == 2)
+      mapping.map(actualOp->getOperand(1), v1);
+    return rewriter.clone(*actualOp, mapping)->getResult(0);
+  }
   }
   llvm_unreachable("unexpected expression kind in build");
 }
diff --git a/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir b/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/unsparsifiable_dense_op.mlir
@@ -0,0 +1,95 @@
+// RUN: mlir-opt %s -sparsification | FileCheck %s
+
+#trait = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, 0)>,
+    affine_map<(d0, d1, d2, d3) -> (d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+}
+
+#VEC = #sparse_tensor.encoding<{ lvlTypes = [ "compressed" ], posWidth = 32, crdWidth = 32 }>
+#COO = #sparse_tensor.encoding<{ lvlTypes = [ "compressed-nu", "singleton" ], posWidth = 32, crdWidth = 32 }>
+#CCC = #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed", "compressed" ], posWidth = 32, crdWidth = 32 }>
+
+//
+// This kernel can be sparsified because every operation that cannot itself
+// be sparsified takes all of its operands from dense tensors.
+//
+// CHECK-LABEL: func @dense_op_without_sp_dep
+// CHECK-NOT: linalg.generic {{.*}}
+func.func @dense_op_without_sp_dep(%169: tensor<2x10x8xf32>,
+                                   %expanded_54: tensor<2x10x1xf32>,
+                                   %expanded_56: tensor<2x10x1xf32>,
+                                   %expanded_57: tensor<2x10x1xf32>,
+                                   %176: tensor<8xf32, #VEC>,
+                                   %177: tensor<8xf32, #VEC>,
+                                   %9: tensor<100x8xf32, #COO>) -> tensor<2x10x100xf32> {
+  %cst_13 = arith.constant -3.40282347E+38 : f32
+  %178 = tensor.empty() : tensor<2x10x100xf32>
+  %179 = linalg.generic #trait
+    ins(%169, %expanded_54, %expanded_56, %expanded_57, %176, %177, %9 :
+        tensor<2x10x8xf32>, tensor<2x10x1xf32>, tensor<2x10x1xf32>, tensor<2x10x1xf32>,
+        tensor<8xf32, #VEC>, tensor<8xf32, #VEC>, tensor<100x8xf32, #COO>)
+    outs(%178 : tensor<2x10x100xf32>) {
+  ^bb0(%in: f32, %in_58: f32, %in_59: f32, %in_60: f32, %in_61: f32, %in_62: f32, %in_63: f32, %out: f32):
+    %180 = arith.mulf %in_60, %in_60 : f32
+    %181 = arith.mulf %in_59, %cst_13 : f32
+    %182 = arith.subf %181, %180 : f32
+    %183 = arith.maxf %182, %cst_13 : f32
+    %184 = arith.addf %183, %cst_13 : f32
+    %185 = math.rsqrt %184 : f32 // not sparsifiable, but all operands are dense
+    %186 = arith.mulf %185, %in_61 : f32
+    %187 = arith.subf %in, %in_58 : f32
+    %188 = arith.mulf %187, %186 : f32
+    %189 = arith.addf %188, %in_62 : f32
+    %190 = arith.mulf %189, %in_63 : f32
+    %191 = arith.addf %out, %190 : f32
+    linalg.yield %191 : f32
+  } -> tensor<2x10x100xf32>
+  return %179 : tensor<2x10x100xf32>
+}
+
+//
+// This kernel cannot be sparsified because some of the operations that cannot
+// be sparsified take operands that are loaded from sparse tensors.
+//
+// CHECK-LABEL: func @dense_op_with_sp_dep
+// CHECK: linalg.generic {{.*}}
+func.func @dense_op_with_sp_dep(%169: tensor<2x10x8xf32>,
+                                %expanded_54: tensor<2x10x1xf32, #CCC>,
+                                %expanded_56: tensor<2x10x1xf32, #CCC>,
+                                %expanded_57: tensor<2x10x1xf32, #CCC>,
+                                %176: tensor<8xf32, #VEC>,
+                                %177: tensor<8xf32, #VEC>,
+                                %9: tensor<100x8xf32, #COO>) -> tensor<2x10x100xf32> {
+  %cst_13 = arith.constant -3.40282347E+38 : f32
+  %178 = tensor.empty() : tensor<2x10x100xf32>
+  %179 = linalg.generic #trait
+    ins(%169, %expanded_54, %expanded_56, %expanded_57, %176, %177, %9 :
+        tensor<2x10x8xf32>, tensor<2x10x1xf32, #CCC>, tensor<2x10x1xf32, #CCC>, tensor<2x10x1xf32, #CCC>,
+        tensor<8xf32, #VEC>, tensor<8xf32, #VEC>, tensor<100x8xf32, #COO>)
+    outs(%178 : tensor<2x10x100xf32>) {
+  ^bb0(%in: f32, %in_58: f32, %in_59: f32, %in_60: f32, %in_61: f32, %in_62: f32, %in_63: f32, %out: f32):
+    %180 = arith.mulf %in_60, %in_60 : f32
+    %181 = arith.mulf %in_59, %cst_13 : f32
+    %182 = arith.subf %181, %180 : f32
+    %183 = arith.maxf %182, %cst_13 : f32
+    %184 = arith.addf %183, %cst_13 : f32
+    %185 = math.rsqrt %184 : f32
+    %186 = arith.mulf %185, %in_61 : f32
+    %187 = arith.subf %in, %in_58 : f32
+    %188 = arith.mulf %187, %186 : f32
+    %189 = arith.addf %188, %in_62 : f32
+    %190 = arith.mulf %189, %in_63 : f32
+    %191 = arith.addf %out, %190 : f32
+    linalg.yield %191 : f32
+  } -> tensor<2x10x100xf32>
+  return %179 : tensor<2x10x100xf32>
+}
diff --git a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
--- a/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
+++ b/mlir/unittests/Dialect/SparseTensor/MergerTest.cpp
@@ -305,6 +305,12 @@
   case TensorExp::Kind::kReduce:
     return compareExpression(tensorExp.children.e0, pattern.children.e0) &&
            compareExpression(tensorExp.children.e1, pattern.children.e1);
+  case TensorExp::Kind::kDenseOp: {
+    bool eq = compareExpression(tensorExp.children.e0, pattern.children.e0);
+    if (eq && tensorExp.children.e1 != sparse_tensor::detail::kInvalidId)
+      return compareExpression(tensorExp.children.e1, pattern.children.e1);
+    return eq;
+  }
   }
   llvm_unreachable("unexpected kind");
 }
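
Note for reviewers (not part of the diff): the new kDenseOp case in Merger::buildExp re-emits the original, unsparsifiable operation by cloning it with its operands remapped to the scalar values of the current loop iteration. The sketch below restates that cloning pattern as a standalone helper under the assumption that all operands are dense; the helper name reEmitDenseOp is hypothetical and only illustrates the IRMapping-based clone.

// Minimal sketch (assumptions labeled above, not part of the patch): re-emit a
// "dense only" operation by remapping its operands to loop-local scalars.
#include "mlir/IR/Builders.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

static mlir::Value reEmitDenseOp(mlir::OpBuilder &builder,
                                 mlir::Operation *denseOp, mlir::Value v0,
                                 mlir::Value v1) {
  // Map the operands of the original (unsparsifiable) operation to the scalar
  // values produced for the current iteration of the generated loop nest.
  mlir::IRMapping mapping;
  mapping.map(denseOp->getOperand(0), v0);
  if (denseOp->getNumOperands() == 2)
    mapping.map(denseOp->getOperand(1), v1);
  // All operands are dense, so the clone can be evaluated unconditionally at
  // every iteration; no sparse branch semantics are needed for it.
  return builder.clone(*denseOp, mapping)->getResult(0);
}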