diff --git a/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h b/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
--- a/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
+++ b/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
@@ -22,7 +22,11 @@
 
 namespace tosa {
 
-std::unique_ptr<Pass> createTosaToArith();
+/// Creates the TOSA-to-Arith conversion pass. The two flags are copied, in
+/// order, into the pass's `TosaToArithOptions`; the default arguments keep
+/// existing zero-argument call sites compiling.
+std::unique_ptr<Pass> createTosaToArith(bool includeApplyRescale = false,
+                                        bool use32BitApplyRescale = false);
 
 void populateTosaToArithConversionPatterns(RewritePatternSet *patterns);
 
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
--- a/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
@@ -31,6 +31,14 @@
 namespace {
 struct TosaToArith : public impl::TosaToArithBase<TosaToArith> {
 public:
+  /// Default construction delegates to the generated base, exactly as the
+  /// pass was constructed before the options constructor existed.
+  TosaToArith() = default;
+
+  /// Constructs the pass from explicitly supplied options, as
+  /// `mlir::tosa::createTosaToArith` does. The options are only read.
+  explicit TosaToArith(const TosaToArithOptions &options)
+      : TosaToArithBase(options) {}
+
   void runOnOperation() override {
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
@@ -52,6 +60,10 @@
 };
 } // namespace
 
-std::unique_ptr<Pass> mlir::tosa::createTosaToArith() {
-  return std::make_unique<TosaToArith>();
+std::unique_ptr<Pass> mlir::tosa::createTosaToArith(bool includeApplyRescale,
+                                                    bool use32BitApplyRescale) {
+  // Positional aggregate initialization: the argument order must match the
+  // field order of the generated TosaToArithOptions struct.
+  TosaToArithOptions options = {includeApplyRescale, use32BitApplyRescale};
+  return std::make_unique<TosaToArith>(options);
 }