diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -14,6 +14,7 @@
 #ifndef MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H
 #define MLIR_CONVERSION_TOSATOLINALG_TOSATOLINALG_H
 
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -31,8 +32,11 @@
 /// the pass, the function will only contain linalg ops or standard ops if the
 /// pipeline succeeds. The option to disable decompositions is available for
 /// benchmarking performance improvements from the canonicalizations.
-void addTosaToLinalgPasses(OpPassManager &pm,
-                           bool disableTosaDecompositions = false);
+void addTosaToLinalgPasses(
+    OpPassManager &pm, bool disableTosaDecompositions = false,
+    // Note: Default to 'none' level unless otherwise specified.
+    tosa::ValidationOptions const &validationOptions =
+        tosa::ValidationOptions().setLevel(tosa::TosaLevelEnum::None));
 
 /// Populates conversion passes from TOSA dialect to Linalg dialect.
 void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -40,7 +40,34 @@
 std::unique_ptr createTosaMakeBroadcastablePass();
 std::unique_ptr createTosaTestQuantUtilAPIPass();
 std::unique_ptr createTosaOptionalDecompositions();
-std::unique_ptr createTosaValidationPass();
+
+/// Options for the TOSA validation pass. Each setter returns `*this` so that
+/// calls can be chained.
+struct ValidationOptions {
+  /// Validate if operations match for the given profile.
+  TosaProfileEnum profile = TosaProfileEnum::Undefined;
+  ValidationOptions &setProfile(TosaProfileEnum profile) {
+    this->profile = profile;
+    return *this;
+  }
+  /// Verify if the properties of certain operations align the spec requirement.
+  bool strictOperationSpecAlignment = false;
+  ValidationOptions &enableStrictOperationSpecAlignment(bool enable = true) {
+    strictOperationSpecAlignment = enable;
+    return *this;
+  }
+  /// Validate if operator parameters are within specification for the given
+  /// level.
+  TosaLevelEnum level = TosaLevelEnum::EightK;
+  ValidationOptions &setLevel(TosaLevelEnum level) {
+    this->level = level;
+    return *this;
+  }
+};
+
+/// Creates the TOSA validation pass configured with `options`.
+std::unique_ptr createTosaValidationPass(
+    ValidationOptions const &options = ValidationOptions());
 
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -91,15 +91,32 @@
   let constructor = "createTosaValidationPass()";
 
   let options = [
-      Option<"profileName", "profile", "std::string",
-             /*default=*/"\"undefined\"",
-             "Validate if operations match for the given profile">,
+      Option<"profile", "profile", "mlir::tosa::TosaProfileEnum",
+             /*default=*/"mlir::tosa::TosaProfileEnum::Undefined",
+             "Validate if operations match for the given profile",
+             [{::llvm::cl::values(
+               clEnumValN(mlir::tosa::TosaProfileEnum::BaseInference, "bi",
+                "Use Base Inference profile."),
+               clEnumValN(mlir::tosa::TosaProfileEnum::MainInference, "mi",
+                "Use Main Inference profile."),
+               clEnumValN(mlir::tosa::TosaProfileEnum::MainTraining, "mt",
+                "Use Main Training profile."),
+               clEnumValN(mlir::tosa::TosaProfileEnum::Undefined, "undefined",
+                "Do not define a profile.")
+              )}]>,
       Option<"StrictOperationSpecAlignment", "strict-op-spec-alignment", "bool",
              /*default=*/"false",
              "Verify if the properties of certain operations align the spec requirement">,
-      Option<"levelName", "level", "std::string",
-             /*default=*/"\"8k\"",
-             "Validate if operator parameters are within specfication for the given level">,
+      Option<"level", "level", "mlir::tosa::TosaLevelEnum",
+             /*default=*/"mlir::tosa::TosaLevelEnum::EightK",
+             "Validate if operator parameters are within specfication for the given level",
+             [{::llvm::cl::values(
+               clEnumValN(mlir::tosa::TosaLevelEnum::EightK, "8k",
+                "Ranges are expected to be sufficient for applications with frame sizes up to 8K."),
+               clEnumValN(mlir::tosa::TosaLevelEnum::None, "none",
+                "Allows the full range of arguments specified by the operations according "
+                "to the operation data types.")
+              )}]>
   ];
 }
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -74,8 +74,9 @@
   return std::make_unique();
 }
 
-void mlir::tosa::addTosaToLinalgPasses(OpPassManager &pm,
-                                       bool disableTosaDecompositions) {
+void mlir::tosa::addTosaToLinalgPasses(
+    OpPassManager &pm, bool disableTosaDecompositions,
+    tosa::ValidationOptions const &validationOptions) {
   // Optional decompositions are designed to benefit linalg.
   if (!disableTosaDecompositions)
     pm.addNestedPass(tosa::createTosaOptionalDecompositions());
@@ -88,6 +89,7 @@
   // TODO: Remove pass that operates on const tensor and enable optionality
   pm.addNestedPass(tosa::createTosaLayerwiseConstantFoldPass());
   pm.addNestedPass(tosa::createTosaMakeBroadcastablePass());
-  pm.addNestedPass(tosa::createTosaValidationPass());
+  pm.addNestedPass(
+      tosa::createTosaValidationPass(validationOptions));
   pm.addNestedPass(tosa::createTosaToLinalg());
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -96,6 +96,11 @@
 struct TosaValidation : public tosa::impl::TosaValidationBase {
 public:
   explicit TosaValidation() { populateConstantOperandChecks(); }
+  explicit TosaValidation(const ValidationOptions &options) : TosaValidation() {
+    this->profile = options.profile;
+    this->StrictOperationSpecAlignment = options.strictOperationSpecAlignment;
+    this->level = options.level;
+  }
   void runOnOperation() override;
 
   LogicalResult applyConstantOperandCheck(Operation *op) {
@@ -387,18 +392,13 @@
-  // configure profile and level values from pass options profileName and
-  // levelName
+  // configure the tosa_level value from the pass option `level`; the
+  // `profile` option is read directly where it is needed
   void configLevelAndProfile() {
-    profileType = symbolizeEnum(profileName);
-
-    auto levelType = symbolizeEnum(levelName);
-
     tosa_level = TOSA_LEVEL_NONE;
-    if (levelType == TosaLevelEnum::EightK) {
+    if (level == TosaLevelEnum::EightK) {
       tosa_level = TOSA_LEVEL_EIGHTK;
     }
   }
 
   SmallVector> const_checkers;
-  std::optional profileType;
   tosa_level_t tosa_level;
 };
 
@@ -431,7 +431,7 @@
   configLevelAndProfile();
   getOperation().walk([&](Operation *op) {
     for (Value operand : op->getOperands()) {
-      if ((profileType == TosaProfileEnum::BaseInference) &&
+      if ((profile == TosaProfileEnum::BaseInference) &&
           isa(getElementTypeOrSelf(operand))) {
         return signalPassFailure();
       }
@@ -451,6 +451,7 @@
 }
 } // namespace
 
-std::unique_ptr mlir::tosa::createTosaValidationPass() {
-  return std::make_unique();
+std::unique_ptr
+mlir::tosa::createTosaValidationPass(ValidationOptions const &options) {
+  return std::make_unique(options);
 }