diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -752,6 +752,13 @@ operations in the Standard dialect. }]; + let options = [ + Option<"includeApplyRescale", "include-apply-rescale", + "bool", /*default=*/"false", + "Whether to include the lowering for tosa.apply_rescale to standard" + " ops"> + ]; + let constructor = "tosa::createTosaToStandard()"; } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -12,6 +12,7 @@ #include "../PassDetail.h" #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" +#include "mlir/Conversion/TosaToStandard/TosaToStandard.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp @@ -173,8 +173,7 @@ void mlir::tosa::populateTosaToStandardConversionPatterns( RewritePatternSet *patterns) { - patterns->add( - patterns->getContext()); + patterns->add(patterns->getContext()); } void mlir::tosa::populateTosaRescaleToStandardConversionPatterns( diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp @@ -33,11 +33,16 @@ ConversionTarget target(getContext()); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addLegalDialect(); target.addLegalDialect(); mlir::tosa::populateTosaToStandardConversionPatterns(&patterns); + + if (this->includeApplyRescale) { + mlir::tosa::populateTosaRescaleToStandardConversionPatterns(&patterns); + target.addIllegalOp(); + } + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir --- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir +++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt --split-input-file --tosa-to-standard %s -verify-diagnostics -o -| FileCheck %s +// RUN: mlir-opt --split-input-file --tosa-to-standard="include-apply-rescale=true" %s -verify-diagnostics -o -| FileCheck %s // CHECK-LABEL: func @const_test func @const_test() -> (tensor) {