diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -49,7 +49,7 @@
   AffineMap invProducerResultIndexMap =
       inversePermutation(producerResultIndexMap);
   assert(invProducerResultIndexMap &&
-         "expected producer result indexig map to be invertible");
+         "expected producer result indexing map to be invertible");
 
   LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
   // argMap is a map from producer loop -> producer arg tensor index.
@@ -2217,6 +2217,186 @@
 };
 } // namespace
 
+//===---------------------------------------------------------------------===//
+// Patterns that help fusion in the context of sparse tensors.
+//===---------------------------------------------------------------------===//
+
+// Helper to detect a sparse tensor type operand.
+static bool isSparseTensor(OpOperand *op) {
+  if (auto enc = sparse_tensor::getSparseTensorEncoding(op->get().getType())) {
+    ArrayRef<sparse_tensor::SparseTensorEncodingAttr::DimLevelType> dimTypes =
+        enc.getDimLevelType();
+    for (unsigned i = 0, e = dimTypes.size(); i < e; i++)
+      if (dimTypes[i] ==
+          sparse_tensor::SparseTensorEncodingAttr::DimLevelType::Compressed)
+        return true; // at least one compressed
+  }
+  return false;
+}
+
+// Helper method to find zero or empty initialization.
+static bool isEmptyInit(OpOperand *op) {
+  Value val = op->get();
+  if (matchPattern(val, m_Zero()))
+    return true;
+  if (matchPattern(val, m_AnyZeroFloat()))
+    return true;
+  if (val.getDefiningOp<InitTensorOp>())
+    return true;
+  if (val.getDefiningOp<sparse_tensor::InitOp>())
+    return true;
+  return false;
+}
+
+// Helper to detect sampling operation.
+static bool isSampling(GenericOp op) {
+  if (auto yieldOp =
+          dyn_cast<linalg::YieldOp>(op.region().front().getTerminator())) {
+    if (auto def = yieldOp.getOperand(0).getDefiningOp()) {
+      if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
+        // Both scalar input arguments used exactly once.
+        Value s1 = op.getBlock()->getArgument(0);
+        Value s2 = op.getBlock()->getArgument(1);
+        return (def->getOperand(0) == s1 && def->getOperand(1) == s2) ||
+               (def->getOperand(1) == s1 && def->getOperand(0) == s2);
+      }
+    }
+  }
+  return false;
+}
+
+// Helper to detect chain of multiplications that do not involve x.
+static bool isMulChain(Value val, Value x) {
+  if (auto arg = val.dyn_cast<BlockArgument>())
+    return arg != x;
+  if (auto def = val.getDefiningOp()) {
+    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def))
+      return isMulChain(def->getOperand(0), x) &&
+             isMulChain(def->getOperand(1), x);
+  }
+  return false;
+}
+
+// Helper to detect x = x + <multiplications>.
+static bool isSumOfMul(GenericOp op) {
+  if (auto yieldOp =
+          dyn_cast<linalg::YieldOp>(op.region().front().getTerminator())) {
+    if (auto def = yieldOp.getOperand(0).getDefiningOp()) {
+      if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
+        Value x = op.getBlock()->getArguments().back();
+        return (def->getOperand(0) == x && isMulChain(def->getOperand(1), x)) ||
+               (def->getOperand(1) == x && isMulChain(def->getOperand(0), x));
+      }
+    }
+  }
+  return false;
+}
+
+namespace {
+/// Rewriting rule that converts:
+///
+///   T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
+///   X(i,j) = S(i,j) * T(i,j)
+///
+/// into the following form, using the distributive law:
+///
+///   X(i,j) = SUM(k, S(i,j) * A(i,j,k) * B(i,j,k) * ... )
+///
+/// This kind of fusion would be undesirable in the dense case, since we
+/// bring the multiplication into the reduction loop. However, for a sparse
+/// sampling tensor S, this fusion may actually reduce the asymptotic
+/// complexity of the kernel, since intermediate results may be nullified!
+//
+// TODO: lift all rewriting methods to reusable fusion methods?
+//
+struct FuseSparseSamplingOps : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    // Check consumer.
+    if (!op.hasTensorSemantics() || op.getNumInputs() != 2 ||
+        op.getNumResults() != 1)
+      return failure();
+    if (op.getNumParallelLoops() != op.getNumLoops())
+      return failure();
+    if (!op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() ||
+        !op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() ||
+        !op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity())
+      return failure();
+    // Find consuming OP2(sparse, other) or OP2(other, sparse). The other
+    // operand can be sparse or dense, since the point of this rewriting rule
+    // is detecting a situation in which *more* sparsity is introduced into
+    // a computation, be it already sparse or still dense.
+    unsigned other = 0;
+    if (isSparseTensor(op.getInputOperand(0)))
+      other = 1;
+    else if (!isSparseTensor(op.getInputOperand(1)))
+      return failure();
+    // Check producer.
+    auto prod = dyn_cast_or_null<GenericOp>(
+        op.getInputOperand(other)->get().getDefiningOp());
+    if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1)
+      return failure();
+    if (!prod.getResult(0).hasOneUse())
+      return failure();
+    // Sampling consumer and sum of multiplication chain producer.
+    if (isEmptyInit(op.getOutputOperand(0)) &&
+        isEmptyInit(prod.getOutputOperand(0)) && isSampling(op) &&
+        isSumOfMul(prod)) {
+      // Modify operand structure of producer and consumer.
+      Location loc = prod.getLoc();
+      SmallVector<Value> inputOps = prod.getInputOperands();
+      SmallVector<Value> outputOps = op.getOutputOperands();
+      SmallVector<AffineMap> fusedIndexMaps = prod.getIndexingMaps();
+      inputOps.push_back(op.getInputOperand(1 - other)->get());
+      fusedIndexMaps.push_back(fusedIndexMaps.back()); // mimic other
+      // Fuse producer and consumer into a new generic op.
+      auto fusedOp = rewriter.create<GenericOp>(
+          loc, op.getResult(0).getType(), inputOps, outputOps,
+          rewriter.getAffineMapArrayAttr(fusedIndexMaps), prod.iterator_types(),
+          /*doc=*/nullptr,
+          /*library_call=*/nullptr);
+      Block &prodBlock = prod.region().front();
+      Block &consBlock = op.region().front();
+      BlockAndValueMapping mapper;
+      Block *fusedBlock = new Block();
+      fusedOp.region().push_back(fusedBlock);
+      unsigned num = prodBlock.getNumArguments();
+      for (unsigned i = 0; i < num - 1; i++)
+        addArg(mapper, fusedBlock, prodBlock.getArgument(i));
+      addArg(mapper, fusedBlock, consBlock.getArgument(1 - other));
+      addArg(mapper, fusedBlock, prodBlock.getArgument(num - 1));
+      // Clone bodies of the producer and consumer in new evaluation order.
+      auto acc = prodBlock.getTerminator()->getOperand(0).getDefiningOp();
+      auto sampler = consBlock.getTerminator()->getOperand(0).getDefiningOp();
+      rewriter.setInsertionPointToStart(fusedBlock);
+      Value last;
+      for (auto &op : prodBlock.without_terminator())
+        if (&op != acc) {
+          last = op.getResult(0);
+          rewriter.clone(op, mapper);
+        }
+      mapper.map(consBlock.getArgument(other), fusedBlock->back().getResult(0));
+      mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
+      last = rewriter.clone(*acc, mapper)->getResult(0);
+      rewriter.create<linalg::YieldOp>(loc, last);
+      // Replace consumer with fused operation. Old producer
+      // and consumer ops will be removed by DCE.
+      rewriter.replaceOp(op, fusedOp->getResults());
+      return success();
+    }
+    return failure();
+  }
+
+private:
+  // Helper to add argument and record the mapping.
+  static void addArg(BlockAndValueMapping &mapper, Block *b, BlockArgument a) {
+    mapper.map(a, b->addArgument(a.getType(), a.getLoc()));
+  }
+};
+} // namespace
+
 //===---------------------------------------------------------------------===//
 // Methods that add patterns described in this file to a pattern list.
 //===---------------------------------------------------------------------===//
