diff --git a/mlir/docs/BufferDeallocationInternals.md b/mlir/docs/BufferDeallocationInternals.md
--- a/mlir/docs/BufferDeallocationInternals.md
+++ b/mlir/docs/BufferDeallocationInternals.md
@@ -48,7 +48,7 @@
   partial_write(%0, %0)
   br ^bb3()
 ^bb3():
-  "linalg.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
+  test.copy(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> ()
   return
 }
 ```
@@ -133,11 +133,11 @@
 ^bb1:
   br ^bb3(%arg1 : memref<2xf32>)
 ^bb2:
-  %0 = alloc() : memref<2xf32>  // aliases: %1
+  %0 = memref.alloc() : memref<2xf32>  // aliases: %1
   use(%0)
   br ^bb3(%0 : memref<2xf32>)
 ^bb3(%1: memref<2xf32>): // %1 could be %0 or %arg1
-  "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+  test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
   return
 }
 ```
@@ -149,7 +149,7 @@
 
 ```mlir
 func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
-  %0 = alloc() : memref<2xf32>  // moved to bb0
+  %0 = memref.alloc() : memref<2xf32>  // moved to bb0
   cond_br %arg0, ^bb1, ^bb2
 ^bb1:
   br ^bb3(%arg1 : memref<2xf32>)
@@ -157,7 +157,7 @@
   use(%0)
   br ^bb3(%0 : memref<2xf32>)
 ^bb3(%1: memref<2xf32>):
-  "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
+  test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
   return
 }
 ```
@@ -179,17 +179,17 @@
 ^bb1:
   br ^bb3(%arg1 : memref<?xf32>)
 ^bb2(%0: index):
-  %1 = alloc(%0) : memref<?xf32>        // cannot be moved upwards to the data
+  %1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards to the data
                                         // dependency to %0
   use(%1)
   br ^bb3(%1 : memref<?xf32>)
 ^bb3(%2: memref<?xf32>):
-  "linalg.copy"(%2, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
+  test.copy(%2, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
   return
 }
 ```
 
-## Introduction of Copies
+## Introduction of Clones
 
 In order to guarantee that all allocated buffers are freed properly, we have to
 pay attention to the control flow and all potential aliases a buffer allocation
@@ -200,10 +200,10 @@
 
 ```mlir
 func @branch(%arg0: i1) {
-  %0 = alloc() : memref<2xf32>  // aliases: %2
+  %0 = memref.alloc() : memref<2xf32>  // aliases: %2
   cond_br %arg0, ^bb1, ^bb2
 ^bb1:
-  %1 = alloc() : memref<2xf32>  // resides here for demonstration purposes
+  %1 = memref.alloc() : memref<2xf32>  // resides here for demonstration purposes
                                 // aliases: %2
   br ^bb3(%1 : memref<2xf32>)
 ^bb2:
@@ -232,88 +232,31 @@
 
 ```mlir
 func @branch(%arg0: i1) {
-  %0 = alloc() : memref<2xf32>
+  %0 = memref.alloc() : memref<2xf32>
   cond_br %arg0, ^bb1, ^bb2
 ^bb1:
-  %1 = alloc() : memref<2xf32>
-  %3 = alloc() : memref<2xf32>  // temp copy for %1
-  "linalg.copy"(%1, %3) : (memref<2xf32>, memref<2xf32>) -> ()
-  dealloc %1 : memref<2xf32>  // %1 can be safely freed here
+  %1 = memref.alloc() : memref<2xf32>
+  %3 = memref.clone %1 : memref<2xf32> to memref<2xf32>
+  memref.dealloc %1 : memref<2xf32>  // %1 can be safely freed here
   br ^bb3(%3 : memref<2xf32>)
 ^bb2:
   use(%0)
-  %4 = alloc() : memref<2xf32>  // temp copy for %0
-  "linalg.copy"(%0, %4) : (memref<2xf32>, memref<2xf32>) -> ()
+  %4 = memref.clone %0 : memref<2xf32> to memref<2xf32>
   br ^bb3(%4 : memref<2xf32>)
 ^bb3(%2: memref<2xf32>):
   …
-  dealloc %2 : memref<2xf32>  // free temp buffer %2
-  dealloc %0 : memref<2xf32>  // %0 can be safely freed here
+  memref.dealloc %2 : memref<2xf32>  // free temp buffer %2
+  memref.dealloc %0 : memref<2xf32>  // %0 can be safely freed here
   return
 }
 ```
 
 Note that a temporary buffer for %2 was introduced to free all allocations
 properly. Note further that the unnecessary allocation of %3 can be easily
-removed using one of the post-pass transformations.
-
-Reconsider the previously introduced sample demonstrating dynamically shaped
-types:
-
-```mlir
-func @condBranchDynamicType(
-  %arg0: i1,
-  %arg1: memref<?xf32>,
-  %arg2: memref<?xf32>,
-  %arg3: index) {
-  cond_br %arg0, ^bb1, ^bb2(%arg3: index)
-^bb1:
-  br ^bb3(%arg1 : memref<?xf32>)
-^bb2(%0: index):
-  %1 = alloc(%0) : memref<?xf32>  // aliases: %2
-  use(%1)
-  br ^bb3(%1 : memref<?xf32>)
-^bb3(%2: memref<?xf32>):
-  "linalg.copy"(%2, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
-  return
-}
-```
+removed using one of the post-pass transformations or the canonicalization
+pass.
 
-In the presence of DSTs, we have to parameterize the allocations with
-additional dimension information of the source buffers, we want to copy from.
-BufferDeallocation automatically introduces all required operations to extract
-dimension specifications and wires them with the associated allocations:
-
-```mlir
-func @condBranchDynamicType(
-  %arg0: i1,
-  %arg1: memref<?xf32>,
-  %arg2: memref<?xf32>,
-  %arg3: index) {
-  cond_br %arg0, ^bb1, ^bb2(%arg3 : index)
-^bb1:
-  %c0 = constant 0 : index
-  %0 = dim %arg1, %c0 : memref<?xf32>  // dimension operation to parameterize
-                                       // the following temp allocation
-  %1 = alloc(%0) : memref<?xf32>
-  "linalg.copy"(%arg1, %1) : (memref<?xf32>, memref<?xf32>) -> ()
-  br ^bb3(%1 : memref<?xf32>)
-^bb2(%2: index):
-  %3 = alloc(%2) : memref<?xf32>
-  use(%3)
-  %c0_0 = constant 0 : index
-  %4 = dim %3, %c0_0 : memref<?xf32>   // dimension operation to parameterize
-                                       // the following temp allocation
-  %5 = alloc(%4) : memref<?xf32>
-  "linalg.copy"(%3, %5) : (memref<?xf32>, memref<?xf32>) -> ()
-  dealloc %3 : memref<?xf32>           // %3 can be safely freed here
-  br ^bb3(%5 : memref<?xf32>)
-^bb3(%6: memref<?xf32>):
-  "linalg.copy"(%6, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
-  dealloc %6 : memref<?xf32>           // %6 can be safely freed here
-  return
-}
-```
+The presented example also works with dynamically shaped types.
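+For instance, the dynamically typed variant (@condBranchDynamicType) roughly
+looks as follows after the transformation (a sketch of the expected output;
+the clone operation hides the dimension computations that the former
+alloc-and-copy lowering had to emit explicitly):
+
+```mlir
+func @condBranchDynamicType(
+  %arg0: i1,
+  %arg1: memref<?xf32>,
+  %arg2: memref<?xf32>,
+  %arg3: index) {
+  cond_br %arg0, ^bb1, ^bb2(%arg3 : index)
+^bb1:
+  %0 = memref.clone %arg1 : memref<?xf32> to memref<?xf32>
+  br ^bb3(%0 : memref<?xf32>)
+^bb2(%1: index):
+  %2 = memref.alloc(%1) : memref<?xf32>
+  use(%2)
+  %3 = memref.clone %2 : memref<?xf32> to memref<?xf32>
+  memref.dealloc %2 : memref<?xf32>
+  br ^bb3(%3 : memref<?xf32>)
+^bb3(%4: memref<?xf32>):
+  test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
+  memref.dealloc %4 : memref<?xf32>
+  return
+}
+```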
 
 BufferDeallocation performs a fix-point iteration taking all aliases of all
 tracked allocations into account. We initialize the general iteration process
@@ -335,7 +278,7 @@
 ^bb1:
   br ^bb6(%arg1 : memref<?xf32>)
 ^bb2(%0: index):
-  %1 = alloc(%0) : memref<?xf32>        // cannot be moved upwards due to the data
+  %1 = memref.alloc(%0) : memref<?xf32> // cannot be moved upwards due to the data
                                         // dependency to %0
                                         // aliases: %2, %3, %4
   use(%1)
@@ -349,7 +292,7 @@
 ^bb6(%3: memref<?xf32>):  // crit. alias of %arg1 and %2 (in other words %1)
   br ^bb7(%3 : memref<?xf32>)
 ^bb7(%4: memref<?xf32>):  // non-crit. alias of %3, since %3 dominates %4
-  "linalg.copy"(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
+  test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
   return
 }
 ```
@@ -366,13 +309,10 @@
   %arg3: index) {
   cond_br %arg0, ^bb1, ^bb2(%arg3 : index)
 ^bb1:
-  %c0 = constant 0 : index
-  %d0 = dim %arg1, %c0 : memref<?xf32>
-  %5 = alloc(%d0) : memref<?xf32>  // temp buffer required due to alias %3
-  "linalg.copy"(%arg1, %5) : (memref<?xf32>, memref<?xf32>) -> ()
+  %5 = memref.clone %arg1 : memref<?xf32> to memref<?xf32>
   br ^bb6(%5 : memref<?xf32>)
 ^bb2(%0: index):
-  %1 = alloc(%0) : memref<?xf32>
+  %1 = memref.alloc(%0) : memref<?xf32>
   use(%1)
   cond_br %arg0, ^bb3, ^bb4
 ^bb3:
@@ -380,17 +320,14 @@
 ^bb4:
   br ^bb5(%1 : memref<?xf32>)
 ^bb5(%2: memref<?xf32>):
-  %c0_0 = constant 0 : index
-  %d1 = dim %2, %c0_0 : memref<?xf32>
-  %6 = alloc(%d1) : memref<?xf32>  // temp buffer required due to alias %3
-  "linalg.copy"(%1, %6) : (memref<?xf32>, memref<?xf32>) -> ()
-  dealloc %1 : memref<?xf32>
+  %6 = memref.clone %1 : memref<?xf32> to memref<?xf32>
+  memref.dealloc %1 : memref<?xf32>
   br ^bb6(%6 : memref<?xf32>)
 ^bb6(%3: memref<?xf32>):
   br ^bb7(%3 : memref<?xf32>)
 ^bb7(%4: memref<?xf32>):
-  "linalg.copy"(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
-  dealloc %3 : memref<?xf32>  // free %3, since %4 is a non-crit. alias of %3
+  test.copy(%4, %arg2) : (memref<?xf32>, memref<?xf32>) -> ()
+  memref.dealloc %3 : memref<?xf32>  // free %3, since %4 is a non-crit. alias of %3
   return
 }
 ```
@@ -399,7 +336,7 @@
 temporary copy in all predecessor blocks. %3 has an additional (non-critical)
 alias %4 that extends the live range until the end of bb7. Therefore, we can
 free %3 after its last use, while taking all aliases into account. Note that %4
- does not need to be freed, since we did not introduce a copy for it.
+does not need to be freed, since we did not introduce a copy for it.
 
 The actual introduction of buffer copies is done after the fix-point iteration
 has been terminated and all critical aliases have been detected. A critical
@@ -445,7 +382,7 @@
 func @inner_region_control_flow(
   %arg0 : index,
   %arg1 : index) -> memref<?x?xf32> {
-  %0 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %0 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
   %1 = custom.region_if %0 : memref<?x?xf32> -> (memref<?x?xf32>)
    then(%arg2 : memref<?x?xf32>) {  // aliases: %arg4, %1
     custom.region_if_yield %arg2 : memref<?x?xf32>
@@ -468,11 +405,11 @@
 ```mlir
 func @nested_region_control_flow(%arg0 : index, %arg1 : index) -> memref<?x?xf32> {
   %0 = cmpi "eq", %arg0, %arg1 : index
-  %1 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %1 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
   %2 = scf.if %0 -> (memref<?x?xf32>) {
     scf.yield %1 : memref<?x?xf32>   // %2 will be an alias of %1
   } else {
-    %3 = alloc(%arg0, %arg1) : memref<?x?xf32>         // nested allocation in a div.
+    %3 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>  // nested allocation in a div.
                                                        // branch
     use(%3)
     scf.yield %1 : memref<?x?xf32>   // %2 will be an alias of %1
@@ -489,13 +426,13 @@
 ```mlir
 func @nested_region_control_flow(%arg0: index, %arg1: index) -> memref<?x?xf32> {
   %0 = cmpi "eq", %arg0, %arg1 : index
-  %1 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %1 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
   %2 = scf.if %0 -> (memref<?x?xf32>) {
     scf.yield %1 : memref<?x?xf32>
   } else {
-    %3 = alloc(%arg0, %arg1) : memref<?x?xf32>
+    %3 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
     use(%3)
-    dealloc %3 : memref<?x?xf32>  // %3 can be safely freed here
+    memref.dealloc %3 : memref<?x?xf32>  // %3 can be safely freed here
     scf.yield %1 : memref<?x?xf32>
   }
   return %2 : memref<?x?xf32>
@@ -514,12 +451,12 @@
 func @inner_region_control_flow_div(
   %arg0 : index,
   %arg1 : index) -> memref<?x?xf32> {
-  %0 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %0 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
   %1 = custom.region_if %0 : memref<?x?xf32> -> (memref<?x?xf32>)
    then(%arg2 : memref<?x?xf32>) {  // aliases: %arg4, %1
     custom.region_if_yield %arg2 : memref<?x?xf32>
   } else(%arg3 : memref<?x?xf32>) {
-    %2 = alloc(%arg0, %arg1) : memref<?x?xf32>         // aliases: %arg4, %1
+    %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>  // aliases: %arg4, %1
     custom.region_if_yield %2 : memref<?x?xf32>
   } join(%arg4 : memref<?x?xf32>) {  // aliases: %1
     custom.region_if_yield %arg4 : memref<?x?xf32>
@@ -537,40 +474,22 @@
 func @inner_region_control_flow_div(
   %arg0 : index,
   %arg1 : index) -> memref<?x?xf32> {
-  %0 = alloc(%arg0, %arg0) : memref<?x?xf32>
+  %0 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
   %1 = custom.region_if %0 : memref<?x?xf32> -> (memref<?x?xf32>)
    then(%arg2 : memref<?x?xf32>) {
-    %c0 = constant 0 : index  // determine dimension extents for temp allocation
-    %2 = dim %arg2, %c0 : memref<?x?xf32>
-    %c1 = constant 1 : index
-    %3 = dim %arg2, %c1 : memref<?x?xf32>
-    %4 = alloc(%2, %3) : memref<?x?xf32>  // temp buffer required due to critic.
-                                          // alias %arg4
-    linalg.copy(%arg2, %4) : memref<?x?xf32>, memref<?x?xf32>
+    %4 = memref.clone %arg2 : memref<?x?xf32> to memref<?x?xf32>
     custom.region_if_yield %4 : memref<?x?xf32>
   } else(%arg3 : memref<?x?xf32>) {
-    %2 = alloc(%arg0, %arg1) : memref<?x?xf32>
-    %c0 = constant 0 : index  // determine dimension extents for temp allocation
-    %3 = dim %2, %c0 : memref<?x?xf32>
-    %c1 = constant 1 : index
-    %4 = dim %2, %c1 : memref<?x?xf32>
-    %5 = alloc(%3, %4) : memref<?x?xf32>  // temp buffer required due to critic.
-                                          // alias %arg4
-    linalg.copy(%2, %5) : memref<?x?xf32>, memref<?x?xf32>
-    dealloc %2 : memref<?x?xf32>
+    %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
+    %5 = memref.clone %2 : memref<?x?xf32> to memref<?x?xf32>
+    memref.dealloc %2 : memref<?x?xf32>
     custom.region_if_yield %5 : memref<?x?xf32>
   } join(%arg4: memref<?x?xf32>) {
-    %c0 = constant 0 : index  // determine dimension extents for temp allocation
-    %2 = dim %arg4, %c0 : memref<?x?xf32>
-    %c1 = constant 1 : index
-    %3 = dim %arg4, %c1 : memref<?x?xf32>
-    %4 = alloc(%2, %3) : memref<?x?xf32>  // this allocation will be removed by
-                                          // applying the copy removal pass
-    linalg.copy(%arg4, %4) : memref<?x?xf32>, memref<?x?xf32>
-    dealloc %arg4 : memref<?x?xf32>
+    %4 = memref.clone %arg4 : memref<?x?xf32> to memref<?x?xf32>
+    memref.dealloc %arg4 : memref<?x?xf32>
    custom.region_if_yield %4 : memref<?x?xf32>
   }
-  dealloc %0 : memref<?x?xf32>  // %0 can be safely freed here
+  memref.dealloc %0 : memref<?x?xf32>  // %0 can be safely freed here
   return %1 : memref<?x?xf32>
 }
 ```
@@ -600,7 +519,7 @@
       iter_args(%iterBuf = %buf) -> memref<2xf32> {
     %1 = cmpi "eq", %i, %ub : index
     %2 = scf.if %1 -> (memref<2xf32>) {
-      %3 = alloc() : memref<2xf32>        // makes %2 a critical alias due to a
+      %3 = memref.alloc() : memref<2xf32> // makes %2 a critical alias due to a
                                           // divergent allocation
       use(%3)
       scf.yield %3 : memref<2xf32>
@@ -609,7 +528,7 @@
     }
     scf.yield %2 : memref<2xf32>
   }
-  "linalg.copy"(%0, %res) : (memref<2xf32>, memref<2xf32>) -> ()
+  test.copy(%0, %res) : (memref<2xf32>, memref<2xf32>) -> ()
   return
 }
 ```
@@ -634,31 +553,27 @@
     %step: index,
     %buf: memref<2xf32>,
     %res: memref<2xf32>) {
-  %4 = alloc() : memref<2xf32>
-  "linalg.copy"(%buf, %4) : (memref<2xf32>, memref<2xf32>) -> ()
+  %4 = memref.clone %buf : memref<2xf32> to memref<2xf32>
   %0 = scf.for %i = %lb to %ub step %step
       iter_args(%iterBuf = %4) -> memref<2xf32> {
     %1 = cmpi "eq", %i, %ub : index
     %2 = scf.if %1 -> (memref<2xf32>) {
-      %3 = alloc() : memref<2xf32>        // makes %2 a critical alias
+      %3 = memref.alloc() : memref<2xf32> // makes %2 a critical alias
       use(%3)
-      %5 = alloc() : memref<2xf32>        // temp copy due to crit. alias %2
-      "linalg.copy"(%3, %5) : memref<2xf32>, memref<2xf32>
-      dealloc %3 : memref<2xf32>
+      %5 = memref.clone %3 : memref<2xf32> to memref<2xf32>
+      memref.dealloc %3 : memref<2xf32>
      scf.yield %5 : memref<2xf32>
     } else {
-      %6 = alloc() : memref<2xf32>        // temp copy due to crit. alias %2
-      "linalg.copy"(%iterBuf, %6) : memref<2xf32>, memref<2xf32>
+      %6 = memref.clone %iterBuf : memref<2xf32> to memref<2xf32>
       scf.yield %6 : memref<2xf32>
     }
-    %7 = alloc() : memref<2xf32>          // temp copy due to crit. alias %iterBuf
-    "linalg.copy"(%2, %7) : memref<2xf32>, memref<2xf32>
-    dealloc %2 : memref<2xf32>
-    dealloc %iterBuf : memref<2xf32>      // free backedge iteration variable
+    %7 = memref.clone %2 : memref<2xf32> to memref<2xf32>
+    memref.dealloc %2 : memref<2xf32>
+    memref.dealloc %iterBuf : memref<2xf32> // free backedge iteration variable
     scf.yield %7 : memref<2xf32>
   }
-  "linalg.copy"(%0, %res) : (memref<2xf32>, memref<2xf32>) -> ()
-  dealloc %0 : memref<2xf32>  // free temp copy %0
+  test.copy(%0, %res) : (memref<2xf32>, memref<2xf32>) -> ()
+  memref.dealloc %0 : memref<2xf32> // free temp copy %0
   return
 }
 ```
@@ -684,12 +599,100 @@
 
 In order to limit the complexity of the BufferDeallocation transformation, some
 tiny code-polishing/optimization transformations are not applied on-the-fly
-during placement. Currently, there is only the CopyRemoval transformation to
-remove unnecessary copy and allocation operations.
+during placement. Currently, there is the CopyRemoval transformation to
+remove unnecessary copy and allocation operations. Furthermore, a
+canonicalization pattern is added to the clone operation to reduce the
+appearance of unnecessary clones.
 
 Note: further transformations might be added to the post-pass phase in the
 future.
 
+## Clone Canonicalization
+
+During the placement of clones it may happen that unnecessary clones are
+inserted. If a clone and its corresponding dealloc operation appear within the
+same block, the canonicalizer can remove these unnecessary operations. Note
+that this step has to take place after the insertion of clones and deallocs in
+the buffer deallocation step. The canonicalization covers both the newly
+created target value of the clone operation and its source value.
+
+### Canonicalization of the Source Buffer of the Clone Operation
+
+In this case, the source of the clone operation is used directly in place of
+its target; the clone operation and the deallocation of its source are
+removed. Here is a working example generated by the BufferDeallocation pass
+that allocates a buffer with dynamic size. A closer look reveals that the
+clone and dealloc operations are redundant and can be removed:
+
+```mlir
+func @dynamic_allocation(%arg0: index, %arg1: index) -> memref<?x?xf32> {
+  %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
+  %2 = memref.clone %1 : memref<?x?xf32> to memref<?x?xf32>
+  memref.dealloc %1 : memref<?x?xf32>
+  return %2 : memref<?x?xf32>
+}
+```
+
+Will be transformed to:
+
+```mlir
+func @dynamic_allocation(%arg0: index, %arg1: index) -> memref<?x?xf32> {
+  %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
+  return %1 : memref<?x?xf32>
+}
+```
+
+In this case, the clone %2 is replaced with its original source buffer %1, and
+the dealloc of %1 is removed along with it.
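+
+For the replacement to be safe, the source buffer must not be used between the
+clone operation and the dealloc of the source. A sketch of a case that the
+canonicalization leaves untouched (use(...) is the same placeholder operation
+as in the earlier examples and may read or write its operand):
+
+```mlir
+func @dynamic_allocation(%arg0: index, %arg1: index) -> memref<?x?xf32> {
+  %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
+  %2 = memref.clone %1 : memref<?x?xf32> to memref<?x?xf32>
+  use(%1)   // %1 is still accessed here; forwarding %1 to the users of %2
+            // could observe a write through this access
+  memref.dealloc %1 : memref<?x?xf32>
+  return %2 : memref<?x?xf32>
+}
+```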
+
+### Canonicalization of the Target Buffer of the Clone Operation
+
+In this case, the target buffer of the clone operation can be used instead of
+its source. The deallocation operation that belongs to this clone operation is
+also removed.
+
+Consider the following example, in which a generic test operation writes its
+result to %temp, which is afterwards cloned into %result. These two steps can
+be merged: the canonicalization removes the clone operation together with
+%temp and replaces the uses of %temp with %result:
+
+```mlir
+func @reuseTarget(%arg0: memref<2xf32>) -> memref<2xf32> {
+  %temp = memref.alloc() : memref<2xf32>
+  test.generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel"]} %arg0, %temp {
+  ^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
+    %tmp2 = exp %gen2_arg0 : f32
+    test.yield %tmp2 : f32
+  }: memref<2xf32>, memref<2xf32>
+  %result = memref.clone %temp : memref<2xf32> to memref<2xf32>
+  memref.dealloc %temp : memref<2xf32>
+  return %result : memref<2xf32>
+}
+```
+
+Will be transformed to:
+
+```mlir
+func @reuseTarget(%arg0: memref<2xf32>) -> memref<2xf32> {
+  %result = memref.alloc() : memref<2xf32>
+  test.generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel"]} %arg0, %result {
+  ^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
+    %tmp2 = exp %gen2_arg0 : f32
+    test.yield %tmp2 : f32
+  }: memref<2xf32>, memref<2xf32>
+  return %result : memref<2xf32>
+}
+```
+
 ## CopyRemoval Pass
 
 A common pattern that arises during placement is the introduction of
@@ -767,16 +770,16 @@
 
 ```mlir
 func @reuseTarget(%arg0: memref<2xf32>, %result: memref<2xf32>){
   %temp = alloc() : memref<2xf32>
-  linalg.generic {
+  test.generic {
     args_in = 1 : i64,
     args_out = 1 : i64,
     indexing_maps = [#map0, #map0],
     iterator_types = ["parallel"]} %arg0, %temp {
   ^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
     %tmp2 = exp %gen2_arg0 : f32
-    linalg.yield %tmp2 : f32
+    test.yield %tmp2 : f32
   }: memref<2xf32>, memref<2xf32>
-  "linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> ()
+  test.copy(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> ()
   dealloc %temp : memref<2xf32>
   return
 }
@@ -786,14 +789,14 @@
 
 ```mlir
 func @reuseTarget(%arg0: memref<2xf32>, %result: memref<2xf32>){
-  linalg.generic {
+  test.generic {
     args_in = 1 : i64,
     args_out = 1 : i64,
     indexing_maps = [#map0, #map0],
     iterator_types = ["parallel"]} %arg0, %result {
   ^bb0(%gen2_arg0: f32, %gen2_arg1: f32):
     %tmp2 = exp %gen2_arg0 : f32
-    linalg.yield %tmp2 : f32
+    test.yield %tmp2 : f32
   }: memref<2xf32>, memref<2xf32>
   return
 }
@@ -813,6 +816,6 @@
 BufferDeallocation introduces additional copies using allocations from the
 “memref” dialect (“memref.alloc”). Analogously, all deallocations use the
 free operation from the “memref” dialect (“memref.dealloc”). The actual copy
-process is realized using “linalg.copy”. Furthermore, buffers are essentially
+process is realized using “test.copy”. Furthermore, buffers are essentially
 immutable after their creation in a block. Further limitations are known in
 the case of unstructured control flow.
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -12,6 +12,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -11,6 +11,7 @@
 include "mlir/Dialect/MemRef/IR/MemRefBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
+include "mlir/Interfaces/CopyOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
@@ -333,6 +334,43 @@
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// CloneOp
+//===----------------------------------------------------------------------===//
+
+def CloneOp : MemRef_Op<"clone", [
+    CopyOpInterface,
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+  ]> {
+  let builders = [
+    OpBuilder<(ins "Value":$value), [{
+      return build($_builder, $_state, value.getType(), value);
+    }]>];
+
+  let description = [{
+    Clones the data in the input view into an implicitly defined output view.
+
+    Usage:
+
+    ```mlir
+    %arg1 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
+    ```
+  }];
+
+  let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "", [MemRead]>:$input);
+  let results = (outs Arg<AnyRankedOrUnrankedMemRef, "",
+                         [MemAlloc, MemWrite]>:$output);
+
+  let extraClassDeclaration = [{
+    Value getSource() { return input(); }
+    Value getTarget() { return output(); }
+  }];
+
+  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
+
+  let hasFolder = 1;
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -0,0 +1,37 @@
+//===- MemRefUtils.h - MemRef transformation utilities ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes for various transformation utilities for
+// the MemRef dialect. These are not passes by themselves but are used either
+// by passes, optimization sequences, or in turn by other transformation
+// utilities.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREFOPS_UTILS_MEMREFUTILS_H
+#define MLIR_DIALECT_MEMREFOPS_UTILS_MEMREFUTILS_H
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+namespace mlir {
+
+/// Returns true if `val` has at least one user between the `start` and `end`
+/// operations.
+bool hasUsersBetween(Value val, Operation *start, Operation *end);
+
+/// Finds an associated dealloc that can be linked to the given allocation
+/// node (if any).
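+/// The lookup is purely effect-based: the first user of `allocValue` that
+/// exhibits a MemoryEffects::Free effect on it is returned, so the result is
+/// not necessarily a memref.dealloc operation.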
+Operation *findDealloc(Value allocValue); + +/// Returns the allocation operation for `value` in `block` if it exists. +/// nullptr otherwise. +Operation *findAlloc(Value value, Block *block); + +} // end namespace mlir + +#endif // MLIR_DIALECT_MEMREFOPS_UTILS_MEMREFUTILS_H diff --git a/mlir/include/mlir/Transforms/BufferUtils.h b/mlir/include/mlir/Transforms/BufferUtils.h --- a/mlir/include/mlir/Transforms/BufferUtils.h +++ b/mlir/include/mlir/Transforms/BufferUtils.h @@ -39,10 +39,6 @@ static Operation *getStartOperation(Value allocValue, Block *placementBlock, const Liveness &liveness); - /// Find an associated dealloc operation that is linked to the given - /// allocation node (if any). - static Operation *findDealloc(Value allocValue); - public: /// Initializes the internal list by discovering all supported allocation /// nodes. diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -282,8 +282,6 @@ }]; let constructor = "mlir::createBufferDeallocationPass()"; - // TODO: this pass likely shouldn't depend on Linalg? - let dependentDialects = ["linalg::LinalgDialect"]; } def BufferHoisting : FunctionPass<"buffer-hoisting"> { diff --git a/mlir/lib/Dialect/MemRef/CMakeLists.txt b/mlir/lib/Dialect/MemRef/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/CMakeLists.txt @@ -1 +1,22 @@ -add_subdirectory(IR) +add_mlir_dialect_library(MLIRMemRef + IR/MemRefDialect.cpp + IR/MemRefOps.cpp + Utils/MemRefUtils.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect + + DEPENDS + MLIRStandardOpsIncGen + MLIRMemRefOpsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRDialect + MLIRIR + MLIRStandard + MLIRTensor + MLIRViewLikeInterface +) diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt deleted file mode 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ /dev/null @@ -1,21 +0,0 @@ -add_mlir_dialect_library(MLIRMemRef - MemRefDialect.cpp - MemRefOps.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect - - DEPENDS - MLIRStandardOpsIncGen - MLIRMemRefOpsIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRDialect - MLIRIR - MLIRStandard - MLIRTensor - MLIRViewLikeInterface -) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -464,6 +465,72 @@ return succeeded(foldMemRefCast(*this)) ? 
getResult() : Value();
 }
 
+//===----------------------------------------------------------------------===//
+// CloneOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(CloneOp op) { return success(); }
+
+void CloneOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), input(),
+                       SideEffects::DefaultResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), output(),
+                       SideEffects::DefaultResource::get());
+}
+
+namespace {
+
+/// Removes redundant clone operations by forwarding the source buffer to the
+/// users of the clone, together with the dealloc that becomes obsolete; dead
+/// clones without any users are erased directly.
+struct SimplifyClones : public OpRewritePattern<CloneOp> {
+  using OpRewritePattern<CloneOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CloneOp cloneOp,
+                                PatternRewriter &rewriter) const override {
+    if (cloneOp.use_empty()) {
+      rewriter.eraseOp(cloneOp);
+      return success();
+    }
+
+    Value source = cloneOp.input();
+    Operation *clone = cloneOp.getOperation();
+    // Case 1: the clone result is deallocated in the block that defines the
+    // source, and the source is not used after the clone.
+    Operation *deallocOp = mlir::findDealloc(cloneOp.output());
+    Operation *sourceOp = source.getDefiningOp();
+    if (sourceOp && deallocOp &&
+        sourceOp->getBlock() == deallocOp->getBlock() &&
+        !mlir::hasUsersBetween(source, clone, deallocOp)) {
+      rewriter.replaceOp(cloneOp, source);
+      rewriter.eraseOp(deallocOp);
+      return success();
+    }
+
+    // Case 2: the source is allocated and deallocated in the clone's block
+    // and has no users apart from the clone itself.
+    deallocOp = mlir::findDealloc(source);
+    Operation *allocOp = mlir::findAlloc(source, clone->getBlock());
+    if (sourceOp && deallocOp && allocOp &&
+        allocOp->getBlock() == deallocOp->getBlock() &&
+        !mlir::hasUsersBetween(source, allocOp, cloneOp) &&
+        !mlir::hasUsersBetween(source, cloneOp, deallocOp)) {
+      rewriter.replaceOp(cloneOp, source);
+      rewriter.eraseOp(deallocOp);
+      return success();
+    }
+    return failure();
+  }
+};
+
+} // end anonymous namespace.
+
+void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {
+  results.insert<SimplifyClones>(context);
+}
+
+OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
+  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
+}
+
 //===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -0,0 +1,60 @@
+//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for the MemRef dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+
+using namespace mlir;
+
+/// Returns true if `val` has at least one user between the `start` and `end`
+/// operations.
+bool mlir::hasUsersBetween(Value val, Operation *start, Operation *end) {
+  assert(start && end && "Start and end operations cannot be null");
+  Block *block = start->getBlock();
+  assert(block == end->getBlock() &&
+         "Start and end operations should be in the same block.");
+  return llvm::any_of(val.getUsers(), [&](Operation *op) {
+    return op->getBlock() == block && start->isBeforeInBlock(op) &&
+           op->isBeforeInBlock(end);
+  });
+}
+
+/// Finds an associated dealloc that can be linked to the given allocation
+/// node (if any).
+Operation *mlir::findDealloc(Value allocValue) {
+  auto userIt = llvm::find_if(allocValue.getUsers(), [&](Operation *user) {
+    auto effectInterface = dyn_cast<MemoryEffectOpInterface>(user);
+    if (!effectInterface)
+      return false;
+    // Try to find a free effect that is applied to one of our values
+    // that will be automatically freed by our pass.
+    SmallVector<MemoryEffects::EffectInstance, 2> effects;
+    effectInterface.getEffectsOnValue(allocValue, effects);
+    return llvm::any_of(effects, [&](MemoryEffects::EffectInstance &it) {
+      return isa<MemoryEffects::Free>(it.getEffect());
+    });
+  });
+  // Assign the associated dealloc operation (if any).
+  return userIt != allocValue.user_end() ? *userIt : nullptr;
+}
+
+/// Returns the allocation operation for `value` in `block` if it exists.
+/// nullptr otherwise.
+Operation *mlir::findAlloc(Value value, Block *block) {
+  Operation *op = value.getDefiningOp();
+  if (op && op->getBlock() == block) {
+    auto effects = dyn_cast<MemoryEffectOpInterface>(op);
+    if (effects && effects.hasEffect<MemoryEffects::Allocate>())
+      return op;
+  }
+  return nullptr;
+}
diff --git a/mlir/lib/Transforms/BufferDeallocation.cpp b/mlir/lib/Transforms/BufferDeallocation.cpp
--- a/mlir/lib/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Transforms/BufferDeallocation.cpp
@@ -7,16 +7,15 @@
 //===----------------------------------------------------------------------===//
 //
 // This file implements logic for computing correct alloc and dealloc positions.
-// Furthermore, buffer placement also adds required new alloc and copy
-// operations to ensure that all buffers are deallocated. The main class is the
+// Furthermore, buffer deallocation also adds required new clone operations to
+// ensure that all buffers are deallocated. The main class is the
 // BufferDeallocationPass class that implements the underlying algorithm. In
 // order to put allocations and deallocations at safe positions, it is
 // significantly important to put them into the correct blocks. However, the
 // liveness analysis does not pay attention to aliases, which can occur due to
 // branches (and their associated block arguments) in general. For this purpose,
 // BufferDeallocation firstly finds all possible aliases for a single value
-// (using the BufferAliasAnalysis class). Consider the following
-// example:
+// (using the BufferAliasAnalysis class). Consider the following example:
 //
 // ^bb0(%arg0):
 //   cond_br %cond, ^bb1, ^bb2
@@ -30,16 +29,16 @@
 //
 // We should place the dealloc for %new_value in exit. However, we have to free
 // the buffer in the same block, because it cannot be freed in the post
-// dominator. However, this requires a new copy buffer for %arg1 that will
+// dominator. However, this requires a new clone buffer for %arg1 that will
 // contain the actual contents. Using the class BufferAliasAnalysis, we
 // will find out that %new_value has a potential alias %arg1. In order to find
// the dealloc position we have to find all potential aliases, iterate over
 // their uses and find the common post-dominator block (note that additional
-// copies and buffers remove potential aliases and will influence the placement
+// clones and buffers remove potential aliases and will influence the placement
 // of the deallocs). In all cases, the computed block can be safely used to free
 // the %new_value buffer (may be exit or bb2) as it will die and we can use
 // liveness information to determine the exact operation after which we have to
-// insert the dealloc. However, the algorithm supports introducing copy buffers
+// insert the dealloc. However, the algorithm supports introducing clone buffers
 // and placing deallocs in safe locations to ensure that all buffers will be
 // freed in the end.
 //
@@ -52,10 +51,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "PassDetail.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
-#include "mlir/Dialect/StandardOps/Utils/Utils.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
@@ -187,25 +184,25 @@
 /// The buffer deallocation transformation which ensures that all allocs in the
 /// program have a corresponding de-allocation. As a side-effect, it might also
-/// introduce copies that in turn leads to additional allocs and de-allocations.
+/// introduce clones that in turn lead to additional deallocations.
 class BufferDeallocation : BufferPlacementTransformationBase {
 public:
   BufferDeallocation(Operation *op)
       : BufferPlacementTransformationBase(op), dominators(op),
         postDominators(op) {}
 
-  /// Performs the actual placement/creation of all temporary alloc, copy and
-  /// dealloc nodes.
+  /// Performs the actual placement/creation of all temporary clone and dealloc
+  /// nodes.
   void deallocate() {
-    // Add additional allocations and copies that are required.
-    introduceCopies();
+    // Add additional clones that are required.
+    introduceClones();
     // Place deallocations for all allocation entries.
     placeDeallocs();
   }
 
 private:
-  /// Introduces required allocs and copy operations to avoid memory leaks.
-  void introduceCopies() {
+  /// Introduces required clone operations to avoid memory leaks.
+  void introduceClones() {
     // Initialize the set of values that require a dedicated memory free
     // operation since their operands cannot be safely deallocated in a post
     // dominator.
@@ -214,7 +211,7 @@
     SmallVector<std::tuple<Value, Block *>, 8> toProcess;
 
     // Check dominance relation for proper dominance properties. If the given
-    // value node does not dominate an alias, we will have to create a copy in
+    // value node does not dominate an alias, we will have to create a clone in
     // order to free all buffers that can potentially leak into a post
     // dominator.
     auto findUnsafeValues = [&](Value source, Block *definingBlock) {
@@ -255,7 +252,7 @@
     // arguments at the correct locations.
     aliases.remove(valuesToFree);
 
-    // Add new allocs and additional copy operations.
+    // Add new allocs and additional clone operations.
     for (Value value : valuesToFree) {
       if (auto blockArg = value.dyn_cast<BlockArgument>())
         introduceBlockArgCopy(blockArg);
@@ -269,7 +266,7 @@
     }
   }
 
-  /// Introduces temporary allocs in all predecessors and copies the source
+  /// Introduces temporary clones in all predecessors and copies the source
   /// values into the newly allocated buffers.
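+  /// Rewiring the terminator operands guarantees that the block argument only
+  /// ever receives cloned buffers, which can be freed independently of the
+  /// original allocations.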
   void introduceBlockArgCopy(BlockArgument blockArg) {
     // Allocate a buffer for the current block argument in the block of
@@ -285,9 +282,9 @@
       Value sourceValue =
           branchInterface.getSuccessorOperands(it.getSuccessorIndex())
              .getValue()[blockArg.getArgNumber()];
-      // Create a new alloc and copy at the current location of the terminator.
-      Value alloc = introduceBufferCopy(sourceValue, terminator);
-      // Wire new alloc and successor operand.
+      // Create a new clone at the current location of the terminator.
+      Value clone = introduceCloneBuffers(sourceValue, terminator);
+      // Wire new clone and successor operand.
       auto mutableOperands =
           branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex());
       if (!mutableOperands.hasValue())
@@ -296,7 +293,7 @@
       else
         mutableOperands.getValue()
             .slice(blockArg.getArgNumber(), 1)
-            .assign(alloc);
+            .assign(clone);
     }
 
     // Check whether the block argument has implicitly defined predecessors via
@@ -310,7 +307,7 @@
         !(regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp)))
       return;
 
-    introduceCopiesForRegionSuccessors(
+    introduceClonesForRegionSuccessors(
         regionInterface, argRegion->getParentOp()->getRegions(), blockArg,
         [&](RegionSuccessor &successorRegion) {
           // Find a predecessor of our argRegion.
@@ -318,7 +315,7 @@
         });
 
     // Check whether the block argument belongs to an entry region of the
-    // parent operation. In this case, we have to introduce an additional copy
+    // parent operation. In this case, we have to introduce an additional clone
     // for the buffer that is passed to the argument.
     SmallVector<RegionSuccessor, 2> successorRegions;
     regionInterface.getSuccessorRegions(/*index=*/llvm::None, successorRegions);
@@ -329,20 +326,20 @@
     if (it == successorRegions.end())
       return;
 
-    // Determine the actual operand to introduce a copy for and rewire the
-    // operand to point to the copy instead.
+    // Determine the actual operand to introduce a clone for and rewire the
+    // operand to point to the clone instead.
     Value operand =
         regionInterface.getSuccessorEntryOperands(argRegion->getRegionNumber())
            [llvm::find(it->getSuccessorInputs(), blockArg).getIndex()];
-    Value copy = introduceBufferCopy(operand, parentOp);
+    Value clone = introduceCloneBuffers(operand, parentOp);
 
     auto op = llvm::find(parentOp->getOperands(), operand);
     assert(op != parentOp->getOperands().end() &&
           "parentOp does not contain operand");
-    parentOp->setOperand(op.getIndex(), copy);
+    parentOp->setOperand(op.getIndex(), clone);
   }
 
-  /// Introduces temporary allocs in front of all associated nested-region
+  /// Introduces temporary clones in front of all associated nested-region
   /// terminators and copies the source values into the newly allocated buffers.
   void introduceValueCopyForRegionResult(Value value) {
     // Get the actual result index in the scope of the parent terminator.
@@ -354,20 +351,20 @@
       // its parent operation.
       return !successorRegion.getSuccessor();
     };
-    // Introduce a copy for all region "results" that are returned to the parent
-    // operation. This is required since the parent's result value has been
-    // considered critical. Therefore, the algorithm assumes that a copy of a
-    // previously allocated buffer is returned by the operation (like in the
-    // case of a block argument).
-    introduceCopiesForRegionSuccessors(regionInterface, operation->getRegions(),
+    // Introduce a clone for all region "results" that are returned to the
+    // parent operation. This is required since the parent's result value has
+    // been considered critical. Therefore, the algorithm assumes that a clone
+    // of a previously allocated buffer is returned by the operation (like in
+    // the case of a block argument).
+    introduceClonesForRegionSuccessors(regionInterface, operation->getRegions(),
                                        value, regionPredicate);
   }
 
-  /// Introduces buffer copies for all terminators in the given regions. The
+  /// Introduces buffer clones for all terminators in the given regions. The
   /// regionPredicate is applied to every successor region in order to restrict
-  /// the copies to specific regions.
+  /// the clones to specific regions.
   template <typename TPredicate>
-  void introduceCopiesForRegionSuccessors(
+  void introduceClonesForRegionSuccessors(
       RegionBranchOpInterface regionInterface, MutableArrayRef<Region> regions,
       Value argValue, const TPredicate &regionPredicate) {
     for (Region &region : regions) {
@@ -393,49 +390,37 @@
       walkReturnOperations(&region, [&](Operation *terminator) {
         // Extract the source value from the current terminator.
         Value sourceValue = terminator->getOperand(operandIndex);
-        // Create a new alloc at the current location of the terminator.
-        Value alloc = introduceBufferCopy(sourceValue, terminator);
-        // Wire alloc and terminator operand.
-        terminator->setOperand(operandIndex, alloc);
+        // Create a new clone at the current location of the terminator.
+        Value clone = introduceCloneBuffers(sourceValue, terminator);
+        // Wire clone and terminator operand.
+        terminator->setOperand(operandIndex, clone);
       });
     }
   }
 
-  /// Creates a new memory allocation for the given source value and copies
+  /// Creates a new memory allocation for the given source value and clones
   /// its content into the newly allocated buffer. The terminator operation is
-  /// used to insert the alloc and copy operations at the right places.
-  Value introduceBufferCopy(Value sourceValue, Operation *terminator) {
-    // Avoid multiple copies of the same source value. This can happen in the
+  /// used to insert the clone operation at the right place.
+  Value introduceCloneBuffers(Value sourceValue, Operation *terminator) {
+    // Avoid multiple clones of the same source value. This can happen in the
     // presence of loops when a branch acts as a backedge while also having
     // another successor that returns to its parent operation. Note that
     // copying copied buffers can introduce memory leaks since the invariant of
-    // BufferPlacement assumes that a buffer will be only copied once into a
-    // temporary buffer. Hence, the construction of copy chains introduces
+    // BufferDeallocation assumes that a buffer will be only cloned once into a
+    // temporary buffer. Hence, the construction of clone chains introduces
    // additional allocations that are not tracked automatically by the
     // algorithm.
-    if (copiedValues.contains(sourceValue))
+    if (clonedValues.contains(sourceValue))
       return sourceValue;
-    // Create a new alloc at the current location of the terminator.
-    auto memRefType = sourceValue.getType().cast<MemRefType>();
+    // Create a new clone operation that copies the contents of the old
+    // buffer to the new one.
     OpBuilder builder(terminator);
+    auto cloneOp =
+        builder.create<memref::CloneOp>(terminator->getLoc(), sourceValue);
 
-    // Extract information about dynamically shaped types by
-    // extracting their dynamic dimensions.
-    auto dynamicOperands =
-        getDynOperands(terminator->getLoc(), sourceValue, builder);
-
-    // TODO: provide a generic interface to create dialect-specific
-    // Alloc and CopyOp nodes.
-    auto alloc = builder.create<memref::AllocOp>(terminator->getLoc(),
-                                                 memRefType, dynamicOperands);
-
-    // Create a new copy operation that copies the contents of the old
-    // allocation to the new one.
-    builder.create<linalg::CopyOp>(terminator->getLoc(), sourceValue, alloc);
-
-    // Remember the copy of original source value.
-    copiedValues.insert(alloc);
-    return alloc;
+    // Remember the clone of original source value.
+    clonedValues.insert(cloneOp);
+    return cloneOp;
   }
 
   /// Finds correct dealloc positions according to the algorithm described at
@@ -513,8 +498,8 @@
   /// position.
   PostDominanceInfo postDominators;
 
-  /// Stores already copied allocations to avoid additional copies of copies.
-  ValueSetT copiedValues;
+  /// Stores already cloned buffers to avoid additional clones of clones.
+  ValueSetT clonedValues;
 };
 
 //===----------------------------------------------------------------------===//
@@ -522,8 +507,8 @@
 //===----------------------------------------------------------------------===//
 
 /// The actual buffer deallocation pass that inserts and moves dealloc nodes
-/// into the right positions. Furthermore, it inserts additional allocs and
-/// copies if necessary. It uses the algorithm described at the top of the file.
+/// into the right positions. Furthermore, it inserts additional clones if
+/// necessary. It uses the algorithm described at the top of the file.
 struct BufferDeallocationPass : BufferDeallocationBase<BufferDeallocationPass> {
 
   void runOnFunction() override {
@@ -540,7 +525,7 @@
       return signalPassFailure();
     }
 
-    // Place all required temporary alloc, copy and dealloc nodes.
+    // Place all required temporary clone and dealloc nodes.
     BufferDeallocation deallocation(getFunction());
     deallocation.deallocate();
   }
diff --git a/mlir/lib/Transforms/BufferUtils.cpp b/mlir/lib/Transforms/BufferUtils.cpp
--- a/mlir/lib/Transforms/BufferUtils.cpp
+++ b/mlir/lib/Transforms/BufferUtils.cpp
@@ -12,7 +12,7 @@
 
 #include "mlir/Transforms/BufferUtils.h"
 #include "PassDetail.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
@@ -49,25 +49,6 @@
   return startOperation;
 }
 
-/// Finds associated deallocs that can be linked to our allocation nodes (if
-/// any).
-Operation *BufferPlacementAllocs::findDealloc(Value allocValue) {
-  auto userIt = llvm::find_if(allocValue.getUsers(), [&](Operation *user) {
-    auto effectInterface = dyn_cast<MemoryEffectOpInterface>(user);
-    if (!effectInterface)
-      return false;
-    // Try to find a free effect that is applied to one of our values
-    // that will be automatically freed by our pass.
-    SmallVector<MemoryEffects::EffectInstance, 2> effects;
-    effectInterface.getEffectsOnValue(allocValue, effects);
-    return llvm::any_of(effects, [&](MemoryEffects::EffectInstance &it) {
-      return isa<MemoryEffects::Free>(it.getEffect());
-    });
-  });
-  // Assign the associated dealloc operation (if any).
-  return userIt != allocValue.user_end() ? *userIt : nullptr;
-}
-
 /// Initializes the internal list by discovering all supported allocation
 /// nodes.
BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); }
diff --git a/mlir/lib/Transforms/CopyRemoval.cpp b/mlir/lib/Transforms/CopyRemoval.cpp
--- a/mlir/lib/Transforms/CopyRemoval.cpp
+++ b/mlir/lib/Transforms/CopyRemoval.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/Pass.h"
@@ -43,19 +44,6 @@
   /// List of values that need to be replaced with their counterparts.
   llvm::SmallDenseSet<std::pair<Value, Value>, 4> replaceList;
 
-  /// Returns the allocation operation for `value` in `block` if it exists.
-  /// nullptr otherwise.
-  Operation *getAllocationOpInBlock(Value value, Block *block) {
-    assert(block && "Block cannot be null");
-    Operation *op = value.getDefiningOp();
-    if (op && op->getBlock() == block) {
-      auto effects = dyn_cast<MemoryEffectOpInterface>(op);
-      if (effects && effects.hasEffect<MemoryEffects::Allocate>())
-        return op;
-    }
-    return nullptr;
-  }
-
   /// Returns the deallocation operation for `value` in `block` if it exists.
   /// nullptr otherwise.
   Operation *getDeallocationOpInBlock(Value value, Block *block) {
@@ -83,19 +71,6 @@
     return false;
   };
 
-  /// Returns true if `val` value has at least a user between `start` and
-  /// `end` operations.
-  bool hasUsersBetween(Value val, Operation *start, Operation *end) {
-    assert((start || end) && "Start and end operations cannot be null");
-    Block *block = start->getBlock();
-    assert(block == end->getBlock() &&
-           "Start and end operations should be in the same block.");
-    return llvm::any_of(val.getUsers(), [&](Operation *op) {
-      return op->getBlock() == block && start->isBeforeInBlock(op) &&
-             op->isBeforeInBlock(end);
-    });
-  };
-
   bool areOpsInTheSameBlock(ArrayRef<Operation *> operations) {
     assert(!operations.empty() &&
            "The operations list should contain at least a single operation");
@@ -141,7 +116,7 @@
     Block *copyBlock = copy->getBlock();
     Operation *fromDefiningOp = from.getDefiningOp();
     Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock);
-    Operation *toDefiningOp = getAllocationOpInBlock(to, copyBlock);
+    Operation *toDefiningOp = findAlloc(to, copyBlock);
     if (!fromDefiningOp || !fromFreeingOp || !toDefiningOp ||
         !areOpsInTheSameBlock({fromFreeingOp, toDefiningOp, copy}) ||
         hasUsersBetween(to, toDefiningOp, copy) ||
@@ -190,7 +165,7 @@
 
     Operation *copy = copyOp.getOperation();
     Block *copyBlock = copy->getBlock();
-    Operation *fromDefiningOp = getAllocationOpInBlock(from, copyBlock);
+    Operation *fromDefiningOp = findAlloc(from, copyBlock);
     Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock);
     if (!fromDefiningOp || !fromFreeingOp ||
         !areOpsInTheSameBlock({fromFreeingOp, fromDefiningOp, copy}) ||
diff --git a/mlir/test/Transforms/buffer-deallocation.mlir b/mlir/test/Transforms/buffer-deallocation.mlir
--- a/mlir/test/Transforms/buffer-deallocation.mlir
+++ b/mlir/test/Transforms/buffer-deallocation.mlir
@@ -30,13 +30,11 @@
 }
 
 // CHECK-NEXT: cond_br
-// CHECK: %[[ALLOC0:.*]] = memref.alloc()
-// CHECK-NEXT: linalg.copy
+// CHECK: %[[ALLOC0:.*]] = memref.clone
 // CHECK-NEXT: br ^bb3(%[[ALLOC0]]
-// CHECK: %[[ALLOC1:.*]] = memref.alloc()
+// CHECK: %[[ALLOC1:.*]] = memref.alloc
 // CHECK-NEXT: test.buffer_based
-// CHECK: %[[ALLOC2:.*]] = memref.alloc()
-// CHECK-NEXT: linalg.copy
+// CHECK-NEXT: %[[ALLOC2:.*]] = memref.clone %[[ALLOC1]]
 // CHECK-NEXT: memref.dealloc %[[ALLOC1]]
 // CHECK-NEXT: br ^bb3(%[[ALLOC2]]
 // CHECK: test.copy
@@ -77,16 +75,12 
@@ } // CHECK-NEXT: cond_br -// CHECK: %[[DIM0:.*]] = memref.dim -// CHECK-NEXT: %[[ALLOC0:.*]] = memref.alloc(%[[DIM0]]) -// CHECK-NEXT: linalg.copy(%{{.*}}, %[[ALLOC0]]) +// CHECK: %[[ALLOC0:.*]] = memref.clone // CHECK-NEXT: br ^bb3(%[[ALLOC0]] // CHECK: ^bb2(%[[IDX:.*]]:{{.*}}) // CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]]) // CHECK-NEXT: test.buffer_based -// CHECK: %[[DIM1:.*]] = memref.dim %[[ALLOC1]] -// CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc(%[[DIM1]]) -// CHECK-NEXT: linalg.copy(%[[ALLOC1]], %[[ALLOC2]]) +// CHECK-NEXT: %[[ALLOC2:.*]] = memref.clone // CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK-NEXT: br ^bb3 // CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}}) @@ -142,12 +136,10 @@ return } -// CHECK-NEXT: cond_br -// CHECK: ^bb1 -// CHECK: %[[DIM0:.*]] = memref.dim -// CHECK-NEXT: %[[ALLOC0:.*]] = memref.alloc(%[[DIM0]]) -// CHECK-NEXT: linalg.copy(%{{.*}}, %[[ALLOC0]]) -// CHECK-NEXT: br ^bb6 +// CHECK-NEXT: cond_br{{.*}} +// CHECK-NEXT: ^bb1 +// CHECK-NEXT: %[[ALLOC0:.*]] = memref.clone +// CHECK-NEXT: br ^bb6(%[[ALLOC0]] // CHECK: ^bb2(%[[IDX:.*]]:{{.*}}) // CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]]) // CHECK-NEXT: test.buffer_based @@ -157,9 +149,7 @@ // CHECK: ^bb4: // CHECK-NEXT: br ^bb5(%[[ALLOC1]]{{.*}}) // CHECK-NEXT: ^bb5(%[[ALLOC2:.*]]:{{.*}}) -// CHECK: %[[DIM2:.*]] = memref.dim %[[ALLOC2]] -// CHECK-NEXT: %[[ALLOC3:.*]] = memref.alloc(%[[DIM2]]) -// CHECK-NEXT: linalg.copy(%[[ALLOC2]], %[[ALLOC3]]) +// CHECK-NEXT: %[[ALLOC3:.*]] = memref.clone %[[ALLOC2]] // CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK-NEXT: br ^bb6(%[[ALLOC3]]{{.*}}) // CHECK-NEXT: ^bb6(%[[ALLOC4:.*]]:{{.*}}) @@ -208,13 +198,11 @@ return } -// CHECK-NEXT: %[[ALLOC0:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy +// CHECK-NEXT: %[[ALLOC0:.*]] = memref.clone // CHECK-NEXT: cond_br // CHECK: %[[ALLOC1:.*]] = memref.alloc() // CHECK-NEXT: test.buffer_based -// CHECK: %[[ALLOC2:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy +// CHECK-NEXT: %[[ALLOC2:.*]] = memref.clone %[[ALLOC1]] // CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK: test.copy // CHECK-NEXT: memref.dealloc @@ -419,20 +407,17 @@ return } -// CHECK-NEXT: cond_br -// CHECK: ^bb1 -// CHECK: ^bb1 +// CHECK-NEXT: cond_br{{.*}} +// CHECK-NEXT: ^bb1 // CHECK: %[[ALLOC0:.*]] = memref.alloc() // CHECK-NEXT: test.buffer_based -// CHECK: %[[ALLOC1:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy +// CHECK-NEXT: %[[ALLOC1:.*]] = memref.clone %[[ALLOC0]] // CHECK-NEXT: memref.dealloc %[[ALLOC0]] // CHECK-NEXT: br ^bb3(%[[ALLOC1]] // CHECK-NEXT: ^bb2 // CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc() // CHECK-NEXT: test.buffer_based -// CHECK: %[[ALLOC3:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy +// CHECK-NEXT: %[[ALLOC3:.*]] = memref.clone %[[ALLOC2]] // CHECK-NEXT: memref.dealloc %[[ALLOC2]] // CHECK-NEXT: br ^bb3(%[[ALLOC3]] // CHECK-NEXT: ^bb3(%[[ALLOC4:.*]]:{{.*}}) @@ -545,8 +530,7 @@ } // CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}}) // CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]] -// CHECK: %[[ALLOC0:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[ALLOC0]]) +// CHECK: %[[ALLOC0:.*]] = memref.clone %[[ARG1]] // CHECK: ^[[BB2]]: // CHECK: %[[ALLOC1:.*]] = memref.alloc() // CHECK-NEXT: test.region_buffer_based in(%[[ARG1]]{{.*}}out(%[[ALLOC1]] @@ -554,12 +538,11 @@ // CHECK-NEXT: test.buffer_based in(%[[ARG1]]{{.*}}out(%[[ALLOC2]] // CHECK: memref.dealloc %[[ALLOC2]] // CHECK-NEXT: %{{.*}} = math.exp -// CHECK: %[[ALLOC3:.*]] = memref.alloc() -// CHECK-NEXT: 
linalg.copy(%[[ALLOC1]], %[[ALLOC3]]) +// CHECK: %[[ALLOC3:.*]] = memref.clone %[[ALLOC1]] // CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK: ^[[BB3:.*]]({{.*}}): // CHECK: test.copy -// CHECK-NEXT: dealloc +// CHECK-NEXT: memref.dealloc // ----- @@ -641,12 +624,10 @@ // CHECK: %[[ALLOC0:.*]] = memref.alloc(%arg0, %arg0) // CHECK-NEXT: %[[ALLOC1:.*]] = scf.if -// CHECK: %[[ALLOC2:.*]] = memref.alloc -// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC2]]) +// CHECK-NEXT: %[[ALLOC2:.*]] = memref.clone %[[ALLOC0]] // CHECK: scf.yield %[[ALLOC2]] // CHECK: %[[ALLOC3:.*]] = memref.alloc(%arg0, %arg1) -// CHECK: %[[ALLOC4:.*]] = memref.alloc -// CHECK-NEXT: linalg.copy(%[[ALLOC3]], %[[ALLOC4]]) +// CHECK-NEXT: %[[ALLOC4:.*]] = memref.clone %[[ALLOC3]] // CHECK: memref.dealloc %[[ALLOC3]] // CHECK: scf.yield %[[ALLOC4]] // CHECK: memref.dealloc %[[ALLOC0]] @@ -823,20 +804,18 @@ // CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}}) // CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]] // CHECK: ^[[BB1]]: -// CHECK: %[[ALLOC0:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy +// CHECK: %[[ALLOC0:.*]] = memref.clone // CHECK: ^[[BB2]]: // CHECK: %[[ALLOC1:.*]] = memref.alloc() // CHECK-NEXT: test.region_buffer_based in(%[[ARG1]]{{.*}}out(%[[ALLOC1]] // CHECK: %[[ALLOCA:.*]] = memref.alloca() // CHECK-NEXT: test.buffer_based in(%[[ARG1]]{{.*}}out(%[[ALLOCA]] // CHECK: %{{.*}} = math.exp -// CHECK: %[[ALLOC2:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy +// CHECK: %[[ALLOC2:.*]] = memref.clone %[[ALLOC1]] // CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK: ^[[BB3:.*]]({{.*}}): // CHECK: test.copy -// CHECK-NEXT: dealloc +// CHECK-NEXT: memref.dealloc // ----- @@ -888,15 +867,13 @@ // CHECK: %[[ALLOC0:.*]] = memref.alloc() // CHECK-NEXT: memref.dealloc %[[ALLOC0]] -// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc() -// CHECK: linalg.copy(%arg3, %[[ALLOC1]]) +// CHECK-NEXT: %[[ALLOC1:.*]] = memref.clone %arg3 // CHECK: %[[ALLOC2:.*]] = scf.for {{.*}} iter_args // CHECK-SAME: (%[[IALLOC:.*]] = %[[ALLOC1]] // CHECK: cmpi // CHECK: memref.dealloc %[[IALLOC]] // CHECK: %[[ALLOC3:.*]] = memref.alloc() -// CHECK: %[[ALLOC4:.*]] = memref.alloc() -// CHECK: linalg.copy(%[[ALLOC3]], %[[ALLOC4]]) +// CHECK: %[[ALLOC4:.*]] = memref.clone %[[ALLOC3]] // CHECK: memref.dealloc %[[ALLOC3]] // CHECK: scf.yield %[[ALLOC4]] // CHECK: } @@ -974,25 +951,21 @@ } // CHECK: %[[ALLOC0:.*]] = memref.alloc() -// CHECK: %[[ALLOC1:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%arg3, %[[ALLOC1]]) +// CHECK-NEXT: %[[ALLOC1:.*]] = memref.clone %arg3 // CHECK-NEXT: %[[ALLOC2:.*]] = scf.for {{.*}} iter_args // CHECK-SAME: (%[[IALLOC:.*]] = %[[ALLOC1]] // CHECK: memref.dealloc %[[IALLOC]] // CHECK: %[[ALLOC3:.*]] = scf.if // CHECK: %[[ALLOC4:.*]] = memref.alloc() -// CHECK-NEXT: %[[ALLOC5:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC4]], %[[ALLOC5]]) +// CHECK-NEXT: %[[ALLOC5:.*]] = memref.clone %[[ALLOC4]] // CHECK-NEXT: memref.dealloc %[[ALLOC4]] // CHECK-NEXT: scf.yield %[[ALLOC5]] -// CHECK: %[[ALLOC6:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC6]]) +// CHECK: %[[ALLOC6:.*]] = memref.clone %[[ALLOC0]] // CHECK-NEXT: scf.yield %[[ALLOC6]] -// CHECK: %[[ALLOC7:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC3:.*]], %[[ALLOC7]]) +// CHECK: %[[ALLOC7:.*]] = memref.clone %[[ALLOC3]] // CHECK-NEXT: memref.dealloc %[[ALLOC3]] // CHECK-NEXT: scf.yield %[[ALLOC7]] @@ -1040,17 +1013,14 @@ // CHECK: %[[ALLOC0:.*]] = memref.alloc() // CHECK-NEXT: memref.dealloc 
%[[ALLOC0]] -// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%arg3, %[[ALLOC1]]) +// CHECK-NEXT: %[[ALLOC1:.*]] = memref.clone %arg3 // CHECK-NEXT: %[[VAL_7:.*]] = scf.for {{.*}} iter_args // CHECK-SAME: (%[[IALLOC0:.*]] = %[[ALLOC1]]) -// CHECK: %[[ALLOC2:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[IALLOC0]], %[[ALLOC2]]) +// CHECK-NEXT: %[[ALLOC2:.*]] = memref.clone %[[IALLOC0]] // CHECK-NEXT: memref.dealloc %[[IALLOC0]] // CHECK-NEXT: %[[ALLOC3:.*]] = scf.for {{.*}} iter_args // CHECK-SAME: (%[[IALLOC1:.*]] = %[[ALLOC2]]) -// CHECK: %[[ALLOC5:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[IALLOC1]], %[[ALLOC5]]) +// CHECK-NEXT: %[[ALLOC5:.*]] = memref.clone %[[IALLOC1]] // CHECK-NEXT: memref.dealloc %[[IALLOC1]] // CHECK: %[[ALLOC6:.*]] = scf.for {{.*}} iter_args @@ -1060,28 +1030,23 @@ // CHECK: %[[ALLOC9:.*]] = scf.if // CHECK: %[[ALLOC11:.*]] = memref.alloc() -// CHECK-NEXT: %[[ALLOC12:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC11]], %[[ALLOC12]]) +// CHECK-NEXT: %[[ALLOC12:.*]] = memref.clone %[[ALLOC11]] // CHECK-NEXT: memref.dealloc %[[ALLOC11]] // CHECK-NEXT: scf.yield %[[ALLOC12]] -// CHECK: %[[ALLOC13:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[IALLOC2]], %[[ALLOC13]]) +// CHECK: %[[ALLOC13:.*]] = memref.clone %[[IALLOC2]] // CHECK-NEXT: scf.yield %[[ALLOC13]] // CHECK: memref.dealloc %[[IALLOC2]] -// CHECK-NEXT: %[[ALLOC10:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC9]], %[[ALLOC10]]) +// CHECK-NEXT: %[[ALLOC10:.*]] = memref.clone %[[ALLOC9]] // CHECK-NEXT: memref.dealloc %[[ALLOC9]] // CHECK-NEXT: scf.yield %[[ALLOC10]] -// CHECK: %[[ALLOC7:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC6]], %[[ALLOC7]]) +// CHECK: %[[ALLOC7:.*]] = memref.clone %[[ALLOC6]] // CHECK-NEXT: memref.dealloc %[[ALLOC6]] // CHECK-NEXT: scf.yield %[[ALLOC7]] -// CHECK: %[[ALLOC4:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC3]], %[[ALLOC4]]) +// CHECK: %[[ALLOC4:.*]] = memref.clone %[[ALLOC3]] // CHECK-NEXT: memref.dealloc %[[ALLOC3]] // CHECK-NEXT: scf.yield %[[ALLOC4]] @@ -1183,8 +1148,7 @@ // CHECK-NEXT: shape.assuming_yield %[[ARG1]] // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] // CHECK-NEXT: %[[TMP_ALLOC:.*]] = memref.alloc() -// CHECK-NEXT: %[[RETURNING_ALLOC:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[TMP_ALLOC]], %[[RETURNING_ALLOC]]) +// CHECK-NEXT: %[[RETURNING_ALLOC:.*]] = memref.clone %[[TMP_ALLOC]] // CHECK-NEXT: memref.dealloc %[[TMP_ALLOC]] // CHECK-NEXT: shape.assuming_yield %[[RETURNING_ALLOC]] // CHECK: test.copy(%[[ASSUMING_RESULT:.*]], %[[ARG2]]) diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1059,3 +1059,139 @@ return %2 : tensor } +// ----- + +// CHECK-LABEL: func @simple_clone_elimination +func @simple_clone_elimination() -> memref<5xf32> { + %ret = memref.alloc() : memref<5xf32> + %temp = memref.clone %ret : memref<5xf32> to memref<5xf32> + memref.dealloc %temp : memref<5xf32> + return %ret : memref<5xf32> +} +// CHECK-NEXT: %[[ret:.*]] = memref.alloc() +// CHECK-NOT: %[[temp:.*]] = memref.clone +// CHECK-NOT: memref.dealloc %[[temp]] +// CHECK: return %[[ret]] + +// ----- + +// CHECK-LABEL: func @clone_loop_alloc +func @clone_loop_alloc(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<2xf32>, %arg4: memref<2xf32>) { + %0 = memref.alloc() : memref<2xf32> + memref.dealloc %0 : memref<2xf32> + %1 = 
memref.clone %arg3 : memref<2xf32> to memref<2xf32>
+  %2 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %1) -> (memref<2xf32>) {
+    %3 = cmpi eq, %arg5, %arg1 : index
+    memref.dealloc %arg6 : memref<2xf32>
+    %4 = memref.alloc() : memref<2xf32>
+    %5 = memref.clone %4 : memref<2xf32> to memref<2xf32>
+    memref.dealloc %4 : memref<2xf32>
+    %6 = memref.clone %5 : memref<2xf32> to memref<2xf32>
+    memref.dealloc %5 : memref<2xf32>
+    scf.yield %6 : memref<2xf32>
+  }
+  linalg.copy(%2, %arg4) : memref<2xf32>, memref<2xf32>
+  memref.dealloc %2 : memref<2xf32>
+  return
+}
+
+// CHECK-NEXT: %[[ALLOC0:.*]] = memref.clone
+// CHECK-NEXT: %[[ALLOC1:.*]] = scf.for
+// CHECK-NEXT: memref.dealloc
+// CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc
+// CHECK-NEXT: scf.yield %[[ALLOC2]]
+// CHECK: linalg.copy(%[[ALLOC1]]
+// CHECK-NEXT: memref.dealloc %[[ALLOC1]]
+
+// -----
+
+// CHECK-LABEL: func @clone_nested_region
+func @clone_nested_region(%arg0: index, %arg1: index) -> memref<?x?xf32> {
+  %0 = cmpi eq, %arg0, %arg1 : index
+  %1 = memref.alloc(%arg0, %arg0) : memref<?x?xf32>
+  %2 = scf.if %0 -> (memref<?x?xf32>) {
+    %3 = scf.if %0 -> (memref<?x?xf32>) {
+      %9 = memref.clone %1 : memref<?x?xf32> to memref<?x?xf32>
+      scf.yield %9 : memref<?x?xf32>
+    } else {
+      %7 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
+      %10 = memref.clone %7 : memref<?x?xf32> to memref<?x?xf32>
+      memref.dealloc %7 : memref<?x?xf32>
+      scf.yield %10 : memref<?x?xf32>
+    }
+    %6 = memref.clone %3 : memref<?x?xf32> to memref<?x?xf32>
+    memref.dealloc %3 : memref<?x?xf32>
+    scf.yield %6 : memref<?x?xf32>
+  } else {
+    %3 = memref.alloc(%arg1, %arg1) : memref<?x?xf32>
+    %6 = memref.clone %3 : memref<?x?xf32> to memref<?x?xf32>
+    memref.dealloc %3 : memref<?x?xf32>
+    scf.yield %6 : memref<?x?xf32>
+  }
+  memref.dealloc %1 : memref<?x?xf32>
+  return %2 : memref<?x?xf32>
+}
+
+// CHECK: %[[ALLOC1:.*]] = memref.alloc
+// CHECK-NEXT: %[[ALLOC2:.*]] = scf.if
+// CHECK-NEXT: %[[ALLOC3:.*]] = scf.if
+// CHECK-NEXT: %[[ALLOC5:.*]] = memref.clone %[[ALLOC1]]
+// CHECK-NEXT: scf.yield %[[ALLOC5]]
+// CHECK: %[[ALLOC0:.*]] = memref.alloc
+// CHECK-NEXT: scf.yield %[[ALLOC0]]
+// CHECK: %[[ALLOC4:.*]] = memref.clone %[[ALLOC3]]
+// CHECK-NEXT: memref.dealloc %[[ALLOC3]]
+// CHECK-NEXT: scf.yield %[[ALLOC4]]
+// CHECK: %[[ALLOC6:.*]] = memref.alloc
+// CHECK-NEXT: scf.yield %[[ALLOC6]]
+// CHECK: memref.dealloc %[[ALLOC1]]
+// CHECK-NEXT: return %[[ALLOC2]]
+
+// -----
+
+// CHECK-LABEL: func @clone_ret_usage_before_copy
+func @clone_ret_usage_before_copy() -> memref<5xf32> {
+  %ret = memref.alloc() : memref<5xf32>
+  %c0 = constant 0 : index
+  %dimension = memref.dim %ret, %c0 : memref<5xf32>
+  %temp = memref.clone %ret : memref<5xf32> to memref<5xf32>
+  memref.dealloc %ret : memref<5xf32>
+  return %temp : memref<5xf32>
+}
+
+// CHECK: %[[ALLOC0:.*]] = memref.alloc
+// CHECK-NEXT: return %[[ALLOC0]]
+
+// -----
+
+// CHECK-LABEL: func @clone_ret_usage_after_copy
+func @clone_ret_usage_after_copy() -> memref<?xf32> {
+  %ret = memref.alloc() : memref<5xf32>
+  %temp = memref.clone %ret : memref<5xf32> to memref<5xf32>
+  %c0 = constant 0 : index
+  %dimension = memref.dim %ret, %c0 : memref<5xf32>
+  memref.dealloc %ret : memref<5xf32>
+  %ret2 = memref.alloc(%dimension) : memref<?xf32>
+  return %ret2 : memref<?xf32>
+}
+
+// CHECK: %[[ALLOC0:.*]] = memref.alloc
+// CHECK-NEXT: %[[ALLOC1:.*]] = memref.cast
+// CHECK-NEXT: return %[[ALLOC1]]
+
+// -----
+
+// CHECK-LABEL: func @clone_temp_usage_after_copy
+func @clone_temp_usage_after_copy() -> memref<?xf32> {
+  %ret = memref.alloc() : memref<5xf32>
+  %temp = memref.clone %ret : memref<5xf32> to memref<5xf32>
+  %c0 = constant 0 : index
+  %dimension = memref.dim %temp, %c0 : memref<5xf32>
+  memref.dealloc %ret : memref<5xf32>
+  %ret2 = memref.alloc(%dimension) : memref<?xf32>
+  return %ret2 : memref<?xf32>
+}
+
+// CHECK: %[[ALLOC0:.*]] = memref.alloc
+// CHECK-NEXT: %[[ALLOC1:.*]] = memref.cast
+// CHECK-NEXT: return %[[ALLOC1]]
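+
+// -----
+
+// Sketch of the use_empty case of the clone canonicalization (illustrative
+// addition): a clone whose result is never used is erased directly.
+// CHECK-LABEL: func @dead_clone
+func @dead_clone(%arg0: memref<5xf32>) {
+  %temp = memref.clone %arg0 : memref<5xf32> to memref<5xf32>
+  return
+}
+
+// CHECK-NOT: memref.clone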