diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Traits.h"
+#include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -268,12 +269,67 @@
     return success();
   }
 };
+
+/// Canonicalizes a `shape.assuming` op by dropping the results that have no
+/// uses. The op is rebuilt with a copy of its body whose terminator yields
+/// only the values backing the remaining (used) results.
+struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
+  using OpRewritePattern<AssumingOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AssumingOp op,
+                                PatternRewriter &rewriter) const override {
+    Block *body = op.getBody();
+    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
+
+    // Collect the yielded values whose corresponding op result is used.
+    SmallVector<Value, 4> usedYieldValues;
+    Value opResult, yieldOperand;
+    for (auto it : llvm::zip(op.getResults(), yieldOp.operands())) {
+      std::tie(opResult, yieldOperand) = it;
+      if (!opResult.use_empty())
+        usedYieldValues.push_back(yieldOperand);
+    }
+
+    // Rewrite only if redundant results exist.
+    if (usedYieldValues.size() == yieldOp->getNumOperands())
+      return failure();
+
+    auto newOp = rewriter.create<AssumingOp>(
+        op.getLoc(), op.witness(), [&](OpBuilder &b, Location) {
+          // Copy the body, except for the terminator, into the new op.
+          BlockAndValueMapping mapping;
+          for (Operation &nested : body->without_terminator())
+            b.clone(nested, mapping);
+
+          // Yield only the values that are used later on. Values defined
+          // above the op are not in the mapping and are forwarded as-is.
+          SmallVector<Value, 2> newUsedYieldValues;
+          for (Value v : usedYieldValues)
+            newUsedYieldValues.push_back(mapping.lookupOrDefault(v));
+          return newUsedYieldValues;
+        });
+
+    // Map each used result of the old op to the next result of the new op.
+    // Unused results get a null placeholder; they have no uses to rewrite.
+    SmallVector<Value, 4> replacementValues;
+    auto src = newOp.getResults().begin();
+    for (Value result : op.getResults()) {
+      if (result.use_empty())
+        replacementValues.push_back(nullptr);
+      else
+        replacementValues.push_back(*src++);
+    }
+
+    rewriter.replaceOp(op, replacementValues);
+    return success();
+  }
+};
 } // namespace
 
 void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
-  // If taking a passing witness, inline region.
-  patterns.add<AssumingWithTrue>(context);
+  // Drop unused results, and inline the region if taking a passing witness.
+  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
 }
 
 // See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -543,6 +543,28 @@
   return
 }
 
+// -----
+
+// Remove unused results from assuming ops.
+// CHECK-LABEL: func @unused_assuming_results
+func @unused_assuming_results() {
+  // CHECK: %[[WITNESS:.*]] = "test.source"
+  // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (f32) {
+  // CHECK:   %{{.*}} = "produce.redundant"
+  // CHECK:   %[[MEANINGFUL:.*]] = "produce.meaningful"
+  // CHECK:   shape.assuming_yield %[[MEANINGFUL]] : f32
+  // CHECK: }
+  // CHECK: "use"(%[[ASSUMING_RESULT]])
+  %0 = "test.source"() : () -> (!shape.witness)
+  %1:2 = shape.assuming %0 -> (f32, f32) {
+    %2 = "produce.redundant"() : () -> (f32)
+    %3 = "produce.meaningful"() : () -> (f32)
+    shape.assuming_yield %2, %3 : f32, f32
+  }
+  "use"(%1#1) : (f32) -> ()
+  return
+}
+
 // -----
 // Broadcastable with broadcastable constant shapes can be removed.
 // CHECK-LABEL: func @f