diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -88,6 +88,26 @@
 };
 } // namespace
 
+namespace {
+/// Bufferizes a `select` with a scalar condition by recreating it on the
+/// converted (memref) operands, i.e. selecting between memref descriptors.
+class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(SelectOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op.condition().getType().isa<IntegerType>())
+      return rewriter.notifyMatchFailure(op, "requires scalar condition");
+
+    SelectOp::Adaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<SelectOp>(
+        op, adaptor.condition(), adaptor.true_value(), adaptor.false_value());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
 public:
@@ -128,10 +148,15 @@
 void mlir::populateStdBufferizePatterns(MLIRContext *context,
                                         BufferizeTypeConverter &typeConverter,
                                         OwningRewritePatternList &patterns) {
-  patterns
-      .insert<BufferizeDynamicTensorFromElementsOp, BufferizeExtractElementOp,
-              BufferizeTensorCastOp, BufferizeTensorFromElementsOp>(
-          typeConverter, context);
+  patterns.insert<
+      // clang-format off
+      BufferizeDynamicTensorFromElementsOp,
+      BufferizeExtractElementOp,
+      BufferizeSelectOp,
+      BufferizeTensorCastOp,
+      BufferizeTensorFromElementsOp
+      // clang-format on
+      >(typeConverter, context);
 }
 
 namespace {
@@ -148,6 +173,13 @@
 
     populateStdBufferizePatterns(context, typeConverter, patterns);
     target.addIllegalOp<DynamicTensorFromElementsOp, ExtractElementOp,
                         TensorCastOp, TensorFromElementsOp>();
+    // We only bufferize the case of tensor selected type and scalar condition,
+    // as that boils down to a select over memref descriptors (don't need to
+    // touch the data).
+    target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
+      return typeConverter.isLegal(op.getType()) ||
+             !op.condition().getType().isa<IntegerType>();
+    });
     if (failed(
             applyPartialConversion(getFunction(), target, std::move(patterns))))
       signalPassFailure();
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -61,6 +61,20 @@
   return %0 : f32
 }
 
+// CHECK-LABEL:   func @select(
+// CHECK-SAME:                 %[[PRED:.*]]: i1,
+// CHECK-SAME:                 %[[TRUE_VAL:.*]]: tensor<f32>,
+// CHECK-SAME:                 %[[FALSE_VAL:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK:           %[[TRUE_VAL_MEMREF:.*]] = tensor_to_memref %[[TRUE_VAL]] : memref<f32>
+// CHECK:           %[[FALSE_VAL_MEMREF:.*]] = tensor_to_memref %[[FALSE_VAL]] : memref<f32>
+// CHECK:           %[[RET_MEMREF:.*]] = select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref<f32>
+// CHECK:           %[[RET:.*]] = tensor_load %[[RET_MEMREF]] : memref<f32>
+// CHECK:           return %[[RET]] : tensor<f32>
+func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
+  %0 = select %arg0, %arg1, %arg2 : tensor<f32>
+  return %0 : tensor<f32>
+}
+
 // CHECK-LABEL:   func @tensor_cast(
 // CHECK-SAME:                      %[[TENSOR:.*]]: tensor<?xindex>) -> tensor<2xindex> {
 // CHECK:           %[[MEMREF:.*]] = tensor_to_memref %[[TENSOR]]