diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -59,6 +59,19 @@ } }; +struct BufferizeInsertOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + rewriter.create(op.getLoc(), adaptor.scalar(), + adaptor.dest(), adaptor.indices()); + rewriter.replaceOp(op, {adaptor.dest()}); + return success(); + } +}; + struct BufferizeFromElementsOp : public OpConversionPattern { public: @@ -208,7 +221,7 @@ StandardOpsDialect>( [&](Operation *op) { return typeConverter.isLegal(op); }); target.addLegalOp(); - target.addIllegalOp(); bufferization::populateBufferizeMaterializationLegality(target); @@ -226,8 +239,8 @@ bufferization::BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add( - typeConverter, patterns.getContext()); + BufferizeInsertOp, BufferizeFromElementsOp, BufferizeGenerateOp, + BufferizeRankOp>(typeConverter, patterns.getContext()); } std::unique_ptr mlir::createTensorBufferizePass() { diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -242,3 +242,17 @@ } : tensor return %tensor : tensor } + +// CHECK-LABEL: func @tensor.insert( +// CHECK-SAME: %[[TENSOR:.*]]: tensor, +// CHECK-SAME: %[[ELEM_VAL:.*]]: f32, +// CHECK-SAME: %[[IDX:.*]]: index) -> tensor { +// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref +// CHECK: memref.store %[[ELEM_VAL]], %[[MEMREF]][%[[IDX]]] : memref +// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] : memref +// CHECK: return %[[RET]] : tensor +// CHECK: } +func @tensor.insert(%arg0: tensor, %arg1: f32, %arg2: index) -> tensor { + %0 = tensor.insert %arg1 into %arg0[%arg2] : tensor + return %0 : tensor +}