diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1896,6 +1896,47 @@
     return success();
   }
 };
+
+/// Folds an affine.if whose condition is trivially always true (`0 == 0`) or
+/// always false (the canonical empty set) into its parent block.
+struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
+  using OpRewritePattern<AffineIfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineIfOp op,
+                                PatternRewriter &rewriter) const override {
+
+    // An affine.if with results cannot simply be erased; leave it alone.
+    // TODO: Also handle affine.if ops that return results.
+    if (op.getNumResults() > 0)
+      return failure();
+
+    IntegerSet conditionSet = op.getIntegerSet();
+    Block *blockToMove = nullptr;
+    if (conditionSet.isEmptyIntegerSet()) {
+      // The condition never holds. Without an else region the whole op is
+      // dead, so erase it outright.
+      if (!op.hasElse()) {
+        rewriter.eraseOp(op);
+        return success();
+      }
+      blockToMove = op.getElseBlock();
+    } else if (conditionSet.getNumEqualities() == 1 &&
+               conditionSet.getNumInequalities() == 0 &&
+               conditionSet.getConstraint(0) == 0) {
+      // The only constraint is `0 == 0`, so the condition always holds.
+      blockToMove = op.getThenBlock();
+    } else {
+      return failure();
+    }
+    // Splicing the block in front of `op` would place its terminator in the
+    // middle of the parent block, so erase that terminator first.
+    Operation *blockTerminator = blockToMove->getTerminator();
+    rewriter.eraseOp(blockTerminator);
+    rewriter.mergeBlockBefore(blockToMove, op);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
 } // end anonymous namespace.
 
