diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -397,6 +397,14 @@ void visitRegionSuccessors(RegionBranchOpInterface branch, ArrayRef operands); + /// Visit a terminator (an op implementing `RegionBranchTerminatorOpInterface` + /// or a return-like op) whose parent op is `branch`, computing the lattice + /// values of its operands from the corresponding arguments of its region + /// successor(s). Operands not forwarded to any successor are passed to + /// `visitBranchOperand`. + void visitRegionSuccessorsFromTerminator(Operation *terminator, + RegionBranchOpInterface branch); + /// Get the lattice element for a value, and also set up /// dependencies so that the analysis on the given ProgramPoint is re-invoked /// if the value changes. diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -89,13 +89,15 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { // We know (at the moment) and assume (for the future) that `operand` is a - // non-forwarded branch operand of an op of type `RegionBranchOpInterface`, - // `BranchOpInterface`, or `RegionBranchTerminatorOpInterface`. + // non-forwarded branch operand of a `RegionBranchOpInterface`, + // `BranchOpInterface`, `RegionBranchTerminatorOpInterface` or return-like op.
Operation *op = operand.getOwner(); assert((isa(op) || isa(op) || - isa(op)) && + isa(op) || + op->hasTrait()) && "expected the op to be `RegionBranchOpInterface`, " - "`BranchOpInterface`, or `RegionBranchTerminatorOpInterface`"); + "`BranchOpInterface`, `RegionBranchTerminatorOpInterface`, or " + "return-like"); // The lattices of the non-forwarded branch operands don't get updated like // the forwarded branch operands or the non-branch operands. Thus they need @@ -120,10 +122,14 @@ // successors. blocks = op->getSuccessors(); } else { - // When the op is a `RegionBranchTerminatorOpInterface`, like a - // `scf.condition` op, its branch operand controls the flow into this op's - // parent's (which is a `RegionBranchOpInterface`'s) regions. - for (Region &region : op->getParentOp()->getRegions()) { + // When the op is a `RegionBranchTerminatorOpInterface` (e.g. + // `scf.condition`) or a return-like op (e.g. `scf.yield`), its branch + // operand controls the flow into the regions of its parent op, which is a + // `RegionBranchOpInterface`. + Operation *parentOp = op->getParentOp(); + assert(isa(parentOp) && + "expected parent op to implement `RegionBranchOpInterface`"); + for (Region &region : parentOp->getRegions()) { for (Block &block : region) blocks.push_back(&block); } @@ -155,10 +161,11 @@ visitOperation(op, operandLiveness, resultsLiveness); // We also visit the parent op with the parent's results and this operand if - // `op` is a `RegionBranchTerminatorOpInterface` because its non-forwarded - // operand depends on not only its memory effects/results but also on those of - // its parent's. - if (!isa(op)) + // `op` is a `RegionBranchTerminatorOpInterface` or a return-like op, because + // its non-forwarded operand depends not only on its own memory + // effects/results but also on those of its parent.
+ if (!isa(op) && + !op->hasTrait()) return; Operation *parentOp = op->getParentOp(); SmallVector parentResultsLiveness; diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -429,55 +429,19 @@ } } - // The block arguments of the branched to region flow back into the - // operands of the yield operation. - if (auto terminator = dyn_cast(op)) { + // When the region of an op implementing `RegionBranchOpInterface` has a + // terminator implementing `RegionBranchTerminatorOpInterface` or a + // return-like terminator, the region's successors' arguments flow back into + // the "successor operands" of this terminator. + if (isa(op) || + op->hasTrait()) { if (auto branch = dyn_cast(op->getParentOp())) { - SmallVector successors; - SmallVector operands(op->getNumOperands(), nullptr); - branch.getSuccessorRegions(op->getParentRegion()->getRegionNumber(), - operands, successors); - // All operands not forwarded to any successor. This set can be - // non-contiguous in the presence of multiple successors. - BitVector unaccounted(op->getNumOperands(), true); - - for (const RegionSuccessor &successor : successors) { - ValueRange inputs = successor.getSuccessorInputs(); - Region *region = successor.getSuccessor(); - OperandRange operands = - region ? terminator.getSuccessorOperands(region->getRegionNumber()) - : terminator.getSuccessorOperands({}); - MutableArrayRef opoperands = operandsToOpOperands(operands); - for (auto [opoperand, input] : llvm::zip(opoperands, inputs)) { - meet(getLatticeElement(opoperand.get()), - *getLatticeElementFor(op, input)); - unaccounted.reset( - const_cast(opoperand).getOperandNumber()); - } - } - // Visit operands of the branch op not forwarded to the next region. - // (Like e.g.
the boolean of `scf.conditional`) - for (int index : unaccounted.set_bits()) { - visitBranchOperand(op->getOpOperand(index)); - } + visitRegionSuccessorsFromTerminator(op, branch); return; } } - // yield-like ops usually don't implement `RegionBranchTerminatorOpInterface`, - // since they behave like a return in the sense that they forward to the - // results of some other (here: the parent) op. if (op->hasTrait()) { - if (auto branch = dyn_cast(op->getParentOp())) { - OperandRange operands = op->getOperands(); - ResultRange results = op->getParentOp()->getResults(); - assert(results.size() == operands.size() && - "Can't derive arg mapping for yield-like op."); - for (auto [operand, result] : llvm::zip(operands, results)) - meet(getLatticeElement(operand), *getLatticeElementFor(op, result)); - return; - } - // Going backwards, the operands of the return are derived from the // results of all CallOps calling this CallableOp. if (auto callable = dyn_cast(op->getParentOp())) { @@ -535,6 +499,46 @@ } } +void AbstractSparseBackwardDataFlowAnalysis:: + visitRegionSuccessorsFromTerminator(Operation *terminator, + RegionBranchOpInterface branch) { + assert((isa(terminator) || + terminator->hasTrait()) && + "expected a `RegionBranchTerminatorOpInterface` op or a " + "return-like op"); + assert(terminator->getParentOp() == branch.getOperation() && + "expected `branch` to be the parent op of `terminator`"); + + SmallVector operandAttributes(terminator->getNumOperands(), + nullptr); + SmallVector successors; + branch.getSuccessorRegions(terminator->getParentRegion()->getRegionNumber(), + operandAttributes, successors); + // All operands not forwarded to any successor. This set can be + // non-contiguous in the presence of multiple successors.
+ BitVector unaccounted(terminator->getNumOperands(), true); + + for (const RegionSuccessor &successor : successors) { + ValueRange inputs = successor.getSuccessorInputs(); + Region *region = successor.getSuccessor(); + OperandRange operands = + region ? *getRegionBranchSuccessorOperands(terminator, + region->getRegionNumber()) + : *getRegionBranchSuccessorOperands(terminator, {}); + MutableArrayRef opOperands = operandsToOpOperands(operands); + for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) { + meet(getLatticeElement(opOperand.get()), + *getLatticeElementFor(terminator, input)); + unaccounted.reset(const_cast(opOperand).getOperandNumber()); + } + } + // Visit operands of the terminator that are not forwarded to any successor + // (e.g. the condition operand of `scf.condition`). + for (unsigned index : unaccounted.set_bits()) { + visitBranchOperand(terminator->getOpOperand(index)); + } +} + const AbstractSparseLattice * AbstractSparseBackwardDataFlowAnalysis::getLatticeElementFor(ProgramPoint point, Value value) { diff --git a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir --- a/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir +++ b/mlir/test/Analysis/DataFlow/test-liveness-analysis.mlir @@ -16,8 +16,6 @@ // where its op could take the control has an op with memory effects" // %arg2 is live because it can make the control go into a block with a memory // effecting op. -// Note that if `visitBranchOperand()` was left empty, it would have been -// incorrectly marked as "not live". // CHECK-LABEL: test_tag: br: // CHECK-NEXT: operand #0: live // CHECK-NEXT: operand #1: live @@ -41,8 +39,6 @@ // where its op could take the control has an op with memory effects" // %arg0 is live because it can make the control go into a block with a memory // effecting op. -// Note that if `visitBranchOperand()` was left empty, it would have been -// incorrectly marked as "not live".
// CHECK-LABEL: test_tag: flag: // CHECK-NEXT: operand #0: live func.func @test_3_BranchOpInterface_type_1.b(%arg0: i32, %arg1: memref, %arg2: memref) { @@ -77,26 +73,35 @@ // Positive test: Type (3) "is used to compute a value of type (1) or (2)" // %arg1 is live because the scf.while has a live result and %arg1 is a // non-forwarded branch operand. -// Note that if `visitBranchOperand()` was left empty, it would have been -// incorrectly marked as "not live". // %arg2 is live because it is forwarded to the live result of the scf.while // op. -// Negative test: %arg3 is not live even though %arg1 and %arg2 are live -// because it is neither a non-forwarded branch operand nor a forwarded -// operand that forwards to a live value. It actually is a forwarded operand -// that forwards to a non-live value. +// %arg5 is live because it is forwarded to %arg8 which is live. +// %arg8 is live because it is forwarded to %arg4, which is live because it +// is stored to memory. +// Negative test: +// %arg3 is not live even though %arg1, %arg2, and %arg5 are live because it +// is neither a non-forwarded branch operand nor a forwarded operand that +// forwards to a live value. It actually is a forwarded operand that forwards +// to non-live values %0#1 and %arg7.
// CHECK-LABEL: test_tag: condition: // CHECK-NEXT: operand #0: live // CHECK-NEXT: operand #1: live // CHECK-NEXT: operand #2: not live +// CHECK-NEXT: operand #3: live +// CHECK-LABEL: test_tag: add: +// CHECK-NEXT: operand #0: live func.func @test_5_RegionBranchTerminatorOpInterface_type_3(%arg0: memref, %arg1: i1) -> (i32) { %c0_i32 = arith.constant 0 : i32 %c1_i32 = arith.constant 1 : i32 - %0:2 = scf.while (%arg2 = %c0_i32, %arg3 = %c1_i32) : (i32, i32) -> (i32, i32) { - scf.condition(%arg1) {tag = "condition"} %arg2, %arg3 : i32, i32 + %c2_i32 = arith.constant 2 : i32 + %0:3 = scf.while (%arg2 = %c0_i32, %arg3 = %c1_i32, %arg4 = %c2_i32, %arg5 = %c2_i32) : (i32, i32, i32, i32) -> (i32, i32, i32) { + memref.store %arg4, %arg0[] : memref + scf.condition(%arg1) {tag = "condition"} %arg2, %arg3, %arg5 : i32, i32, i32 } do { - ^bb0(%arg2: i32, %arg3: i32): - scf.yield %arg2, %arg3 : i32, i32 + ^bb0(%arg6: i32, %arg7: i32, %arg8: i32): + %1 = arith.addi %arg8, %arg8 {tag = "add"} : i32 + %c3_i32 = arith.constant 3 : i32 + scf.yield %arg6, %arg7, %arg8, %c3_i32 : i32, i32, i32, i32 } return %0#0 : i32 } @@ -112,12 +117,10 @@ // zero, ten, and one are live because they are used to decide the number of // times the `for` loop executes, which in turn decides the value stored in // memory. -// Note that if `visitBranchOperand()` was left empty, they would have been -// incorrectly marked as "not live". // in_private0 and x are also live because they decide the value stored in // memory. -// Negative test: y is not live even though the non-forwarded branch operand -// and x are live. +// Negative test: +// y is not live even though the non-forwarded branch operand and x are live.
// CHECK-LABEL: test_tag: in_private0: // CHECK-NEXT: operand #0: live // CHECK-NEXT: operand #1: live diff --git a/mlir/test/Analysis/DataFlow/test-written-to.mlir b/mlir/test/Analysis/DataFlow/test-written-to.mlir --- a/mlir/test/Analysis/DataFlow/test-written-to.mlir +++ b/mlir/test/Analysis/DataFlow/test-written-to.mlir @@ -168,9 +168,10 @@ // CHECK-LABEL: test_tag: zero // CHECK: result #0: [c] // CHECK-LABEL: test_tag: init -// CHECK: result #0: [a b] +// CHECK: result #0: [a b c] // CHECK-LABEL: test_tag: condition // CHECK: operand #0: [brancharg0] +// CHECK: operand #2: [a b c] func.func @test_while(%m0: memref, %init : i32, %cond: i1) { %zero = arith.constant {tag = "zero"} 0 : i32 %init2 = arith.addi %init, %init {tag = "init"} : i32 @@ -181,7 +182,7 @@ ^bb0(%arg1: i32, %arg2: i32): memref.store %arg1, %m0[] {tag_name = "c"} : memref %res = arith.addi %arg2, %arg2 : i32 - scf.yield %arg1, %res: i32, i32 + scf.yield %res, %res: i32, i32 } memref.store %1, %m0[] {tag_name = "b"} : memref return @@ -189,6 +190,32 @@ // ----- +// CHECK-LABEL: test_tag: zero +// CHECK: result #0: [] +// CHECK-LABEL: test_tag: one +// CHECK: result #0: [a] +// CHECK-LABEL: test_tag: condition +// CHECK: operand #0: [brancharg0] +// +// The important thing to note in this test is that the sparse backward dataflow +// analysis framework also works on complex region branch ops like this one, +// where the number of operands in the `scf.yield` op doesn't match the number +// of results in the parent op.
+func.func @test_complex_while(%m0: memref, %cond: i1) { + %zero = arith.constant {tag = "zero"} 0 : i32 + %one = arith.constant {tag = "one"} 1 : i32 + %0 = scf.while (%arg1 = %zero, %arg2 = %one) : (i32, i32) -> (i32) { + scf.condition(%cond) {tag = "condition"} %arg2 : i32 + } do { + ^bb0(%arg1: i32): + scf.yield %arg1, %arg1: i32, i32 + } + memref.store %0, %m0[] {tag_name = "a"} : memref + return +} + +// ----- + // CHECK-LABEL: test_tag: zero // CHECK: result #0: [brancharg0] // CHECK-LABEL: test_tag: ten