diff --git a/mlir/include/mlir/Transforms/Utils.h b/mlir/include/mlir/Transforms/Utils.h --- a/mlir/include/mlir/Transforms/Utils.h +++ b/mlir/include/mlir/Transforms/Utils.h @@ -45,13 +45,14 @@ /// correspond to memref's indices, and its symbolic inputs if any should be /// provided in `symbolOperands`. /// -/// `domInstFilter`, if non-null, restricts the replacement to only those -/// operations that are dominated by the former; similarly, `postDomInstFilter` +/// `domOpFilter`, if non-null, restricts the replacement to only those +/// operations that are dominated by the former; similarly, `postDomOpFilter` /// restricts replacement to only those operations that are postdominated by it. /// /// 'allowNonDereferencingOps', if set, allows replacement of non-dereferencing -/// uses of a memref without any requirement for access index rewrites. The -/// default value of this flag variable is false. +/// uses of a memref without any requirement for access index rewrites as long +/// as the user operation has the MemRefsNormalizable trait. The default value +/// of this flag is false. /// -/// 'replaceInDeallocOp', if set, lets DeallocOp, a non-dereferencing user, to +/// 'replaceInDeallocOp', if set, allows DeallocOp, a non-dereferencing user, to /// also be a candidate for replacement.
The default value of this flag is @@ -73,9 +74,9 @@ LogicalResult replaceAllMemRefUsesWith( Value oldMemRef, Value newMemRef, ArrayRef extraIndices = {}, AffineMap indexRemap = AffineMap(), ArrayRef extraOperands = {}, - ArrayRef symbolOperands = {}, Operation *domInstFilter = nullptr, - Operation *postDomInstFilter = nullptr, - bool allowNonDereferencingOps = false, bool replaceInDeallocOp = false); + ArrayRef symbolOperands = {}, Operation *domOpFilter = nullptr, + Operation *postDomOpFilter = nullptr, bool allowNonDereferencingOps = false, + bool replaceInDeallocOp = false); /// Performs the same replacement as the other version above but only for the /// dereferencing uses of `oldMemRef` in `op`, except in cases where diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -23,6 +23,9 @@ #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" + +#define DEBUG_TYPE "transforms-utils" + using namespace mlir; // Perform the replacement in `op`. 
@@ -207,8 +210,8 @@ LogicalResult mlir::replaceAllMemRefUsesWith( Value oldMemRef, Value newMemRef, ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, - ArrayRef symbolOperands, Operation *domInstFilter, - Operation *postDomInstFilter, bool allowNonDereferencingOps, + ArrayRef symbolOperands, Operation *domOpFilter, + Operation *postDomOpFilter, bool allowNonDereferencingOps, bool replaceInDeallocOp) { unsigned newMemRefRank = newMemRef.getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode @@ -230,25 +233,25 @@ std::unique_ptr domInfo; std::unique_ptr postDomInfo; - if (domInstFilter) - domInfo = std::make_unique( - domInstFilter->getParentOfType()); + if (domOpFilter) + domInfo = + std::make_unique(domOpFilter->getParentOfType()); - if (postDomInstFilter) + if (postDomOpFilter) postDomInfo = std::make_unique( - postDomInstFilter->getParentOfType()); + postDomOpFilter->getParentOfType()); // Walk all uses of old memref; collect ops to perform replacement. We use a // DenseSet since an operation could potentially have multiple uses of a // memref (although rare), and the replacement later is going to erase ops. DenseSet opsToReplace; for (auto *op : oldMemRef.getUsers()) { - // Skip this use if it's not dominated by domInstFilter. - if (domInstFilter && !domInfo->dominates(domInstFilter, op)) + // Skip this use if it's not dominated by domOpFilter. + if (domOpFilter && !domInfo->dominates(domOpFilter, op)) continue; - // Skip this use if it's not post-dominated by postDomInstFilter. - if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op)) + // Skip this use if it's not post-dominated by postDomOpFilter. + if (postDomOpFilter && !postDomInfo->postDominates(postDomOpFilter, op)) continue; // Skip dealloc's - no replacement is necessary, and a memref replacement @@ -260,13 +263,20 @@ // for the memref to be used in a non-dereferencing way outside of the // region where this replacement is happening. 
if (!isa(*op)) { - if (!allowNonDereferencingOps) + if (!allowNonDereferencingOps) { + LLVM_DEBUG(llvm::dbgs() + << "Memref replacement failed: non-dereferencing memref op: \n" + << *op << '\n'); return failure(); - // Currently we support the following non-dereferencing ops to be a - // candidate for replacement: Dealloc, CallOp and ReturnOp. - // TODO: Add support for other kinds of ops. - if (!op->hasTrait()) + } + // Non-dereferencing ops with the MemRefsNormalizable trait are + // supported for replacement. + if (!op->hasTrait()) { + LLVM_DEBUG(llvm::dbgs() << "Memref replacement failed: use without a " + "memrefs normalizable trait: \n" + << *op << '\n'); return failure(); + } } // We'll first collect and then replace --- since replacement erases the op @@ -661,8 +671,8 @@ /*indexRemap=*/layoutMap, /*extraOperands=*/{}, /*symbolOperands=*/symbolOperands, - /*domInstFilter=*/nullptr, - /*postDomInstFilter=*/nullptr, - /*allowDereferencingOps=*/true))) { + /*domOpFilter=*/nullptr, + /*postDomOpFilter=*/nullptr, + /*allowNonDereferencingOps=*/true))) { // If it failed (due to escapes for example), bail out. newAlloc.erase();