diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -140,15 +140,23 @@
 RankedTensorType getCOOFromType(RankedTensorType src, bool ordered);
 
-/// Returns true iff MLIR operand has any sparse operand or result.
-inline bool hasAnySparseOperandOrResult(Operation *op) {
-  bool anySparseIn = llvm::any_of(op->getOperands().getTypes(), [](Type t) {
+/// Returns true iff MLIR operation has any sparse operand.
+inline bool hasAnySparseOperand(Operation *op) {
+  return llvm::any_of(op->getOperands().getTypes(), [](Type t) {
     return getSparseTensorEncoding(t) != nullptr;
   });
-  bool anySparseOut = llvm::any_of(op->getResults().getTypes(), [](Type t) {
+}
+
+/// Returns true iff MLIR operation has any sparse result.
+inline bool hasAnySparseResult(Operation *op) {
+  return llvm::any_of(op->getResults().getTypes(), [](Type t) {
     return getSparseTensorEncoding(t) != nullptr;
   });
-  return anySparseIn || anySparseOut;
+}
+
+/// Returns true iff MLIR operation has any sparse operand or result.
+inline bool hasAnySparseOperandOrResult(Operation *op) {
+  return hasAnySparseOperand(op) || hasAnySparseResult(op);
 }
 
 //
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -422,11 +422,20 @@
       if (!controlFn(&opOperand))
         continue;
 
+      // Find the producer of the operand.
       FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
       if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");
       Operation *producer = opOperand.get().getDefiningOp();
+
+      // Do not fuse a sparse-in/dense-out operation, as the
+      // result is too often not sparsifiable anymore.
+      if (sparse_tensor::hasAnySparseOperand(producer) &&
+          !sparse_tensor::hasAnySparseResult(producer))
+        return failure();
+
+      // Perform the fusion.
       for (auto [origVal, replacement] : fusionResult->replacements) {
         rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
           // Only replace consumer uses.
diff --git a/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir b/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_fusion.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt %s --linalg-fuse-elementwise-ops | FileCheck %s
+
+#SV = #sparse_tensor.encoding<{ lvlTypes = ["compressed"] }>
+
+#trait = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // A
+    affine_map<(i) -> (i)>   // B (out)
+  ],
+  iterator_types = ["parallel"],
+  doc = "B(i) = OP A(i)"
+}
+
+// CHECK-LABEL: func @sparse_fusion
+// CHECK: linalg.generic
+// CHECK: arith.addf
+// CHECK: linalg.generic
+// CHECK: math.exp
+// CHECK: arith.maxf
+// CHECK-NOT: linalg.generic
+// CHECK: return
+func.func @sparse_fusion(%argA: tensor<100xf64, #SV>) -> tensor<100xf64> {
+  %c1 = arith.constant 1.0 : f64
+  %c100 = arith.constant 100.0 : f64
+
+  //
+  // Densifying op.
+  // Should not be fused with the subsequent dense ops.
+  //
+  %t0 = tensor.empty() : tensor<100xf64>
+  %l0 = linalg.generic #trait
+      ins(%argA: tensor<100xf64, #SV>) outs(%t0: tensor<100xf64>) {
+    ^bb0(%in0: f64, %out0: f64):
+      %b0 = arith.addf %in0, %c1 : f64
+      linalg.yield %b0 : f64
+  } -> tensor<100xf64>
+
+
+  //
+  // Two subsequent dense ops.
+  // Should be fused with each other, but not with the op above.
+  //
+  %t1 = tensor.empty() : tensor<100xf64>
+  %l1 = linalg.generic #trait
+      ins(%l0: tensor<100xf64>) outs(%t1: tensor<100xf64>) {
+    ^bb0(%in1: f64, %out1: f64):
+      %b1 = math.exp %in1 : f64
+      linalg.yield %b1 : f64
+  } -> tensor<100xf64>
+  %t2 = tensor.empty() : tensor<100xf64>
+  %l2 = linalg.generic #trait
+      ins(%l1: tensor<100xf64>) outs(%t2: tensor<100xf64>) {
+    ^bb0(%in2: f64, %out2: f64):
+      %b2 = arith.maxf %in2, %c100 : f64
+      linalg.yield %b2 : f64
+  } -> tensor<100xf64>
+
+  return %l2 : tensor<100xf64>
+}
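
A minimal sketch of how the two new helpers compose outside of this patch: the
predicate below is exactly the sparse-in/dense-out test the fusion guard above
applies to the producer. The name `isDensifying` is illustrative only and is
not part of the change.

  #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

  using namespace mlir;

  // Hypothetical helper (not in the patch): an op "densifies" when it reads
  // at least one sparse tensor but produces only dense results. Fusing dense
  // ops into such a producer tends to make the result unsparsifiable, which
  // is why the fusion pattern rejects it.
  static bool isDensifying(Operation *op) {
    return sparse_tensor::hasAnySparseOperand(op) &&
           !sparse_tensor::hasAnySparseResult(op);
  }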