diff --git a/mlir/include/mlir/IR/BlockAndValueMapping.h b/mlir/include/mlir/IR/BlockAndValueMapping.h
--- a/mlir/include/mlir/IR/BlockAndValueMapping.h
+++ b/mlir/include/mlir/IR/BlockAndValueMapping.h
@@ -76,6 +76,9 @@
   /// Clears all mappings held by the mapper.
   void clear() { valueMap.clear(); }
 
+  /// Returns true if the mapper holds no mappings.
+  bool isEmpty() const { return valueMap.empty(); }
+
   /// Returns a new mapper containing the inverse mapping.
   BlockAndValueMapping getInverse() const {
     BlockAndValueMapping result;
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -685,12 +685,57 @@
     return success();
   }
 };
+
+/// Canonicalization pattern enabling ForOpIterArgsFolder to remove unused
+/// iteration arguments of an scf::ForOp. For each region iteration argument
+/// that has no use in the loop body and whose yielded value is the loop's own
+/// init operand at the same position, the corresponding scf.yield operand is
+/// replaced with the block argument itself; ForOpIterArgsFolder then performs
+/// the actual cleanup.
+struct EnableIterArgsFolderOnUnusedArgs : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+    unsigned numControlOperands = forOp.getNumControlOperands();
+
+    // Record (position, block argument) for every iteration argument that is
+    // unused in the body and whose yielded value is its own init operand. The
+    // check is positional: a value yielded at several positions is rewritten
+    // only where the condition holds for that particular position.
+    SmallVector<std::pair<unsigned, BlockArgument>, 4> forwardable;
+    for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
+      if (!bbArg.use_empty())
+        continue;
+
+      unsigned idx = bbArg.getArgNumber() - 1 /*indVar*/;
+      Value yieldVal = yieldOp.getOperand(idx);
+      Value inputVal = forOp.getOperand(idx + numControlOperands);
+      if (yieldVal == inputVal)
+        forwardable.emplace_back(idx, bbArg);
+    }
+
+    if (forwardable.empty())
+      return failure();
+
+    // A single in-place update notifies the rewriter once for the yield op.
+    rewriter.updateRootInPlace(yieldOp, [&] {
+      for (auto &entry : forwardable)
+        yieldOp.setOperand(entry.first, entry.second);
+    });
+    return success();
+  }
+};
+
 } // namespace
 
 void ForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
-  results.insert(context);
+  results
+      .insert(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -382,3 +382,22 @@
   // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
   return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
 }
+
+// -----
+
+// CHECK-LABEL: EnableIterArgsFolderOnUnusedSecondArgInFor
+// CHECK-SAME: %[[A0:[0-9a-z]*]]: i32
+func @EnableIterArgsFolderOnUnusedSecondArgInFor(%arg0 : i32,
+                    %ub : index, %lb : index, %step : index) -> (i32, i32) {
+  // CHECK-NEXT: %[[C32:.*]] = constant 32 : i32
+  %cst = constant 32 : i32
+  // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) {
+  %0:2 = scf.for %arg1 = %lb to %ub step %step iter_args(%arg2 = %arg0, %arg3 = %cst)
+    -> (i32, i32) {
+    %1 = addi %arg2, %cst : i32
+    scf.yield %1, %cst : i32, i32
+  }
+
+  // CHECK: return %[[FOR_RES]], %[[C32]] : i32, i32
+  return %0#0, %0#1 : i32, i32
+}