static LogicalResult verify(AffineIfOp op) { @@ -2059,7 +2100,7 @@ void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<SimplifyDeadElse>(context); + results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Affine/loop-unswitch.mlir b/mlir/test/Dialect/Affine/loop-unswitch.mlir --- a/mlir/test/Dialect/Affine/loop-unswitch.mlir +++ b/mlir/test/Dialect/Affine/loop-unswitch.mlir @@ -245,9 +245,7 @@ } return } -// CHECK: affine.if -// CHECK-NEXT: call -// CHECK-NEXT: } +// CHECK: call // CHECK-NEXT: affine.if // CHECK-NEXT: affine.for // CHECK-NEXT: call diff --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir --- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir +++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir @@ -1,6 +1,5 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -simplify-affine-structures | FileCheck %s -// CHECK-DAG: #[[$SET_EMPTY:.*]] = affine_set<() : (1 == 0)> // CHECK-DAG: #[[$SET_2D:.*]] = affine_set<(d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0)> // CHECK-DAG: #[[$SET_7_11:.*]] = affine_set<(d0, d1) : (d0 * 7 + d1 * 5 + 88 == 0, d0 * 5 - d1 * 11 + 60 == 0, d0 * 11 + d1 * 7 - 24 == 0, d0 * 7 + d1 * 5 + 88 == 0)> @@ -11,7 +10,7 @@ func @test_gaussian_elimination_empty_set0() { affine.for %arg0 = 1 to 10 { affine.for %arg1 = 1 to 100 { - // CHECK: affine.if #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if affine.if affine_set<(d0, d1) : (2 == 0)>(%arg0, %arg1) { call @external() : () -> () } @@ -24,7 +23,7 @@ func @test_gaussian_elimination_empty_set1() { affine.for %arg0 = 1 to 10 { affine.for %arg1 = 1 to 100 { - // CHECK: affine.if #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if affine.if affine_set<(d0, d1) : (1 >= 0, -1 >= 0)> (%arg0, %arg1) { call @external() : () -> () } @@ -52,7 +51,7 @@ %c11 = constant 11
: index affine.for %arg0 = 1 to 10 { affine.for %arg1 = 1 to 100 { - // CHECK: #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if affine.if affine_set<(d0, d1)[s0, s1] : (d0 - s0 == 0, d0 + s0 == 0, s0 - 1 == 0)>(%arg0, %arg1)[%c7, %c11] { call @external() : () -> () } @@ -95,7 +94,7 @@ %c11 = constant 11 : index affine.for %arg0 = 1 to 10 { affine.for %arg1 = 1 to 100 { - // CHECK: #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if affine.if #set_2d_empty(%arg0, %arg1)[%c7, %c11] { call @external() : () -> () } @@ -162,33 +161,33 @@ func @test_empty_set(%N : index) { affine.for %i = 0 to 10 { affine.for %j = 0 to 10 { - // CHECK: affine.if #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if affine.if affine_set<(d0, d1) : (d0 - d1 >= 0, d1 - d0 - 1 >= 0)>(%i, %j) { "foo"() : () -> () } - // CHECK: affine.if #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if affine.if affine_set<(d0) : (d0 >= 0, -d0 - 1 >= 0)>(%i) { "bar"() : () -> () } - // CHECK: affine.if #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if affine.if affine_set<(d0) : (d0 >= 0, -d0 - 1 >= 0)>(%i) { "foo"() : () -> () } - // CHECK: affine.if #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if affine.if affine_set<(d0)[s0, s1] : (d0 >= 0, -d0 + s0 - 1 >= 0, -s0 >= 0)>(%i)[%N, %N] { "bar"() : () -> () } - // CHECK: affine.if #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if // The set below implies d0 = d1; so d1 >= d0, but d0 >= d1 + 1. affine.if affine_set<(d0, d1, d2) : (d0 - d1 == 0, d2 - d0 >= 0, d0 - d1 - 1 >= 0)>(%i, %j, %N) { "foo"() : () -> () } - // CHECK: affine.if #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if // The set below has rational solutions but no integer solutions; GCD test catches it.
affine.if affine_set<(d0, d1) : (d0*2 -d1*2 - 1 == 0, d0 >= 0, -d0 + 100 >= 0, d1 >= 0, -d1 + 100 >= 0)>(%i, %j) { "foo"() : () -> () } - // CHECK: affine.if #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if affine.if affine_set<(d0, d1) : (d1 == 0, d0 - 1 >= 0, - d0 - 1 >= 0)>(%i, %j) { "foo"() : () -> () } @@ -198,12 +197,12 @@ affine.for %k = 0 to 10 { affine.for %l = 0 to 10 { // Empty because no multiple of 8 lies between 4 and 7. - // CHECK: affine.if #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if affine.if affine_set<(d0) : (8*d0 - 4 >= 0, -8*d0 + 7 >= 0)>(%k) { "foo"() : () -> () } // Same as above but with equalities and inequalities. - // CHECK: affine.if #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if affine.if affine_set<(d0, d1) : (d0 - 4*d1 == 0, 4*d1 - 5 >= 0, -4*d1 + 7 >= 0)>(%k, %l) { "foo"() : () -> () } @@ -211,12 +210,12 @@ // 8*d1 here is a multiple of 4, and so can't lie between 9 and 11. GCD // tightening will tighten constraints to 4*d0 + 8*d1 >= 12 and 4*d0 + // 8*d1 <= 8; hence infeasible. - // CHECK: affine.if #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if affine.if affine_set<(d0, d1) : (4*d0 + 8*d1 - 9 >= 0, -4*d0 - 8*d1 + 11 >= 0)>(%k, %l) { "foo"() : () -> () } // Same as above but with equalities added into the mix.
- // CHECK: affine.if #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if affine.if affine_set<(d0, d1, d2) : (d0 - 4*d2 == 0, d0 + 8*d1 - 9 >= 0, -d0 - 8*d1 + 11 >= 0)>(%k, %k, %l) { "foo"() : () -> () } @@ -224,7 +223,7 @@ } affine.for %m = 0 to 10 { - // CHECK: affine.if #[[$SET_EMPTY]]() + // CHECK-NOT: affine.if affine.if affine_set<(d0) : (d0 mod 2 - 3 == 0)> (%m) { "foo"() : () -> () } @@ -239,8 +238,6 @@ func private @external() -> () // CHECK-DAG: #[[$SET:.*]] = affine_set<()[s0] : (s0 >= 0, -s0 + 50 >= 0) -// CHECK-DAG: #[[$EMPTY_SET:.*]] = affine_set<() : (1 == 0) -// CHECK-DAG: #[[$UNIV_SET:.*]] = affine_set<() : (0 == 0) // CHECK-LABEL: func @simplify_set func @simplify_set(%a : index, %b : index) { @@ -248,11 +245,11 @@ affine.if affine_set<(d0, d1) : (d0 - d1 + d1 + d0 >= 0, 2 >= 0, d0 >= 0, -d0 + 50 >= 0, -d0 + 100 >= 0)>(%a, %b) { call @external() : () -> () } - // CHECK: affine.if #[[$EMPTY_SET]] + // CHECK-NOT: affine.if affine.if affine_set<(d0, d1) : (d0 mod 2 - 1 == 0, d0 - 2 * (d0 floordiv 2) == 0)>(%a, %b) { call @external() : () -> () } - // CHECK: affine.if #[[$UNIV_SET]] + // CHECK-NOT: affine.if affine.if affine_set<(d0, d1) : (1 >= 0, 3 >= 0)>(%a, %b) { call @external() : () -> () } @@ -325,3 +322,79 @@ // CHECK: %[[CST:.*]] = constant 0 return %a : index } + +// ----- + +// Two external functions; the test bodies call them so the bodies are not removed by DCE.
+func private @external() -> ()
+func private @external1() -> ()
+
+// CHECK-LABEL: func @test_always_true_if_elimination() {
+func @test_always_true_if_elimination() {
+  affine.for %arg0 = 1 to 10 {
+    affine.for %arg1 = 1 to 100 {
+      affine.if affine_set<(d0, d1) : (1 >= 0)> (%arg0, %arg1) {
+        call @external() : () -> ()
+      } else {
+        call @external1() : () -> ()
+      }
+    }
+  }
+  return
+}
+
+// CHECK: affine.for
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: call @external()
+// CHECK-NEXT: }
+// CHECK-NEXT: }
+
+// CHECK-LABEL: func @test_always_false_if_elimination() {
+func @test_always_false_if_elimination() {
+  // CHECK: affine.for
+  affine.for %arg0 = 1 to 10 {
+    // CHECK-NEXT: affine.for
+    affine.for %arg1 = 1 to 100 {
+      // CHECK-NEXT: call @external1()
+      // CHECK-NEXT: }
+      affine.if affine_set<(d0, d1) : (-1 >= 0)> (%arg0, %arg1) {
+        call @external() : () -> ()
+      } else {
+        call @external1() : () -> ()
+      }
+    }
+  }
+  return
+}
+
+
+// The condition depends on d0: not trivially true or false, so affine.if stays.
+// CHECK-LABEL: func @test_dimensional_if_elimination() {
+func @test_dimensional_if_elimination() {
+  affine.for %arg0 = 1 to 10 {
+    affine.for %arg1 = 1 to 100 {
+      // CHECK: affine.if
+      // CHECK: } else {
+      affine.if affine_set<(d0, d1) : (d0-1 == 0)> (%arg0, %arg1) {
+        call @external() : () -> ()
+      } else {
+        call @external() : () -> ()
+      }
+    }
+  }
+  return
+}
+
+// An affine.if that returns results is kept even if its condition always holds.
+// CHECK-LABEL: func @test_num_results_if_elimination
+func @test_num_results_if_elimination() -> f32 {
+  %zero = constant 0.0 : f32
+  // CHECK: affine.if
+  %0 = affine.if affine_set<() : ()> () -> f32 {
+    affine.yield %zero : f32
+  // CHECK: else {
+  } else {
+    affine.yield %zero : f32
+  }
+  return %0 : f32
+}