Index: mlir/include/mlir/Dialect/SCF/SCFOps.td
===================================================================
--- mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -262,6 +262,8 @@
                                : OpBuilder::atBlockEnd(body);
     }
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 def ParallelOp : SCF_Op<"parallel",
Index: mlir/lib/Dialect/SCF/SCF.cpp
===================================================================
--- mlir/lib/Dialect/SCF/SCF.cpp
+++ mlir/lib/Dialect/SCF/SCF.cpp
@@ -508,6 +508,77 @@
   regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion));
 }
 
+namespace {
+/// Canonicalization pattern that removes the unused results of an scf.if.
+///
+/// A replacement scf.if is created that returns only the used results; the
+/// then/else blocks are moved into it and their scf.yield terminators are
+/// rewritten to forward only the corresponding operands.
+struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  /// Moves all operations of `source` into `dest` and replaces the moved
+  /// scf.yield by one that yields only the operands of the used results.
+  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
+                    PatternRewriter &rewriter) const {
+    // Move all operations to the destination block.
+    rewriter.mergeBlocks(source, dest);
+    // Replace the yield op by one that returns only the used values. Operand
+    // i of the yield feeds result i of the original op.
+    auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
+    SmallVector<Value, 4> usedOperands;
+    llvm::transform(usedResults, std::back_inserter(usedOperands),
+                    [&](OpResult result) {
+                      return yieldOp.getOperand(result.getResultNumber());
+                    });
+    rewriter.setInsertionPoint(yieldOp);
+    rewriter.create<scf::YieldOp>(yieldOp.getLoc(), usedOperands);
+    rewriter.eraseOp(yieldOp);
+  }
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter &rewriter) const override {
+    // Compute the list of used results.
+    SmallVector<OpResult, 4> usedResults;
+    llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
+                  [](OpResult result) { return !result.use_empty(); });
+
+    // Nothing to do if every result is used (this also rejects ops without
+    // results).
+    if (usedResults.size() == op.getNumResults())
+      return failure();
+
+    // Compute the result types of the replacement operation.
+    SmallVector<Type, 4> newTypes;
+    llvm::transform(usedResults, std::back_inserter(newTypes),
+                    [](OpResult result) { return result.getType(); });
+
+    // Create a replacement operation with empty then and else regions.
+    auto emptyBuilder = [](OpBuilder &, Location) {};
+    auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.condition(),
+                                       emptyBuilder, emptyBuilder);
+
+    // Move the bodies and replace the terminators (note there is a then and
+    // an else region since the operation returns results).
+    transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
+    transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
+
+    // Replace the operation by the new one. Unused results are mapped to a
+    // null value, which is fine since they have no uses to update.
+    SmallVector<Value, 4> repResults(op.getNumResults());
+    for (auto en : llvm::enumerate(usedResults))
+      repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
+    rewriter.replaceOp(op, repResults);
+    return success();
+  }
+};
+} // namespace
+
+void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                       MLIRContext *context) {
+  results.insert<RemoveUnusedResults>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//
Index: mlir/test/Dialect/SCF/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/SCF/canonicalize.mlir
+++ mlir/test/Dialect/SCF/canonicalize.mlir
@@ -53,3 +53,87 @@
 // CHECK: scf.yield
 // CHECK: }
 // CHECK: return
+
+// -----
+
+func @one_unused() -> (index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %true = constant true
+  %0, %1 = scf.if %true -> (index, index) {
+    scf.yield %c0, %c1 : index, index
+  } else {
+    scf.yield %c0, %c1 : index, index
+  }
+  return %1 : index
+}
+
+// CHECK-LABEL: func @one_unused
+// CHECK:   [[C1:%.*]] = constant 1 : index
+// CHECK:   [[TRUE:%.*]] = constant true
+// CHECK:   [[V0:%.*]] = scf.if [[TRUE]] -> (index) {
+// CHECK:     scf.yield [[C1]] : index
+// CHECK:   } else
+// CHECK:     scf.yield [[C1]] : index
+// CHECK:   }
+// CHECK:   return [[V0]] : index
+
+// -----
+
+func @nested_unused() -> (index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %true = constant true
+  %0, %1 = scf.if %true -> (index, index) {
+    %2, %3 = scf.if %true -> (index, index) {
+      scf.yield %c0, %c1 : index, index
+    } else {
+      scf.yield %c0, %c1 : index, index
+    }
+    scf.yield %2, %3 : index, index
+  } else {
+    scf.yield %c0, %c1 : index, index
+  }
+  return %1 : index
+}
+
+// CHECK-LABEL: func @nested_unused
+// CHECK:   [[C1:%.*]] = constant 1 : index
+// CHECK:   [[TRUE:%.*]] = constant true
+// CHECK:   [[V0:%.*]] = scf.if [[TRUE]] -> (index) {
+// CHECK:     [[V1:%.*]] = scf.if [[TRUE]] -> (index) {
+// CHECK:       scf.yield [[C1]] : index
+// CHECK:     } else
+// CHECK:       scf.yield [[C1]] : index
+// CHECK:     }
+// CHECK:     scf.yield [[V1]] : index
+// CHECK:   } else
+// CHECK:     scf.yield [[C1]] : index
+// CHECK:   }
+// CHECK:   return [[V0]] : index
+
+// -----
+
+func @side_effect() {}
+func @all_unused() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %true = constant true
+  %0, %1 = scf.if %true -> (index, index) {
+    call @side_effect() : () -> ()
+    scf.yield %c0, %c1 : index, index
+  } else {
+    call @side_effect() : () -> ()
+    scf.yield %c0, %c1 : index, index
+  }
+  return
+}
+
+// CHECK-LABEL: func @all_unused
+// CHECK:   [[TRUE:%.*]] = constant true
+// CHECK:   scf.if [[TRUE]] {
+// CHECK:     call @side_effect() : () -> ()
+// CHECK:   } else
+// CHECK:     call @side_effect() : () -> ()
+// CHECK:   }
+// CHECK:   return