diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -466,7 +466,6 @@
   // Compute allocation memref type.
   assert(shapedValue.getType().isa<ShapedType>());
-  MemRefType memRefType = shapedValue.getType().dyn_cast<MemRefType>();
   SmallVector<Value> dynShape;
   MemRefType allocMemRefType =
       getAllocationTypeAndShape(b, loc, shapedValue, dynShape);
@@ -485,17 +484,7 @@
   }

   // Create the buffer allocation.
-  Value alloc =
-      createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);
-
-  // Insert a cast if a different type was requested.
-  if (memRefType && memRefType != allocMemRefType) {
-    assert(memref::CastOp::areCastCompatible(allocMemRefType, memRefType) &&
-           "createAlloc: cast incompatible");
-    alloc = b.create<memref::CastOp>(loc, memRefType, alloc);
-  }
-
-  return alloc;
+  return createBufferAllocation(b, loc, allocMemRefType, dynShape, skipDealloc);
 }

 /// Create a memory copy between two memref buffers.
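Illustration (not part of the patch; value names and types are hypothetical): before this change, createAlloc materialized the originally requested memref type right at the allocation site, e.g. for a 1-D buffer of dynamic size %dim

  %alloc = memref.alloc(%dim) : memref<?xf32>
  %casted = memref.cast %alloc : memref<?xf32> to memref<?xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>

where the affine map is the fully dynamic layout. With the cast removed here, createAlloc simply returns the memref.alloc result, and a cast is created only by callers that actually need the mapped type, such as the scf.for bufferization in the next file.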
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -411,7 +411,22 @@
                                        *yieldedAlloc, state.getOptions());
       (void)copyStatus;
       assert(succeeded(copyStatus) && "could not create memcpy");
-      return *yieldedAlloc;
+
+      if (yieldedVal.getType() == yieldedAlloc->getType())
+        return *yieldedAlloc;
+
+      // The iter_arg memref type has a layout map. Cast the new buffer to
+      // the same type.
+      // TODO: In case the iter_arg has a layout map that is not the fully
+      // dynamic one, we cannot cast the new buffer. In that case, the
+      // iter_arg must be changed to the fully dynamic layout map. (And then
+      // the new buffer can be casted.)
+      assert(memref::CastOp::areCastCompatible(yieldedAlloc->getType(),
+                                               yieldedVal.getType()) &&
+             "scf.for op bufferization: cast incompatible");
+      Value casted = rewriter.create<memref::CastOp>(
+          val.getLoc(), yieldedVal.getType(), *yieldedAlloc);
+      return casted;
     });
     yieldOp.getResultsMutable().assign(yieldValues);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-partial.mlir
@@ -193,12 +193,11 @@
   // CHECK-TENSOR: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
   %c0 = arith.constant 0 : index
   // CHECK-TENSOR: %[[alloc:.*]] = memref.alloc
-  // CHECK-TENSOR: %[[casted:.*]] = memref.cast %[[alloc]]
-  // CHECK-TENSOR: %[[casted_tensor:.*]] = bufferization.to_tensor %[[casted]]
+  // CHECK-TENSOR: %[[casted_alloc:.*]] = bufferization.to_tensor %[[alloc]]
   // CHECK-TENSOR: memref.copy %[[t1_memref]], %[[alloc]]
   // CHECK-TENSOR: memref.store %{{.*}}, %[[alloc]]
   %0 = tensor.insert %f into %t1[%c0] : tensor
-  // CHECK-TENSOR: return %[[casted_tensor]]
+  // CHECK-TENSOR: return %[[casted_alloc]]
   return %0 : tensor
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -29,10 +29,9 @@
   // CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
   // CHECK: %[[dim:.*]] = tensor.dim %[[A]]
   // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
-  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
   // CHECK: memref.copy %[[A_memref]], %[[alloc]]
   // CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
-  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[casted]]
+  // CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
   %0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor
   // CHECK: memref.dealloc %[[alloc]]
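For reference, a minimal sketch of what the new scf.for code path above produces, assuming a 1-D f32 buffer and the fully dynamic layout on the iter_arg (all names and types here are hypothetical and not taken from the tests):

  #layout = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

  %r = scf.for %iv = %lb to %ub step %step iter_args(%arg = %init) -> (memref<?xf32, #layout>) {
    // The loop body yields some buffer %buf that is not equivalent to %arg,
    // so a new allocation is created and filled with a copy.
    %alloc = memref.alloc(%sz) : memref<?xf32>
    memref.copy %buf, %alloc : memref<?xf32, #layout> to memref<?xf32>
    // New in this patch: the yielded buffer is cast to the iter_arg type here,
    // instead of the allocation being cast inside createAlloc.
    %casted = memref.cast %alloc : memref<?xf32> to memref<?xf32, #layout>
    scf.yield %casted : memref<?xf32, #layout>
  }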
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -308,16 +308,16 @@

 // CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

-// CHECK-LABEL: func @scf_for_yield_only
-//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref
+// CHECK-LABEL: func @scf_for_yield_only(
+//  CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref,
 //  CHECK-SAME:   %[[t:[a-zA-Z0-9]*]]: memref
+//  CHECK-SAME:   ) -> memref {
 func @scf_for_yield_only(%A : tensor {linalg.inplaceable = false},
                          %B : tensor {linalg.inplaceable = true},
                          %lb : index, %ub : index, %step : index)
   -> (tensor, tensor)
 {
   // CHECK:   %[[ALLOC_FOR_A:.*]] = memref.alloc
-  // CHECK:   %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]]
   // CHECK:   memref.copy %[[A]], %[[ALLOC_FOR_A]]

   // The first scf.for remains but just turns into dead code.
@@ -330,7 +330,7 @@
     scf.yield %t : tensor
   }

-  // CHECK: return %[[CASTED]] : memref
+  // CHECK: return %[[ALLOC_FOR_A]] : memref
   // CHECK-NOT: dealloc
   return %r0, %r1: tensor, tensor
 }
@@ -373,7 +373,6 @@
   -> (tensor, tensor)
 {
   // CHECK:   %[[ALLOC_FOR_A:.*]] = memref.alloc
-  // CHECK:   %[[CASTED:.*]] = memref.cast %[[ALLOC_FOR_A]]
   // CHECK:   memref.copy %[[A]], %[[ALLOC_FOR_A]]

   // CHECK: %[[svA:.*]] = memref.subview %[[ALLOC_FOR_A]][0] [4] [1]
@@ -396,7 +395,7 @@
     scf.yield %ttA, %ttB : tensor, tensor
   }

-  // CHECK: return %[[CASTED]] : memref
+  // CHECK: return %[[ALLOC_FOR_A]] : memref
   return %r0#0, %r0#1: tensor, tensor
 }
@@ -418,7 +417,6 @@
   // memref.store is left over.
   // CHECK: %[[alloc:.*]] = memref.alloc
-  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
   // CHECK: memref.copy %[[m1]], %[[alloc]]
   // CHECK: memref.store %{{.*}}, %[[alloc]][%{{.*}}]
   %0, %1, %2 = scf.execute_region -> (f32, tensor, f32) {
@@ -426,6 +424,7 @@
     scf.yield %f1, %t2, %f1 : f32, tensor, f32
   }

+  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
   // CHECK: %[[load:.*]] = memref.load %[[m1]]
   %3 = tensor.extract %t1[%idx] : tensor
@@ -783,8 +782,8 @@
   %idx = arith.constant 0 : index

   // CHECK: %[[alloc:.*]] = memref.alloc
-  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
-  // CHECK: memref.copy %[[t1]], %[[alloc]]
+  // CHECK-DAG: %[[casted:.*]] = memref.cast %[[alloc]]
+  // CHECK-DAG: memref.copy %[[t1]], %[[alloc]]
   // CHECK: %[[select:.*]] = arith.select %{{.*}}, %[[casted]], %[[t2]]
   %s = arith.select %c, %t1, %t2 : tensor
@@ -859,13 +858,11 @@
 // CHECK-LABEL: func @scf_for_yield_non_equivalent(
 //  CHECK-SAME:   %[[t:.*]]: memref
     %lb : index, %ub : index, %step : index) -> tensor {
diff --git a/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir
@@ -358,7 +358,7 @@
   // %r0#0 requires a copy because we have no idea what the function is doing.
   // CHECK-DAG: %[[alloc:.*]] = memref.alloc
   // CHECK-DAG: %[[casted:.*]] = memref.cast %[[alloc]]
-  // CHECK: memref.copy %[[B]], %[[alloc]]
+  // CHECK-DAG: memref.copy %[[B]], %[[alloc]]
   // CHECK-NEXT: call @some_external_func(%[[casted]]) : (memref) -> ()
   call @some_external_func(%r0#0) : (tensor) -> ()

@@ -475,15 +475,15 @@
 // conflict. However, inside `entry`, the writes do cause a conflict because
 // %A, %B and %C are not inplaceable. This test case shows that this kind of
 // conflict detection has a "transitive" nature.
-// CHECK: %[[ALLOC_C:.*]] = memref.alloc
-// CHECK: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]]
-// CHECK: %[[ALLOC_B:.*]] = memref.alloc
-// CHECK: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]]
-// CHECK: %[[ALLOC_A:.*]] = memref.alloc
-// CHECK: memref.copy %[[A]], %[[ALLOC_A]]
-// CHECK: memref.copy %[[B]], %[[ALLOC_B]]
-// CHECK: memref.copy %[[C]], %[[ALLOC_C]]
-// CHECK: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]]
+// CHECK-DAG: %[[ALLOC_C:.*]] = memref.alloc
+// CHECK-DAG: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]]
+// CHECK-DAG: %[[ALLOC_B:.*]] = memref.alloc
+// CHECK-DAG: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]]
+// CHECK-DAG: %[[ALLOC_A:.*]] = memref.alloc
+// CHECK-DAG: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]]
+// CHECK-DAG: memref.copy %[[A]], %[[ALLOC_A]]
+// CHECK-DAG: memref.copy %[[B]], %[[ALLOC_B]]
+// CHECK-DAG: memref.copy %[[C]], %[[ALLOC_C]]
 // CHECK-NEXT: call @callee(%[[CASTED_A]], %[[CASTED_B]], %[[CASTED_C]])
 call @callee(%A, %B, %C) : (tensor, tensor, tensor) -> ()
 return
@@ -539,8 +539,8 @@
   // CHECK: scf.for {{.*}} {
   %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor) {
     // CHECK: %[[alloc:.*]] = memref.alloc
-    // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
-    // CHECK: memref.copy %[[arg0]], %[[alloc]]
+    // CHECK-DAG: %[[casted:.*]] = memref.cast %[[alloc]]
+    // CHECK-DAG: memref.copy %[[arg0]], %[[alloc]]
     // CHECK: call @inner_func_2(%[[casted]])
     // CHECK: memref.dealloc %[[alloc]]
     // CHECK-NOT: scf.yield
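A note on the FileCheck directive changes in the test updates above: consecutive CHECK-DAG lines may match the output in any order, whereas plain CHECK lines must match in the order written. Since the memref.cast is no longer emitted directly after the memref.alloc, the relative order of the cast and the memref.copy is no longer fixed, so a pair of checks such as

  // CHECK-DAG: %[[casted:.*]] = memref.cast %[[alloc]]
  // CHECK-DAG: memref.copy %[[arg0]], %[[alloc]]

accepts both orderings of the two operations in the bufferized output.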