diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -2115,6 +2115,22 @@ return success(); llvm_unreachable("unexpected yieldOp"); } + +/// Bufferization for tensor::ExtractOp just translates to memref.load: the op +/// only reads the tensor, so it is rewritten as a load from the source buffer. +static LogicalResult bufferize(OpBuilder &b, tensor::ExtractOp extractOp, + BlockAndValueMapping &bvm, + BufferizationAliasInfo &aliasInfo) { + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(extractOp); + + Location loc = extractOp.getLoc(); + Value srcMemref = lookup(bvm, extractOp.tensor()); + Value l = b.create<memref::LoadOp>(loc, srcMemref, extractOp.indices()); + extractOp.replaceAllUsesWith(l); + return success(); +} //===----------------------------------------------------------------------===// // Bufferization analyses. //===----------------------------------------------------------------------===// @@ -2310,6 +2326,7 @@ scf::ForOp, InitTensorOp, InsertSliceOp, + tensor::ExtractOp, LinalgOp, ReturnOp, TiledLoopOp, diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -34,6 +34,19 @@ // ----- +// CHECK-LABEL: func @tensor_extract(%{{.*}}: memref) -> f32 { +func @tensor_extract(%A : tensor) -> (f32) { + %c0 = constant 0 : index + +// CHECK: %[[RES:.*]] = memref.load {{.*}} : memref + %0 = tensor.extract %A[%c0] : tensor + +// CHECK: return %[[RES]] : f32 + return %0 : f32 +} + +// ----- + // CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> /// No linalg.inplaceable flag, must allocate.