diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -411,7 +411,7 @@
     void getNumRegionInvocations(ArrayRef<Attribute> operands,
                                  SmallVectorImpl<int64_t> &countPerRegion);
   }];
-
+  let hasFolder = 1;
   let hasCanonicalizer = 1;
 }
 
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -13,10 +13,11 @@
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
 
 using namespace mlir;
 using namespace mlir::scf;
 
@@ -1199,6 +1200,37 @@
   }
 }
 
+LogicalResult IfOp::fold(ArrayRef<Attribute> operands,
+                         SmallVectorImpl<OpFoldResult> &results) {
+  // if (!c) then A() else B() -> if c then B() else A()
+  //
+  // The op is updated in place (condition operand and region bodies), so
+  // `results` is left empty on success.
+  //
+  // Without an else region there is no second body to exchange with.
+  if (getElseRegion().empty())
+    return failure();
+
+  // Match `c xor true`, i.e. `!c` on the i1 condition. Only the RHS is
+  // checked for the constant.
+  auto xorStmt = getCondition().getDefiningOp<arith::XOrIOp>();
+  if (!xorStmt || !matchPattern(xorStmt.getRhs(), m_One()))
+    return failure();
+
+  getConditionMutable().assign(xorStmt.getLhs());
+
+  // Exchange the two bodies. Capture the then block first: after the first
+  // splice the then region's front block is the former else block.
+  // It would be nicer to use iplist::swap, but that has no implemented
+  // callbacks. See: https://llvm.org/doxygen/ilist_8h_source.html#l00224
+  Block *thenBlock = &getThenRegion().front();
+  getThenRegion().getBlocks().splice(getThenRegion().getBlocks().begin(),
+                                     getElseRegion().getBlocks());
+  getElseRegion().getBlocks().splice(getElseRegion().getBlocks().begin(),
+                                     getThenRegion().getBlocks(), thenBlock);
+  return success();
+}
+
 namespace {
 // Pattern to remove unused IfOp results.
struct RemoveUnusedResults : public OpRewritePattern { diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -447,6 +447,29 @@ // ----- +// CHECK-LABEL: func @if_condition_swap +// CHECK-NEXT: %{{.*}} = scf.if %arg0 -> (index) { +// CHECK-NEXT: %[[i1:.+]] = "test.origFalse"() : () -> index +// CHECK-NEXT: scf.yield %[[i1]] : index +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[i2:.+]] = "test.origTrue"() : () -> index +// CHECK-NEXT: scf.yield %[[i2]] : index +// CHECK-NEXT: } +func @if_condition_swap(%cond: i1) -> index { + %true = arith.constant true + %not = arith.xori %cond, %true : i1 + %0 = scf.if %not -> (index) { + %1 = "test.origTrue"() : () -> index + scf.yield %1 : index + } else { + %1 = "test.origFalse"() : () -> index + scf.yield %1 : index + } + return %0 : index +} + +// ----- + // CHECK-LABEL: @remove_zero_iteration_loop func @remove_zero_iteration_loop() { %c42 = arith.constant 42 : index