diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -521,6 +521,13 @@ LogicalResult matchAndRewrite(ForOp op, PatternRewriter &rewriter) const override { + // If the upper bound is the same as the lower bound, the loop does not + // iterate, just remove it. + if (op.lowerBound() == op.upperBound()) { + rewriter.replaceOp(op, op.getIterOperands()); + return success(); + } + auto lb = op.lowerBound().getDefiningOp(); auto ub = op.upperBound().getDefiningOp(); if (!lb || !ub) @@ -1066,11 +1073,30 @@ return success(); } }; + +/// Removes parallel loops in which at least one lower/upper bound pair consists +/// of the same values - such loops have an empty iteration domain. +struct RemoveEmptyParallelLoops : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ParallelOp op, + PatternRewriter &rewriter) const override { + for (auto dim : llvm::zip(op.lowerBound(), op.upperBound())) { + if (std::get<0>(dim) == std::get<1>(dim)) { + rewriter.replaceOp(op, op.initVals()); + return success(); + } + } + return failure(); + } +}; + } // namespace void ParallelOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert( + context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -32,10 +32,10 @@ // ----- -func @no_iteration(%A: memref) { +func @no_iteration(%A: memref, %ub: index) { %c0 = constant 0 : index %c1 = constant 1 : index - scf.parallel (%i0, %i1) = (%c0, %c0) to (%c1, %c0) step (%c1, %c1) { + scf.parallel (%i0, %i1) = (%c0, %c0) to (%c1, %ub) step (%c1, %c1) { %c42 = constant 42 : i32 store %c42, %A[%i0, %i1] : memref scf.yield @@ -44,11 +44,11 @@ } // CHECK-LABEL: func @no_iteration( -// CHECK-SAME: [[ARG0:%.*]]: memref) { +// CHECK-SAME: [[ARG0:%.*]]: memref, [[UB:%.*]]: index) { // CHECK: [[C0:%.*]] = constant 0 : index // CHECK: [[C1:%.*]] = constant 1 : index // CHECK: [[C42:%.*]] = constant 42 : i32 -// CHECK: scf.parallel ([[V1:%.*]]) = ([[C0]]) to ([[C0]]) step ([[C1]]) { +// CHECK: scf.parallel ([[V1:%.*]]) = ([[C0]]) to ([[UB]]) step ([[C1]]) { // CHECK: store [[C42]], [[ARG0]]{{\[}}[[C0]], [[V1]]] : memref // CHECK: scf.yield // CHECK: } @@ -241,6 +241,22 @@ return } +// CHECK-LABEL: @remove_zero_iteration_loop_vals +func @remove_zero_iteration_loop_vals(%arg0: index) { + %c2 = constant 2 : index + // CHECK: %[[INIT:.*]] = "test.init" + %init = "test.init"() : () -> i32 + // CHECK-NOT: scf.for + // CHECK-NOT: test.op + %0 = scf.for %i = %arg0 to %arg0 step %c2 iter_args(%arg = %init) -> (i32) { + %1 = "test.op"(%i, %arg) : (index, i32) -> i32 + scf.yield %1 : i32 + } + // CHECK: "test.consume"(%[[INIT]]) + "test.consume"(%0) : (i32) -> () + return +} + // CHECK-LABEL: @replace_single_iteration_loop func @replace_single_iteration_loop() { // CHECK: %[[LB:.*]] = constant 42 @@ -278,3 +294,24 @@ "test.consume"(%0) : (i32) -> () return } + +// CHECK-LABEL: @remove_empty_parallel_loop +func @remove_empty_parallel_loop(%lb: index, %ub: index, %s: index) { + // CHECK: %[[INIT:.*]] = "test.init" + %init = "test.init"() : () -> f32 + // CHECK-NOT: scf.parallel + // CHECK-NOT: test.produce + // CHECK-NOT: test.transform + %0 = scf.parallel (%i, %j, %k) = (%lb, %ub, %lb) to (%ub, %ub, %ub) step (%s, %s, %s) init(%init) -> f32 { + %1 = "test.produce"() : () -> f32 + scf.reduce(%1) : f32 { + ^bb0(%lhs: f32, %rhs: f32): + %2 = "test.transform"(%lhs, %rhs) : (f32, f32) -> f32 + scf.reduce.return %2 : f32 + } + scf.yield + } + // CHECK: "test.consume"(%[[INIT]]) + "test.consume"(%0) : (f32) -> () + return +}