diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -10,6 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "CodegenUtils.h"
+
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -94,12 +96,50 @@
   return false;
 }
 
+// Helper to detect direct yield of a zero value.
+static bool isZeroYield(GenericOp op) {
+  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
+  if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
+    if (arg.getOwner()->getParentOp() == op) {
+      OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()];
+      return matchPattern(t->get(), m_Zero()) ||
+             matchPattern(t->get(), m_AnyZeroFloat());
+    }
+  } else if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
+    return matchPattern(def, m_Zero()) || matchPattern(def, m_AnyZeroFloat());
+  }
+  return false;
+}
+
 //===---------------------------------------------------------------------===//
 // The actual sparse tensor rewriting rules.
 //===---------------------------------------------------------------------===//
 
 namespace {
 
+/// Rewriting rule that converts direct yield of zero with initial allocation.
+struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
+public:
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.hasTensorSemantics() || op.getNumResults() != 1 ||
+        !isAlloc(op.getOutputOperand(0), /*isZero=*/false) || !isZeroYield(op))
+      return failure();
+    auto outputType = op.getResult(0).getType().cast<RankedTensorType>();
+    if (!outputType.hasStaticShape() || getSparseTensorEncoding(outputType))
+      return failure();
+    // Incorporate zero value into allocation copy.
+    Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType());
+    AllocTensorOp a =
+        op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
+    rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(zero); });
+    rewriter.replaceOp(op, op.getOutputOperand(0)->get());
+    return success();
+  }
+};
+
 /// Rewriting rule that converts two kernels:
 ///
 ///      T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
@@ -187,11 +227,13 @@
     rewriter.create<linalg::YieldOp>(loc, last);
     // Force initial value on merged allocation for dense outputs.
     if (!getSparseTensorEncoding(op.getResult(0).getType())) {
-      AllocTensorOp a1 =
-          prod.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
-      AllocTensorOp a2 =
+      Value init = prod.getOutputOperand(0)
+                       ->get()
+                       .getDefiningOp<AllocTensorOp>()
+                       .getCopy();
+      AllocTensorOp a =
           op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
-      a2.getCopyMutable().assign(a1.getCopy());
+      rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(init); });
     }
     // Replace consumer with fused operation. Old producer
     // and consumer ops will be removed by DCE.
@@ -253,7 +295,7 @@
 //===---------------------------------------------------------------------===//
 
 void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
-  patterns
-      .add<FuseSparseMultiplyOverAdd, ReshapeRewriter<tensor::ExpandShapeOp>,
-           ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
+               ReshapeRewriter<tensor::ExpandShapeOp>,
+               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
old mode 100644
new mode 100755
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -20,6 +20,42 @@
   iterator_types = ["parallel", "parallel"]
 }
 
+// CHECK-LABEL: func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
+// CHECK:         %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
+// CHECK:         %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false], memory_space = 0 : ui64} : tensor<1024x1024xf64>
+// CHECK:         return %[[VAL_1]] : tensor<1024x1024xf64>
+// CHECK:       }
+func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
+  %cst = arith.constant 0.000000e+00 : f64
+  %0 = linalg.init_tensor [1024, 1024] : tensor<1024x1024xf64>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>,
+                                        affine_map<(d0, d1) -> (d0, d1)>],
+                       iterator_types = ["parallel", "parallel"]}
+                       ins(%cst : f64)
+                       outs(%0 : tensor<1024x1024xf64>) {
+    ^bb0(%a: f64, %x: f64):
+      linalg.yield %a : f64
+    } -> tensor<1024x1024xf64>
+  return %1 : tensor<1024x1024xf64>
+}
+
+// CHECK-LABEL: func.func @fold_yield_direct_zero() -> tensor<32xf64> {
+// CHECK:         %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
+// CHECK:         %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false], memory_space = 0 : ui64} : tensor<32xf64>
+// CHECK:         return %[[VAL_1]] : tensor<32xf64>
+// CHECK:       }
+func.func @fold_yield_direct_zero() -> tensor<32xf64> {
+  %cst = arith.constant 0.000000e+00 : f64
+  %0 = linalg.init_tensor [32] : tensor<32xf64>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
+                       iterator_types = ["parallel"]}
+                       outs(%0 : tensor<32xf64>) {
+    ^bb0(%x: f64):
+      linalg.yield %cst : f64
+    } -> tensor<32xf64>
+  return %1 : tensor<32xf64>
+}
+
 // CHECK-LABEL: func.func @sampled_dd_unfused(
 // CHECK-SAME:  %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
 // CHECK-SAME:  %[[VAL_1:.*]]: tensor<8x8xf64>,
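
Note: a minimal sketch (not part of this patch) of how these rewriting rules
are typically exercised. The helper name runSparseRewriting and the Passes.h
include site are assumptions; RewritePatternSet, applyPatternsAndFoldGreedily,
and the populateSparseTensorRewriting entry point extended above are the
actual MLIR APIs:

  // Hypothetical driver: collect the sparse tensor rewriting patterns
  // (now including FoldInvariantYield) and apply them greedily to all
  // regions nested under `op`, folding along the way.
  #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" // assumed decl site
  #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

  static mlir::LogicalResult runSparseRewriting(mlir::Operation *op) {
    mlir::RewritePatternSet patterns(op->getContext());
    mlir::populateSparseTensorRewriting(patterns);
    return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
  }

Wrapping the alloc_tensor mutation in rewriter.updateRootInPlace() (rather
than assigning the copy operand directly, as the old code did) keeps the
greedy driver informed of the in-place change, so dependent patterns are
correctly revisited.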