diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -435,7 +435,7 @@
   let results = (outs Arg:$output);
 
   let extraClassDeclaration = [{
-    Value getSource() { return input();}
+    Value getSource() { return input(); }
     Value getTarget() { return output(); }
   }];
 
diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -20,10 +20,13 @@
 
 namespace mlir {
 
-/// Finds the associated dealloc that can be linked to our allocation nodes (if
-/// any).
-Operation *findDealloc(Value allocValue);
+/// Collects into `deallocs` the operations that deallocate `allocValue`.
+void findDeallocs(Value allocValue, SmallVectorImpl &deallocs);
 
+/// Finds the unique dealloc operation of `allocValue`. Returns None when more
+/// than one dealloc exists, nullptr when there is none, and the dealloc
+/// operation itself when there is exactly one.
+llvm::Optional findDealloc(Value allocValue);
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -175,9 +175,9 @@
   LogicalResult matchAndRewrite(T alloc,
                                 PatternRewriter &rewriter) const override {
     if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
-        if (auto storeOp = dyn_cast(op))
-          return storeOp.value() == alloc;
-        return !isa(op);
+          if (auto storeOp = dyn_cast(op))
+            return storeOp.value() == alloc;
+          return !isa(op);
         }))
       return failure();
 
@@ -519,8 +519,8 @@
 }
 
 namespace {
-/// Fold Dealloc operations that are deallocating an AllocOp that is only used
-/// by other Dealloc operations.
+/// Merge a clone with its source, when possible, by converting the clone to a
+/// cast.
 struct SimplifyClones : public OpRewritePattern {
   using OpRewritePattern::OpRewritePattern;
 
@@ -536,8 +536,16 @@
     // This only finds dealloc operations for the immediate value. It should
     // also consider aliases. That would also make the safety check below
     // redundant.
-    Operation *cloneDeallocOp = findDealloc(cloneOp.output());
-    Operation *sourceDeallocOp = findDealloc(source);
+    // Give up when either value is deallocated by more than one operation.
+    llvm::Optional cloneDeallocOrNone =
+        findDealloc(cloneOp.output());
+    if (!cloneDeallocOrNone.hasValue())
+      return failure();
+    llvm::Optional sourceDeallocOrNone = findDealloc(source);
+    if (!sourceDeallocOrNone.hasValue())
+      return failure();
+    Operation *cloneDeallocOp = *cloneDeallocOrNone;
+    Operation *sourceDeallocOp = *sourceDeallocOrNone;
 
     // If both are deallocated in the same block, their in-block lifetimes
     // might not fully overlap, so we cannot decide which one to drop.
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -1,4 +1,4 @@
-//===- Utils.cpp - Utilities to support the MemRef dialect ----------------===//
+//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -15,21 +15,34 @@
 
 using namespace mlir;
 
-/// Finds associated deallocs that can be linked to our allocation nodes (if
-/// any).
-Operation *mlir::findDealloc(Value allocValue) {
-  auto userIt = llvm::find_if(allocValue.getUsers(), [&](Operation *user) {
+/// Appends to `deallocs` every user of `allocValue` that has a free effect on
+/// that value. Existing entries of `deallocs` are left untouched.
+void mlir::findDeallocs(Value allocValue,
+                        SmallVectorImpl &deallocs) {
+  for (Operation *user : allocValue.getUsers()) {
     auto effectInterface = dyn_cast(user);
     if (!effectInterface)
-      return false;
+      continue;
     // Try to find a free effect that is applied to one of our values
     // that will be automatically freed by our pass.
     SmallVector effects;
     effectInterface.getEffectsOnValue(allocValue, effects);
-    return llvm::any_of(effects, [&](MemoryEffects::EffectInstance &it) {
-      return isa(it.getEffect());
-    });
-  });
-  // Assign the associated dealloc operation (if any).
-  return userIt != allocValue.user_end() ? *userIt : nullptr;
+    if (llvm::any_of(effects, [&](MemoryEffects::EffectInstance &it) {
+          return isa(it.getEffect());
+        }))
+      deallocs.push_back(user);
+  }
+}
+
+/// Finds a single dealloc operation for the given allocated value. Returns
+/// None if more than one dealloc exists, nullptr if there is none, and the
+/// dealloc operation itself if there is exactly one.
+llvm::Optional mlir::findDealloc(Value allocValue) {
+  SmallVector deallocs;
+  findDeallocs(allocValue, deallocs);
+  if (deallocs.size() > 1)
+    return llvm::None;
+  if (deallocs.size() == 1)
+    return deallocs.front();
+  return nullptr;
 }
diff --git a/mlir/lib/Transforms/BufferUtils.cpp b/mlir/lib/Transforms/BufferUtils.cpp
--- a/mlir/lib/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Transforms/BufferUtils.cpp
@@ -77,7 +77,11 @@
     // Get allocation result.
     Value allocValue = allocateResultEffects[0].getValue();
     // Find the associated dealloc value and register the allocation entry.
-    allocs.push_back(std::make_tuple(allocValue, findDealloc(allocValue)));
+    llvm::Optional dealloc = findDealloc(allocValue);
+    // If the allocation has > 1 dealloc associated with it, skip handling it.
+    if (!dealloc.hasValue())
+      return;
+    allocs.push_back(std::make_tuple(allocValue, *dealloc));
   });
 }
 
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -195,6 +195,45 @@
 
 // -----
 
+// Verify SimplifyClones skips clones whose source has multiple deallocations.
+// CHECK-LABEL: @clone_multiple_dealloc_of_source
+// CHECK-SAME: %[[ARG:.*]]: memref
+func @clone_multiple_dealloc_of_source(%arg0: memref) -> memref {
+  // CHECK-NEXT: %[[RES:.*]] = memref.clone %[[ARG]]
+  // CHECK: memref.dealloc %[[ARG]]
+  // CHECK: memref.dealloc %[[ARG]]
+  // CHECK: return %[[RES]]
+  %0 = memref.clone %arg0 : memref to memref
+  "if_else"() ({
+    memref.dealloc %arg0 : memref
+    }, {
+    memref.dealloc %arg0 : memref
+    }) : () -> ()
+  return %0 : memref
+}
+
+// -----
+
+// Verify SimplifyClones skips clones whose result has multiple deallocations.
+// CHECK-LABEL: @clone_multiple_dealloc_of_clone
+// CHECK-SAME: %[[ARG:.*]]: memref
+func @clone_multiple_dealloc_of_clone(%arg0: memref) -> memref {
+  // CHECK-NEXT: %[[CLONE:.*]] = memref.clone %[[ARG]]
+  // CHECK: memref.dealloc %[[CLONE]]
+  // CHECK: memref.dealloc %[[CLONE]]
+  // CHECK: return %[[ARG]]
+  %0 = memref.clone %arg0 : memref to memref
+  "use"(%0) : (memref) -> ()
+  "if_else"() ({
+    memref.dealloc %0 : memref
+    }, {
+    memref.dealloc %0 : memref
+    }) : () -> ()
+  return %arg0 : memref
+}
+
+// -----
+
 // CHECK-LABEL: func @dim_of_sized_view
 // CHECK-SAME: %{{[a-z0-9A-Z_]+}}: memref
 // CHECK-SAME: %[[SIZE:.[a-z0-9A-Z_]+]]: index
@@ -393,7 +432,7 @@
 func @allocator(%arg0 : memref>, %arg1 : index) {
   %0 = memref.alloc(%arg1) : memref
   memref.store %0, %arg0[] : memref>
-  return 
+  return
 }
 
 // -----