diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -268,12 +268,62 @@
     return success();
   }
 };
+
+/// Removes results of a `shape.assuming` op that have no uses. The op is
+/// rebuilt with a narrowed `shape.assuming_yield` and its body region is moved
+/// into a new op created with the reduced result types.
+struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
+  using OpRewritePattern<AssumingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *body = op.getBody();
+    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
+
+    // Collect the yielded values whose corresponding results are still used.
+    SmallVector<Value, 4> newYieldOperands;
+    for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) {
+      if (!std::get<0>(it).use_empty())
+        newYieldOperands.push_back(std::get<1>(it));
+    }
+
+    // Rewrite only if redundant results exist.
+    if (newYieldOperands.size() == yieldOp->getNumOperands())
+      return failure();
+
+    // Replace the yield op in the old assuming op's body and move the entire
+    // region to the new assuming op.
+    rewriter.setInsertionPointToEnd(body);
+    auto newYieldOp =
+        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
+    rewriter.setInsertionPoint(op);
+    auto newOp = rewriter.create<AssumingOp>(
+        op.getLoc(), newYieldOp->getOperandTypes(), op.witness());
+    // IR mutations in a pattern must go through the rewriter (not takeBody).
+    rewriter.inlineRegionBefore(op.doRegion(), newOp.doRegion(),
+                                newOp.doRegion().end());
+
+    // Map every old result to its counterpart in the new op. Unused results
+    // get a null replacement; nothing refers to them.
+    SmallVector<Value, 4> replacementValues;
+    auto src = newOp.getResults().begin();
+    for (Value result : op.getResults()) {
+      if (result.use_empty())
+        replacementValues.push_back(nullptr);
+      else
+        replacementValues.push_back(*src++);
+    }
+    rewriter.replaceOp(op, replacementValues);
+    return success();
+  }
+};
 } // namespace
 
 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
   // If taking a passing witness, inline region.
-  patterns.add<AssumingWithTrue>(context);
+  // Drop results that have no uses.
+  patterns.add<AssumingWithTrue, AssumingOpRemoveUnusedResults>(context);
 }
 
 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -543,6 +543,29 @@
   return
 }
 
+// -----
+
+// Remove unused results from assuming ops. Only the second result is used, so
+// the first one is dropped; the op producing it stays in the body.
+// CHECK-LABEL: func @unused_assuming_results
+func @unused_assuming_results() {
+  // CHECK: %[[WITNESS:.*]] = "test.source"
+  // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (f32) {
+  // CHECK:   %{{.*}} = "produce.redundant"
+  // CHECK:   %[[MEANINGFUL:.*]] = "produce.meaningful"
+  // CHECK:   shape.assuming_yield %[[MEANINGFUL]] : f32
+  // CHECK: }
+  // CHECK: "use"(%[[ASSUMING_RESULT]])
+  %0 = "test.source"() : () -> (!shape.witness)
+  %1:2 = shape.assuming %0 -> (f32, f32) {
+    %2 = "produce.redundant"() : () -> (f32)
+    %3 = "produce.meaningful"() : () -> (f32)
+    shape.assuming_yield %2, %3 : f32, f32
+  }
+  "use"(%1#1) : (f32) -> ()
+  return
+}
+
 // -----
 // Broadcastable with broadcastable constant shapes can be removed.
 // CHECK-LABEL: func @f