diff --git a/mlir/docs/BufferDeallocationInternals.md b/mlir/docs/BufferDeallocationInternals.md --- a/mlir/docs/BufferDeallocationInternals.md +++ b/mlir/docs/BufferDeallocationInternals.md @@ -48,7 +48,7 @@ partial_write(%0, %0) br ^bb3() ^bb3(): - "linalg.copy"(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () + test.copy(%0, %arg1) : (memref<2xf32>, memref<2xf32>) -> () return } ``` @@ -133,11 +133,11 @@ ^bb1: br ^bb3(%arg1 : memref<2xf32>) ^bb2: - %0 = alloc() : memref<2xf32> // aliases: %1 + %0 = memref.alloc() : memref<2xf32> // aliases: %1 use(%0) br ^bb3(%0 : memref<2xf32>) ^bb3(%1: memref<2xf32>): // %1 could be %0 or %arg1 - "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () return } ``` @@ -149,7 +149,7 @@ ```mlir func @condBranch(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { - %0 = alloc() : memref<2xf32> // moved to bb0 + %0 = memref.alloc() : memref<2xf32> // moved to bb0 cond_br %arg0, ^bb1, ^bb2 ^bb1: br ^bb3(%arg1 : memref<2xf32>) @@ -157,7 +157,7 @@ use(%0) br ^bb3(%0 : memref<2xf32>) ^bb3(%1: memref<2xf32>): - "linalg.copy"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () + test.copy(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () return } ``` @@ -179,17 +179,17 @@ ^bb1: br ^bb3(%arg1 : memref) ^bb2(%0: index): - %1 = alloc(%0) : memref // cannot be moved upwards to the data + %1 = memref.alloc(%0) : memref // cannot be moved upwards to the data // dependency to %0 use(%1) br ^bb3(%1 : memref) ^bb3(%2: memref): - "linalg.copy"(%2, %arg2) : (memref, memref) -> () + test.copy(%2, %arg2) : (memref, memref) -> () return } ``` -## Introduction of Copies +## Introduction of Clones In order to guarantee that all allocated buffers are freed properly, we have to pay attention to the control flow and all potential aliases a buffer allocation @@ -200,10 +200,10 @@ ```mlir func @branch(%arg0: i1) { - %0 = alloc() : memref<2xf32> // aliases: %2 + %0 = memref.alloc() : 
memref<2xf32> // aliases: %2 cond_br %arg0, ^bb1, ^bb2 ^bb1: - %1 = alloc() : memref<2xf32> // resides here for demonstration purposes + %1 = memref.alloc() : memref<2xf32> // resides here for demonstration purposes // aliases: %2 br ^bb3(%1 : memref<2xf32>) ^bb2: @@ -232,88 +232,31 @@ ```mlir func @branch(%arg0: i1) { - %0 = alloc() : memref<2xf32> + %0 = memref.alloc() : memref<2xf32> cond_br %arg0, ^bb1, ^bb2 ^bb1: - %1 = alloc() : memref<2xf32> - %3 = alloc() : memref<2xf32> // temp copy for %1 - "linalg.copy"(%1, %3) : (memref<2xf32>, memref<2xf32>) -> () - dealloc %1 : memref<2xf32> // %1 can be safely freed here + %1 = memref.alloc() : memref<2xf32> + %3 = memref.clone %1 : (memref<2xf32>) -> (memref<2xf32>) + memref.dealloc %1 : memref<2xf32> // %1 can be safely freed here br ^bb3(%3 : memref<2xf32>) ^bb2: use(%0) - %4 = alloc() : memref<2xf32> // temp copy for %0 - "linalg.copy"(%0, %4) : (memref<2xf32>, memref<2xf32>) -> () + %4 = memref.clone %0 : (memref<2xf32>) -> (memref<2xf32>) br ^bb3(%4 : memref<2xf32>) ^bb3(%2: memref<2xf32>): … - dealloc %2 : memref<2xf32> // free temp buffer %2 - dealloc %0 : memref<2xf32> // %0 can be safely freed here + memref.dealloc %2 : memref<2xf32> // free temp buffer %2 + memref.dealloc %0 : memref<2xf32> // %0 can be safely freed here return } ``` Note that a temporary buffer for %2 was introduced to free all allocations properly. Note further that the unnecessary allocation of %3 can be easily -removed using one of the post-pass transformations. 
- -Reconsider the previously introduced sample demonstrating dynamically shaped -types: - -```mlir -func @condBranchDynamicType( - %arg0: i1, - %arg1: memref, - %arg2: memref, - %arg3: index) { - cond_br %arg0, ^bb1, ^bb2(%arg3: index) -^bb1: - br ^bb3(%arg1 : memref) -^bb2(%0: index): - %1 = alloc(%0) : memref // aliases: %2 - use(%1) - br ^bb3(%1 : memref) -^bb3(%2: memref): - "linalg.copy"(%2, %arg2) : (memref, memref) -> () - return -} -``` +removed using one of the post-pass transformations or the canonicalization +pass. -In the presence of DSTs, we have to parameterize the allocations with -additional dimension information of the source buffers, we want to copy from. -BufferDeallocation automatically introduces all required operations to extract -dimension specifications and wires them with the associated allocations: - -```mlir -func @condBranchDynamicType( - %arg0: i1, - %arg1: memref, - %arg2: memref, - %arg3: index) { - cond_br %arg0, ^bb1, ^bb2(%arg3 : index) -^bb1: - %c0 = constant 0 : index - %0 = dim %arg1, %c0 : memref // dimension operation to parameterize - // the following temp allocation - %1 = alloc(%0) : memref - "linalg.copy"(%arg1, %1) : (memref, memref) -> () - br ^bb3(%1 : memref) -^bb2(%2: index): - %3 = alloc(%2) : memref - use(%3) - %c0_0 = constant 0 : index - %4 = dim %3, %c0_0 : memref // dimension operation to parameterize - // the following temp allocation - %5 = alloc(%4) : memref - "linalg.copy"(%3, %5) : (memref, memref) -> () - dealloc %3 : memref // %3 can be safely freed here - br ^bb3(%5 : memref) -^bb3(%6: memref): - "linalg.copy"(%6, %arg2) : (memref, memref) -> () - dealloc %6 : memref // %6 can be safely freed here - return -} -``` +The presented example also works with dynamically shaped types. BufferDeallocation performs a fix-point iteration taking all aliases of all tracked allocations into account. 
We initialize the general iteration process @@ -335,7 +278,7 @@ ^bb1: br ^bb6(%arg1 : memref) ^bb2(%0: index): - %1 = alloc(%0) : memref // cannot be moved upwards due to the data + %1 = memref.alloc(%0) : memref // cannot be moved upwards due to the data // dependency to %0 // aliases: %2, %3, %4 use(%1) @@ -349,7 +292,7 @@ ^bb6(%3: memref): // crit. alias of %arg1 and %2 (in other words %1) br ^bb7(%3 : memref) ^bb7(%4: memref): // non-crit. alias of %3, since %3 dominates %4 - "linalg.copy"(%4, %arg2) : (memref, memref) -> () + test.copy(%4, %arg2) : (memref, memref) -> () return } ``` @@ -366,13 +309,11 @@ %arg3: index) { cond_br %arg0, ^bb1, ^bb2(%arg3 : index) ^bb1: - %c0 = constant 0 : index - %d0 = dim %arg1, %c0 : memref - %5 = alloc(%d0) : memref // temp buffer required due to alias %3 - "linalg.copy"(%arg1, %5) : (memref, memref) -> () + // temp buffer required due to alias %3 + %5 = memref.clone %arg1 : (memref) -> (memref) br ^bb6(%5 : memref) ^bb2(%0: index): - %1 = alloc(%0) : memref + %1 = memref.alloc(%0) : memref use(%1) cond_br %arg0, ^bb3, ^bb4 ^bb3: @@ -380,17 +321,14 @@ ^bb4: br ^bb5(%1 : memref) ^bb5(%2: memref): - %c0_0 = constant 0 : index - %d1 = dim %2, %c0_0 : memref - %6 = alloc(%d1) : memref // temp buffer required due to alias %3 - "linalg.copy"(%1, %6) : (memref, memref) -> () - dealloc %1 : memref + %6 = memref.clone %1 : (memref) -> (memref) + memref.dealloc %1 : memref br ^bb6(%6 : memref) ^bb6(%3: memref): br ^bb7(%3 : memref) ^bb7(%4: memref): - "linalg.copy"(%4, %arg2) : (memref, memref) -> () - dealloc %3 : memref // free %3, since %4 is a non-crit. alias of %3 + test.copy(%4, %arg2) : (memref, memref) -> () + memref.dealloc %3 : memref // free %3, since %4 is a non-crit. alias of %3 return } ``` @@ -399,7 +337,7 @@ temporary copy in all predecessor blocks. %3 has an additional (non-critical) alias %4 that extends the live range until the end of bb7. 
Therefore, we can free %3 after its last use, while taking all aliases into account. Note that %4 - does not need to be freed, since we did not introduce a copy for it. +does not need to be freed, since we did not introduce a copy for it. The actual introduction of buffer copies is done after the fix-point iteration has been terminated and all critical aliases have been detected. A critical @@ -445,7 +383,7 @@ func @inner_region_control_flow( %arg0 : index, %arg1 : index) -> memref { - %0 = alloc(%arg0, %arg0) : memref + %0 = memref.alloc(%arg0, %arg0) : memref %1 = custom.region_if %0 : memref -> (memref) then(%arg2 : memref) { // aliases: %arg4, %1 custom.region_if_yield %arg2 : memref @@ -468,11 +406,11 @@ ```mlir func @nested_region_control_flow(%arg0 : index, %arg1 : index) -> memref { %0 = cmpi "eq", %arg0, %arg1 : index - %1 = alloc(%arg0, %arg0) : memref + %1 = memref.alloc(%arg0, %arg0) : memref %2 = scf.if %0 -> (memref) { scf.yield %1 : memref // %2 will be an alias of %1 } else { - %3 = alloc(%arg0, %arg1) : memref // nested allocation in a div. + %3 = memref.alloc(%arg0, %arg1) : memref // nested allocation in a div. 
// branch use(%3) scf.yield %1 : memref // %2 will be an alias of %1 @@ -489,13 +427,13 @@ ```mlir func @nested_region_control_flow(%arg0: index, %arg1: index) -> memref { %0 = cmpi "eq", %arg0, %arg1 : index - %1 = alloc(%arg0, %arg0) : memref + %1 = memref.alloc(%arg0, %arg0) : memref %2 = scf.if %0 -> (memref) { scf.yield %1 : memref } else { - %3 = alloc(%arg0, %arg1) : memref + %3 = memref.alloc(%arg0, %arg1) : memref use(%3) - dealloc %3 : memref // %3 can be safely freed here + memref.dealloc %3 : memref // %3 can be safely freed here scf.yield %1 : memref } return %2 : memref @@ -514,12 +452,12 @@ func @inner_region_control_flow_div( %arg0 : index, %arg1 : index) -> memref { - %0 = alloc(%arg0, %arg0) : memref + %0 = memref.alloc(%arg0, %arg0) : memref %1 = custom.region_if %0 : memref -> (memref) then(%arg2 : memref) { // aliases: %arg4, %1 custom.region_if_yield %arg2 : memref } else(%arg3 : memref) { - %2 = alloc(%arg0, %arg1) : memref // aliases: %arg4, %1 + %2 = memref.alloc(%arg0, %arg1) : memref // aliases: %arg4, %1 custom.region_if_yield %2 : memref } join(%arg4 : memref) { // aliases: %1 custom.region_if_yield %arg4 : memref @@ -537,40 +475,22 @@ func @inner_region_control_flow_div( %arg0 : index, %arg1 : index) -> memref { - %0 = alloc(%arg0, %arg0) : memref + %0 = memref.alloc(%arg0, %arg0) : memref %1 = custom.region_if %0 : memref -> (memref) then(%arg2 : memref) { - %c0 = constant 0 : index // determine dimension extents for temp allocation - %2 = dim %arg2, %c0 : memref - %c1 = constant 1 : index - %3 = dim %arg2, %c1 : memref - %4 = alloc(%2, %3) : memref // temp buffer required due to critic. 
- // alias %arg4 - linalg.copy(%arg2, %4) : memref, memref + %4 = memref.clone %arg2 : (memref) -> (memref) custom.region_if_yield %4 : memref } else(%arg3 : memref) { - %2 = alloc(%arg0, %arg1) : memref - %c0 = constant 0 : index // determine dimension extents for temp allocation - %3 = dim %2, %c0 : memref - %c1 = constant 1 : index - %4 = dim %2, %c1 : memref - %5 = alloc(%3, %4) : memref // temp buffer required due to critic. - // alias %arg4 - linalg.copy(%2, %5) : memref, memref - dealloc %2 : memref + %2 = memref.alloc(%arg0, %arg1) : memref + %5 = memref.clone %2 : (memref) -> (memref) + memref.dealloc %2 : memref custom.region_if_yield %5 : memref } join(%arg4: memref) { - %c0 = constant 0 : index // determine dimension extents for temp allocation - %2 = dim %arg4, %c0 : memref - %c1 = constant 1 : index - %3 = dim %arg4, %c1 : memref - %4 = alloc(%2, %3) : memref // this allocation will be removed by - // applying the copy removal pass - linalg.copy(%arg4, %4) : memref, memref - dealloc %arg4 : memref + %4 = memref.clone %arg4 : (memref) -> (memref) + memref.dealloc %arg4 : memref custom.region_if_yield %4 : memref } - dealloc %0 : memref // %0 can be safely freed here + memref.dealloc %0 : memref // %0 can be safely freed here return %1 : memref } ``` @@ -600,7 +520,7 @@ iter_args(%iterBuf = %buf) -> memref<2xf32> { %1 = cmpi "eq", %i, %ub : index %2 = scf.if %1 -> (memref<2xf32>) { - %3 = alloc() : memref<2xf32> // makes %2 a critical alias due to a + %3 = memref.alloc() : memref<2xf32> // makes %2 a critical alias due to a // divergent allocation use(%3) scf.yield %3 : memref<2xf32> @@ -609,7 +529,7 @@ } scf.yield %2 : memref<2xf32> } - "linalg.copy"(%0, %res) : (memref<2xf32>, memref<2xf32>) -> () + test.copy(%0, %res) : (memref<2xf32>, memref<2xf32>) -> () return } ``` @@ -634,31 +554,27 @@ %step: index, %buf: memref<2xf32>, %res: memref<2xf32>) { - %4 = alloc() : memref<2xf32> - "linalg.copy"(%buf, %4) : (memref<2xf32>, memref<2xf32>) -> () + %4 = 
memref.clone %buf : (memref<2xf32>) -> (memref<2xf32>) %0 = scf.for %i = %lb to %ub step %step iter_args(%iterBuf = %4) -> memref<2xf32> { %1 = cmpi "eq", %i, %ub : index %2 = scf.if %1 -> (memref<2xf32>) { - %3 = alloc() : memref<2xf32> // makes %2 a critical alias + %3 = memref.alloc() : memref<2xf32> // makes %2 a critical alias use(%3) - %5 = alloc() : memref<2xf32> // temp copy due to crit. alias %2 - "linalg.copy"(%3, %5) : memref<2xf32>, memref<2xf32> - dealloc %3 : memref<2xf32> + %5 = memref.clone %3 : (memref<2xf32>) -> (memref<2xf32>) + memref.dealloc %3 : memref<2xf32> scf.yield %5 : memref<2xf32> } else { - %6 = alloc() : memref<2xf32> // temp copy due to crit. alias %2 - "linalg.copy"(%iterBuf, %6) : memref<2xf32>, memref<2xf32> + %6 = memref.clone %iterBuf : (memref<2xf32>) -> (memref<2xf32>) scf.yield %6 : memref<2xf32> } - %7 = alloc() : memref<2xf32> // temp copy due to crit. alias %iterBuf - "linalg.copy"(%2, %7) : memref<2xf32>, memref<2xf32> - dealloc %2 : memref<2xf32> - dealloc %iterBuf : memref<2xf32> // free backedge iteration variable + %7 = memref.clone %2 : (memref<2xf32>) -> (memref<2xf32>) + memref.dealloc %2 : memref<2xf32> + memref.dealloc %iterBuf : memref<2xf32> // free backedge iteration variable scf.yield %7 : memref<2xf32> } - "linalg.copy"(%0, %res) : (memref<2xf32>, memref<2xf32>) -> () - dealloc %0 : memref<2xf32> // free temp copy %0 + test.copy(%0, %res) : (memref<2xf32>, memref<2xf32>) -> () + memref.dealloc %0 : memref<2xf32> // free temp copy %0 return } ``` @@ -684,46 +600,37 @@ In order to limit the complexity of the BufferDeallocation transformation, some tiny code-polishing/optimization transformations are not applied on-the-fly -during placement. Currently, there is only the CopyRemoval transformation to -remove unnecessary copy and allocation operations. +during placement. Currently, a canonicalization pattern is added to the clone +operation to reduce the appearance of unnecessary clones. 
Note: further transformations might be added to the post-pass phase in the future. -## CopyRemoval Pass - -A common pattern that arises during placement is the introduction of -unnecessary temporary copies that are used instead of the original source -buffer. For this reason, there is a post-pass transformation that removes these -allocations and copies via `-copy-removal`. This pass, besides removing -unnecessary copy operations, will also remove the dead allocations and their -corresponding deallocation operations. The CopyRemoval pass can currently be -applied to operations that implement the `CopyOpInterface` in any of these two -situations which are +## Clone Canonicalization -* reusing the source buffer of the copy operation. -* reusing the target buffer of the copy operation. +During placement of clones it may happen that unnecessary clones are inserted. +If these clones appear with their corresponding dealloc operation within the +same block, we can use the canonicalizer to remove these unnecessary operations. +Note that this step needs to take place after the insertion of clones and +deallocs in the buffer deallocation step. The canonicalization handles both +the newly created target value of the clone operation and its source +value. -## Reusing the Source Buffer of the Copy Operation +## Canonicalization of the Source Buffer of the Clone Operation -In this case, the source of the copy operation can be used instead of target. -The unused allocation and deallocation operations that are defined for this -copy operation are also removed. Here is a working example generated by the -BufferDeallocation pass that allocates a buffer with dynamic size. A deeper +In this case, the source of the clone operation can be used instead of its +target. The unused allocation and deallocation operations that are defined for +this clone operation are also removed.
Here is a working example generated by +the BufferDeallocation pass that allocates a buffer with dynamic size. A deeper analysis of this sample reveals that the highlighted operations are redundant and can be removed. ```mlir func @dynamic_allocation(%arg0: index, %arg1: index) -> memref { - %7 = alloc(%arg0, %arg1) : memref - %c0_0 = constant 0 : index - %8 = dim %7, %c0_0 : memref - %c1_1 = constant 1 : index - %9 = dim %7, %c1_1 : memref - %10 = alloc(%8, %9) : memref - linalg.copy(%7, %10) : memref, memref - dealloc %7 : memref - return %10 : memref + %1 = memref.alloc(%arg0, %arg1) : memref + %2 = memref.clone %1 : (memref) -> (memref) + memref.dealloc %1 : memref + return %2 : memref } ``` @@ -731,53 +638,39 @@ ```mlir func @dynamic_allocation(%arg0: index, %arg1: index) -> memref { - %7 = alloc(%arg0, %arg1) : memref - %c0_0 = constant 0 : index - %8 = dim %7, %c0_0 : memref - %c1_1 = constant 1 : index - %9 = dim %7, %c1_1 : memref - return %7 : memref + %1 = memref.alloc(%arg0, %arg1) : memref + return %1 : memref } ``` -In this case, the additional copy %10 can be replaced with its original source -buffer %7. This also applies to the associated dealloc operation of %7. +In this case, the additional copy %2 can be replaced with its original source +buffer %1. This also applies to the associated dealloc operation of %1. -To limit the complexity of this transformation, it only removes copy operations -when the following constraints are met: +## Canonicalization of the Target Buffer of the Clone Operation -* The copy operation, the defining operation for the target value, and the -deallocation of the source value lie in the same block. -* There are no users/aliases of the target value between the defining operation -of the target value and its copy operation. -* There are no users/aliases of the source value between its associated copy -operation and the deallocation of the source value. 
+In this case, the target buffer of the clone operation can be used instead of +its source. The unused deallocation operation that is defined for this clone +operation is also removed. -## Reusing the Target Buffer of the Copy Operation - -In this case, the target buffer of the copy operation can be used instead of -its source. The unused allocation and deallocation operations that are defined -for this copy operation are also removed. - -Consider the following example where a generic linalg operation writes the -result to %temp and then copies %temp to %result. However, these two operations -can be merged into a single step. Copy removal removes the copy operation and -%temp, and replaces the uses of %temp with %result: +Consider the following example where a generic test operation writes the result +to %temp and then copies %temp to %result. However, these two operations +can be merged into a single step. Canonicalization removes the clone operation +and %temp, and replaces the uses of %temp with %result: ```mlir func @reuseTarget(%arg0: memref<2xf32>, %result: memref<2xf32>){ - %temp = alloc() : memref<2xf32> - linalg.generic { + %temp = memref.alloc() : memref<2xf32> + test.generic { args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %temp { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): %tmp2 = exp %gen2_arg0 : f32 - linalg.yield %tmp2 : f32 + test.yield %tmp2 : f32 }: memref<2xf32>, memref<2xf32> - "linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> () - dealloc %temp : memref<2xf32> + %result = memref.clone %temp : (memref<2xf32>) -> (memref<2xf32>) + memref.dealloc %temp : memref<2xf32> return } ``` @@ -786,33 +679,24 @@ ```mlir func @reuseTarget(%arg0: memref<2xf32>, %result: memref<2xf32>){ - linalg.generic { + test.generic { args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0, %result { ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): %tmp2 = 
exp %gen2_arg0 : f32 - linalg.yield %tmp2 : f32 + test.yield %tmp2 : f32 }: memref<2xf32>, memref<2xf32> return } ``` -Like before, several constraints to use the transformation apply: - -* The copy operation, the defining operation of the source value, and the -deallocation of the source value lie in the same block. -* There are no users/aliases of the target value between the defining operation -of the source value and the copy operation. -* There are no users/aliases of the source value between the copy operation and -the deallocation of the source value. - ## Known Limitations -BufferDeallocation introduces additional copies using allocations from the -“memref” dialect (“memref.alloc”). Analogous, all deallocations use the -“memref” dialect-free operation “memref.dealloc”. The actual copy process is -realized using “linalg.copy”. Furthermore, buffers are essentially immutable -after their creation in a block. Another limitations are known in the case -using unstructered control flow. +BufferDeallocation introduces additional clones using the “memref” dialect +operation “memref.clone”. Analogously, all deallocations use the “memref” +dialect's free operation “memref.dealloc”. The actual copy of the buffer +contents is performed by “memref.clone” itself. Furthermore, buffers are +essentially immutable after their creation in a block. Other limitations are +known in the case of unstructured control flow.
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -12,6 +12,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -11,6 +11,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td" include "mlir/Interfaces/CastInterfaces.td" +include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/IR/SymbolInterfaces.td" @@ -235,6 +236,9 @@ // Result type is tensor<4x?xf32> %12 = memref.buffer_cast %10 : memref<4x?xf32, #map0, 42> ``` + + Note that mutating the result of the buffer cast operation leads to + undefined behavior. }]; let arguments = (ins AnyTensor:$tensor); @@ -333,6 +337,46 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// CloneOp +//===----------------------------------------------------------------------===// + +def CloneOp : MemRef_Op<"clone", [ + CopyOpInterface, + DeclareOpInterfaceMethods + ]> { + let builders = [ + OpBuilder<(ins "Value":$value), [{ + return build($_builder, $_state, value.getType(), value); + }]>]; + + let description = [{ + Clones the data in the input view into an implicitly defined output view. + + Usage: + + ```mlir + %arg1 = memref.clone %arg0 : memref to memref + ``` + + Note that mutating the source or result of the clone operation leads to + undefined behavior.
+ }]; + + let arguments = (ins Arg:$input); + let results = (outs Arg:$output); + + let extraClassDeclaration = [{ + Value getSource() { return input(); } + Value getTarget() { return output(); } + }]; + + let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; + + let hasFolder = 1; + let hasCanonicalizer = 1; +} + //===----------------------------------------------------------------------===// // DeallocOp //===----------------------------------------------------------------------===// @@ -1109,6 +1153,9 @@ // Produces a value of tensor<4x?xf32> type. %12 = memref.tensor_load %10 : memref<4x?xf32, #layout, memspace0> ``` + + If tensor load is used in the bufferization steps, mutating the source + buffer after loading leads to undefined behavior. }]; let arguments = (ins Arg createCanonicalizerPass(); -/// Create a pass that removes unnecessary Copy operations. -std::unique_ptr createCopyRemovalPass(); - /// Creates a pass to perform common sub expression elimination. std::unique_ptr createCSEPass(); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -282,8 +282,7 @@ }]; let constructor = "mlir::createBufferDeallocationPass()"; - // TODO: this pass likely shouldn't depend on Linalg?
- let dependentDialects = ["linalg::LinalgDialect"]; + let dependentDialects = ["memref::MemRefDialect"]; } def BufferHoisting : FunctionPass<"buffer-hoisting"> { @@ -366,11 +365,6 @@ let dependentDialects = ["memref::MemRefDialect"]; } -def CopyRemoval : FunctionPass<"copy-removal"> { - let summary = "Remove the redundant copies from input IR"; - let constructor = "mlir::createCopyRemovalPass()"; -} - def CSE : Pass<"cse"> { let summary = "Eliminate common sub-expressions"; let description = [{ diff --git a/mlir/lib/Dialect/MemRef/CMakeLists.txt b/mlir/lib/Dialect/MemRef/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/CMakeLists.txt @@ -1 +1,22 @@ -add_subdirectory(IR) +add_mlir_dialect_library(MLIRMemRef + IR/MemRefDialect.cpp + IR/MemRefOps.cpp + Utils/MemRefUtils.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/mlir/Dialect/MemRef + + DEPENDS + MLIRStandardOpsIncGen + MLIRMemRefOpsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRDialect + MLIRIR + MLIRStandard + MLIRTensor + MLIRViewLikeInterface +) diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt deleted file mode 100644 --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ /dev/null @@ -1,21 +0,0 @@ -add_mlir_dialect_library(MLIRMemRef - MemRefDialect.cpp - MemRefOps.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect - - DEPENDS - MLIRStandardOpsIncGen - MLIRMemRefOpsIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRDialect - MLIRIR - MLIRStandard - MLIRTensor - MLIRViewLikeInterface -) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include
"mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" @@ -464,6 +465,97 @@ return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); } +//===----------------------------------------------------------------------===// +// CloneOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(CloneOp op) { return success(); } + +void CloneOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), input(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), output(), + SideEffects::DefaultResource::get()); +} + +namespace { +/// Removes clones that are redundant because either the clone result or the +/// clone source is deallocated in the same block as the clone, so a single +/// buffer can serve both purposes. Clones without users are erased. +struct SimplifyClones : public OpRewritePattern<CloneOp> { + using OpRewritePattern<CloneOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(CloneOp cloneOp, + PatternRewriter &rewriter) const override { + if (cloneOp.use_empty()) { + rewriter.eraseOp(cloneOp); + return success(); + } + + Value source = cloneOp.input(); + // Forwarding the source is only possible without a type change. + if (source.getType() != cloneOp.getType()) + return failure(); + + // Note: only deallocs of the two values themselves are found; aliases + // are not considered. + Operation *cloneDeallocOp = findDealloc(cloneOp.output()); + Operation *sourceDeallocOp = findDealloc(source); + + // If both buffers are freed in the same block, their lifetimes within + // that block might not fully overlap, so neither dealloc can be dropped. + if (cloneDeallocOp && sourceDeallocOp && + cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock()) + return failure(); + + // Pick the dealloc that lives in the block of the clone; it becomes + // redundant once all uses of the clone are redirected to its source. 
+ Block *currentBlock = cloneOp->getBlock(); + Operation *redundantDealloc = nullptr; + if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) + redundantDealloc = cloneDeallocOp; + else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) + redundantDealloc = sourceDeallocOp; + if (!redundantDealloc) + return failure(); + + // Bail out if any buffer is freed between the clone and the redundant + // dealloc (it might alias the source), or if the dealloc does not follow + // the clone within the block. + Operation *pos = cloneOp->getNextNode(); + for (; pos && pos != redundantDealloc; pos = pos->getNextNode()) { + auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos); + if (!effectInterface) + continue; + SmallVector<MemoryEffects::EffectInstance, 2> effects; + effectInterface.getEffects(effects); + if (llvm::any_of(effects, [](MemoryEffects::EffectInstance &it) { + return isa<MemoryEffects::Free>(it.getEffect()); + })) + return failure(); + } + if (!pos) + return failure(); + + rewriter.replaceOp(cloneOp, source); + rewriter.eraseOp(redundantDealloc); + return success(); + } +}; + +} // end anonymous namespace. + +void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert<SimplifyClones>(context); +} + +OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) { + return succeeded(foldMemRefCast(*this)) ? getResult() : Value(); +} + //===----------------------------------------------------------------------===// // DeallocOp //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -0,0 +1,35 @@ +//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utilities for the MemRef dialect. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +using namespace mlir; + +/// Returns the first user of `allocValue` that has a Free memory effect on +/// it (e.g. a dealloc), or nullptr if there is no such user. +Operation *mlir::findDealloc(Value allocValue) { + auto userIt = llvm::find_if(allocValue.getUsers(), [&](Operation *user) { + auto effectInterface = dyn_cast<MemoryEffectOpInterface>(user); + if (!effectInterface) + return false; + // Check whether this user frees `allocValue`, i.e. whether it acts as a + // dealloc for it. + SmallVector<MemoryEffects::EffectInstance, 2> effects; + effectInterface.getEffectsOnValue(allocValue, effects); + return llvm::any_of(effects, [&](MemoryEffects::EffectInstance &it) { + return isa<MemoryEffects::Free>(it.getEffect()); + }); + }); + // Return the associated dealloc operation (if any). + return userIt != allocValue.user_end() ? *userIt : nullptr; +}
diff --git a/mlir/lib/Transforms/BufferDeallocation.cpp b/mlir/lib/Transforms/BufferDeallocation.cpp --- a/mlir/lib/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Transforms/BufferDeallocation.cpp @@ -7,16 +7,15 @@ //===----------------------------------------------------------------------===// // // This file implements logic for computing correct alloc and dealloc positions. -// Furthermore, buffer placement also adds required new alloc and copy -// operations to ensure that all buffers are deallocated. The main class is the +// Furthermore, buffer deallocation also adds required new clone operations to +// ensure that all buffers are deallocated. The main class is the // BufferDeallocationPass class that implements the underlying algorithm. In // order to put allocations and deallocations at safe positions, it is // significantly important to put them into the correct blocks. However, the // liveness analysis does not pay attention to aliases, which can occur due to // branches (and their associated block arguments) in general. For this purpose, // BufferDeallocation firstly finds all possible aliases for a single value -// (using the BufferAliasAnalysis class). Consider the following -// example: +// (using the BufferAliasAnalysis class). Consider the following example: // // ^bb0(%arg0): // cond_br %cond, ^bb1, ^bb2 @@ -30,16 +29,16 @@ // // We should place the dealloc for %new_value in exit. However, we have to free // the buffer in the same block, because it cannot be freed in the post
However, this requires a new clone buffer for %arg1 that will // contain the actual contents. Using the class BufferAliasAnalysis, we // will find out that %new_value has a potential alias %arg1. In order to find // the dealloc position we have to find all potential aliases, iterate over // their uses and find the common post-dominator block (note that additional -// copies and buffers remove potential aliases and will influence the placement +// clones and buffers remove potential aliases and will influence the placement // of the deallocs). In all cases, the computed block can be safely used to free // the %new_value buffer (may be exit or bb2) as it will die and we can use // liveness information to determine the exact operation after which we have to -// insert the dealloc. However, the algorithm supports introducing copy buffers +// insert the dealloc. However, the algorithm supports introducing clone buffers // and placing deallocs in safe locations to ensure that all buffers will be // freed in the end. // @@ -52,10 +51,8 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/StandardOps/Utils/Utils.h" #include "mlir/IR/Operation.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/LoopLikeInterface.h" @@ -187,25 +184,25 @@ /// The buffer deallocation transformation which ensures that all allocs in the /// program have a corresponding de-allocation. As a side-effect, it might also -/// introduce copies that in turn leads to additional allocs and de-allocations. +/// introduce clones that in turn leads to additional deallocations. 
class BufferDeallocation : BufferPlacementTransformationBase { public: BufferDeallocation(Operation *op) : BufferPlacementTransformationBase(op), dominators(op), postDominators(op) {} - /// Performs the actual placement/creation of all temporary alloc, copy and - /// dealloc nodes. + /// Performs the actual placement/creation of all temporary clone and dealloc + /// nodes. void deallocate() { - // Add additional allocations and copies that are required. - introduceCopies(); + // Add additional clones that are required. + introduceClones(); // Place deallocations for all allocation entries. placeDeallocs(); } private: - /// Introduces required allocs and copy operations to avoid memory leaks. - void introduceCopies() { + /// Introduces required clone operations to avoid memory leaks. + void introduceClones() { // Initialize the set of values that require a dedicated memory free // operation since their operands cannot be safely deallocated in a post // dominator. @@ -214,7 +211,7 @@ SmallVector, 8> toProcess; // Check dominance relation for proper dominance properties. If the given - // value node does not dominate an alias, we will have to create a copy in + // value node does not dominate an alias, we will have to create a clone in // order to free all buffers that can potentially leak into a post // dominator. auto findUnsafeValues = [&](Value source, Block *definingBlock) { @@ -255,7 +252,7 @@ // arguments at the correct locations. aliases.remove(valuesToFree); - // Add new allocs and additional copy operations. + // Add additional clone operations. for (Value value : valuesToFree) { if (auto blockArg = value.dyn_cast()) introduceBlockArgCopy(blockArg); @@ -269,7 +266,7 @@ } } - /// Introduces temporary allocs in all predecessors and copies the source + /// Introduces temporary clones in all predecessors that copy the source /// values into the newly allocated buffers.
void introduceBlockArgCopy(BlockArgument blockArg) { // Allocate a buffer for the current block argument in the block of @@ -285,9 +282,9 @@ Value sourceValue = branchInterface.getSuccessorOperands(it.getSuccessorIndex()) .getValue()[blockArg.getArgNumber()]; - // Create a new alloc and copy at the current location of the terminator. - Value alloc = introduceBufferCopy(sourceValue, terminator); - // Wire new alloc and successor operand. + // Create a new clone at the current location of the terminator. + Value clone = introduceCloneBuffers(sourceValue, terminator); + // Wire new clone and successor operand. auto mutableOperands = branchInterface.getMutableSuccessorOperands(it.getSuccessorIndex()); if (!mutableOperands.hasValue()) @@ -296,7 +293,7 @@ else mutableOperands.getValue() .slice(blockArg.getArgNumber(), 1) - .assign(alloc); + .assign(clone); } // Check whether the block argument has implicitly defined predecessors via @@ -310,7 +307,7 @@ !(regionInterface = dyn_cast(parentOp))) return; - introduceCopiesForRegionSuccessors( + introduceClonesForRegionSuccessors( regionInterface, argRegion->getParentOp()->getRegions(), blockArg, [&](RegionSuccessor &successorRegion) { // Find a predecessor of our argRegion. @@ -318,7 +315,7 @@ }); // Check whether the block argument belongs to an entry region of the - // parent operation. In this case, we have to introduce an additional copy + // parent operation. In this case, we have to introduce an additional clone // for buffer that is passed to the argument. SmallVector successorRegions; regionInterface.getSuccessorRegions(/*index=*/llvm::None, successorRegions); @@ -329,20 +326,20 @@ if (it == successorRegions.end()) return; - // Determine the actual operand to introduce a copy for and rewire the - // operand to point to the copy instead. + // Determine the actual operand to introduce a clone for and rewire the + // operand to point to the clone instead. 
Value operand = regionInterface.getSuccessorEntryOperands(argRegion->getRegionNumber()) [llvm::find(it->getSuccessorInputs(), blockArg).getIndex()]; - Value copy = introduceBufferCopy(operand, parentOp); + Value clone = introduceCloneBuffers(operand, parentOp); auto op = llvm::find(parentOp->getOperands(), operand); assert(op != parentOp->getOperands().end() && "parentOp does not contain operand"); - parentOp->setOperand(op.getIndex(), copy); + parentOp->setOperand(op.getIndex(), clone); } - /// Introduces temporary allocs in front of all associated nested-region + /// Introduces temporary clones in front of all associated nested-region /// terminators and copies the source values into the newly allocated buffers. void introduceValueCopyForRegionResult(Value value) { // Get the actual result index in the scope of the parent terminator. @@ -354,20 +351,20 @@ // its parent operation. return !successorRegion.getSuccessor(); }; - // Introduce a copy for all region "results" that are returned to the parent - // operation. This is required since the parent's result value has been - // considered critical. Therefore, the algorithm assumes that a copy of a - // previously allocated buffer is returned by the operation (like in the - // case of a block argument). - introduceCopiesForRegionSuccessors(regionInterface, operation->getRegions(), + // Introduce a clone for all region "results" that are returned to the + // parent operation. This is required since the parent's result value has + // been considered critical. Therefore, the algorithm assumes that a clone + // of a previously allocated buffer is returned by the operation (like in + // the case of a block argument). + introduceClonesForRegionSuccessors(regionInterface, operation->getRegions(), value, regionPredicate); } - /// Introduces buffer copies for all terminators in the given regions. The + /// Introduces buffer clones for all terminators in the given regions. 
The /// regionPredicate is applied to every successor region in order to restrict - /// the copies to specific regions. + /// the clones to specific regions. template - void introduceCopiesForRegionSuccessors( + void introduceClonesForRegionSuccessors( RegionBranchOpInterface regionInterface, MutableArrayRef regions, Value argValue, const TPredicate ®ionPredicate) { for (Region ®ion : regions) { @@ -393,49 +390,37 @@ walkReturnOperations(®ion, [&](Operation *terminator) { // Extract the source value from the current terminator. Value sourceValue = terminator->getOperand(operandIndex); - // Create a new alloc at the current location of the terminator. - Value alloc = introduceBufferCopy(sourceValue, terminator); - // Wire alloc and terminator operand. - terminator->setOperand(operandIndex, alloc); + // Create a new clone at the current location of the terminator. + Value clone = introduceCloneBuffers(sourceValue, terminator); + // Wire clone and terminator operand. + terminator->setOperand(operandIndex, clone); }); } } - /// Creates a new memory allocation for the given source value and copies + /// Creates a new memory allocation for the given source value and clones /// its content into the newly allocated buffer. The terminator operation is - /// used to insert the alloc and copy operations at the right places. - Value introduceBufferCopy(Value sourceValue, Operation *terminator) { - // Avoid multiple copies of the same source value. This can happen in the + /// used to insert the clone operation at the right place. + Value introduceCloneBuffers(Value sourceValue, Operation *terminator) { + // Avoid multiple clones of the same source value. This can happen in the // presence of loops when a branch acts as a backedge while also having // another successor that returns to its parent operation. 
Note: that // copying copied buffers can introduce memory leaks since the invariant of - // BufferPlacement assumes that a buffer will be only copied once into a - // temporary buffer. Hence, the construction of copy chains introduces + // BufferDeallocation assumes that a buffer will be only cloned once into a + // temporary buffer. Hence, the construction of clone chains introduces // additional allocations that are not tracked automatically by the // algorithm. - if (copiedValues.contains(sourceValue)) + if (clonedValues.contains(sourceValue)) return sourceValue; - // Create a new alloc at the current location of the terminator. - auto memRefType = sourceValue.getType().cast(); + // Create a new clone operation that copies the contents of the old + // buffer to the new one. OpBuilder builder(terminator); + auto cloneOp = + builder.create(terminator->getLoc(), sourceValue); - // Extract information about dynamically shaped types by - // extracting their dynamic dimensions. - auto dynamicOperands = - getDynOperands(terminator->getLoc(), sourceValue, builder); - - // TODO: provide a generic interface to create dialect-specific - // Alloc and CopyOp nodes. - auto alloc = builder.create(terminator->getLoc(), - memRefType, dynamicOperands); - - // Create a new copy operation that copies to contents of the old - // allocation to the new one. - builder.create(terminator->getLoc(), sourceValue, alloc); - - // Remember the copy of original source value. - copiedValues.insert(alloc); - return alloc; + // Remember the clone of original source value. + clonedValues.insert(cloneOp); + return cloneOp; } /// Finds correct dealloc positions according to the algorithm described at @@ -513,8 +498,8 @@ /// position. PostDominanceInfo postDominators; - /// Stores already copied allocations to avoid additional copies of copies. - ValueSetT copiedValues; + /// Stores already cloned buffers to avoid additional clones of clones. 
+ ValueSetT clonedValues; }; //===----------------------------------------------------------------------===// @@ -522,8 +507,8 @@ //===----------------------------------------------------------------------===// /// The actual buffer deallocation pass that inserts and moves dealloc nodes -/// into the right positions. Furthermore, it inserts additional allocs and -/// copies if necessary. It uses the algorithm described at the top of the file. +/// into the right positions. Furthermore, it inserts additional clones if +/// necessary. It uses the algorithm described at the top of the file. struct BufferDeallocationPass : BufferDeallocationBase { void runOnFunction() override { @@ -540,7 +525,7 @@ return signalPassFailure(); } - // Place all required temporary alloc, copy and dealloc nodes. + // Place all required temporary clone and dealloc nodes. BufferDeallocation deallocation(getFunction()); deallocation.deallocate(); } diff --git a/mlir/lib/Transforms/BufferUtils.cpp b/mlir/lib/Transforms/BufferUtils.cpp --- a/mlir/lib/Transforms/BufferUtils.cpp +++ b/mlir/lib/Transforms/BufferUtils.cpp @@ -12,7 +12,7 @@ #include "mlir/Transforms/BufferUtils.h" #include "PassDetail.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Operation.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -49,25 +49,6 @@ return startOperation; } -/// Finds associated deallocs that can be linked to our allocation nodes (if -/// any). -Operation *BufferPlacementAllocs::findDealloc(Value allocValue) { - auto userIt = llvm::find_if(allocValue.getUsers(), [&](Operation *user) { - auto effectInterface = dyn_cast(user); - if (!effectInterface) - return false; - // Try to find a free effect that is applied to one of our values - // that will be automatically freed by our pass. 
- SmallVector effects; - effectInterface.getEffectsOnValue(allocValue, effects); - return llvm::any_of(effects, [&](MemoryEffects::EffectInstance &it) { - return isa(it.getEffect()); - }); - }); - // Assign the associated dealloc operation (if any). - return userIt != allocValue.user_end() ? *userIt : nullptr; -} - /// Initializes the internal list by discovering all supported allocation /// nodes. BufferPlacementAllocs::BufferPlacementAllocs(Operation *op) { build(op); } diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -7,7 +7,6 @@ BufferUtils.cpp Bufferize.cpp Canonicalizer.cpp - CopyRemoval.cpp CSE.cpp Inliner.cpp LocationSnapshot.cpp diff --git a/mlir/lib/Transforms/CopyRemoval.cpp b/mlir/lib/Transforms/CopyRemoval.cpp deleted file mode 100644 --- a/mlir/lib/Transforms/CopyRemoval.cpp +++ /dev/null @@ -1,217 +0,0 @@ -//===- CopyRemoval.cpp - Removing the redundant copies --------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Interfaces/CopyOpInterface.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/Passes.h" - -using namespace mlir; -using namespace MemoryEffects; - -namespace { - -//===----------------------------------------------------------------------===// -// CopyRemovalPass -//===----------------------------------------------------------------------===// - -/// This pass removes the redundant Copy operations. Additionally, it -/// removes the leftover definition and deallocation operations by erasing the -/// copy operation. 
-class CopyRemovalPass : public PassWrapper> { -public: - void runOnOperation() override { - getOperation()->walk([&](CopyOpInterface copyOp) { - reuseCopySourceAsTarget(copyOp); - reuseCopyTargetAsSource(copyOp); - }); - for (std::pair &pair : replaceList) - pair.first.replaceAllUsesWith(pair.second); - for (Operation *op : eraseList) - op->erase(); - } - -private: - /// List of operations that need to be removed. - llvm::SmallPtrSet eraseList; - - /// List of values that need to be replaced with their counterparts. - llvm::SmallDenseSet, 4> replaceList; - - /// Returns the allocation operation for `value` in `block` if it exists. - /// nullptr otherwise. - Operation *getAllocationOpInBlock(Value value, Block *block) { - assert(block && "Block cannot be null"); - Operation *op = value.getDefiningOp(); - if (op && op->getBlock() == block) { - auto effects = dyn_cast(op); - if (effects && effects.hasEffect()) - return op; - } - return nullptr; - } - - /// Returns the deallocation operation for `value` in `block` if it exists. - /// nullptr otherwise. - Operation *getDeallocationOpInBlock(Value value, Block *block) { - assert(block && "Block cannot be null"); - auto valueUsers = value.getUsers(); - auto it = llvm::find_if(valueUsers, [&](Operation *op) { - auto effects = dyn_cast(op); - return effects && op->getBlock() == block && effects.hasEffect(); - }); - return (it == valueUsers.end() ? nullptr : *it); - } - - /// Returns true if an operation between start and end operations has memory - /// effect. 
- bool hasMemoryEffectOpBetween(Operation *start, Operation *end) { - assert((start || end) && "Start and end operations cannot be null"); - assert(start->getBlock() == end->getBlock() && - "Start and end operations should be in the same block."); - Operation *op = start->getNextNode(); - while (op->isBeforeInBlock(end)) { - if (isa(op)) - return true; - op = op->getNextNode(); - } - return false; - }; - - /// Returns true if `val` value has at least a user between `start` and - /// `end` operations. - bool hasUsersBetween(Value val, Operation *start, Operation *end) { - assert((start || end) && "Start and end operations cannot be null"); - Block *block = start->getBlock(); - assert(block == end->getBlock() && - "Start and end operations should be in the same block."); - return llvm::any_of(val.getUsers(), [&](Operation *op) { - return op->getBlock() == block && start->isBeforeInBlock(op) && - op->isBeforeInBlock(end); - }); - }; - - bool areOpsInTheSameBlock(ArrayRef operations) { - assert(!operations.empty() && - "The operations list should contain at least a single operation"); - Block *block = operations.front()->getBlock(); - return llvm::none_of( - operations, [&](Operation *op) { return block != op->getBlock(); }); - } - - /// Input: - /// func(){ - /// %from = alloc() - /// write_to(%from) - /// %to = alloc() - /// copy(%from,%to) - /// dealloc(%from) - /// return %to - /// } - /// - /// Output: - /// func(){ - /// %from = alloc() - /// write_to(%from) - /// return %from - /// } - /// Constraints: - /// 1) %to, copy and dealloc must all be defined and lie in the same block. - /// 2) This transformation cannot be applied if there is a single user/alias - /// of `to` value between the defining operation of `to` and the copy - /// operation. - /// 3) This transformation cannot be applied if there is a single user/alias - /// of `from` value between the copy operation and the deallocation of `from`. - /// TODO: Alias analysis is not available at the moment. 
Currently, we check - /// if there are any operations with memory effects between copy and - /// deallocation operations. - void reuseCopySourceAsTarget(CopyOpInterface copyOp) { - if (eraseList.count(copyOp)) - return; - - Value from = copyOp.getSource(); - Value to = copyOp.getTarget(); - - Operation *copy = copyOp.getOperation(); - Block *copyBlock = copy->getBlock(); - Operation *fromDefiningOp = from.getDefiningOp(); - Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock); - Operation *toDefiningOp = getAllocationOpInBlock(to, copyBlock); - if (!fromDefiningOp || !fromFreeingOp || !toDefiningOp || - !areOpsInTheSameBlock({fromFreeingOp, toDefiningOp, copy}) || - hasUsersBetween(to, toDefiningOp, copy) || - hasUsersBetween(from, copy, fromFreeingOp) || - hasMemoryEffectOpBetween(copy, fromFreeingOp)) - return; - - replaceList.insert({to, from}); - eraseList.insert(copy); - eraseList.insert(toDefiningOp); - eraseList.insert(fromFreeingOp); - } - - /// Input: - /// func(){ - /// %to = alloc() - /// %from = alloc() - /// write_to(%from) - /// copy(%from,%to) - /// dealloc(%from) - /// return %to - /// } - /// - /// Output: - /// func(){ - /// %to = alloc() - /// write_to(%to) - /// return %to - /// } - /// Constraints: - /// 1) %from, copy and dealloc must all be defined and lie in the same block. - /// 2) This transformation cannot be applied if there is a single user/alias - /// of `to` value between the defining operation of `from` and the copy - /// operation. - /// 3) This transformation cannot be applied if there is a single user/alias - /// of `from` value between the copy operation and the deallocation of `from`. - /// TODO: Alias analysis is not available at the moment. Currently, we check - /// if there are any operations with memory effects between copy and - /// deallocation operations. 
- void reuseCopyTargetAsSource(CopyOpInterface copyOp) { - if (eraseList.count(copyOp)) - return; - - Value from = copyOp.getSource(); - Value to = copyOp.getTarget(); - - Operation *copy = copyOp.getOperation(); - Block *copyBlock = copy->getBlock(); - Operation *fromDefiningOp = getAllocationOpInBlock(from, copyBlock); - Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock); - if (!fromDefiningOp || !fromFreeingOp || - !areOpsInTheSameBlock({fromFreeingOp, fromDefiningOp, copy}) || - hasUsersBetween(to, fromDefiningOp, copy) || - hasUsersBetween(from, copy, fromFreeingOp) || - hasMemoryEffectOpBetween(copy, fromFreeingOp)) - return; - - replaceList.insert({from, to}); - eraseList.insert(copy); - eraseList.insert(fromDefiningOp); - eraseList.insert(fromFreeingOp); - } -}; - -} // end anonymous namespace - -//===----------------------------------------------------------------------===// -// CopyRemovalPass construction -//===----------------------------------------------------------------------===// - -std::unique_ptr mlir::createCopyRemovalPass() { - return std::make_unique(); -} diff --git a/mlir/test/Transforms/buffer-deallocation.mlir b/mlir/test/Transforms/buffer-deallocation.mlir --- a/mlir/test/Transforms/buffer-deallocation.mlir +++ b/mlir/test/Transforms/buffer-deallocation.mlir @@ -30,13 +30,11 @@ } // CHECK-NEXT: cond_br -// CHECK: %[[ALLOC0:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy +// CHECK: %[[ALLOC0:.*]] = memref.clone // CHECK-NEXT: br ^bb3(%[[ALLOC0]] -// CHECK: %[[ALLOC1:.*]] = memref.alloc() +// CHECK: %[[ALLOC1:.*]] = memref.alloc // CHECK-NEXT: test.buffer_based -// CHECK: %[[ALLOC2:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy +// CHECK-NEXT: %[[ALLOC2:.*]] = memref.clone %[[ALLOC1]] // CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK-NEXT: br ^bb3(%[[ALLOC2]] // CHECK: test.copy @@ -77,16 +75,12 @@ } // CHECK-NEXT: cond_br -// CHECK: %[[DIM0:.*]] = memref.dim -// CHECK-NEXT: %[[ALLOC0:.*]] = memref.alloc(%[[DIM0]]) 
-// CHECK-NEXT: linalg.copy(%{{.*}}, %[[ALLOC0]]) +// CHECK: %[[ALLOC0:.*]] = memref.clone // CHECK-NEXT: br ^bb3(%[[ALLOC0]] // CHECK: ^bb2(%[[IDX:.*]]:{{.*}}) // CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]]) // CHECK-NEXT: test.buffer_based -// CHECK: %[[DIM1:.*]] = memref.dim %[[ALLOC1]] -// CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc(%[[DIM1]]) -// CHECK-NEXT: linalg.copy(%[[ALLOC1]], %[[ALLOC2]]) +// CHECK-NEXT: %[[ALLOC2:.*]] = memref.clone // CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK-NEXT: br ^bb3 // CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}}) @@ -142,12 +136,10 @@ return } -// CHECK-NEXT: cond_br -// CHECK: ^bb1 -// CHECK: %[[DIM0:.*]] = memref.dim -// CHECK-NEXT: %[[ALLOC0:.*]] = memref.alloc(%[[DIM0]]) -// CHECK-NEXT: linalg.copy(%{{.*}}, %[[ALLOC0]]) -// CHECK-NEXT: br ^bb6 +// CHECK-NEXT: cond_br{{.*}} +// CHECK-NEXT: ^bb1 +// CHECK-NEXT: %[[ALLOC0:.*]] = memref.clone +// CHECK-NEXT: br ^bb6(%[[ALLOC0]] // CHECK: ^bb2(%[[IDX:.*]]:{{.*}}) // CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]]) // CHECK-NEXT: test.buffer_based @@ -157,9 +149,7 @@ // CHECK: ^bb4: // CHECK-NEXT: br ^bb5(%[[ALLOC1]]{{.*}}) // CHECK-NEXT: ^bb5(%[[ALLOC2:.*]]:{{.*}}) -// CHECK: %[[DIM2:.*]] = memref.dim %[[ALLOC2]] -// CHECK-NEXT: %[[ALLOC3:.*]] = memref.alloc(%[[DIM2]]) -// CHECK-NEXT: linalg.copy(%[[ALLOC2]], %[[ALLOC3]]) +// CHECK-NEXT: %[[ALLOC3:.*]] = memref.clone %[[ALLOC2]] // CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK-NEXT: br ^bb6(%[[ALLOC3]]{{.*}}) // CHECK-NEXT: ^bb6(%[[ALLOC4:.*]]:{{.*}}) @@ -208,13 +198,11 @@ return } -// CHECK-NEXT: %[[ALLOC0:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy +// CHECK-NEXT: %[[ALLOC0:.*]] = memref.clone // CHECK-NEXT: cond_br // CHECK: %[[ALLOC1:.*]] = memref.alloc() // CHECK-NEXT: test.buffer_based -// CHECK: %[[ALLOC2:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy +// CHECK-NEXT: %[[ALLOC2:.*]] = memref.clone %[[ALLOC1]] // CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK: test.copy // CHECK-NEXT: memref.dealloc @@ 
-419,20 +407,17 @@ return } -// CHECK-NEXT: cond_br -// CHECK: ^bb1 -// CHECK: ^bb1 +// CHECK-NEXT: cond_br{{.*}} +// CHECK-NEXT: ^bb1 // CHECK: %[[ALLOC0:.*]] = memref.alloc() // CHECK-NEXT: test.buffer_based -// CHECK: %[[ALLOC1:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy +// CHECK-NEXT: %[[ALLOC1:.*]] = memref.clone %[[ALLOC0]] // CHECK-NEXT: memref.dealloc %[[ALLOC0]] // CHECK-NEXT: br ^bb3(%[[ALLOC1]] // CHECK-NEXT: ^bb2 // CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc() // CHECK-NEXT: test.buffer_based -// CHECK: %[[ALLOC3:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy +// CHECK-NEXT: %[[ALLOC3:.*]] = memref.clone %[[ALLOC2]] // CHECK-NEXT: memref.dealloc %[[ALLOC2]] // CHECK-NEXT: br ^bb3(%[[ALLOC3]] // CHECK-NEXT: ^bb3(%[[ALLOC4:.*]]:{{.*}}) @@ -545,8 +530,7 @@ } // CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}}) // CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]] -// CHECK: %[[ALLOC0:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ARG1]], %[[ALLOC0]]) +// CHECK: %[[ALLOC0:.*]] = memref.clone %[[ARG1]] // CHECK: ^[[BB2]]: // CHECK: %[[ALLOC1:.*]] = memref.alloc() // CHECK-NEXT: test.region_buffer_based in(%[[ARG1]]{{.*}}out(%[[ALLOC1]] @@ -554,12 +538,11 @@ // CHECK-NEXT: test.buffer_based in(%[[ARG1]]{{.*}}out(%[[ALLOC2]] // CHECK: memref.dealloc %[[ALLOC2]] // CHECK-NEXT: %{{.*}} = math.exp -// CHECK: %[[ALLOC3:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC1]], %[[ALLOC3]]) +// CHECK: %[[ALLOC3:.*]] = memref.clone %[[ALLOC1]] // CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK: ^[[BB3:.*]]({{.*}}): // CHECK: test.copy -// CHECK-NEXT: dealloc +// CHECK-NEXT: memref.dealloc // ----- @@ -641,12 +624,10 @@ // CHECK: %[[ALLOC0:.*]] = memref.alloc(%arg0, %arg0) // CHECK-NEXT: %[[ALLOC1:.*]] = scf.if -// CHECK: %[[ALLOC2:.*]] = memref.alloc -// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC2]]) +// CHECK-NEXT: %[[ALLOC2:.*]] = memref.clone %[[ALLOC0]] // CHECK: scf.yield %[[ALLOC2]] // CHECK: %[[ALLOC3:.*]] = 
memref.alloc(%arg0, %arg1) -// CHECK: %[[ALLOC4:.*]] = memref.alloc -// CHECK-NEXT: linalg.copy(%[[ALLOC3]], %[[ALLOC4]]) +// CHECK-NEXT: %[[ALLOC4:.*]] = memref.clone %[[ALLOC3]] // CHECK: memref.dealloc %[[ALLOC3]] // CHECK: scf.yield %[[ALLOC4]] // CHECK: memref.dealloc %[[ALLOC0]] @@ -823,20 +804,18 @@ // CHECK: (%[[cond:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %{{.*}}: {{.*}}) // CHECK-NEXT: cond_br %[[cond]], ^[[BB1:.*]], ^[[BB2:.*]] // CHECK: ^[[BB1]]: -// CHECK: %[[ALLOC0:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy +// CHECK: %[[ALLOC0:.*]] = memref.clone // CHECK: ^[[BB2]]: // CHECK: %[[ALLOC1:.*]] = memref.alloc() // CHECK-NEXT: test.region_buffer_based in(%[[ARG1]]{{.*}}out(%[[ALLOC1]] // CHECK: %[[ALLOCA:.*]] = memref.alloca() // CHECK-NEXT: test.buffer_based in(%[[ARG1]]{{.*}}out(%[[ALLOCA]] // CHECK: %{{.*}} = math.exp -// CHECK: %[[ALLOC2:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy +// CHECK: %[[ALLOC2:.*]] = memref.clone %[[ALLOC1]] // CHECK-NEXT: memref.dealloc %[[ALLOC1]] // CHECK: ^[[BB3:.*]]({{.*}}): // CHECK: test.copy -// CHECK-NEXT: dealloc +// CHECK-NEXT: memref.dealloc // ----- @@ -888,15 +867,13 @@ // CHECK: %[[ALLOC0:.*]] = memref.alloc() // CHECK-NEXT: memref.dealloc %[[ALLOC0]] -// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc() -// CHECK: linalg.copy(%arg3, %[[ALLOC1]]) +// CHECK-NEXT: %[[ALLOC1:.*]] = memref.clone %arg3 // CHECK: %[[ALLOC2:.*]] = scf.for {{.*}} iter_args // CHECK-SAME: (%[[IALLOC:.*]] = %[[ALLOC1]] // CHECK: cmpi // CHECK: memref.dealloc %[[IALLOC]] // CHECK: %[[ALLOC3:.*]] = memref.alloc() -// CHECK: %[[ALLOC4:.*]] = memref.alloc() -// CHECK: linalg.copy(%[[ALLOC3]], %[[ALLOC4]]) +// CHECK: %[[ALLOC4:.*]] = memref.clone %[[ALLOC3]] // CHECK: memref.dealloc %[[ALLOC3]] // CHECK: scf.yield %[[ALLOC4]] // CHECK: } @@ -974,25 +951,21 @@ } // CHECK: %[[ALLOC0:.*]] = memref.alloc() -// CHECK: %[[ALLOC1:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%arg3, %[[ALLOC1]]) +// CHECK-NEXT: %[[ALLOC1:.*]] = memref.clone 
%arg3 // CHECK-NEXT: %[[ALLOC2:.*]] = scf.for {{.*}} iter_args // CHECK-SAME: (%[[IALLOC:.*]] = %[[ALLOC1]] // CHECK: memref.dealloc %[[IALLOC]] // CHECK: %[[ALLOC3:.*]] = scf.if // CHECK: %[[ALLOC4:.*]] = memref.alloc() -// CHECK-NEXT: %[[ALLOC5:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC4]], %[[ALLOC5]]) +// CHECK-NEXT: %[[ALLOC5:.*]] = memref.clone %[[ALLOC4]] // CHECK-NEXT: memref.dealloc %[[ALLOC4]] // CHECK-NEXT: scf.yield %[[ALLOC5]] -// CHECK: %[[ALLOC6:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC0]], %[[ALLOC6]]) +// CHECK: %[[ALLOC6:.*]] = memref.clone %[[ALLOC0]] // CHECK-NEXT: scf.yield %[[ALLOC6]] -// CHECK: %[[ALLOC7:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC3:.*]], %[[ALLOC7]]) +// CHECK: %[[ALLOC7:.*]] = memref.clone %[[ALLOC3]] // CHECK-NEXT: memref.dealloc %[[ALLOC3]] // CHECK-NEXT: scf.yield %[[ALLOC7]] @@ -1040,17 +1013,14 @@ // CHECK: %[[ALLOC0:.*]] = memref.alloc() // CHECK-NEXT: memref.dealloc %[[ALLOC0]] -// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%arg3, %[[ALLOC1]]) +// CHECK-NEXT: %[[ALLOC1:.*]] = memref.clone %arg3 // CHECK-NEXT: %[[VAL_7:.*]] = scf.for {{.*}} iter_args // CHECK-SAME: (%[[IALLOC0:.*]] = %[[ALLOC1]]) -// CHECK: %[[ALLOC2:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[IALLOC0]], %[[ALLOC2]]) +// CHECK-NEXT: %[[ALLOC2:.*]] = memref.clone %[[IALLOC0]] // CHECK-NEXT: memref.dealloc %[[IALLOC0]] // CHECK-NEXT: %[[ALLOC3:.*]] = scf.for {{.*}} iter_args // CHECK-SAME: (%[[IALLOC1:.*]] = %[[ALLOC2]]) -// CHECK: %[[ALLOC5:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[IALLOC1]], %[[ALLOC5]]) +// CHECK-NEXT: %[[ALLOC5:.*]] = memref.clone %[[IALLOC1]] // CHECK-NEXT: memref.dealloc %[[IALLOC1]] // CHECK: %[[ALLOC6:.*]] = scf.for {{.*}} iter_args @@ -1060,28 +1030,23 @@ // CHECK: %[[ALLOC9:.*]] = scf.if // CHECK: %[[ALLOC11:.*]] = memref.alloc() -// CHECK-NEXT: %[[ALLOC12:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC11]], %[[ALLOC12]]) 
+// CHECK-NEXT: %[[ALLOC12:.*]] = memref.clone %[[ALLOC11]] // CHECK-NEXT: memref.dealloc %[[ALLOC11]] // CHECK-NEXT: scf.yield %[[ALLOC12]] -// CHECK: %[[ALLOC13:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[IALLOC2]], %[[ALLOC13]]) +// CHECK: %[[ALLOC13:.*]] = memref.clone %[[IALLOC2]] // CHECK-NEXT: scf.yield %[[ALLOC13]] // CHECK: memref.dealloc %[[IALLOC2]] -// CHECK-NEXT: %[[ALLOC10:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC9]], %[[ALLOC10]]) +// CHECK-NEXT: %[[ALLOC10:.*]] = memref.clone %[[ALLOC9]] // CHECK-NEXT: memref.dealloc %[[ALLOC9]] // CHECK-NEXT: scf.yield %[[ALLOC10]] -// CHECK: %[[ALLOC7:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC6]], %[[ALLOC7]]) +// CHECK: %[[ALLOC7:.*]] = memref.clone %[[ALLOC6]] // CHECK-NEXT: memref.dealloc %[[ALLOC6]] // CHECK-NEXT: scf.yield %[[ALLOC7]] -// CHECK: %[[ALLOC4:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ALLOC3]], %[[ALLOC4]]) +// CHECK: %[[ALLOC4:.*]] = memref.clone %[[ALLOC3]] // CHECK-NEXT: memref.dealloc %[[ALLOC3]] // CHECK-NEXT: scf.yield %[[ALLOC4]] @@ -1183,8 +1148,7 @@ // CHECK-NEXT: shape.assuming_yield %[[ARG1]] // CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[ARG0]] // CHECK-NEXT: %[[TMP_ALLOC:.*]] = memref.alloc() -// CHECK-NEXT: %[[RETURNING_ALLOC:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[TMP_ALLOC]], %[[RETURNING_ALLOC]]) +// CHECK-NEXT: %[[RETURNING_ALLOC:.*]] = memref.clone %[[TMP_ALLOC]] // CHECK-NEXT: memref.dealloc %[[TMP_ALLOC]] // CHECK-NEXT: shape.assuming_yield %[[RETURNING_ALLOC]] // CHECK: test.copy(%[[ASSUMING_RESULT:.*]], %[[ARG2]]) diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -1059,3 +1059,88 @@ return %2 : tensor } +// ----- + +// CHECK-LABEL: func @simple_clone_elimination +func @simple_clone_elimination() -> memref<5xf32> { + %ret = memref.alloc() : memref<5xf32> + %temp = 
memref.clone %ret : memref<5xf32> to memref<5xf32> + memref.dealloc %temp : memref<5xf32> + return %ret : memref<5xf32> +} +// CHECK-NEXT: %[[ret:.*]] = memref.alloc() +// CHECK-NOT: %[[temp:.*]] = memref.clone +// CHECK-NOT: memref.dealloc %[[temp]] +// CHECK: return %[[ret]] + +// ----- + +// CHECK-LABEL: func @clone_loop_alloc +func @clone_loop_alloc(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<2xf32>, %arg4: memref<2xf32>) { + %0 = memref.alloc() : memref<2xf32> + memref.dealloc %0 : memref<2xf32> + %1 = memref.clone %arg3 : memref<2xf32> to memref<2xf32> + %2 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %1) -> (memref<2xf32>) { + %3 = cmpi eq, %arg5, %arg1 : index + memref.dealloc %arg6 : memref<2xf32> + %4 = memref.alloc() : memref<2xf32> + %5 = memref.clone %4 : memref<2xf32> to memref<2xf32> + memref.dealloc %4 : memref<2xf32> + %6 = memref.clone %5 : memref<2xf32> to memref<2xf32> + memref.dealloc %5 : memref<2xf32> + scf.yield %6 : memref<2xf32> + } + linalg.copy(%2, %arg4) : memref<2xf32>, memref<2xf32> + memref.dealloc %2 : memref<2xf32> + return +} + +// CHECK-NEXT: %[[ALLOC0:.*]] = memref.clone +// CHECK-NEXT: %[[ALLOC1:.*]] = scf.for +// CHECK-NEXT: memref.dealloc +// CHECK-NEXT: %[[ALLOC2:.*]] = memref.alloc +// CHECK-NEXT: scf.yield %[[ALLOC2]] +// CHECK: linalg.copy(%[[ALLOC1]] +// CHECK-NEXT: memref.dealloc %[[ALLOC1]] + +// ----- + +// CHECK-LABEL: func @clone_nested_region +func @clone_nested_region(%arg0: index, %arg1: index) -> memref { + %0 = cmpi eq, %arg0, %arg1 : index + %1 = memref.alloc(%arg0, %arg0) : memref + %2 = scf.if %0 -> (memref) { + %3 = scf.if %0 -> (memref) { + %9 = memref.clone %1 : memref to memref + scf.yield %9 : memref + } else { + %7 = memref.alloc(%arg0, %arg1) : memref + %10 = memref.clone %7 : memref to memref + memref.dealloc %7 : memref + scf.yield %10 : memref + } + %6 = memref.clone %3 : memref to memref + memref.dealloc %3 : memref + scf.yield %6 : memref + } else { + %3 = 
memref.alloc(%arg1, %arg1) : memref + %6 = memref.clone %3 : memref to memref + memref.dealloc %3 : memref + scf.yield %6 : memref + } + memref.dealloc %1 : memref + return %2 : memref +} + +// CHECK: %[[ALLOC1:.*]] = memref.alloc +// CHECK-NEXT: %[[ALLOC2:.*]] = scf.if +// CHECK-NEXT: %[[ALLOC3_1:.*]] = scf.if +// CHECK-NEXT: %[[ALLOC4_1:.*]] = memref.clone %[[ALLOC1]] +// CHECK-NEXT: scf.yield %[[ALLOC4_1]] +// CHECK: %[[ALLOC4_2:.*]] = memref.alloc +// CHECK-NEXT: scf.yield %[[ALLOC4_2]] +// CHECK: scf.yield %[[ALLOC3_1]] +// CHECK: %[[ALLOC3_2:.*]] = memref.alloc +// CHECK-NEXT: scf.yield %[[ALLOC3_2]] +// CHECK: memref.dealloc %[[ALLOC1]] +// CHECK-NEXT: return %[[ALLOC2]] diff --git a/mlir/test/Transforms/copy-removal.mlir b/mlir/test/Transforms/copy-removal.mlir deleted file mode 100644 --- a/mlir/test/Transforms/copy-removal.mlir +++ /dev/null @@ -1,361 +0,0 @@ -// RUN: mlir-opt -copy-removal -split-input-file %s | FileCheck %s - -// All linalg copies except the linalg.copy(%1, %9) must be removed since the -// defining operation of %1 and its DeallocOp have been defined in another block. 
- -// CHECK-LABEL: func @nested_region_control_flow_div_nested -func @nested_region_control_flow_div_nested(%arg0: index, %arg1: index) -> memref { - %0 = cmpi eq, %arg0, %arg1 : index - %1 = memref.alloc(%arg0, %arg0) : memref - // CHECK: %{{.*}} = scf.if - %2 = scf.if %0 -> (memref) { - // CHECK: %[[PERCENT3:.*]] = scf.if - %3 = scf.if %0 -> (memref) { - %c0_0 = constant 0 : index - %7 = memref.dim %1, %c0_0 : memref - %c1_1 = constant 1 : index - %8 = memref.dim %1, %c1_1 : memref - %9 = memref.alloc(%7, %8) : memref - // CHECK: linalg.copy({{.*}}, %[[PERCENT9:.*]]) - linalg.copy(%1, %9) : memref, memref - // CHECK: scf.yield %[[PERCENT9]] - scf.yield %9 : memref - } else { - // CHECK: %[[PERCENT7:.*]] = memref.alloc - %7 = memref.alloc(%arg0, %arg1) : memref - %c0_0 = constant 0 : index - %8 = memref.dim %7, %c0_0 : memref - %c1_1 = constant 1 : index - %9 = memref.dim %7, %c1_1 : memref - // CHECK-NOT: %{{.*}} = memref.alloc - // CHECK-NOT: linalg.copy(%[[PERCENT7]], %{{.*}}) - // CHECK-NOT: memref.dealloc %[[PERCENT7]] - %10 = memref.alloc(%8, %9) : memref - linalg.copy(%7, %10) : memref, memref - memref.dealloc %7 : memref - // CHECK: scf.yield %[[PERCENT7]] - scf.yield %10 : memref - } - %c0 = constant 0 : index - %4 = memref.dim %3, %c0 : memref - %c1 = constant 1 : index - %5 = memref.dim %3, %c1 : memref - // CHECK-NOT: %{{.*}} = memref.alloc - // CHECK-NOT: linalg.copy(%[[PERCENT3]], %{{.*}}) - // CHECK-NOT: memref.dealloc %[[PERCENT3]] - %6 = memref.alloc(%4, %5) : memref - linalg.copy(%3, %6) : memref, memref - memref.dealloc %3 : memref - // CHECK: scf.yield %[[PERCENT3]] - scf.yield %6 : memref - } else { - // CHECK: %[[PERCENT3:.*]] = memref.alloc - %3 = memref.alloc(%arg1, %arg1) : memref - %c0 = constant 0 : index - %4 = memref.dim %3, %c0 : memref - %c1 = constant 1 : index - %5 = memref.dim %3, %c1 : memref - // CHECK-NOT: %{{.*}} = memref.alloc - // CHECK-NOT: linalg.copy(%[[PERCENT3]], %{{.*}}) - // CHECK-NOT: memref.dealloc %[[PERCENT3]] - 
%6 = memref.alloc(%4, %5) : memref - linalg.copy(%3, %6) : memref, memref - memref.dealloc %3 : memref - // CHECK: scf.yield %[[PERCENT3]] - scf.yield %6 : memref - } - memref.dealloc %1 : memref - return %2 : memref -} - -// ----- - -// CHECK-LABEL: func @simple_test -func @simple_test() -> memref<5xf32> { - %temp = memref.alloc() : memref<5xf32> - %ret = memref.alloc() : memref<5xf32> - linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32> - memref.dealloc %ret : memref<5xf32> - return %temp : memref<5xf32> -} -// CHECK-SAME: () -> memref<5xf32> -// CHECK-NEXT: %[[ret:.*]] = memref.alloc() -// CHECK-NOT: linalg.copy(%[[ret]], %{{.*}}) -// CHECK-NOT: memref.dealloc %[[ret]] -// CHECK: return %[[ret]] - -// ----- - -// It is legal to remove the copy operation that %ret has a usage before the copy -// operation. The allocation of %temp and the deallocation of %ret should be also -// removed. - -// CHECK-LABEL: func @test_with_ret_usage_before_copy -func @test_with_ret_usage_before_copy() -> memref<5xf32> { - %ret = memref.alloc() : memref<5xf32> - %temp = memref.alloc() : memref<5xf32> - %c0 = constant 0 : index - %dimension = memref.dim %ret, %c0 : memref<5xf32> - linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32> - memref.dealloc %ret : memref<5xf32> - return %temp : memref<5xf32> -} -// CHECK-NEXT: %[[ret:.*]] = memref.alloc() -// CHECK-NOT: %{{.*}} = memref.alloc -// CHECK-NEXT: %{{.*}} = constant -// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[ret]] -// CHECK-NOT: linalg.copy(%[[ret]], %{{.*}}) -// CHECK-NOT: memref.dealloc %[[ret]] -// CHECK: return %[[ret]] - -// ----- - -// It is illegal to remove a copy operation that %ret has a usage after copy -// operation. 
- -// CHECK-LABEL: func @test_with_ret_usage_after_copy -func @test_with_ret_usage_after_copy() -> memref<5xf32> { - %ret = memref.alloc() : memref<5xf32> - %temp = memref.alloc() : memref<5xf32> - // CHECK: linalg.copy - linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32> - %c0 = constant 0 : index - %dimension = memref.dim %ret, %c0 : memref<5xf32> - memref.dealloc %ret : memref<5xf32> - return %temp : memref<5xf32> -} - -// ----- - -// It is illegal to remove a copy operation that %temp has a usage before copy -// operation. - -// CHECK-LABEL: func @test_with_temp_usage_before_copy -func @test_with_temp_usage_before_copy() -> memref<5xf32> { - %ret = memref.alloc() : memref<5xf32> - %temp = memref.alloc() : memref<5xf32> - %c0 = constant 0 : index - %dimension = memref.dim %temp, %c0 : memref<5xf32> - // CHECK: linalg.copy - linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32> - memref.dealloc %ret : memref<5xf32> - return %temp : memref<5xf32> -} - -// ----- - -// It is legal to remove the copy operation that %temp has a usage after the copy -// operation. The allocation of %temp and the deallocation of %ret could be also -// removed. - -// However the following pattern is not handled by copy removal. -// %from = memref.alloc() -// %to = memref.alloc() -// copy(%from, %to) -// read_from(%from) + write_to(%something_else) -// memref.dealloc(%from) -// return %to -// In particular, linalg.generic is a memoryEffectOp between copy and dealloc. -// Since no alias analysis is performed and no distinction is made between reads -// and writes, the linalg.generic with effects blocks copy removal. 
- -#map0 = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: func @test_with_temp_usage_after_copy -func @test_with_temp_usage_after_copy() -> memref<5xf32> { - %ret = memref.alloc() : memref<5xf32> - %res = memref.alloc() : memref<5xf32> - %temp = memref.alloc() : memref<5xf32> - linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32> - linalg.generic { - indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} - ins(%temp : memref<5xf32>) - outs(%res : memref<5xf32>) { - ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): - %tmp1 = math.exp %gen1_arg0 : f32 - linalg.yield %tmp1 : f32 - } - memref.dealloc %ret : memref<5xf32> - return %temp : memref<5xf32> -} -// CHECK-NEXT: %[[ret:.*]] = memref.alloc() -// CHECK-NEXT: %[[res:.*]] = memref.alloc() -// CHECK-NEXT: %[[temp:.*]] = memref.alloc() -// CHECK-NEXT: linalg.copy(%[[ret]], %[[temp]]) -// CHECK-NEXT: linalg.generic -// CHECK: memref.dealloc %[[ret]] -// CHECK: return %[[temp]] - -// ----- - -// CHECK-LABEL: func @make_allocation -func @make_allocation() -> memref<5xf32> { - %mem = memref.alloc() : memref<5xf32> - return %mem : memref<5xf32> -} - -// CHECK-LABEL: func @test_with_function_call -func @test_with_function_call() -> memref<5xf32> { - // CHECK-NEXT: %[[ret:.*]] = call @make_allocation() : () -> memref<5xf32> - %ret = call @make_allocation() : () -> (memref<5xf32>) - // CHECK-NOT: %{{.*}} = memref.alloc - // CHECK-NOT: linalg.copy(%[[ret]], %{{.*}}) - // CHECK-NOT: memref.dealloc %[[ret]] - %temp = memref.alloc() : memref<5xf32> - linalg.copy(%ret, %temp) : memref<5xf32>, memref<5xf32> - memref.dealloc %ret : memref<5xf32> - // CHECK: return %[[ret]] - return %temp : memref<5xf32> -} - -// ----- - -// CHECK-LABEL: func @multiple_deallocs_in_different_blocks -func @multiple_deallocs_in_different_blocks(%cond : i1) -> memref<5xf32> { - // CHECK-NEXT: %[[PERCENT0:.*]] = memref.alloc() - %0 = memref.alloc() : memref<5xf32> - cond_br %cond, ^bb1, ^bb2 -^bb1: - memref.dealloc %0 : memref<5xf32> - // CHECK: br 
^[[BB3:.*]](%[[PERCENT0]] - br ^bb3(%0 : memref<5xf32>) -^bb2: - // CHECK-NOT: %{{.*}} = memref.alloc - // CHECK-NOT: linalg.copy(%[[PERCENT0]], %{{.*}}) - // CHECK-NOT: memref.dealloc %[[PERCENT0]] - %temp = memref.alloc() : memref<5xf32> - linalg.copy(%0, %temp) : memref<5xf32>, memref<5xf32> - memref.dealloc %0 : memref<5xf32> - // CHECK: br ^[[BB3]](%[[PERCENT0]] - br ^bb3(%temp : memref<5xf32>) -^bb3(%res : memref<5xf32>): - return %res : memref<5xf32> -} - -// ----- - -#map0 = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: func @test_ReuseCopyTargetAsSource -func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>, %result: memref<2xf32>){ - // CHECK-SAME: (%[[ARG0:.*]]: memref<2xf32>, %[[RES:.*]]: memref<2xf32>) - // CHECK-NOT: %{{.*}} = memref.alloc - %temp = memref.alloc() : memref<2xf32> - // CHECK-NEXT: linalg.generic - // CHECK-SAME: ins(%[[ARG0]]{{.*}}outs(%[[RES]] - // CHECK-NOT: linalg.copy(%{{.*}}, %[[RES]]) - // CHECK-NOT: memref.dealloc %{{.*}} - linalg.generic { - indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} - ins(%arg0 : memref<2xf32>) - outs(%temp : memref<2xf32>) { - ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): - %tmp2 = math.exp %gen2_arg0 : f32 - linalg.yield %tmp2 : f32 - } - linalg.copy(%temp, %result) : memref<2xf32>, memref<2xf32> - memref.dealloc %temp : memref<2xf32> - // CHECK: return - return -} - -// ----- - -// Copy operation must not be removed since an operation writes to %to value -// before copy. 
- -#map0 = affine_map<(d0) -> (d0)> - -// CHECK-LABEL: func @test_ReuseCopyTargetAsSource -func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){ - %to = memref.alloc() : memref<2xf32> - %temp = memref.alloc() : memref<2xf32> - linalg.generic { - indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} - ins(%arg0 : memref<2xf32>) - outs(%temp : memref<2xf32>) { - ^bb0(%gen1_arg0: f32, %gen1_arg1: f32): - %tmp1 = math.exp %gen1_arg0 : f32 - linalg.yield %tmp1 : f32 - } - linalg.generic { - indexing_maps = [#map0, #map0], - iterator_types = ["parallel"]} - ins(%arg0 : memref<2xf32>) - outs(%to : memref<2xf32>) { - ^bb0(%gen2_arg0: f32, %gen2_arg1: f32): - %tmp2 = math.exp %gen2_arg0 : f32 - linalg.yield %tmp2 : f32 - } - // CHECK: linalg.copy - linalg.copy(%temp, %to) : memref<2xf32>, memref<2xf32> - memref.dealloc %temp : memref<2xf32> - return -} - -// ----- - -// The only redundant copy is linalg.copy(%4, %5) - -// CHECK-LABEL: func @loop_alloc -func @loop_alloc(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<2xf32>, %arg4: memref<2xf32>) { - // CHECK: %{{.*}} = memref.alloc() - %0 = memref.alloc() : memref<2xf32> - memref.dealloc %0 : memref<2xf32> - // CHECK: %{{.*}} = memref.alloc() - %1 = memref.alloc() : memref<2xf32> - // CHECK: linalg.copy - linalg.copy(%arg3, %1) : memref<2xf32>, memref<2xf32> - %2 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %1) -> (memref<2xf32>) { - %3 = cmpi eq, %arg5, %arg1 : index - // CHECK: memref.dealloc - memref.dealloc %arg6 : memref<2xf32> - // CHECK: %[[PERCENT4:.*]] = memref.alloc() - %4 = memref.alloc() : memref<2xf32> - // CHECK-NOT: memref.alloc - // CHECK-NOT: linalg.copy - // CHECK-NOT: memref.dealloc - %5 = memref.alloc() : memref<2xf32> - linalg.copy(%4, %5) : memref<2xf32>, memref<2xf32> - memref.dealloc %4 : memref<2xf32> - // CHECK: %[[PERCENT6:.*]] = memref.alloc() - %6 = memref.alloc() : memref<2xf32> - // CHECK: linalg.copy(%[[PERCENT4]], %[[PERCENT6]]) - linalg.copy(%5, %6) : 
memref<2xf32>, memref<2xf32> - scf.yield %6 : memref<2xf32> - } - // CHECK: linalg.copy - linalg.copy(%2, %arg4) : memref<2xf32>, memref<2xf32> - memref.dealloc %2 : memref<2xf32> - return -} - -// ----- - -// The linalg.copy operation can be removed in addition to alloc and dealloc -// operations. All uses of %0 is then replaced with %arg2. - -// CHECK-LABEL: func @check_with_affine_dialect -func @check_with_affine_dialect(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xf32>) { - // CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32>, %[[ARG1:.*]]: memref<4xf32>, %[[RES:.*]]: memref<4xf32>) - // CHECK-NOT: memref.alloc - %0 = memref.alloc() : memref<4xf32> - affine.for %arg3 = 0 to 4 { - %5 = affine.load %arg0[%arg3] : memref<4xf32> - %6 = affine.load %arg1[%arg3] : memref<4xf32> - %7 = cmpf ogt, %5, %6 : f32 - // CHECK: %[[SELECT_RES:.*]] = select - %8 = select %7, %5, %6 : f32 - // CHECK-NEXT: affine.store %[[SELECT_RES]], %[[RES]] - affine.store %8, %0[%arg3] : memref<4xf32> - } - // CHECK-NOT: linalg.copy - // CHECK-NOT: dealloc - linalg.copy(%0, %arg2) : memref<4xf32>, memref<4xf32> - memref.dealloc %0 : memref<4xf32> - //CHECK: return - return -}