diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -701,12 +701,48 @@
     return success();
   }
 };
+
+/// Propagate constants in scf::For to trigger later simplifications.
+///
+/// When an iter operand is defined by a ConstantOp and the loop yields that
+/// same value at the matching position, uses of the corresponding region iter
+/// arg inside the body are redirected to the constant.
+struct PropagateConstantsInLoopBody : public OpRewritePattern<ForOp> {
+  using OpRewritePattern<ForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForOp forOp,
+                                PatternRewriter &rewriter) const override {
+    if (!forOp.hasIterOperands())
+      return failure();
+
+    Block &block = forOp.getRegion().front();
+    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
+
+    // Track whether anything was rewritten: reporting success without changing
+    // the IR would keep the greedy driver from reaching a fixed point.
+    bool changed = false;
+    for (auto it : llvm::zip(forOp.getIterOperands(), forOp.getRegionIterArgs(),
+                             yieldOp.getOperands())) {
+      Value initValue = std::get<0>(it);
+      BlockArgument iterArg = std::get<1>(it);
+      auto cst = initValue.getDefiningOp<ConstantOp>();
+      if (!cst || initValue != std::get<2>(it) || iterArg.use_empty())
+        continue;
+      // Wrap the mutation in a root update on the loop so the rewriter is
+      // told about the in-place change instead of being bypassed.
+      rewriter.updateRootInPlace(
+          forOp, [&] { iterArg.replaceAllUsesWith(cst.getResult()); });
+      changed = true;
+    }
+    return success(changed);
+  }
+};
 } // namespace
 
 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
-  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
-              LastTensorLoadCanonicalization>(context);
+  results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
+              LastTensorLoadCanonicalization, PropagateConstantsInLoopBody>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -516,3 +516,23 @@
   // CHECK: return %[[FOR_RES]] : i32
   return %0#0 : i32
 }
+
+// -----
+
+// CHECK-LABEL: constant_prop_in_for
+// CHECK-SAME: %[[A0:[0-9a-z]*]]: i32
+func @constant_prop_in_for(%arg0 : i32,
+                           %ub : index, %lb : index, %step : index) -> (i32) {
+  // CHECK-NEXT: %[[C32:.*]] = constant 32 : i32
+  %cst = constant 32 : i32
+  // CHECK-NEXT: %[[FOR_RES:.*]] = scf.for {{.*}} iter_args({{.*}} = %[[A0]]) -> (i32) {
+  %0:2 = scf.for %arg1 = %lb to %ub step %step iter_args(%arg2 = %arg0, %arg3 = %cst)
+  -> (i32, i32) {
+    // CHECK-NEXT: %{{.*}} = addi %{{.*}}, %[[C32]] : i32
+    %1 = addi %arg2, %arg3 : i32
+    %2 = addi %1, %cst : i32
+    scf.yield %2, %cst : i32, i32
+  }
+  // CHECK: return %[[FOR_RES]] : i32
+  return %0#0 : i32
+}