@@ -2263,6 +2443,7 @@
   patterns.add(context, options.controlElementwiseOpsFusionFn);
+  patterns.add<FuseSparseSamplingOps>(context);
   patterns.add(context);
   populateFoldReshapeOpsByExpansionPatterns(patterns,
                                             options.controlFoldingReshapesFn);
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
@@ -5,7 +5,7 @@
 //
 // Do the same run, but now with SIMDization as well. This should not change the outcome.
 //
-// RUN: mlir-opt %s -sparse-compiler="vectorization-strategy=2 vl=8" | \
+// RUN: mlir-opt %s --sparse-compiler="vectorization-strategy=2 vl=8" | \
 // RUN: mlir-cpu-runner -e entry -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
@@ -46,7 +46,8 @@
 //
 module {
   //
-  // A kernel that computes a direct sampled matrix matrix multiplication.
+  // A kernel that computes a direct sampled matrix matrix multiplication
+  // (with dense result).
   //
   func @sampled_dd(%args: tensor<8x8xf64, #SM>,
                    %arga: tensor<8x8xf64>,
@@ -66,11 +67,13 @@
   }
 
   //
-  // A kernel that computes an unfused sampled matrix matrix multiplication.
+  // A kernel that computes an unfused sampled matrix matrix multiplication
+  // (with dense result).
   //
   func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
                            %arga: tensor<8x8xf64>,
-                           %argb: tensor<8x8xf64>) -> (tensor<8x8xf64>, tensor<8x8xf64>) {
+                           %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
+    // Perform dense-dense matrix matrix multiplication.
     %1 = arith.constant dense<0.0> : tensor<8x8xf64>
     %2 = linalg.generic #trait_matmul
       ins(%arga, %argb : tensor<8x8xf64>, tensor<8x8xf64>)
@@ -80,17 +83,68 @@
         %q = arith.addf %x, %p : f64
         linalg.yield %q : f64
     } -> tensor<8x8xf64>
