diff --git a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SparseTensorRewriting.cpp @@ -45,15 +45,8 @@ // Helper method to find zero or empty initialization. static bool isEmptyInit(OpOperand *op) { Value val = op->get(); - if (matchPattern(val, m_Zero())) - return true; - if (matchPattern(val, m_AnyZeroFloat())) - return true; - if (val.getDefiningOp()) - return true; - if (val.getDefiningOp()) - return true; - return false; + return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) || + val.getDefiningOp() || val.getDefiningOp(); } // Helper to detect sampling operation. @@ -123,11 +116,9 @@ PatternRewriter &rewriter) const override { // Check consumer. if (!op.hasTensorSemantics() || op.getNumInputs() != 2 || - op.getNumResults() != 1) - return failure(); - if (op.getNumParallelLoops() != op.getNumLoops()) - return failure(); - if (!op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() || + op.getNumResults() != 1 || + op.getNumParallelLoops() != op.getNumLoops() || + !op.getTiedIndexingMap(op.getOutputOperand(0)).isIdentity() || !op.getTiedIndexingMap(op.getInputOperand(0)).isIdentity() || !op.getTiedIndexingMap(op.getInputOperand(1)).isIdentity()) return failure(); @@ -143,15 +134,13 @@ // Check producer. auto prod = dyn_cast_or_null( op.getInputOperand(other)->get().getDefiningOp()); - if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1) - return failure(); - if (!prod.getResult(0).hasOneUse()) + if (!prod || !prod.hasTensorSemantics() || prod.getNumResults() != 1 || + !prod.getResult(0).hasOneUse()) return failure(); // Sampling consumer and sum of multiplication chain producer. if (!isEmptyInit(op.getOutputOperand(0)) || - !isEmptyInit(prod.getOutputOperand(0))) - return failure(); - if (!isSampling(op) || !isSumOfMul(prod)) + !isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) || + !isSumOfMul(prod)) return failure(); // Modify operand structure of producer and consumer. Location loc = prod.getLoc();