diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -41,7 +41,7 @@ /// Returns the nesting depth of this operation, i.e., the number of loops /// surrounding this operation. -unsigned getNestingDepth(Operation &op); +unsigned getNestingDepth(Operation *op); /// Returns in 'sequentialLoops' all sequential loops in loop nest rooted /// at 'forOp'. diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -569,8 +569,8 @@ if (srcAccess.memref != dstAccess.memref) continue; // Check if 'loopDepth' exceeds nesting depth of src/dst ops. - if ((!isBackwardSlice && loopDepth > getNestingDepth(*opsA[i])) || - (isBackwardSlice && loopDepth > getNestingDepth(*opsB[j]))) { + if ((!isBackwardSlice && loopDepth > getNestingDepth(opsA[i])) || + (isBackwardSlice && loopDepth > getNestingDepth(opsB[j]))) { LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n."); return failure(); } @@ -895,8 +895,8 @@ /// Returns the nesting depth of this statement, i.e., the number of loops /// surrounding this statement. -unsigned mlir::getNestingDepth(Operation &op) { - Operation *currOp = &op; +unsigned mlir::getNestingDepth(Operation *op) { + Operation *currOp = op; unsigned depth = 0; while ((currOp = currOp->getParentOp())) { if (isa(currOp)) @@ -957,7 +957,7 @@ auto region = std::make_unique(opInst->getLoc()); if (failed( region->compute(opInst, - /*loopDepth=*/getNestingDepth(*block.begin())))) { + /*loopDepth=*/getNestingDepth(&*block.begin())))) { return opInst->emitError("error obtaining memory region\n"); } @@ -1023,7 +1023,7 @@ return false; // Dep check depth would be number of enclosing loops + 1. - unsigned depth = getNestingDepth(*forOp.getOperation()) + 1; + unsigned depth = getNestingDepth(forOp) + 1; // Check dependences between all pairs of ops in 'loadAndStoreOpInsts'. for (auto *srcOpInst : loadAndStoreOpInsts) { diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -1492,7 +1492,7 @@ srcStoreOp = nullptr; break; } - unsigned loopDepth = getNestingDepth(*storeOp); + unsigned loopDepth = getNestingDepth(storeOp); if (loopDepth > maxLoopDepth) { maxLoopDepth = loopDepth; srcStoreOp = storeOp; diff --git a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp --- a/mlir/lib/Transforms/MemRefDataFlowOpt.cpp +++ b/mlir/lib/Transforms/MemRefDataFlowOpt.cpp @@ -85,21 +85,18 @@ // This is a straightforward implementation not optimized for speed. Optimize // if needed. void MemRefDataFlowOpt::forwardStoreToLoad(AffineLoadOp loadOp) { - Operation *loadOpInst = loadOp.getOperation(); - - // First pass over the use list to get minimum number of surrounding + // First pass over the use list to get the minimum number of surrounding // loops common between the load op and the store op, with min taken across // all store ops. SmallVector storeOps; - unsigned minSurroundingLoops = getNestingDepth(*loadOpInst); + unsigned minSurroundingLoops = getNestingDepth(loadOp); for (auto *user : loadOp.getMemRef().getUsers()) { auto storeOp = dyn_cast(user); if (!storeOp) continue; - auto *storeOpInst = storeOp.getOperation(); - unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst); + unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp); minSurroundingLoops = std::min(nsLoops, minSurroundingLoops); - storeOps.push_back(storeOpInst); + storeOps.push_back(storeOp); } // The list of store op candidates for forwarding that satisfy conditions @@ -111,12 +108,12 @@ // post-dominance on these. 'fwdingCandidates' are a subset of depSrcStores. SmallVector depSrcStores; - for (auto *storeOpInst : storeOps) { - MemRefAccess srcAccess(storeOpInst); - MemRefAccess destAccess(loadOpInst); + for (auto *storeOp : storeOps) { + MemRefAccess srcAccess(storeOp); + MemRefAccess destAccess(loadOp); // Find stores that may be reaching the load. FlatAffineConstraints dependenceConstraints; - unsigned nsLoops = getNumCommonSurroundingLoops(*loadOpInst, *storeOpInst); + unsigned nsLoops = getNumCommonSurroundingLoops(*loadOp, *storeOp); unsigned d; // Dependences at loop depth <= minSurroundingLoops do NOT matter. for (d = nsLoops + 1; d > minSurroundingLoops; d--) { @@ -130,7 +127,7 @@ continue; // Stores that *may* be reaching the load. - depSrcStores.push_back(storeOpInst); + depSrcStores.push_back(storeOp); // 1. Check if the store and the load have mathematically equivalent // affine access functions; this implies that they statically refer to the @@ -144,11 +141,11 @@ continue; // 2. The store has to dominate the load op to be candidate. - if (!domInfo->dominates(storeOpInst, loadOpInst)) + if (!domInfo->dominates(storeOp, loadOp)) continue; // We now have a candidate for forwarding. - fwdingCandidates.push_back(storeOpInst); + fwdingCandidates.push_back(storeOp); } // 3. Of all the store op's that meet the above criteria, the store that @@ -158,11 +155,11 @@ // Note: this can be implemented in a cleaner way with postdominator tree // traversals. Consider this for the future if needed. Operation *lastWriteStoreOp = nullptr; - for (auto *storeOpInst : fwdingCandidates) { + for (auto *storeOp : fwdingCandidates) { if (llvm::all_of(depSrcStores, [&](Operation *depStore) { - return postDomInfo->postDominates(storeOpInst, depStore); + return postDomInfo->postDominates(storeOp, depStore); })) { - lastWriteStoreOp = storeOpInst; + lastWriteStoreOp = storeOp; break; } } @@ -175,7 +172,7 @@ // Record the memref for a later sweep to optimize away. memrefsToErase.insert(loadOp.getMemRef()); // Record this to erase later. - loadOpsToErase.push_back(loadOpInst); + loadOpsToErase.push_back(loadOp); } void MemRefDataFlowOpt::runOnFunction() { @@ -192,32 +189,31 @@ loadOpsToErase.clear(); memrefsToErase.clear(); - // Walk all load's and perform load/store forwarding. + // Walk all load's and perform store to load forwarding. f.walk([&](AffineLoadOp loadOp) { forwardStoreToLoad(loadOp); }); // Erase all load op's whose results were replaced with store fwd'ed ones. - for (auto *loadOp : loadOpsToErase) { + for (auto *loadOp : loadOpsToErase) loadOp->erase(); - } // Check if the store fwd'ed memrefs are now left with only stores and can // thus be completely deleted. Note: the canonicalize pass should be able // to do this as well, but we'll do it here since we collected these anyway. for (auto memref : memrefsToErase) { // If the memref hasn't been alloc'ed in this function, skip. - Operation *defInst = memref.getDefiningOp(); - if (!defInst || !isa(defInst)) + Operation *defOp = memref.getDefiningOp(); + if (!defOp || !isa(defOp)) // TODO(mlir-team): if the memref was returned by a 'call' operation, we // could still erase it if the call had no side-effects. continue; - if (llvm::any_of(memref.getUsers(), [&](Operation *ownerInst) { - return (!isa(ownerInst) && !isa(ownerInst)); + if (llvm::any_of(memref.getUsers(), [&](Operation *ownerOp) { + return (!isa(ownerOp) && !isa(ownerOp)); })) continue; // Erase all stores, the dealloc, and the alloc on the memref. for (auto *user : llvm::make_early_inc_range(memref.getUsers())) user->erase(); - defInst->erase(); + defOp->erase(); } } diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -1911,7 +1911,7 @@ // Copies will be generated for this depth, i.e., symbolic in all loops // surrounding the this block range. - unsigned copyDepth = getNestingDepth(*begin); + unsigned copyDepth = getNestingDepth(&*begin); LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth << "\n"); diff --git a/mlir/test/lib/Transforms/TestLoopFusion.cpp b/mlir/test/lib/Transforms/TestLoopFusion.cpp --- a/mlir/test/lib/Transforms/TestLoopFusion.cpp +++ b/mlir/test/lib/Transforms/TestLoopFusion.cpp @@ -10,21 +10,14 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Analysis/AffineAnalysis.h" -#include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/LoopFusionUtils.h" #include "mlir/Transforms/LoopUtils.h" #include "mlir/Transforms/Passes.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" - #define DEBUG_TYPE "test-loop-fusion" using namespace mlir; @@ -90,7 +83,7 @@ std::string result; llvm::raw_string_ostream os(result); // Slice insertion point format [loop-depth, operation-block-index] - unsigned ipd = getNestingDepth(*sliceUnion.insertPoint); + unsigned ipd = getNestingDepth(&*sliceUnion.insertPoint); unsigned ipb = getBlockIndex(*sliceUnion.insertPoint); os << "insert point: (" << std::to_string(ipd) << ", " << std::to_string(ipb) << ")";