-
-    %3 = arith.constant dense<0.0> : tensor<8x8xf64>
-    %4 = linalg.generic #trait_scale
+    // Sample the result with element-wise multiplication with sparse matrix.
+    %3 = linalg.generic #trait_scale
       ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
-      outs(%3 : tensor<8x8xf64>) {
+      outs(%1 : tensor<8x8xf64>) {
       ^bb0(%t: f64, %s: f64, %x: f64):
         %r = arith.mulf %t, %s : f64
         linalg.yield %r : f64
     } -> tensor<8x8xf64>
+    return %3 : tensor<8x8xf64>
+  }
 
-    return %4, %2 : tensor<8x8xf64>, tensor<8x8xf64>
+  //
+  // A kernel that computes a direct sampled matrix matrix multiplication
+  // (with sparse result).
+  //
+  func @sparse_sampled_dd(%args: tensor<8x8xf64, #SM>,
+                          %arga: tensor<8x8xf64>,
+                          %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
+    %c8 = arith.constant 8 : index
+    %1 = sparse_tensor.init [%c8, %c8] : tensor<8x8xf64, #SM>
+    %2 = linalg.generic #trait_sampled_dense_dense
+      ins(%args, %arga, %argb: tensor<8x8xf64, #SM>,
+                               tensor<8x8xf64>, tensor<8x8xf64>)
+      outs(%1: tensor<8x8xf64, #SM>) {
+        ^bb(%s: f64, %a: f64, %b: f64, %x: f64):
+          %p = arith.mulf %a, %b : f64
+          %q = arith.mulf %s, %p : f64
+          %r = arith.addf %x, %q : f64
+          linalg.yield %r : f64
+    } -> tensor<8x8xf64, #SM>
+    return %2 : tensor<8x8xf64, #SM>
+  }
+
+  //
+  // A kernel that computes an unfused sampled matrix matrix multiplication
+  // (with sparse result).
+  //
+  func @sparse_sampled_dd_unfused(
+      %args: tensor<8x8xf64, #SM>,
+      %arga: tensor<8x8xf64>,
+      %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
+    // Perform dense-dense matrix matrix multiplication.
+    %1 = arith.constant dense<0.0> : tensor<8x8xf64>
+    %2 = linalg.generic #trait_matmul
+      ins(%arga, %argb : tensor<8x8xf64>, tensor<8x8xf64>)
+      outs(%1 : tensor<8x8xf64>) {
+        ^bb0(%a: f64, %b: f64, %x: f64):
+          %p = arith.mulf %a, %b : f64
+          %q = arith.addf %x, %p : f64
+          linalg.yield %q : f64
+    } -> tensor<8x8xf64>
+    // Sample the result with element-wise multiplication with sparse matrix.
+    %c8 = arith.constant 8 : index
+    %3 = sparse_tensor.init [%c8, %c8] : tensor<8x8xf64, #SM>
+    %4 = linalg.generic #trait_scale
+      ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
+      outs(%3 : tensor<8x8xf64, #SM>) {
+        ^bb0(%t: f64, %s: f64, %x: f64):
+          %r = arith.mulf %t, %s : f64
+          linalg.yield %r : f64
+    } -> tensor<8x8xf64, #SM>
+    return %4 : tensor<8x8xf64, #SM>
   }
 
   //
