diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1106,12 +1106,184 @@
     return success();
   }
 };
+
+// Allow the true region of an if to assume the condition is true
+// and vice versa. For example:
+//
+//   scf.if %cmp {
+//      print(%cmp)
+//   }
+//
+//  becomes
+//
+//   scf.if %cmp {
+//      print(true)
+//   }
+//
+// Uses nested in the then-region are rewritten to a `constant true` and uses
+// nested in the else-region to a `constant false`; the constants are created
+// right before the if so that they dominate both regions.
+struct ConditionPropagation : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter &rewriter) const override {
+    // Early exit if the condition is constant since replacing a constant
+    // in the body with another constant isn't a simplification.
+    if (op.condition().getDefiningOp<ConstantOp>())
+      return failure();
+
+    // Create the replacement constants in front of the if itself instead of
+    // relying on the insertion point left behind by the pattern driver.
+    rewriter.setInsertionPoint(op);
+    bool changed = false;
+    mlir::Type i1Ty = rewriter.getI1Type();
+
+    // Lazily created true/false values shared by all rewritten uses, so that
+    // at most one constant of each kind is created.
+    Value constantTrue = nullptr;
+    Value constantFalse = nullptr;
+
+    for (OpOperand &use :
+         llvm::make_early_inc_range(op.condition().getUses())) {
+      Region *useRegion = use.getOwner()->getParentRegion();
+      Value replacement = nullptr;
+      if (op.thenRegion().isAncestor(useRegion)) {
+        if (!constantTrue)
+          constantTrue = rewriter.create<ConstantOp>(
+              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
+        replacement = constantTrue;
+      } else if (op.elseRegion().isAncestor(useRegion)) {
+        if (!constantFalse)
+          constantFalse = rewriter.create<ConstantOp>(
+              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));
+        replacement = constantFalse;
+      } else {
+        // Uses outside of the if regions (including the if itself) are kept.
+        continue;
+      }
+
+      rewriter.updateRootInPlace(use.getOwner(),
+                                 [&]() { use.set(replacement); });
+      changed = true;
+    }
+
+    return success(changed);
+  }
+};
+
+// Replace uses of an if result that is equivalent to the condition or its
+// negation, i.e. whose branches yield opposite i1 constants. For example:
+//
+//  %res:2 = scf.if %cmp {
+//     yield something(), true
+//  } else {
+//     yield something2(), false
+//  }
+//  print(%res#1)
+//
+//  becomes (once RemoveUnusedResults drops the now-unused result)
+//
+//  %res = scf.if %cmp {
+//     yield something()
+//  } else {
+//     yield something2()
+//  }
+//  print(%cmp)
+//
+// Additionally if both branches yield the same value, replace all uses
+// of the result with the yielded value:
+//
+//  %res:2 = scf.if %cmp {
+//     yield something(), %arg1
+//  } else {
+//     yield something2(), %arg1
+//  }
+//  print(%res#1)
+//
+//  becomes
+//
+//  %res = scf.if %cmp {
+//     yield something()
+//  } else {
+//     yield something2()
+//  }
+//  print(%arg1)
+struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
+  using OpRewritePattern<IfOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IfOp op,
+                                PatternRewriter &rewriter) const override {
+    // Early exit if there are no results that could be replaced.
+    if (op.getNumResults() == 0)
+      return failure();
+
+    auto trueYield = cast<scf::YieldOp>(op.thenRegion().back().getTerminator());
+    auto falseYield =
+        cast<scf::YieldOp>(op.elseRegion().back().getTerminator());
+
+    // The negated condition must dominate every user of the results, so it
+    // is created right before the if.
+    rewriter.setInsertionPoint(op);
+    bool changed = false;
+    mlir::Type i1Ty = rewriter.getI1Type();
+
+    // Redirect every use of `from` to `to`, notifying the rewriter about each
+    // modified user. Returns true if at least one use was rewritten.
+    auto replaceUses = [&](Value from, Value to) {
+      bool replaced = false;
+      for (OpOperand &use : llvm::make_early_inc_range(from.getUses())) {
+        rewriter.updateRootInPlace(use.getOwner(), [&]() { use.set(to); });
+        replaced = true;
+      }
+      return replaced;
+    };
+
+    for (auto tup :
+         llvm::zip(trueYield.results(), falseYield.results(), op.results())) {
+      Value trueResult, falseResult, opResult;
+      std::tie(trueResult, falseResult, opResult) = tup;
+
+      // A value yielded by both branches must dominate both of them, i.e. it
+      // is defined above the if, so it can replace the result directly.
+      if (trueResult == falseResult) {
+        changed |= replaceUses(opResult, trueResult);
+        continue;
+      }
+
+      // Otherwise only results yielding opposite i1 constants are handled.
+      auto trueConst = trueResult.getDefiningOp<ConstantOp>();
+      auto falseConst = falseResult.getDefiningOp<ConstantOp>();
+      if (!trueConst || !falseConst || opResult.use_empty())
+        continue;
+      auto trueAttr = trueConst.getValue().dyn_cast<BoolAttr>();
+      auto falseAttr = falseConst.getValue().dyn_cast<BoolAttr>();
+      if (!trueAttr || !falseAttr ||
+          trueAttr.getValue() == falseAttr.getValue())
+        continue;
+
+      // (true, false) is the condition itself; (false, true) is its negation,
+      // materialized as `cond xor true`.
+      Value replacement = op.condition();
+      if (!trueAttr.getValue()) {
+        Value constantTrue = rewriter.create<ConstantOp>(
+            op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));
+        replacement =
+            rewriter.create<XOrOp>(op.getLoc(), op.condition(), constantTrue);
+      }
+      changed |= replaceUses(opResult, replacement);
+    }
+    return success(changed);
+  }
+};
+
 } // namespace
 
 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
