diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -230,6 +230,49 @@ } }; +struct InsertOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand) const { + return true; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand) const { + assert(&opOperand == &op->getOpOperand(1) /*dest*/ && + "expected dest OpOperand"); + return op->getOpResult(0); + } + + SmallVector getAliasingOpOperand(Operation *op, + OpResult opResult) const { + return {&op->getOpOperand(1) /*dest*/}; + } + + LogicalResult bufferize(Operation *op, OpBuilder &b, + BufferizationState &state) const { + auto insertOp = cast(op); + + // Take a guard before anything else. + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(insertOp); + + Location loc = insertOp.getLoc(); + Value destMemref = getResultBuffer(b, insertOp->getOpResult(0), state); + b.create(loc, insertOp.scalar(), destMemref, + insertOp.indices()); + state.mapBuffer(insertOp, destMemref); + return success(); + } + + BufferRelation bufferRelation(Operation *op, OpOperand &opOperand) const { + return BufferRelation::Equivalent; + } +}; + /// Return true if the (ExtractSliceOp, InsertSliceOp) pair match (i.e. /// equivalent operand / result and same offset/sizes/strides specification). 
// -----

// tensor.insert into an inplaceable function argument: the scalar must be
// stored with memref.store directly into the buffer of that argument, at the
// given index.
// CHECK-LABEL: func @insert_op
// CHECK-SAME: %[[t1:.*]]: memref<?xf32{{.*}}>, %[[s:.*]]: f32, %[[i:.*]]: index
func @insert_op(%t1 : tensor<?xf32> {linalg.inplaceable = true},
                %s : f32, %i : index) -> tensor<?xf32> {
  // CHECK: memref.store %[[s]], %[[t1]][%[[i]]]
  %0 = tensor.insert %s into %t1[%i] : tensor<?xf32>
  // CHECK: return
  return %0 : tensor<?xf32>
}