diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -18,6 +18,22 @@
 
 using namespace mlir;
 
+namespace {
+/// Bufferizes `extract_element` into a `load` on the converted (memref) operand.
+class BufferizeExtractElementOp : public OpConversionPattern<ExtractElementOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExtractElementOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    ExtractElementOp::Adaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<LoadOp>(op, adaptor.aggregate(),
+                                        adaptor.indices());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
 public:
@@ -32,10 +48,39 @@
 };
 } // namespace
 
+namespace {
+/// Bufferizes `tensor_from_elements` by allocating a 1-D memref with one slot
+/// per element and storing each element at its (constant) index.
+class BufferizeTensorFromElementsOp
+    : public OpConversionPattern<TensorFromElementsOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TensorFromElementsOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Use the remapped operands supplied by the conversion driver rather than
+    // op.elements(), matching how BufferizeExtractElementOp reads its inputs.
+    TensorFromElementsOp::Adaptor adaptor(operands);
+    int numberOfElements = adaptor.elements().size();
+    auto resultType = MemRefType::get(
+        {numberOfElements}, op.getType().cast<TensorType>().getElementType());
+    Value result = rewriter.create<AllocOp>(op.getLoc(), resultType);
+    for (auto element : llvm::enumerate(adaptor.elements())) {
+      Value index =
+          rewriter.create<ConstantIndexOp>(op.getLoc(), element.index());
+      rewriter.create<StoreOp>(op.getLoc(), element.value(), result, index);
+    }
+    rewriter.replaceOp(op, {result});
+    return success();
+  }
+};
+} // namespace
+
 void mlir::populateStdBufferizePatterns(MLIRContext *context,
                                         BufferizeTypeConverter &typeConverter,
                                         OwningRewritePatternList &patterns) {
-  patterns.insert<BufferizeTensorCastOp>(typeConverter, context);
+  patterns.insert<BufferizeExtractElementOp, BufferizeTensorCastOp,
+                  BufferizeTensorFromElementsOp>(typeConverter, context);
 }
 
 namespace {
@@ -49,9 +94,9 @@
     target.addLegalDialect<StandardOpsDialect>();
 
     populateStdBufferizePatterns(context, typeConverter, patterns);
-    target.addIllegalOp<TensorCastOp>();
+    target.addIllegalOp<ExtractElementOp, TensorCastOp, TensorFromElementsOp>();
 
-    if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+    if (failed(applyPartialConversion(getFunction(), target, patterns)))
       signalPassFailure();
   }
 };
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -1,5 +1,17 @@
 // RUN: mlir-opt %s -std-bufferize | FileCheck %s
 
+// CHECK-LABEL:   func @extract_element(
+// CHECK-SAME:                          %[[TENSOR:.*]]: tensor<?xf32>,
+// CHECK-SAME:                          %[[IDX:.*]]: index) -> f32 {
+// CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32>
+// CHECK:           %[[RET:.*]] = load %[[MEMREF]][%[[IDX]]] : memref<?xf32>
+// CHECK:           return %[[RET]] : f32
+// CHECK:         }
+func @extract_element(%arg0: tensor<?xf32>, %arg1: index) -> f32 {
+  %0 = extract_element %arg0[%arg1] : tensor<?xf32>
+  return %0 : f32
+}
+
 // CHECK-LABEL:   func @tensor_cast(
 // CHECK-SAME:                      %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
 // CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]
@@ -10,3 +22,18 @@
   %0 = tensor_cast %arg0 : tensor<?xindex> to tensor<2xindex>
   return %0 : tensor<2xindex>
 }
+
+// CHECK-LABEL:   func @tensor_from_elements(
+// CHECK-SAME:                               %[[ELEM0:.*]]: index,
+// CHECK-SAME:                               %[[ELEM1:.*]]: index) -> tensor<2xindex> {
+// CHECK:           %[[MEMREF:.*]] = alloc()
+// CHECK:           %[[C0:.*]] = constant 0 : index
+// CHECK:           store %[[ELEM0]], %[[MEMREF]][%[[C0]]]
+// CHECK:           %[[C1:.*]] = constant 1 : index
+// CHECK:           store %[[ELEM1]], %[[MEMREF]][%[[C1]]]
+// CHECK:           %[[RET:.*]] = tensor_load %[[MEMREF]]
+// CHECK:           return %[[RET]] : tensor<2xindex>
+func @tensor_from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> {
+  %0 = tensor_from_elements %arg0, %arg1 : tensor<2xindex>
+  return %0 : tensor<2xindex>
+}