diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -86,15 +86,15 @@ // Set insertion point now that potential alloc/dealloc are introduced. b.setInsertionPoint(op); - op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands); + auto bufferizedOp = cast( + op.clone(b, loc, /*resultTypes=*/TypeRange{}, newOperands)); // Replace the results of the old op with the new output buffers. if (op->getNumResults()) state.mapBuffer(op->getResults(), newOutputBuffers); // The original op will be DCE'd away later. - - return success(); + return comprehensive_bufferize::bufferize(bufferizedOp.getBlock(), state); } template diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -1008,3 +1008,31 @@ func @empty_func() -> () { return } + +// ----- + +func @gather_like(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor {linalg.inplaceable = true}) -> tensor { + %0 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%arg1 : tensor) outs(%arg2 : tensor) { + ^bb0(%arg3: i32, %arg4 : f32): + %iv1 = linalg.index 1 : index + %1 = arith.index_cast %arg3: i32 to index + %2 = tensor.extract %arg0[%1, %iv1] : tensor + linalg.yield %2 : f32 + } -> tensor + return %0 : tensor +} +// CHECK-LABEL: func @gather_like( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref