[mlir][scf] Remove scf.for / scf.parallel loops with an empty iteration domain

An scf.for whose lower and upper bound are the same SSA value never runs its
body, so it is replaced by its iter operands. An scf.parallel in which at
least one lower/upper bound pair consists of the same value has an empty
iteration domain, so it is replaced by its init values.

NOTE(review): this patch was recovered from a whitespace-collapsed copy in
which text between angle brackets had been dropped. The following pieces are
reconstructions and should be confirmed against the target tree:
  * `OpRewritePattern<ParallelOp>` (matches the matchAndRewrite signature);
  * the `results.insert<...>` pattern lists - `RemoveEmptyParallelLoops` must
    be registered for the new test to pass; `CollapseSingleIterationLoops`
    is presumed to be the pre-existing pattern name;
  * `memref<?x?xi32>` in the removed test (rank 2 and i32 follow from the
    store; the dynamic dimensions are an assumption).
The `getDefiningOp()` context lines may have lost a template argument too.
Column alignment inside the removed CHECK lines could not be recovered, so
whitespace-insensitive matching may be needed when applying.

diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -521,6 +521,13 @@
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
+    // If the upper bound is the same as the lower bound, the loop does not
+    // iterate, just remove it.
+    if (op.lowerBound() == op.upperBound()) {
+      rewriter.replaceOp(op, op.getIterOperands());
+      return success();
+    }
+
     auto lb = op.lowerBound().getDefiningOp();
     auto ub = op.upperBound().getDefiningOp();
     if (!lb || !ub)
@@ -1066,11 +1073,30 @@
     return success();
   }
 };
+
+/// Removes parallel loops in which at least one lower/upper bound pair consists
+/// of the same values - such loops have an empty iteration domain.
+struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
+  using OpRewritePattern<ParallelOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ParallelOp op,
+                                PatternRewriter &rewriter) const override {
+    for (auto dim : llvm::zip(op.lowerBound(), op.upperBound())) {
+      if (std::get<0>(dim) == std::get<1>(dim)) {
+        rewriter.replaceOp(op, op.initVals());
+        return success();
+      }
+    }
+    return failure();
+  }
+};
+
 } // namespace
 
 void ParallelOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
-  results.insert<CollapseSingleIterationLoops>(context);
+  results.insert<CollapseSingleIterationLoops, RemoveEmptyParallelLoops>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -32,30 +32,6 @@
 
 // -----
 
-func @no_iteration(%A: memref<?x?xi32>) {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  scf.parallel (%i0, %i1) = (%c0, %c0) to (%c1, %c0) step (%c1, %c1) {
-    %c42 = constant 42 : i32
-    store %c42, %A[%i0, %i1] : memref<?x?xi32>
-    scf.yield
-  }
-  return
-}
-
-// CHECK-LABEL: func @no_iteration(
-// CHECK-SAME: [[ARG0:%.*]]: memref<?x?xi32>) {
-// CHECK: [[C0:%.*]] = constant 0 : index
-// CHECK: [[C1:%.*]] = constant 1 : index
-// CHECK: [[C42:%.*]] = constant 42 : i32
-// CHECK: scf.parallel ([[V1:%.*]]) = ([[C0]]) to ([[C0]]) step ([[C1]]) {
-// CHECK: store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V1]]] : memref<?x?xi32>
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: return
-
-// -----
-
 func @one_unused(%cond: i1) -> (index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
@@ -241,6 +217,22 @@
   return
 }
 
+// CHECK-LABEL: @remove_zero_iteration_loop_vals
+func @remove_zero_iteration_loop_vals(%arg0: index) {
+  %c2 = constant 2 : index
+  // CHECK: %[[INIT:.*]] = "test.init"
+  %init = "test.init"() : () -> i32
+  // CHECK-NOT: scf.for
+  // CHECK-NOT: test.op
+  %0 = scf.for %i = %arg0 to %arg0 step %c2 iter_args(%arg = %init) -> (i32) {
+    %1 = "test.op"(%i, %arg) : (index, i32) -> i32
+    scf.yield %1 : i32
+  }
+  // CHECK: "test.consume"(%[[INIT]])
+  "test.consume"(%0) : (i32) -> ()
+  return
+}
+
 // CHECK-LABEL: @replace_single_iteration_loop
 func @replace_single_iteration_loop() {
   // CHECK: %[[LB:.*]] = constant 42
@@ -278,3 +270,24 @@
   "test.consume"(%0) : (i32) -> ()
   return
 }
+
+// CHECK-LABEL: @remove_empty_parallel_loop
+func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) {
+  // CHECK: %[[INIT:.*]] = "test.init"
+  %init = "test.init"() : () -> f32
+  // CHECK-NOT: scf.parallel
+  // CHECK-NOT: test.produce
+  // CHECK-NOT: test.transform
+  %0 = scf.parallel (%i, %j, %k) = (%lb, %ub, %lb) to (%ub, %ub, %ub) step (%s, %s, %s) init(%init) -> f32 {
+    %1 = "test.produce"() : () -> f32
+    scf.reduce(%1) : f32 {
+    ^bb0(%lhs: f32, %rhs: f32):
+      %2 = "test.transform"(%lhs, %rhs) : (f32, f32) -> f32
+      scf.reduce.return %2 : f32
+    }
+    scf.yield
+  }
+  // CHECK: "test.consume"(%[[INIT]])
+  "test.consume"(%0) : (f32) -> ()
+  return
+}