diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -10,6 +10,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "CodegenUtils.h"
+
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -94,12 +96,49 @@
   return false;
 }
 
+// Helper to detect direct yield of a zero value. The yielded value must be
+// a block argument of the linalg.generic op that maps to an operand which
+// matches an integer or floating-point zero constant.
+static bool isZeroYield(GenericOp op) {
+  auto yieldOp = cast<linalg::YieldOp>(op.region().front().getTerminator());
+  if (auto arg = yieldOp.getOperand(0).dyn_cast<BlockArgument>()) {
+    if (arg.getOwner()->getParentOp() == op) {
+      OpOperand *t = op.getInputAndOutputOperands()[arg.getArgNumber()];
+      return matchPattern(t->get(), m_Zero()) ||
+             matchPattern(t->get(), m_AnyZeroFloat());
+    }
+  }
+  return false;
+}
+
 //===---------------------------------------------------------------------===//
 // The actual sparse tensor rewriting rules.
 //===---------------------------------------------------------------------===//
 
 namespace {
 
+/// Rewriting rule that converts direct yield of zero with initial allocation.
+struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
+public:
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!op.hasTensorSemantics() || op.getNumInputs() != 1 ||
+        op.getNumResults() != 1 ||
+        !isAlloc(op.getOutputOperand(0), /*isZero=*/false) || !isZeroYield(op))
+      return failure();
+    auto outputType = op.getResult(0).getType().cast<RankedTensorType>();
+    // Only rewrite dense, statically-shaped outputs for now.
+    if (!outputType.hasStaticShape() || getSparseTensorEncoding(outputType))
+      return failure();
+    // Incorporate zero value into allocation copy.
+    AllocTensorOp a =
+        op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
+    Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType());
+    a.getCopyMutable().assign(zero);
+    rewriter.replaceOp(op, op.getOutputOperand(0)->get());
+    return success();
+  }
+};
+
 /// Rewriting rule that converts two kernels:
 ///
 ///   T(i,j) = SUM(k, A(i,j,k) * B(i,j,k) * ... )
@@ -253,7 +292,7 @@
 //===---------------------------------------------------------------------===//
 
 void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
-  patterns
-      .add<FuseSparseMultiplyOverAdd, ReshapeRewriter<tensor::ExpandShapeOp>,
-           ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+  patterns.add<FoldInvariantYield, FuseSparseMultiplyOverAdd,
+               ReshapeRewriter<tensor::ExpandShapeOp>,
+               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -20,6 +20,25 @@
   iterator_types = ["parallel", "parallel"]
 }
 
+// CHECK-LABEL: func.func @fold_yield_zero() -> tensor<1024x1024xf64> {
+// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
+// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) {bufferization.escape = [false], memory_space = 0 : ui64} : tensor<1024x1024xf64>
+// CHECK: return %[[VAL_1]] : tensor<1024x1024xf64>
+// CHECK: }
+func.func @fold_yield_zero() -> tensor<1024x1024xf64> {
+  %cst = arith.constant 0.000000e+00 : f64
+  %0 = linalg.init_tensor [1024, 1024] : tensor<1024x1024xf64>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>,
+                                        affine_map<(d0, d1) -> (d0, d1)>],
+                       iterator_types = ["parallel", "parallel"]}
+  ins(%cst : f64)
+  outs(%0 : tensor<1024x1024xf64>) {
+  ^bb0(%arg3: f64, %arg4: f64):
+    linalg.yield %arg3 : f64
+  } -> tensor<1024x1024xf64>
+  return %1 : tensor<1024x1024xf64>
+}
+
 // CHECK-LABEL: func.func @sampled_dd_unfused(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,