diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
--- a/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineAnalysis.cpp
@@ -299,69 +299,62 @@
   return numCommonLoops;
 }
 
-/// Returns Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
-static Block *getCommonBlock(const MemRefAccess &srcAccess,
-                             const MemRefAccess &dstAccess,
-                             const FlatAffineValueConstraints &srcDomain,
-                             unsigned numCommonLoops) {
-  // Get the chain of ancestor blocks to the given `MemRefAccess` instance. The
-  // search terminates when either an op with the `AffineScope` trait or
-  // `endBlock` is reached.
-  auto getChainOfAncestorBlocks = [&](const MemRefAccess &access,
-                                      SmallVector<Block *, 4> &ancestorBlocks,
-                                      Block *endBlock = nullptr) {
-    Block *currBlock = access.opInst->getBlock();
-    // Loop terminates when the currBlock is nullptr or equals to the endBlock,
-    // or its parent operation holds an affine scope.
-    while (currBlock && currBlock != endBlock &&
-           !currBlock->getParentOp()->hasTrait<OpTrait::AffineScope>()) {
-      ancestorBlocks.push_back(currBlock);
-      currBlock = currBlock->getParentOp()->getBlock();
-    }
-  };
-
-  if (numCommonLoops == 0) {
-    Block *block = srcAccess.opInst->getBlock();
-    while (!block->getParentOp()->hasTrait<OpTrait::AffineScope>())
-      block = block->getParentOp()->getBlock();
-    return block;
-  }
-  Value commonForIV = srcDomain.getValue(numCommonLoops - 1);
-  AffineForOp forOp = getForInductionVarOwner(commonForIV);
-  assert(forOp && "commonForValue was not an induction variable");
+/// Returns the closest block that contains both `opA` and `opB`, either
+/// directly or through ancestor ops. Only blocks up to and including the one
+/// whose parent op starts the enclosing affine scope are considered. Returns
+/// nullptr if no such block exists, i.e., when the ops lie in different affine
+/// scopes or in different blocks of the op that starts their scope.
+static Block *getCommonBlock(Operation *opA, Operation *opB) {
+  // Collects the ancestor blocks of `op`, innermost first; the chain ends with
+  // the block whose parent op starts an affine scope.
+  auto getChainOfAncestorBlocks =
+      [](Operation *op, SmallVectorImpl<Block *> &ancestorBlocks) {
+        Block *currBlock = op->getBlock();
+        // Walk outwards until reaching a block whose parent op starts an
+        // affine scope.
+        while (currBlock &&
+               !currBlock->getParentOp()->hasTrait<OpTrait::AffineScope>()) {
+          ancestorBlocks.push_back(currBlock);
+          currBlock = currBlock->getParentOp()->getBlock();
+        }
+        assert(currBlock &&
+               "parent op starting an affine scope is always expected");
+        ancestorBlocks.push_back(currBlock);
+      };
 
   // Find the closest common block including those in AffineIf.
   SmallVector<Block *, 4> srcAncestorBlocks, dstAncestorBlocks;
-  getChainOfAncestorBlocks(srcAccess, srcAncestorBlocks, forOp.getBody());
-  getChainOfAncestorBlocks(dstAccess, dstAncestorBlocks, forOp.getBody());
+  getChainOfAncestorBlocks(opA, srcAncestorBlocks);
+  getChainOfAncestorBlocks(opB, dstAncestorBlocks);
 
-  Block *commonBlock = forOp.getBody();
+  // Compare the chains from the outermost block inwards; the last block on
+  // which they agree is the closest common one. None agree -> nullptr.
+  Block *commonBlock = nullptr;
   for (int i = srcAncestorBlocks.size() - 1, j = dstAncestorBlocks.size() - 1;
        i >= 0 && j >= 0 && srcAncestorBlocks[i] == dstAncestorBlocks[j];
        i--, j--)
     commonBlock = srcAncestorBlocks[i];
 
   return commonBlock;
 }
 
-// Returns true if the ancestor operation of 'srcAccess' appears before the
-// ancestor operation of 'dstAccess' in the common ancestral block. Returns
-// false otherwise.
-// Note that because 'srcAccess' or 'dstAccess' may be nested in conditionals,
-// the function is named 'srcAppearsBeforeDstInCommonBlock'. Note that
-// 'numCommonLoops' is the number of contiguous surrounding outer loops.
-static bool srcAppearsBeforeDstInAncestralBlock(
-    const MemRefAccess &srcAccess, const MemRefAccess &dstAccess,
-    const FlatAffineValueConstraints &srcDomain, unsigned numCommonLoops) {
+/// Returns true if the ancestor operation of `srcAccess` appears before the
+/// ancestor operation of `dstAccess` in their closest common block. When the
+/// two ops have no common block (see getCommonBlock), their relative order
+/// cannot be decided here and true is returned conservatively, since a false
+/// result lets the caller conclude that there is no dependence.
+static bool srcAppearsBeforeDstInAncestralBlock(const MemRefAccess &srcAccess,
+                                                const MemRefAccess &dstAccess) {
   // Get Block common to 'srcAccess.opInst' and 'dstAccess.opInst'.
-  auto *commonBlock =
-      getCommonBlock(srcAccess, dstAccess, srcDomain, numCommonLoops);
+  auto *commonBlock = getCommonBlock(srcAccess.opInst, dstAccess.opInst);
+  if (!commonBlock)
+    return true;
   // Check the dominance relationship between the respective ancestors of the
   // src and dst in the Block of the innermost among the common loops.
   auto *srcInst = commonBlock->findAncestorOpInBlock(*srcAccess.opInst);
-  assert(srcInst != nullptr);
+  assert(srcInst && "src access op must lie in common block");
   auto *dstInst = commonBlock->findAncestorOpInBlock(*dstAccess.opInst);
-  assert(dstInst != nullptr);
+  assert(dstInst && "dest access op must lie in common block");
 
   // Determine whether dstInst comes after srcInst.
   return srcInst->isBeforeInBlock(dstInst);
@@ -631,8 +624,7 @@
   unsigned numCommonLoops = getNumCommonLoops(srcDomain, dstDomain);
   assert(loopDepth <= numCommonLoops + 1);
   if (!allowRAR && loopDepth > numCommonLoops &&
-      !srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess, srcDomain,
-                                           numCommonLoops)) {
+      !srcAppearsBeforeDstInAncestralBlock(srcAccess, dstAccess)) {
     return DependenceResult::NoDependence;
   }
 