@@ -112,9 +166,15 @@
     %0 = call @sampled_dd(%s, %a, %b)
       : (tensor<8x8xf64, #SM>,
         tensor<8x8xf64>, tensor<8x8xf64>) -> tensor<8x8xf64>
-    %1, %2 = call @sampled_dd_unfused(%s, %a, %b)
+    %1 = call @sampled_dd_unfused(%s, %a, %b)
       : (tensor<8x8xf64, #SM>,
-        tensor<8x8xf64>, tensor<8x8xf64>) -> (tensor<8x8xf64>, tensor<8x8xf64>)
+        tensor<8x8xf64>, tensor<8x8xf64>) -> tensor<8x8xf64>
+    %2 = call @sparse_sampled_dd(%s, %a, %b)
+      : (tensor<8x8xf64, #SM>,
+        tensor<8x8xf64>, tensor<8x8xf64>) -> tensor<8x8xf64, #SM>
+    %3 = call @sparse_sampled_dd_unfused(%s, %a, %b)
+      : (tensor<8x8xf64, #SM>,
+        tensor<8x8xf64>, tensor<8x8xf64>) -> tensor<8x8xf64, #SM>
 
     // Verify the outputs.
     //
@@ -128,21 +188,31 @@
     // CHECK-SAME: ( 0, 0, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 0 ),
     // CHECK-SAME: ( 0, 0, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 192 ) )
     //
+    // CHECK-NEXT: ( 96, 192, 0, 0 )
+    //
+    // CHECK-NEXT: ( 96, 192, 0, 0 )
+    //
     %m0 = bufferization.to_memref %0 : memref<8x8xf64>
     %m1 = bufferization.to_memref %1 : memref<8x8xf64>
-    %m2 = bufferization.to_memref %2 : memref<8x8xf64>
+    %m2 = sparse_tensor.values %2 : tensor<8x8xf64, #SM> to memref<?xf64>
+    %m3 = sparse_tensor.values %3 : tensor<8x8xf64, #SM> to memref<?xf64>
     %v0 = vector.transfer_read %m0[%c0, %c0], %d0
       : memref<8x8xf64>, vector<8x8xf64>
     %v1 = vector.transfer_read %m1[%c0, %c0], %d0
       : memref<8x8xf64>, vector<8x8xf64>
+    %v2 = vector.transfer_read %m2[%c0], %d0 : memref<?xf64>, vector<4xf64>
+    %v3 = vector.transfer_read %m3[%c0], %d0 : memref<?xf64>, vector<4xf64>
     vector.print %v0 : vector<8x8xf64>
    vector.print %v1 : vector<8x8xf64>
+    vector.print %v2 : vector<4xf64>
+    vector.print %v3 : vector<4xf64>
 
     // Release the resources.
     sparse_tensor.release %s : tensor<8x8xf64, #SM>
     memref.dealloc %m0 : memref<8x8xf64>
     memref.dealloc %m1 : memref<8x8xf64>
-    memref.dealloc %m2 : memref<8x8xf64>
+    sparse_tensor.release %2 : tensor<8x8xf64, #SM>
+    sparse_tensor.release %3 : tensor<8x8xf64, #SM>
 
     return
   }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
--- a/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
+++ b/mlir/test/Integration/Dialect/SparseTensor/taco/test_SDDMM.py
@@ -33,8 +33,9 @@
 
 # Alternative way to define SDDMM kernel. Since this performs the reduction as
 #   sum(k, A[i, k] * B[k, j]) * S[i, j]
-# the MLIR lowering results in two separate tensor index expressions that
-# need to be fused properly to guarantee proper asymptotic complexity.
+# the MLIR lowering results in two separate tensor index expressions that are
+# fused prior to running the sparse compiler in order to guarantee proper
+# asymptotic complexity.
 Y[i, j] = A[i, k] * B[k, j] * S[i, j]
 
 expected = """; extended FROSTT format