diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp --- a/mlir/lib/Analysis/Liveness.cpp +++ b/mlir/lib/Analysis/Liveness.cpp @@ -21,6 +21,30 @@ using namespace mlir; +/// Tries to find a value (e.g. block or operation) that lies in the current +/// container by starting the search at the given value. +template <typename CurT, typename ValT, typename GetCurT, typename GetValT> +static ValT *findInCurrent(CurT *current, ValT *value, const GetCurT &getCurT, + const GetValT &getValT) { + while (value) { + // Check whether the current value has an associated parent. + auto *parent = getCurT(value); + if (!parent) + return nullptr; + + // Check whether we have found the current entry. + if (parent == current) + return value; + + // Check for a nested region. + Operation *targetOp = parent->getParentOp(); + if (!targetOp || !targetOp->getBlock()) + return nullptr; + value = getValT(targetOp); + } + return nullptr; +} + namespace { /// Builds and holds block information during the construction phase. struct BlockInfoBuilder { @@ -31,36 +55,57 @@ /// Fills the block builder with initial liveness information. BlockInfoBuilder(Block *block) : block(block) { + auto gatherOutValues = [&](Value value) { + // Check whether this value will be in the outValues + // set (its uses escape this block). Due to the SSA + // properties of the program, the uses must occur after + // the definition. Therefore, we do not have to check + // additional conditions to detect an escaping value. + for (OpOperand &use : value.getUses()) { + Block *ownerBlock = use.getOwner()->getBlock(); + // Find an owner block in the current region. + // Note that a value does not escape this block if it + // is used in a nested region. + ownerBlock = + findInCurrent(block->getParent(), ownerBlock, + [](Block *b) { return b->getParent(); }, + [](Operation *op) { return op->getBlock(); }); + if (ownerBlock && ownerBlock != block) { + outValues.insert(value); + break; + } + } + }; + // Mark all block arguments (phis) as defined.
- for (BlockArgument argument : block->getArguments()) + for (BlockArgument argument : block->getArguments()) { + // Insert value into the set of defined values. defValues.insert(argument); - // Check all result values and whether their uses - // are inside this block or not (see outValues). + // Gather all out values of all arguments in the current block. + gatherOutValues(argument); + } + + // Gather out values of all operations in the current block. for (Operation &operation : *block) - for (Value result : operation.getResults()) { - defValues.insert(result); + for (Value result : operation.getResults()) + gatherOutValues(result); - // Check whether this value will be in the outValues - // set (its uses escape this block). Due to the SSA - // properties of the program, the uses must occur after - // the definition. Therefore, we do not have to check - // additional conditions to detect an escaping value. - for (OpOperand &use : result.getUses()) - if (use.getOwner()->getBlock() != block) { - outValues.insert(result); - break; - } - } + // Mark nested results as defined. FIXME: nested block args should be too. + block->walk([&](Operation *op) { + for (Value result : op->getResults()) + defValues.insert(result); + }); // Check all operations for used operands. - for (Operation &operation : block->getOperations()) - for (Value operand : operation.getOperands()) { + block->walk([&](Operation *op) { + for (Value operand : op->getOperands()) { // If the operand is already defined in the scope of this // block, we can skip the value in the use set. if (!defValues.count(operand)) useValues.insert(operand); } + }); } /// Updates live-in information of the current block. @@ -110,20 +155,32 @@ }; } // namespace +/// Walks all regions (including nested regions recursively) and invokes the +/// given function for every block.
+template <typename FuncT> +static void walkRegions(MutableArrayRef<Region> regions, const FuncT &func) { + for (Region &region : regions) + for (Block &block : region) { + func(block); + + // Traverse all nested regions. + for (Operation &operation : block) + walkRegions(operation.getRegions(), func); + } +} + /// Builds the internal liveness block mapping. static void buildBlockMapping(MutableArrayRef<Region> regions, DenseMap<Block *, BlockInfoBuilder> &builders) { llvm::SetVector<Block *> toProcess; - // Initialize all block structures - for (Region &region : regions) - for (Block &block : region) { - BlockInfoBuilder &builder = - builders.try_emplace(&block, &block).first->second; + walkRegions(regions, [&](Block &block) { + BlockInfoBuilder &builder = + builders.try_emplace(&block, &block).first->second; - if (builder.updateLiveIn()) - toProcess.insert(block.pred_begin(), block.pred_end()); - } + if (builder.updateLiveIn()) + toProcess.insert(block.pred_begin(), block.pred_end()); + }); // Propagate the in and out-value sets (fixpoint iteration) while (!toProcess.empty()) { @@ -257,17 +314,16 @@ DenseMap<Block *, size_t> blockIds; DenseMap<Operation *, size_t> operationIds; DenseMap<Value, size_t> valueIds; - for (Region &region : operation->getRegions()) - for (Block &block : region) { - blockIds.insert({&block, blockIds.size()}); - for (BlockArgument argument : block.getArguments()) - valueIds.insert({argument, valueIds.size()}); - for (Operation &operation : block) { - operationIds.insert({&operation, operationIds.size()}); - for (Value result : operation.getResults()) - valueIds.insert({result, valueIds.size()}); - } + walkRegions(operation->getRegions(), [&](Block &block) { + blockIds.insert({&block, blockIds.size()}); + for (BlockArgument argument : block.getArguments()) + valueIds.insert({argument, valueIds.size()}); + for (Operation &operation : block) { + operationIds.insert({&operation, operationIds.size()}); + for (Value result : 
operation.getResults()) + valueIds.insert({result, valueIds.size()}); } + }); // Local printing helpers auto printValueRef = [&](Value value) {
@@ -292,39 +348,38 @@ }; // Dump information about in and out values. - for (Region &region : operation->getRegions()) - for (Block &block : region) { - os << "// - Block: " << blockIds[&block] << "\n"; - auto liveness = getLiveness(&block); - os << "// --- LiveIn: "; - printValueRefs(liveness->inValues); - os << "\n// --- LiveOut: "; - printValueRefs(liveness->outValues); + walkRegions(operation->getRegions(), [&](Block &block) { + os << "// - Block: " << blockIds[&block] << "\n"; + auto liveness = getLiveness(&block); + os << "// --- LiveIn: "; + printValueRefs(liveness->inValues); + os << "\n// --- LiveOut: "; + printValueRefs(liveness->outValues); + os << "\n"; + + // Print liveness intervals. + os << "// --- BeginLiveness"; + for (Operation &op : block) { + if (op.getNumResults() < 1) + continue; os << "\n"; - - // Print liveness intervals. - os << "// --- BeginLiveness"; - for (Operation &op : block) { - if (op.getNumResults() < 1) - continue; - os << "\n"; - for (Value result : op.getResults()) { - os << "// "; - printValueRef(result); - os << ":"; - auto liveOperations = resolveLiveness(result); - std::sort(liveOperations.begin(), liveOperations.end(), - [&](Operation *left, Operation *right) { - return operationIds[left] < operationIds[right]; - }); - for (Operation *operation : liveOperations) { - os << "\n// "; - operation->print(os); - } + for (Value result : op.getResults()) { + os << "// "; + printValueRef(result); + os << ":"; + auto liveOperations = resolveLiveness(result); + std::sort(liveOperations.begin(), liveOperations.end(), + [&](Operation *left, Operation *right) { + return operationIds[left] < operationIds[right]; + }); + for (Operation *operation : liveOperations) { + os << "\n// "; + operation->print(os); } } - os << "\n// --- EndLiveness\n"; } + os << "\n// --- EndLiveness\n"; + }); os << "// -------------------\n"; } @@ -365,10 +420,13 @@ Operation *endOperation = startOperation; for (OpOperand &use : value.getUses()) {
Operation *useOperation = use.getOwner(); - // Check whether the use is in our block and after - // the current end operation. - if (useOperation->getBlock() == block && - endOperation->isBeforeInBlock(useOperation)) + // Find the associated operation in the current block (if any). + useOperation = findInCurrent(block, useOperation, + [](Operation *op) { return op->getBlock(); }, + [](Operation *op) { return op; }); + // Check whether the use is in our block and after the current + // end operation. + if (useOperation && endOperation->isBeforeInBlock(useOperation)) endOperation = useOperation; } return endOperation; diff --git a/mlir/test/Analysis/test-liveness.mlir b/mlir/test/Analysis/test-liveness.mlir --- a/mlir/test/Analysis/test-liveness.mlir +++ b/mlir/test/Analysis/test-liveness.mlir @@ -188,4 +188,131 @@ // CHECK-NEXT: LiveOut:{{ *$}} %result = addi %sum, %arg2 : i32 return %result : i32 -} \ No newline at end of file +} + +// ----- + +// CHECK-LABEL: Testing : nested_region + +func @nested_region( + %arg0 : index, %arg1 : index, %arg2 : index, + %val0 : i32, %val1 : i32, %val2 : i32, + %buffer : memref<i32>) -> i32 { + // CHECK: Block: 0 + // CHECK-NEXT: LiveIn:{{ *$}} + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK-NEXT: BeginLiveness + // CHECK-NEXT: val_std.addi + // CHECK-NEXT: %0 = addi + // CHECK-NEXT: %1 = addi + // CHECK-NEXT: loop.for + // CHECK: // %2 = addi + // CHECK-NEXT: %3 = addi + // CHECK-NEXT: val_std.addi + // CHECK-NEXT: %1 = addi + // CHECK-NEXT: loop.for + // CHECK: // return %1 + // CHECK: EndLiveness + %0 = addi %val0, %val1 : i32 + %1 = addi %val1, %val2 : i32 + loop.for %arg3 = %arg0 to %arg1 step %arg2 { + // CHECK: Block: 1 + // CHECK-NEXT: LiveIn: arg5@0 arg6@0 val_std.addi + // CHECK-NEXT: LiveOut:{{ *$}} + %2 = addi %0, %val2 : i32 + %3 = addi %2, %0 : i32 + store %3, %buffer[] : memref<i32> + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL: Testing : nested_region2 + +func @nested_region2( + // CHECK: Block: 0 + // CHECK-NEXT: LiveIn:{{ *$}}
+ // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK-NEXT: BeginLiveness + // CHECK-NEXT: val_std.addi + // CHECK-NEXT: %0 = addi + // CHECK-NEXT: %1 = addi + // CHECK-NEXT: loop.for + // CHECK: // %2 = addi + // CHECK-NEXT: loop.for + // CHECK: // %3 = addi + // CHECK-NEXT: val_std.addi + // CHECK-NEXT: %1 = addi + // CHECK-NEXT: loop.for + // CHECK: // return %1 + // CHECK: EndLiveness + %arg0 : index, %arg1 : index, %arg2 : index, + %val0 : i32, %val1 : i32, %val2 : i32, + %buffer : memref<i32>) -> i32 { + %0 = addi %val0, %val1 : i32 + %1 = addi %val1, %val2 : i32 + loop.for %arg3 = %arg0 to %arg1 step %arg2 { + // CHECK: Block: 1 + // CHECK-NEXT: LiveIn: arg0@0 arg1@0 arg2@0 arg5@0 arg6@0 val_std.addi + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK-NEXT: BeginLiveness + // CHECK-NEXT: val_std.addi + // CHECK-NEXT: %2 = addi + // CHECK-NEXT: loop.for + // CHECK: // %3 = addi + // CHECK: EndLiveness + %2 = addi %0, %val2 : i32 + loop.for %arg4 = %arg0 to %arg1 step %arg2 { + %3 = addi %2, %0 : i32 + store %3, %buffer[] : memref<i32> + } + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL: Testing : nested_region3 + +func @nested_region3( + // CHECK: Block: 0 + // CHECK-NEXT: LiveIn:{{ *$}} + // CHECK-NEXT: LiveOut: arg0@0 arg1@0 arg2@0 arg6@0 val_std.addi val_std.addi + // CHECK-NEXT: BeginLiveness + // CHECK-NEXT: val_std.addi + // CHECK-NEXT: %0 = addi + // CHECK-NEXT: %1 = addi + // CHECK-NEXT: loop.for + // CHECK: // br ^bb1 + // CHECK-NEXT: %2 = addi + // CHECK-NEXT: loop.for + // CHECK: // %2 = addi + // CHECK: EndLiveness + %arg0 : index, %arg1 : index, %arg2 : index, + %val0 : i32, %val1 : i32, %val2 : i32, + %buffer : memref<i32>) -> i32 { + %0 = addi %val0, %val1 : i32 + %1 = addi %val1, %val2 : i32 + loop.for %arg3 = %arg0 to %arg1 step %arg2 { + // CHECK: Block: 1 + // CHECK-NEXT: LiveIn: arg5@0 arg6@0 val_std.addi + // CHECK-NEXT: LiveOut:{{ *$}} + %2 = addi %0, %val2 : i32 + store %2, %buffer[] : memref<i32> + } + br ^exit + +^exit: + // CHECK: Block: 2
+ // CHECK-NEXT: LiveIn: arg0@0 arg1@0 arg2@0 arg6@0 val_std.addi val_std.addi + // CHECK-NEXT: LiveOut:{{ *$}} + loop.for %arg3 = %arg0 to %arg1 step %arg2 { + // CHECK: Block: 3 + // CHECK-NEXT: LiveIn: arg6@0 val_std.addi val_std.addi + // CHECK-NEXT: LiveOut:{{ *$}} + %2 = addi %0, %1 : i32 + store %2, %buffer[] : memref<i32> + } + return %1 : i32 +}