diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1905,7 +1905,7 @@ } }; -/// Removes Affine.If cond if the condition is always true or false in certain +/// Removes affine.if cond if the condition is always true or false in certain /// trivial cases. Promotes the then/else block in the parent operation block. struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> { using OpRewritePattern<AffineIfOp>::OpRewritePattern; @@ -1913,35 +1913,48 @@ LogicalResult matchAndRewrite(AffineIfOp op, PatternRewriter &rewriter) const override { - // If affine.if is returning results then don't remove it. - // TODO: Similar simplication can be done when affine.if return results. - if (op.getNumResults() > 0) - return failure(); + auto isTriviallyFalse = [](IntegerSet iSet) { + return iSet.isEmptyIntegerSet(); + }; - IntegerSet conditionSet = op.getIntegerSet(); + auto isTriviallyTrue = [](IntegerSet iSet) { + return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 && + iSet.getConstraint(0) == 0); + }; + + IntegerSet affineIfConditions = op.getIntegerSet(); Block *blockToMove; - if (conditionSet.isEmptyIntegerSet()) { - // If the else region is not there, simply remove the Affine.if - // operation. - if (!op.hasElse()) { + if (isTriviallyFalse(affineIfConditions)) { + // The absence, or equivalently, the emptiness of the else region need not + // be checked when affine.if is returning results because if an affine.if + // operation is returning results, it always has a non-empty else region. + if (op.getNumResults() == 0 && !op.hasElse()) { + // If the else region is absent, or equivalently, empty, remove the + // affine.if operation (which is not returning any results).
rewriter.eraseOp(op); return success(); } blockToMove = op.getElseBlock(); - } else if (conditionSet.getNumEqualities() == 1 && - conditionSet.getNumInequalities() == 0 && - conditionSet.getConstraint(0) == 0) { - // Condition to check for trivially true condition (0==0). + } else if (isTriviallyTrue(affineIfConditions)) { blockToMove = op.getThenBlock(); } else { return failure(); } - // Remove the terminator from the block as it already exists in parent - // block. - Operation *blockTerminator = blockToMove->getTerminator(); - rewriter.eraseOp(blockTerminator); + Operation *blockToMoveTerminator = blockToMove->getTerminator(); + // Move all operations of "blockToMove" (including its terminator) into the + // parent operation block, immediately before "op". rewriter.mergeBlockBefore(blockToMove, op); - rewriter.eraseOp(op); + // Replace the "op" operation with the operands of the + // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is + // the affine.yield operation present in the "blockToMove" block. It has no + // operands when affine.if is not returning results and therefore, in that + // case, replaceOp just erases "op". When affine.if is not returning + // results, the affine.yield operation can be omitted. It gets inserted + // implicitly. + rewriter.replaceOp(op, blockToMoveTerminator->getOperands()); + // Erase the "blockToMoveTerminator" operation since it is now in the parent + // operation block, which already has its own terminator.
+ rewriter.eraseOp(blockToMoveTerminator); return success(); } }; @@ -2051,6 +2064,7 @@ ->getAttrOfType<IntegerSetAttr>(getConditionAttrName()) .getValue(); } + void AffineIfOp::setIntegerSet(IntegerSet newSet) { (*this)->setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet)); } diff --git a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir --- a/mlir/test/Dialect/Affine/simplify-affine-structures.mlir +++ b/mlir/test/Dialect/Affine/simplify-affine-structures.mlir @@ -368,7 +368,7 @@ } -// Testing: Affine.If is not trivially true or false, nothing happens. +// Testing: affine.if is not trivially true or false, nothing happens. // CHECK-LABEL: func @test_dimensional_if_elimination() { func @test_dimensional_if_elimination() { affine.for %arg0 = 1 to 10 { @@ -385,16 +385,96 @@ return } -// Testing: Affine.If don't get removed if it is returning results. +// Testing: affine.if gets removed. // CHECK-LABEL: func @test_num_results_if_elimination -func @test_num_results_if_elimination() -> f32 { - %zero = constant 0.0 : f32 +func @test_num_results_if_elimination() -> index { + // CHECK: %[[old01:.*]] = constant 0 : index + %zero = constant 0 : index + %0 = affine.if affine_set<() : ()> () -> index { + affine.yield %zero : index + } else { + affine.yield %zero : index + } + // CHECK-NEXT: return %[[old01]] : index + return %0 : index +} + + +// Three more test functions involving affine.if operations which are +// returning results: + +// Testing: affine.if gets removed. Else block gets promoted.
+// CHECK-LABEL: func @test_trivially_false_returning_two_results +func @test_trivially_false_returning_two_results() -> (index, index) { + // CHECK: %[[falsev40:.*]] = constant 7 : index + // CHECK: %[[falsev41:.*]] = constant 13 : index + %var1 = constant 7 : index + %var2 = constant 13 : index + // CHECK: %[[false40:.*]] = constant 2 : index + // CHECK: %[[false41:.*]] = constant 3 : index + %res:2 = affine.if affine_set<(d0, d1) : (5 >= 0, -2 >= 0)> (%var1, %var2) -> (index, index) { + %zero = constant 0 : index + %one = constant 1 : index + affine.yield %zero, %one : index, index + } else { + %two = constant 2 : index + %three = constant 3 : index + affine.yield %two, %three : index, index + } + // CHECK-NEXT: return %[[false40]], %[[false41]] : index, index + return %res#0, %res#1 : index, index +} + +// Testing: affine.if gets removed. Then block gets promoted. +// CHECK-LABEL: func @test_trivially_true_returning_five_results +func @test_trivially_true_returning_five_results() -> (index, index, index, index, index) { + // CHECK: %[[truev10:.*]] = constant 7 : index + // CHECK: %[[truev11:.*]] = constant 13 : index + %var1 = constant 7 : index + %var2 = constant 13 : index + // CHECK: %[[true0:.*]] = constant 0 : index + // CHECK: %[[true1:.*]] = constant 1 : index + // CHECK: %[[true2:.*]] = constant 2 : index + // CHECK: %[[true3:.*]] = constant 3 : index + // CHECK: %[[true4:.*]] = constant 4 : index + %res:5 = affine.if affine_set<(d0, d1) : (1 >= 0, 3 >= 0)>(%var1, %var2) -> (index, index, index, index, index) { + %zero = constant 0 : index + %one = constant 1 : index + %two = constant 2 : index + %three = constant 3 : index + %four = constant 4 : index + affine.yield %zero, %one, %two, %three, %four : index, index, index, index, index + } else { + %five = constant 5 : index + %six = constant 6 : index + %seven = constant 7 : index + %eight = constant 8 : index + %nine = constant 9 : index + affine.yield %five, %six, %seven, %eight, %nine : index, index,
index, index, index + } + // CHECK-NEXT: return %[[true0]], %[[true1]], %[[true2]], %[[true3]], %[[true4]] : index, index, index, index, index + return %res#0, %res#1, %res#2, %res#3, %res#4 : index, index, index, index, index +} + +// Testing: affine.if doesn't get removed. +// CHECK-LABEL: func @test_not_trivially_true_or_false_returning_three_results +func @test_not_trivially_true_or_false_returning_three_results() -> (index, index, index) { + // CHECK: %[[neitherv10:.*]] = constant 7 : index + // CHECK: %[[neitherv11:.*]] = constant 13 : index + %var1 = constant 7 : index + %var2 = constant 13 : index // CHECK: affine.if - %0 = affine.if affine_set<() : ()> () -> f32 { - affine.yield %zero : f32 - // CHECK: else { + %res:3 = affine.if affine_set<(d0, d1) : (d0 - 1 == 0)>(%var1, %var2) -> (index, index, index) { + %zero = constant 0 : index + %one = constant 1 : index + %two = constant 2 : index + affine.yield %zero, %one, %two : index, index, index + // CHECK: } else { } else { - affine.yield %zero : f32 + %three = constant 3 : index + %four = constant 4 : index + %five = constant 5 : index + affine.yield %three, %four, %five : index, index, index } - return %0 : f32 + return %res#0, %res#1, %res#2 : index, index, index }