diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -110,6 +110,11 @@ /// the operation with an offending use. bool isIsolatedFromAbove(Optional<Location> noteLoc = llvm::None); + /// Returns 'block' if 'block' lies in this region, or otherwise finds the + /// ancestor block of 'block' that lies in this region. Returns nullptr if + /// the latter fails. + Block *findAncestorBlockInRegion(Block &block); + /// Drop all operand uses from operations within this region, which is /// an essential step in breaking cyclic dependences between references when /// they are to be deleted. diff --git a/mlir/lib/Analysis/Liveness.cpp b/mlir/lib/Analysis/Liveness.cpp --- a/mlir/lib/Analysis/Liveness.cpp +++ b/mlir/lib/Analysis/Liveness.cpp @@ -31,36 +31,55 @@ /// Fills the block builder with initial liveness information. BlockInfoBuilder(Block *block) : block(block) { + auto gatherOutValues = [&](Value value) { + // Check whether this value will be in the outValues + // set (its uses escape this block). Due to the SSA + // properties of the program, the uses must occur after + // the definition. Therefore, we do not have to check + // additional conditions to detect an escaping value. + for (Operation *useOp : value.getUsers()) { + Block *ownerBlock = useOp->getBlock(); + // Find an owner block in the current region. + // Note that a value does not escape this block if it + // is used in a nested region. + ownerBlock = block->getParent()->findAncestorBlockInRegion(*ownerBlock); + assert(ownerBlock && "Use leaves the current parent region"); + if (ownerBlock != block) { + outValues.insert(value); + break; + } + } + }; + // Mark all block arguments (phis) as defined. - for (BlockArgument argument : block->getArguments()) + for (BlockArgument argument : block->getArguments()) { + // Insert value into the set of defined values. 
defValues.insert(argument); - // Check all result values and whether their uses - // are inside this block or not (see outValues). + // Gather all out values of all arguments in the current block. + gatherOutValues(argument); + } + + // Gather out values of all operations in the current block. for (Operation &operation : *block) - for (Value result : operation.getResults()) { - defValues.insert(result); + for (Value result : operation.getResults()) + gatherOutValues(result); - // Check whether this value will be in the outValues - // set (its uses escape this block). Due to the SSA - // properties of the program, the uses must occur after - // the definition. Therefore, we do not have to check - // additional conditions to detect an escaping value. - for (OpOperand &use : result.getUses()) - if (use.getOwner()->getBlock() != block) { - outValues.insert(result); - break; - } - } + // Mark all nested operation results as defined. + block->walk([&](Operation *op) { + for (Value result : op->getResults()) + defValues.insert(result); + }); // Check all operations for used operands. - for (Operation &operation : block->getOperations()) - for (Value operand : operation.getOperands()) { + block->walk([&](Operation *op) { + for (Value operand : op->getOperands()) { // If the operand is already defined in the scope of this // block, we can skip the value in the use set. if (!defValues.count(operand)) useValues.insert(operand); } + }); } /// Updates live-in information of the current block. @@ -86,7 +105,8 @@ /// Updates live-out information of the current block. /// It iterates over all successors and unifies their live-in /// values with the current live-out values. 
- template <typename SourceT> void updateLiveOut(SourceT &source) { + template <typename SourceT> + void updateLiveOut(SourceT &source) { for (Block *succ : block->getSuccessors()) { BlockInfoBuilder &builder = source[succ]; llvm::set_union(outValues, builder.inValues); @@ -110,20 +130,32 @@ }; } // namespace +/// Walks all regions (including nested regions recursively) and invokes the +/// given function for every block. +template <typename FuncT> +static void walkRegions(MutableArrayRef<Region> regions, const FuncT &func) { + for (Region &region : regions) + for (Block &block : region) { + func(block); + + // Traverse all nested regions. + for (Operation &operation : block) + walkRegions(operation.getRegions(), func); + } +} + /// Builds the internal liveness block mapping. static void buildBlockMapping(MutableArrayRef<Region> regions, DenseMap<Block *, BlockInfoBuilder> &builders) { llvm::SetVector<Block *> toProcess; - // Initialize all block structures - for (Region &region : regions) - for (Block &block : region) { - BlockInfoBuilder &builder = - builders.try_emplace(&block, &block).first->second; + walkRegions(regions, [&](Block &block) { + BlockInfoBuilder &builder = + builders.try_emplace(&block, &block).first->second; - if (builder.updateLiveIn()) - toProcess.insert(block.pred_begin(), block.pred_end()); - } + if (builder.updateLiveIn()) + toProcess.insert(block.pred_begin(), block.pred_end()); + }); // Propagate the in and out-value sets (fixpoint iteration) while (!toProcess.empty()) { @@ -257,22 +289,21 @@ DenseMap<Block *, size_t> blockIds; DenseMap<Operation *, size_t> operationIds; DenseMap<Value, size_t> valueIds; - for (Region &region : operation->getRegions()) - for (Block &block : region) { - blockIds.insert({&block, blockIds.size()}); - for (BlockArgument argument : block.getArguments()) - valueIds.insert({argument, valueIds.size()}); - for (Operation &operation : block) { - operationIds.insert({&operation, operationIds.size()}); - for (Value result : operation.getResults()) - valueIds.insert({result, valueIds.size()}); - } + walkRegions(operation->getRegions(), [&](Block &block) { + blockIds.insert({&block, 
blockIds.size()}); + for (BlockArgument argument : block.getArguments()) + valueIds.insert({argument, valueIds.size()}); + for (Operation &operation : block) { + operationIds.insert({&operation, operationIds.size()}); + for (Value result : operation.getResults()) + valueIds.insert({result, valueIds.size()}); } + }); // Local printing helpers auto printValueRef = [&](Value value) { if (Operation *defOp = value.getDefiningOp()) - os << "val_" << defOp->getName(); + os << "val_" << valueIds[value]; else { auto blockArg = value.cast<BlockArgument>(); os << "arg" << blockArg.getArgNumber() << "@" @@ -292,39 +323,38 @@ }; // Dump information about in and out values. - for (Region &region : operation->getRegions()) - for (Block &block : region) { - os << "// - Block: " << blockIds[&block] << "\n"; - auto liveness = getLiveness(&block); - os << "// --- LiveIn: "; - printValueRefs(liveness->inValues); - os << "\n// --- LiveOut: "; - printValueRefs(liveness->outValues); + walkRegions(operation->getRegions(), [&](Block &block) { + os << "// - Block: " << blockIds[&block] << "\n"; + auto liveness = getLiveness(&block); + os << "// --- LiveIn: "; + printValueRefs(liveness->inValues); + os << "\n// --- LiveOut: "; + printValueRefs(liveness->outValues); + os << "\n"; + + // Print liveness intervals. + os << "// --- BeginLiveness"; + for (Operation &op : block) { + if (op.getNumResults() < 1) + continue; os << "\n"; - - // Print liveness intervals. 
- os << "// --- BeginLiveness"; - for (Operation &op : block) { - if (op.getNumResults() < 1) - continue; - os << "\n"; - for (Value result : op.getResults()) { - os << "// "; - printValueRef(result); - os << ":"; - auto liveOperations = resolveLiveness(result); - std::sort(liveOperations.begin(), liveOperations.end(), - [&](Operation *left, Operation *right) { - return operationIds[left] < operationIds[right]; - }); - for (Operation *operation : liveOperations) { - os << "\n// "; - operation->print(os); - } + for (Value result : op.getResults()) { + os << "// "; + printValueRef(result); + os << ":"; + auto liveOperations = resolveLiveness(result); + std::sort(liveOperations.begin(), liveOperations.end(), + [&](Operation *left, Operation *right) { + return operationIds[left] < operationIds[right]; + }); + for (Operation *operation : liveOperations) { + os << "\n// "; + operation->print(os); } } - os << "\n// --- EndLiveness\n"; } + os << "\n// --- EndLiveness\n"; + }); os << "// -------------------\n"; } @@ -363,13 +393,13 @@ // Resolve the last operation (must exist by definition). Operation *endOperation = startOperation; - for (OpOperand &use : value.getUses()) { - Operation *useOperation = use.getOwner(); - // Check whether the use is in our block and after - // the current end operation. - if (useOperation->getBlock() == block && - endOperation->isBeforeInBlock(useOperation)) - endOperation = useOperation; + for (Operation *useOp : value.getUsers()) { + // Find the associated operation in the current block (if any). + useOp = block->findAncestorOpInBlock(*useOp); + // Check whether the use is in our block and after the current + // end operation. 
+ if (useOp && endOperation->isBeforeInBlock(useOp)) + endOperation = useOp; } return endOperation; } diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -108,6 +108,20 @@ it->walk(remapOperands); } +/// Returns 'block' if 'block' lies in this region, or otherwise finds the +/// ancestor block of 'block' that lies in this region. Returns nullptr if +/// the latter fails. +Block *Region::findAncestorBlockInRegion(Block &block) { + auto currBlock = &block; + while (currBlock->getParent() != this) { + Operation *parentOp = currBlock->getParentOp(); + if (!parentOp || !parentOp->getBlock()) + return nullptr; + currBlock = parentOp->getBlock(); + } + return currBlock; +} + void Region::dropAllReferences() { for (Block &b : *this) b.dropAllReferences(); diff --git a/mlir/test/Analysis/test-liveness.mlir b/mlir/test/Analysis/test-liveness.mlir --- a/mlir/test/Analysis/test-liveness.mlir +++ b/mlir/test/Analysis/test-liveness.mlir @@ -25,7 +25,7 @@ // CHECK-NEXT: LiveIn: arg0@0 arg1@0 // CHECK-NEXT: LiveOut:{{ *$}} // CHECK-NEXT: BeginLiveness - // CHECK: val_std.addi + // CHECK: val_2 // CHECK-NEXT: %0 = addi // CHECK-NEXT: return // CHECK-NEXT: EndLiveness @@ -58,7 +58,7 @@ // CHECK-NEXT: LiveIn: arg1@0 arg2@0 // CHECK-NEXT: LiveOut:{{ *$}} // CHECK-NEXT: BeginLiveness - // CHECK: val_std.addi + // CHECK: val_3 // CHECK-NEXT: %0 = addi // CHECK-NEXT: return // CHECK-NEXT: EndLiveness @@ -80,7 +80,7 @@ // CHECK-NEXT: LiveIn: arg1@0 // CHECK-NEXT: LiveOut: arg1@0 arg0@1 // CHECK-NEXT: BeginLiveness - // CHECK-NEXT: val_std.cmpi + // CHECK-NEXT: val_5 // CHECK-NEXT: %2 = cmpi // CHECK-NEXT: cond_br // CHECK-NEXT: EndLiveness @@ -91,11 +91,11 @@ // CHECK-NEXT: LiveIn: arg1@0 arg0@1 // CHECK-NEXT: LiveOut: arg1@0 // CHECK-NEXT: BeginLiveness - // CHECK-NEXT: val_std.constant + // CHECK-NEXT: val_7 // CHECK-NEXT: %c // CHECK-NEXT: %4 = addi // CHECK-NEXT: %5 = addi - // CHECK-NEXT: val_std.addi + // CHECK-NEXT: 
val_8 // CHECK-NEXT: %4 = addi // CHECK-NEXT: %5 = addi // CHECK-NEXT: br @@ -118,33 +118,33 @@ func @func_ranges(%cond : i1, %arg1 : i32, %arg2 : i32, %arg3 : i32) -> i32 { // CHECK: Block: 0 // CHECK-NEXT: LiveIn:{{ *$}} - // CHECK-NEXT: LiveOut: arg2@0 val_std.muli val_std.addi + // CHECK-NEXT: LiveOut: arg2@0 val_9 val_10 // CHECK-NEXT: BeginLiveness - // CHECK-NEXT: val_std.addi + // CHECK-NEXT: val_4 // CHECK-NEXT: %0 = addi // CHECK-NEXT: %c // CHECK-NEXT: %1 = addi // CHECK-NEXT: %2 = addi // CHECK-NEXT: %3 = muli - // CHECK-NEXT: val_std.constant + // CHECK-NEXT: val_5 // CHECK-NEXT: %c // CHECK-NEXT: %1 = addi // CHECK-NEXT: %2 = addi // CHECK-NEXT: %3 = muli // CHECK-NEXT: %4 = muli // CHECK-NEXT: %5 = addi - // CHECK-NEXT: val_std.addi + // CHECK-NEXT: val_6 // CHECK-NEXT: %1 = addi // CHECK-NEXT: %2 = addi // CHECK-NEXT: %3 = muli - // CHECK-NEXT: val_std.addi + // CHECK-NEXT: val_7 // CHECK-NEXT %2 = addi // CHECK-NEXT %3 = muli // CHECK-NEXT %4 = muli - // CHECK: val_std.muli + // CHECK: val_8 // CHECK-NEXT: %3 = muli // CHECK-NEXT: %4 = muli - // CHECK-NEXT: val_std.muli + // CHECK-NEXT: val_9 // CHECK-NEXT: %4 = muli // CHECK-NEXT: %5 = addi // CHECK-NEXT: cond_br @@ -152,7 +152,7 @@ // CHECK-NEXT: %6 = muli // CHECK-NEXT: %7 = muli // CHECK-NEXT: %8 = addi - // CHECK-NEXT: val_std.addi + // CHECK-NEXT: val_10 // CHECK-NEXT: %5 = addi // CHECK-NEXT: cond_br // CHECK-NEXT: %7 @@ -168,7 +168,7 @@ ^bb1: // CHECK: Block: 1 - // CHECK-NEXT: LiveIn: arg2@0 val_std.muli + // CHECK-NEXT: LiveIn: arg2@0 val_9 // CHECK-NEXT: LiveOut: arg2@0 %const4 = constant 4 : i32 %6 = muli %4, %const4 : i32 @@ -176,7 +176,7 @@ ^bb2: // CHECK: Block: 2 - // CHECK-NEXT: LiveIn: arg2@0 val_std.muli val_std.addi + // CHECK-NEXT: LiveIn: arg2@0 val_9 val_10 // CHECK-NEXT: LiveOut: arg2@0 %7 = muli %4, %5 : i32 %8 = addi %4, %arg2 : i32 @@ -188,4 +188,131 @@ // CHECK-NEXT: LiveOut:{{ *$}} %result = addi %sum, %arg2 : i32 return %result : i32 -} \ No newline at end of file +} + 
+// ----- + +// CHECK-LABEL: Testing : nested_region + +func @nested_region( + %arg0 : index, %arg1 : index, %arg2 : index, + %arg3 : i32, %arg4 : i32, %arg5 : i32, + %buffer : memref<i32>) -> i32 { + // CHECK: Block: 0 + // CHECK-NEXT: LiveIn:{{ *$}} + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK-NEXT: BeginLiveness + // CHECK-NEXT: val_7 + // CHECK-NEXT: %0 = addi + // CHECK-NEXT: %1 = addi + // CHECK-NEXT: loop.for + // CHECK: // %2 = addi + // CHECK-NEXT: %3 = addi + // CHECK-NEXT: val_8 + // CHECK-NEXT: %1 = addi + // CHECK-NEXT: loop.for + // CHECK: // return %1 + // CHECK: EndLiveness + %0 = addi %arg3, %arg4 : i32 + %1 = addi %arg4, %arg5 : i32 + loop.for %arg6 = %arg0 to %arg1 step %arg2 { + // CHECK: Block: 1 + // CHECK-NEXT: LiveIn: arg5@0 arg6@0 val_7 + // CHECK-NEXT: LiveOut:{{ *$}} + %2 = addi %0, %arg5 : i32 + %3 = addi %2, %0 : i32 + store %3, %buffer[] : memref<i32> + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL: Testing : nested_region2 + +func @nested_region2( + // CHECK: Block: 0 + // CHECK-NEXT: LiveIn:{{ *$}} + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK-NEXT: BeginLiveness + // CHECK-NEXT: val_7 + // CHECK-NEXT: %0 = addi + // CHECK-NEXT: %1 = addi + // CHECK-NEXT: loop.for + // CHECK: // %2 = addi + // CHECK-NEXT: loop.for + // CHECK: // %3 = addi + // CHECK-NEXT: val_8 + // CHECK-NEXT: %1 = addi + // CHECK-NEXT: loop.for + // CHECK: // return %1 + // CHECK: EndLiveness + %arg0 : index, %arg1 : index, %arg2 : index, + %arg3 : i32, %arg4 : i32, %arg5 : i32, + %buffer : memref<i32>) -> i32 { + %0 = addi %arg3, %arg4 : i32 + %1 = addi %arg4, %arg5 : i32 + loop.for %arg6 = %arg0 to %arg1 step %arg2 { + // CHECK: Block: 1 + // CHECK-NEXT: LiveIn: arg0@0 arg1@0 arg2@0 arg5@0 arg6@0 val_7 + // CHECK-NEXT: LiveOut:{{ *$}} + // CHECK-NEXT: BeginLiveness + // CHECK-NEXT: val_10 + // CHECK-NEXT: %2 = addi + // CHECK-NEXT: loop.for + // CHECK: // %3 = addi + // CHECK: EndLiveness + %2 = addi %0, %arg5 : i32 + loop.for %arg7 = %arg0 to %arg1 step %arg2 { + %3 = addi 
%2, %0 : i32 + store %3, %buffer[] : memref<i32> + } + } + return %1 : i32 +} + +// ----- + +// CHECK-LABEL: Testing : nested_region3 + +func @nested_region3( + // CHECK: Block: 0 + // CHECK-NEXT: LiveIn:{{ *$}} + // CHECK-NEXT: LiveOut: arg0@0 arg1@0 arg2@0 arg6@0 val_7 val_8 + // CHECK-NEXT: BeginLiveness + // CHECK-NEXT: val_7 + // CHECK-NEXT: %0 = addi + // CHECK-NEXT: %1 = addi + // CHECK-NEXT: loop.for + // CHECK: // br ^bb1 + // CHECK-NEXT: %2 = addi + // CHECK-NEXT: loop.for + // CHECK: // %2 = addi + // CHECK: EndLiveness + %arg0 : index, %arg1 : index, %arg2 : index, + %arg3 : i32, %arg4 : i32, %arg5 : i32, + %buffer : memref<i32>) -> i32 { + %0 = addi %arg3, %arg4 : i32 + %1 = addi %arg4, %arg5 : i32 + loop.for %arg6 = %arg0 to %arg1 step %arg2 { + // CHECK: Block: 1 + // CHECK-NEXT: LiveIn: arg5@0 arg6@0 val_7 + // CHECK-NEXT: LiveOut:{{ *$}} + %2 = addi %0, %arg5 : i32 + store %2, %buffer[] : memref<i32> + } + br ^exit + +^exit: + // CHECK: Block: 2 + // CHECK-NEXT: LiveIn: arg0@0 arg1@0 arg2@0 arg6@0 val_7 val_8 + // CHECK-NEXT: LiveOut:{{ *$}} + loop.for %arg7 = %arg0 to %arg1 step %arg2 { + // CHECK: Block: 3 + // CHECK-NEXT: LiveIn: arg6@0 val_7 val_8 + // CHECK-NEXT: LiveOut:{{ *$}} + %2 = addi %0, %1 : i32 + store %2, %buffer[] : memref<i32> + } + return %1 : i32 +}