diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 set(LLVM_TARGET_DEFINITIONS Passes.td)
 mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt)
+mlir_tablegen(PassesEnums.h.inc -gen-enum-decls)
+mlir_tablegen(PassesEnums.cpp.inc -gen-enum-defs)
 add_public_tablegen_target(MLIRTosaPassIncGen)
 add_dependencies(mlir-headers MLIRTosaPassIncGen)
 
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
 
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Transforms/PassesEnums.h.inc"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
@@ -37,6 +38,7 @@
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
 std::unique_ptr<Pass> createTosaOptionalDecompositions();
+std::unique_ptr<Pass> createTosaValidationPass();
 
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -13,6 +13,7 @@
 #ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
 #define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
 
+include "mlir/IR/EnumAttr.td"
 include "mlir/Pass/PassBase.td"
 
 def TosaLayerwiseConstantFoldPass : Pass<"tosa-layerwise-constant-fold", "func::FuncOp"> {
@@ -63,4 +64,30 @@
   let constructor = "tosa::createTosaOptionalDecompositions()";
 }
 
+def TosaProfileType : I32EnumAttr<"TosaProfileEnum", "Tosa profile",
+    [
+      I32EnumAttrCase<"BaseInference", 0, "bi">,
+      I32EnumAttrCase<"MainInference", 1, "mi">,
+      I32EnumAttrCase<"MainTraining", 2, "mt">,
+      I32EnumAttrCase<"Undefined", 3, "undefined">
+    ]>{
+  let cppNamespace = "mlir::tosa";
+}
+
+def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
+  let summary = "Validates TOSA dialect";
+  let description = [{
+    This pass validates if input TOSA operations match the specification for
+    given criteria, e.g. TOSA profile. The profile is selected with the
+    `profile` option ("bi", "mi", "mt" or "undefined"). Under "bi", every TOSA
+    operation with a floating-point operand element type is reported as an error.
+  }];
+  let constructor = "tosa::createTosaValidationPass()";
+
+  let options = [
+      Option<"profileName", "profile", "std::string",
+             /*default=*/"\"undefined\"",
+             "Validate if operations match for the given profile">];
+}
+
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -84,5 +84,6 @@
   // TODO: Remove pass that operates on const tensor and enable optionality
   pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
+  pm.addNestedPass<func::FuncOp>(tosa::createTosaValidationPass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
 }
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
@@ -53,5 +53,6 @@
 }
 
 void mlir::tosa::addTosaToSCFPasses(OpPassManager &pm) {
+  pm.addNestedPass<func::FuncOp>(createTosaValidationPass());
   pm.addNestedPass<func::FuncOp>(createTosaToSCF());
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@
   TosaLayerwiseConstantFoldPass.cpp
   TosaMakeBroadcastable.cpp
   TosaOptionalDecompositions.cpp
+  TosaValidation.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -0,0 +1,96 @@
+//===- TosaValidation.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Validate that TOSA dialect input matches the specification for the given
+// requirements, e.g. the selected TOSA profile.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Transforms/PassesEnums.cpp.inc"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAVALIDATION
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// TOSA Validation Pass.
+//===----------------------------------------------------------------------===//
+
+/// Checks the TOSA operations of a function against the constraints of the
+/// TOSA profile selected by the `profile` option. Every violation is reported
+/// as an op error and makes the pass fail.
+struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
+  void runOnOperation() override;
+};
+
+void TosaValidation::runOnOperation() {
+  // `profileName` is matched against the string forms declared for
+  // TosaProfileEnum in Passes.td ("bi", "mi", "mt", "undefined").
+  auto profileType = symbolizeEnum<TosaProfileEnum>(profileName);
+  if (!profileType) {
+    // Reject a misspelled profile instead of silently validating nothing.
+    getOperation().emitError()
+        << "unknown TOSA profile '" << StringRef(profileName)
+        << "'; expected one of: bi, mi, mt, undefined";
+    return signalPassFailure();
+  }
+
+  // Only the base-inference profile restricts anything so far; all other
+  // profiles accept every element type, so there is nothing to check.
+  if (*profileType != TosaProfileEnum::BaseInference)
+    return;
+
+  // The profile constrains TOSA operations only. Without a loaded TOSA
+  // dialect the function cannot contain any.
+  Dialect *tosaDialect = getContext().getLoadedDialect<TosaDialect>();
+  if (!tosaDialect)
+    return;
+
+  bool foundIllegalOp = false;
+  getOperation().walk([&](Operation *op) {
+    if (op->getDialect() != tosaDialect)
+      return;
+    for (Value operand : op->getOperands()) {
+      Type elementTy = getElementTypeOrSelf(operand);
+      if (!elementTy.isa<FloatType>())
+        continue;
+      // Floating-point types are rejected under base inference. Report the
+      // op once and keep walking so every violation is diagnosed in one run.
+      op->emitOpError() << "is not profile-aligned: element type "
+                        << elementTy << " is not legal";
+      foundIllegalOp = true;
+      return;
+    }
+  });
+
+  if (foundIllegalOp)
+    signalPassFailure();
+}
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaValidationPass() {
+  return std::make_unique<TosaValidation>();
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-validation.mlir b/mlir/test/Dialect/Tosa/tosa-validation.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/tosa-validation.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate=profile=bi
+
+// Float operands are rejected; the f32 func.return is not a TOSA op and is
+// not checked.
+func.func @test_abs_f32(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  // expected-error@+1 {{'tosa.abs' op is not profile-aligned}}
+  %0 = "tosa.abs"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+// Integer operands are legal under the base-inference profile.
+func.func @test_abs_i32(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+  %0 = "tosa.abs"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}