diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -41,12 +41,17 @@
   return false;
 }
 
-// Helper method to find zero or empty initialization.
-static bool isEmptyInit(OpOperand *op) {
+// Helper method to find zero/uninitialized allocation.
+static bool isAlloc(OpOperand *op, bool isZero) {
   Value val = op->get();
-  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) ||
-         val.getDefiningOp() ||
-         val.getDefiningOp();
+  if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
+    Value copy = alloc.getCopy();
+    if (isZero)
+      return copy && (matchPattern(copy, m_Zero()) ||
+                      matchPattern(copy, m_AnyZeroFloat()));
+    return !copy;
+  }
+  return false;
 }
 
 // Helper to detect sampling operation.
@@ -140,9 +145,9 @@
         !prod.getResult(0).hasOneUse())
       return failure();
     // Sampling consumer and sum of multiplication chain producer.
-    if (!isEmptyInit(op.getOutputOperand(0)) ||
-        !isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) ||
-        !isSumOfMul(prod))
+    if (!isAlloc(op.getOutputOperand(0), /*isZero=*/false) ||
+        !isAlloc(prod.getOutputOperand(0), /*isZero=*/true) ||
+        !isSampling(op) || !isSumOfMul(prod))
       return failure();
     // Modify operand structure of producer and consumer.
     Location loc = prod.getLoc();
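In IR terms, the two allocation forms that the new isAlloc helper distinguishes are a bufferization.alloc_tensor with and without a copy operand. A minimal sketch (function and value names are illustrative; the copy(...) form is what --tensor-copy-insertion produces for the zero-initialized outputs in the new test below):

func.func @alloc_forms() -> (tensor<8x8xf64>, tensor<8x8xf64>) {
  // Accepted by isAlloc(..., /*isZero=*/true): the copy operand is an
  // all-zero constant.
  %cst = arith.constant dense<0.0> : tensor<8x8xf64>
  %zero = bufferization.alloc_tensor() copy(%cst) : tensor<8x8xf64>
  // Accepted by isAlloc(..., /*isZero=*/false): no copy operand, i.e. the
  // allocation is left uninitialized.
  %empty = bufferization.alloc_tensor() : tensor<8x8xf64>
  return %zero, %empty : tensor<8x8xf64>, tensor<8x8xf64>
}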
@@ -180,6 +185,14 @@
       mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
     last = rewriter.clone(*acc, mapper)->getResult(0);
     rewriter.create(loc, last);
+    // Force initial value on merged allocation for dense outputs.
+    if (!getSparseTensorEncoding(op.getResult(0).getType())) {
+      AllocTensorOp a1 =
+          prod.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
+      AllocTensorOp a2 =
+          op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
+      a2.getCopyMutable().assign(a1.getCopy());
+    }
     // Replace consumer with fused operation. Old producer
     // and consumer ops will be removed by DCE.
     rewriter.replaceOp(op, fusedOp->getResults());
@@ -240,7 +253,7 @@
 //===---------------------------------------------------------------------===//
 
 void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
-  // TODO(springerm): enable FuseSparseMultiplyOverAdd
-  patterns.add,
-               ReshapeRewriter>(patterns.getContext());
+  patterns
+      .add,
+           ReshapeRewriter>(patterns.getContext());
 }
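Once both output operands pass the isAlloc checks, the pattern stitches producer and consumer into a single sampled dense-dense kernel. The sketch below shows roughly the shape of fused linalg.generic this corresponds to; the trait name mirrors the integration test, the indexing maps are written out here only for illustration, and the scalar body follows the mulf/mulf/addf chain that the CHECK lines below verify. The actual op is assembled programmatically by the pattern, so treat this as a sketch, not the literal output:

#SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>

#trait_sampled_dense_dense = {
  indexing_maps = [
    affine_map<(d0, d1, d2) -> (d1, d2)>,  // S (sampling matrix)
    affine_map<(d0, d1, d2) -> (d1, d0)>,  // A
    affine_map<(d0, d1, d2) -> (d0, d2)>,  // B
    affine_map<(d0, d1, d2) -> (d1, d2)>   // X (output)
  ],
  iterator_types = ["reduction", "parallel", "parallel"]
}

func.func @sampled_dd_fused_sketch(%args: tensor<8x8xf64, #SM>,
                                   %arga: tensor<8x8xf64>,
                                   %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
  %init = arith.constant dense<0.0> : tensor<8x8xf64>
  %0 = linalg.generic #trait_sampled_dense_dense
    ins(%args, %arga, %argb : tensor<8x8xf64, #SM>, tensor<8x8xf64>, tensor<8x8xf64>)
    outs(%init : tensor<8x8xf64>) {
    ^bb0(%s: f64, %a: f64, %b: f64, %x: f64):
      %p = arith.mulf %a, %b : f64   // A(i,k) * B(k,j)
      %q = arith.mulf %p, %s : f64   // sample with S(i,j)
      %r = arith.addf %x, %q : f64   // accumulate over k
      linalg.yield %r : f64
  } -> tensor<8x8xf64>
  return %0 : tensor<8x8xf64>
}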
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -0,0 +1,162 @@
+// RUN: mlir-opt %s --tensor-copy-insertion --sparsification --cse | FileCheck %s
+
+#SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+#trait_matmul = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d1, d0)>,
+    affine_map<(d0, d1, d2) -> (d0, d2)>,
+    affine_map<(d0, d1, d2) -> (d1, d2)>
+  ],
+  iterator_types = ["reduction", "parallel", "parallel"]
+}
+
+#trait_scale = {
+  indexing_maps = [
+    affine_map<(d0, d1) -> (d0, d1)>,
+    affine_map<(d0, d1) -> (d0, d1)>,
+    affine_map<(d0, d1) -> (d0, d1)>
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
+// CHECK-LABEL: func.func @sampled_dd_unfused(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
+// CHECK: %[[VAL_7:.*]] = bufferization.alloc_tensor() copy(%[[VAL_6]]) {bufferization.escape = [false]} : tensor<8x8xf64>
+// CHECK: %[[VAL_8:.*]] = bufferization.alloc_tensor() copy(%[[VAL_6]]) {bufferization.escape = [false], memory_space = 0 : ui64} : tensor<8x8xf64>
+// CHECK: %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
+// CHECK: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK: %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK: %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_8]] : memref<8x8xf64>
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_4]]] : memref
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref
+// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_5]] {
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_19]]] : memref
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_19]]] : memref
+// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_22]]] : memref
+// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_5]] {
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]]] : memref
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<8x8xf64>
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_24]]] : memref
+// CHECK: %[[VAL_28:.*]] = scf.for %[[VAL_29:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_30:.*]] = %[[VAL_26]]) -> (f64) {
+// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]], %[[VAL_29]]] : memref<8x8xf64>
+// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_29]], %[[VAL_25]]] : memref<8x8xf64>
+// CHECK: %[[VAL_33:.*]] = arith.mulf %[[VAL_31]], %[[VAL_32]] : f64
+// CHECK: %[[VAL_34:.*]] = arith.mulf %[[VAL_33]], %[[VAL_27]] : f64
+// CHECK: %[[VAL_35:.*]] = arith.addf %[[VAL_30]], %[[VAL_34]] : f64
+// CHECK: scf.yield %[[VAL_35]] : f64
+// CHECK: }
+// CHECK: memref.store %[[VAL_24:.*]], %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<8x8xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_37:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<8x8xf64>
+// CHECK: return %[[VAL_37]] : tensor<8x8xf64>
+// CHECK: }
+func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
+                              %arga: tensor<8x8xf64>,
+                              %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
+  // Perform dense-dense matrix matrix multiplication.
+  %1 = arith.constant dense<0.0> : tensor<8x8xf64>
+  %2 = linalg.generic #trait_matmul
+    ins(%arga, %argb : tensor<8x8xf64>, tensor<8x8xf64>)
+    outs(%1 : tensor<8x8xf64>) {
+    ^bb0(%a: f64, %b: f64, %x: f64):
+      %p = arith.mulf %a, %b : f64
+      %q = arith.addf %x, %p : f64
+      linalg.yield %q : f64
+  } -> tensor<8x8xf64>
+  // Sample the result with element-wise multiplication with the sparse matrix.
+  %3 = linalg.generic #trait_scale
+    ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
+    outs(%1 : tensor<8x8xf64>) {
+    ^bb0(%t: f64, %s: f64, %x: f64):
+      %r = arith.mulf %t, %s : f64
+      linalg.yield %r : f64
+  } -> tensor<8x8xf64>
+  return %3 : tensor<8x8xf64>
+}
+
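The dense-output case above is also where the new copy-forcing step matters: before the rewrite, the consumer's allocation is uninitialized, and a2.getCopyMutable().assign(a1.getCopy()) gives it the producer's zero copy value, which is why the CHECK lines expect two copy(...) allocations. A small illustrative sketch of the two buffers as --tensor-copy-insertion leaves them (names are made up, attributes from the CHECK lines omitted):

func.func @dense_output_buffers() -> (tensor<8x8xf64>, tensor<8x8xf64>) {
  %cst = arith.constant dense<0.0> : tensor<8x8xf64>
  // Producer output: zero-initialized, accepted with isZero = true.
  %prod_out = bufferization.alloc_tensor() copy(%cst) : tensor<8x8xf64>
  // Consumer output: uninitialized, accepted with isZero = false. For dense
  // results the rewrite later assigns copy(%cst) to this op as well, so the
  // fused reduction accumulates into a zeroed buffer.
  %cons_out = bufferization.alloc_tensor() : tensor<8x8xf64>
  return %prod_out, %cons_out : tensor<8x8xf64>, tensor<8x8xf64>
}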
+// CHECK-LABEL: func.func @sparse_sampled_dd_unfused(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
+// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() copy(%[[VAL_8]]) {bufferization.escape = [false]} : tensor<8x8xf64>
+// CHECK: %[[VAL_10:.*]] = bufferization.alloc_tensor() {bufferization.escape = [false]} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
+// CHECK: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK: %[[VAL_15:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK: %[[VAL_16:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref
+// CHECK: %[[VAL_18:.*]] = memref.alloca(%[[VAL_6]]) : memref
+// CHECK: %[[VAL_19:.*]] = memref.alloca() : memref
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_4]]] : memref
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref
+// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_20]] to %[[VAL_21]] step %[[VAL_5]] {
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_22]]] : memref
+// CHECK: memref.store %[[VAL_23]], %[[VAL_18]]{{\[}}%[[VAL_4]]] : memref
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_22]]] : memref
+// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_22]], %[[VAL_5]] : index
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]]] : memref
+// CHECK: scf.for %[[VAL_27:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] {
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_27]]] : memref
+// CHECK: memref.store %[[VAL_28]], %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]]] : memref
+// CHECK: %[[VAL_30:.*]] = scf.for %[[VAL_31:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_32:.*]] = %[[VAL_7]]) -> (f64) {
+// CHECK: memref.store %[[VAL_31]], %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref
+// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]], %[[VAL_31]]] : memref<8x8xf64>
+// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_31]], %[[VAL_28]]] : memref<8x8xf64>
+// CHECK: %[[VAL_35:.*]] = arith.mulf %[[VAL_33]], %[[VAL_34]] : f64
+// CHECK: %[[VAL_36:.*]] = arith.mulf %[[VAL_35]], %[[VAL_29]] : f64
+// CHECK: %[[VAL_37:.*]] = arith.addf %[[VAL_32]], %[[VAL_36]] : f64
+// CHECK: scf.yield %[[VAL_37]] : f64
+// CHECK: }
+// CHECK: memref.store %[[VAL_30:.*]], %[[VAL_19]][] : memref
+// CHECK: sparse_tensor.lex_insert %[[VAL_10]], %[[VAL_18]], %[[VAL_19]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>, memref, memref
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_39:.*]] = sparse_tensor.load %[[VAL_10]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: return %[[VAL_39]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: }
+func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
+                                     %arga: tensor<8x8xf64>,
+                                     %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
+  // Perform dense-dense matrix matrix multiplication.
+  %1 = arith.constant dense<0.0> : tensor<8x8xf64>
+  %2 = linalg.generic #trait_matmul
+    ins(%arga, %argb : tensor<8x8xf64>, tensor<8x8xf64>)
+    outs(%1 : tensor<8x8xf64>) {
+    ^bb0(%a: f64, %b: f64, %x: f64):
+      %p = arith.mulf %a, %b : f64
+      %q = arith.addf %x, %p : f64
+      linalg.yield %q : f64
+  } -> tensor<8x8xf64>
+  // Sample the result with element-wise multiplication with the sparse matrix.
+  %3 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
+  %4 = linalg.generic #trait_scale
+    ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
+    outs(%3 : tensor<8x8xf64, #SM>) {
+    ^bb0(%t: f64, %s: f64, %x: f64):
+      %r = arith.mulf %t, %s : f64
+      linalg.yield %r : f64
+  } -> tensor<8x8xf64, #SM>
+  return %4 : tensor<8x8xf64, #SM>
+}
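For the sparse-output variant, the CHECK lines above end in the usual sparsification idiom: coordinates and the reduced value are staged in scratch buffers, inserted with sparse_tensor.lex_insert, and the tensor is materialized with sparse_tensor.load ... hasInserts. A minimal sketch of that idiom follows; the memref element types and the staged coordinates are assumptions here, since the expected output above elides them:

#SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>

func.func @insert_idiom_sketch(%v: f64) -> tensor<8x8xf64, #SM> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %t = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
  %ind = memref.alloca(%c2) : memref<?xindex>   // lexicographic coordinates
  %val = memref.alloca() : memref<f64>          // scalar value to insert
  memref.store %c0, %ind[%c0] : memref<?xindex>
  memref.store %c1, %ind[%c1] : memref<?xindex>
  memref.store %v, %val[] : memref<f64>
  sparse_tensor.lex_insert %t, %ind, %val
    : tensor<8x8xf64, #SM>, memref<?xindex>, memref<f64>
  %r = sparse_tensor.load %t hasInserts : tensor<8x8xf64, #SM>
  return %r : tensor<8x8xf64, #SM>
}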
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
@@ -50,8 +50,8 @@
   // (with dense result).
   //
   func.func @sampled_dd(%args: tensor<8x8xf64, #SM>,
-                      %arga: tensor<8x8xf64>,
-                      %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
+                        %arga: tensor<8x8xf64>,
+                        %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
     %1 = arith.constant dense<0.0> : tensor<8x8xf64>
     %2 = linalg.generic #trait_sampled_dense_dense
       ins(%args, %arga, %argb: tensor<8x8xf64, #SM>,
@@ -71,8 +71,8 @@
   // (with dense result).
   //
   func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
-                              %arga: tensor<8x8xf64>,
-                              %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
+                                %arga: tensor<8x8xf64>,
+                                %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
     // Perform dense-dense matrix matrix multiplication.
     %1 = arith.constant dense<0.0> : tensor<8x8xf64>
     %2 = linalg.generic #trait_matmul
@@ -99,8 +99,8 @@
   // (with sparse result).
   //
   func.func @sparse_sampled_dd(%args: tensor<8x8xf64, #SM>,
-                             %arga: tensor<8x8xf64>,
-                             %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
+                               %arga: tensor<8x8xf64>,
+                               %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
     %1 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
     %2 = linalg.generic #trait_sampled_dense_dense
       ins(%args, %arga, %argb: tensor<8x8xf64, #SM>,