diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -24,8 +24,7 @@ using namespace mlir; namespace { -class BufferizeCastOp : public OpConversionPattern { -public: +struct BufferizeCastOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, @@ -36,11 +35,8 @@ return success(); } }; -} // namespace -namespace { -class BufferizeDimOp : public OpConversionPattern { -public: +struct BufferizeDimOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, @@ -50,11 +46,8 @@ return success(); } }; -} // namespace -namespace { -class BufferizeExtractOp : public OpConversionPattern { -public: +struct BufferizeExtractOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor, @@ -64,10 +57,8 @@ return success(); } }; -} // namespace -namespace { -class BufferizeFromElementsOp +struct BufferizeFromElementsOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -88,11 +79,8 @@ return success(); } }; -} // namespace -namespace { -class BufferizeGenerateOp : public OpConversionPattern { -public: +struct BufferizeGenerateOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -150,44 +138,51 @@ return success(); } }; -} // namespace -void mlir::populateTensorBufferizePatterns( - bufferization::BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns.add( - typeConverter, patterns.getContext()); -} +struct BufferizeRankOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(tensor::RankOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getType(), + adaptor.tensor()); + return success(); + } +}; -namespace { struct TensorBufferizePass : public TensorBufferizeBase { void runOnFunction() override { auto *context = &getContext(); bufferization::BufferizeTypeConverter typeConverter; - RewritePatternSet patterns(context); - ConversionTarget target(*context); - - bufferization::populateBufferizeMaterializationLegality(target); - populateTensorBufferizePatterns(typeConverter, patterns); - target.addIllegalOp(); - target.addLegalDialect(); + ConversionTarget target(*context); + target.addLegalDialect(); target.addDynamicallyLegalDialect( [&](Operation *op) { return typeConverter.isLegal(op); }); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalDialect(); + target.addLegalOp(); + target.addIllegalOp(); + bufferization::populateBufferizeMaterializationLegality(target); + RewritePatternSet patterns(context); + populateTensorBufferizePatterns(typeConverter, patterns); if (failed( applyPartialConversion(getFunction(), target, std::move(patterns)))) signalPassFailure(); } }; + } // namespace +void mlir::populateTensorBufferizePatterns( + bufferization::BufferizeTypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add( + typeConverter, patterns.getContext()); +} + std::unique_ptr mlir::createTensorBufferizePass() { return std::make_unique(); } diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -11,6 +11,15 @@ return %0 : index } +// CHECK-LABEL: func @rank( +// CHECK-SAME: %[[TENSOR:.*]]: tensor<*xf32>) -> index { +// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] +// CHECK: %[[EXTENT:.*]] = memref.rank %[[MEMREF]] : memref<*xf32> +func @rank(%arg0: tensor<*xf32>) -> index { + %0 = tensor.rank %arg0 : tensor<*xf32> + return %0 : index +} + // CHECK-LABEL: func @tensor.cast( // CHECK-SAME: %[[TENSOR:.*]]: tensor) -> tensor<2xindex> { // CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]]