-  results.add<RemoveUnusedResults, RemoveStaticCondition,
-              ConvertTrivialIfToSelect>(context);
+  results
+      .add<RemoveUnusedResults, RemoveStaticCondition, ConvertTrivialIfToSelect,
+           ConditionPropagation, ReplaceIfYieldWithConditionOrValue>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -103,22 +103,25 @@
 func @one_unused(%cond: i1) -> (index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
   %0, %1 = scf.if %cond -> (index, index) {
     call @side_effect() : () -> ()
     scf.yield %c0, %c1 : index, index
   } else {
-    scf.yield %c0, %c1 : index, index
+    scf.yield %c2, %c3 : index, index
   }
   return %1 : index
 }
 
 // CHECK-LABEL: func @one_unused
 // CHECK: [[C0:%.*]] = constant 1 : index
+// CHECK: [[C3:%.*]] = constant 3 : index
 // CHECK: [[V0:%.*]] = scf.if %{{.*}} -> (index) {
 // CHECK: call @side_effect() : () -> ()
 // CHECK: scf.yield [[C0]] : index
 // CHECK: } else
-// CHECK: scf.yield [[C0]] : index
+// CHECK: scf.yield [[C3]] : index
 // CHECK: }
 // CHECK: return [[V0]] : index
 
@@ -128,12 +131,14 @@
 func @nested_unused(%cond1: i1, %cond2: i1) -> (index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
   %0, %1 = scf.if %cond1 -> (index, index) {
     %2, %3 = scf.if %cond2 -> (index, index) {
       call @side_effect() : () -> ()
       scf.yield %c0, %c1 : index, index
     } else {
-      scf.yield %c0, %c1 : index, index
+      scf.yield %c2, %c3 : index, index
     }
     scf.yield %2, %3 : index, index
   } else {
@@ -144,12 +149,13 @@
 
 // CHECK-LABEL: func @nested_unused
 // CHECK: [[C0:%.*]] = constant 1 : index
+// CHECK: [[C3:%.*]] = constant 3 : index
 // CHECK: [[V0:%.*]] = scf.if {{.*}} -> (index) {
 // CHECK: [[V1:%.*]] = scf.if {{.*}} -> (index) {
 // CHECK: call @side_effect() : () -> ()
 // CHECK: scf.yield [[C0]] : index
 // CHECK: } else
-// CHECK: scf.yield [[C0]] : index
+// CHECK: scf.yield [[C3]] : index
 // CHECK: }
 // CHECK: scf.yield [[V1]] : index
 // CHECK: } else
@@ -610,3 +616,116 @@
   %res = subtensor_insert %2 into %t1[0, 0] [32, 1024] [1, 1] : tensor<32x1024xf32> into tensor<1024x1024xf32>
   return %res : tensor<1024x1024xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @cond_prop
+func @cond_prop(%arg0 : i1) -> index {
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c3 = constant 3 : index
+  %c4 = constant 4 : index
+  %res = scf.if %arg0 -> index {
+    %res1 = scf.if %arg0 -> index {
+      %v1 = "test.get_some_value"() : () -> i32
+      scf.yield %c1 : index
+    } else {
+      %v2 = "test.get_some_value"() : () -> i32
+      scf.yield %c2 : index
+    }
+    scf.yield %res1 : index
+  } else {
+    %res2 = scf.if %arg0 -> index {
+      %v3 = "test.get_some_value"() : () -> i32
+      scf.yield %c3 : index
+    } else {
+      %v4 = "test.get_some_value"() : () -> i32
+      scf.yield %c4 : index
+    }
+    scf.yield %res2 : index
+  }
+  return %res : index
+}
+// CHECK-NEXT:  %[[c1:.+]] = constant 1 : index
+// CHECK-NEXT:  %[[c4:.+]] = constant 4 : index
+// CHECK-NEXT:  %[[if:.+]] = scf.if %arg0 -> (index) {
+// CHECK-NEXT:    %{{.+}} = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:    scf.yield %[[c1]] : index
+// CHECK-NEXT:  } else {
+// CHECK-NEXT:    %{{.+}} = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:    scf.yield %[[c4]] : index
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[if]] : index
+// CHECK-NEXT:}
+
+// -----
+
+// CHECK-LABEL: @replace_if_with_cond1
+func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
+  %true = constant true
+  %false = constant false
+  %res:2 = scf.if %arg0 -> (i32, i1) {
+    %v = "test.get_some_value"() : () -> i32
+    scf.yield %v, %true : i32, i1
+  } else {
+    %v2 = "test.get_some_value"() : () -> i32
+    scf.yield %v2, %false : i32, i1
+  }
+  return %res#0, %res#1 : i32, i1
+}
+// CHECK-NEXT:  %[[if:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT:    %[[sv1:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:    scf.yield %[[sv1]] : i32
+// CHECK-NEXT:  } else {
+// CHECK-NEXT:    %[[sv2:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:    scf.yield %[[sv2]] : i32
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[if]], %arg0 : i32, i1
+
+// -----
+
+// CHECK-LABEL: @replace_if_with_cond2
+func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
+  %true = constant true
+  %false = constant false
+  %res:2 = scf.if %arg0 -> (i32, i1) {
+    %v = "test.get_some_value"() : () -> i32
+    scf.yield %v, %false : i32, i1
+  } else {
+    %v2 = "test.get_some_value"() : () -> i32
+    scf.yield %v2, %true : i32, i1
+  }
+  return %res#0, %res#1 : i32, i1
+}
+// CHECK-NEXT:  %true = constant true
+// CHECK-NEXT:  %[[toret:.+]] = xor %arg0, %true : i1
+// CHECK-NEXT:  %[[if:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT:    %[[sv1:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:    scf.yield %[[sv1]] : i32
+// CHECK-NEXT:  } else {
+// CHECK-NEXT:    %[[sv2:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:    scf.yield %[[sv2]] : i32
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[if]], %[[toret]] : i32, i1
+
+// -----
+
+// CHECK-LABEL: @replace_if_with_cond3
+func @replace_if_with_cond3(%arg0 : i1, %arg1: i64) -> (i32, i64) {
+  %res:2 = scf.if %arg0 -> (i32, i64) {
+    %v = "test.get_some_value"() : () -> i32
+    scf.yield %v, %arg1 : i32, i64
+  } else {
+    %v2 = "test.get_some_value"() : () -> i32
+    scf.yield %v2, %arg1 : i32, i64
+  }
+  return %res#0, %res#1 : i32, i64
+}
+// CHECK-NEXT:  %[[if:.+]] = scf.if %arg0 -> (i32) {
+// CHECK-NEXT:    %[[sv1:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:    scf.yield %[[sv1]] : i32
+// CHECK-NEXT:  } else {
+// CHECK-NEXT:    %[[sv2:.+]] = "test.get_some_value"() : () -> i32
+// CHECK-NEXT:    scf.yield %[[sv2]] : i32
+// CHECK-NEXT:  }
+// CHECK-NEXT:  return %[[if]], %arg1 : i32, i64
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1178,10 +1178,12 @@
 
 // CHECK-LABEL: func @clone_nested_region
 func @clone_nested_region(%arg0: index, %arg1: index) -> memref<?x?xf32> {
+  // A separate condition keeps ConditionPropagation from folding the inner if.
+  %cmp = cmpi eq, %arg0, %arg1 : index
   %0 = cmpi eq, %arg0, %arg1 : index
   %1 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
   %2 = scf.if %0 -> (memref<?x?xf32>) {
-    %3 = scf.if %0 -> (memref<?x?xf32>) {
+    %3 = scf.if %cmp -> (memref<?x?xf32>) {
       %9 = memref.clone %1 : memref<?x?xf32> to memref<?x?xf32>
       scf.yield %9 : memref<?x?xf32>
     } else {