diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -37,6 +37,7 @@
 std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
 std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
 std::unique_ptr<Pass> createTosaOptionalDecompositions();
+std::unique_ptr<Pass> createTosaValidationPass();
 
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -63,4 +63,17 @@
   let constructor = "tosa::createTosaOptionalDecompositions()";
 }
 
+def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
+  let summary = "Validates TOSA dialect";
+  let description = [{
+    This pass validates that the input TOSA operations conform to the
+    specification for the given criteria, e.g. the profile.
+  }];
+  let constructor = "tosa::createTosaValidationPass()";
+  let options = [
+      Option<"profileName", "profile", "std::string",
+             /*default=*/"\"undefined\"",
+             "Validate ops against the given profile (bi, mi, mt)">];
+}
+
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -84,5 +84,6 @@
   // TODO: Remove pass that operates on const tensor and enable optionality
   pm.addNestedPass<func::FuncOp>(tosa::createTosaLayerwiseConstantFoldPass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaMakeBroadcastablePass());
+  pm.addNestedPass<func::FuncOp>(tosa::createTosaValidationPass());
   pm.addNestedPass<func::FuncOp>(tosa::createTosaToLinalg());
 }
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
@@ -53,5 +53,6 @@
 }
 
 void mlir::tosa::addTosaToSCFPasses(OpPassManager &pm) {
+  pm.addNestedPass<func::FuncOp>(createTosaValidationPass());
   pm.addNestedPass<func::FuncOp>(createTosaToSCF());
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -7,6 +7,7 @@
   TosaLayerwiseConstantFoldPass.cpp
   TosaMakeBroadcastable.cpp
   TosaOptionalDecompositions.cpp
+  TosaValidation.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -0,0 +1,107 @@
+//===- TosaValidation.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Validate that TOSA dialect input matches the specification for the given
+// requirements (currently: the element types allowed by a TOSA profile).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/StringSwitch.h"
+
+namespace mlir {
+namespace tosa {
+#define GEN_PASS_DEF_TOSAVALIDATION
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+} // namespace tosa
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// TOSA Validation Pass.
+//===----------------------------------------------------------------------===//
+
+/// The type of profiles supported by TOSA.
+enum class TosaProfileType {
+  BaseInference = 0,
+  MainInference = 1,
+  MainTraining = 2,
+  Undefined = 3,
+};
+
+/// Checks every operation nested under the function (including the function
+/// itself) against the profile selected by the `profile` option.
+struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
+public:
+  void runOnOperation() override {
+    StringRef profile = profileName;
+    profileType = llvm::StringSwitch<TosaProfileType>(profile)
+                      .Case("bi", TosaProfileType::BaseInference)
+                      .Case("mi", TosaProfileType::MainInference)
+                      .Case("mt", TosaProfileType::MainTraining)
+                      .Default(TosaProfileType::Undefined);
+
+    // Only the explicit default "undefined" opts out of validation; any other
+    // unrecognized name is most likely a typo that would otherwise silently
+    // disable every check.
+    if (profileType == TosaProfileType::Undefined && profile != "undefined") {
+      getOperation().emitError()
+          << "unknown TOSA profile '" << profile
+          << "', expected one of 'bi', 'mi', 'mt' or 'undefined'";
+      return signalPassFailure();
+    }
+
+    // Keep walking after a violation so that every offending op is reported.
+    bool hasViolation = false;
+    getOperation().walk([&](Operation *op) {
+      if (failed(validateOperation(op)))
+        hasViolation = true;
+    });
+    if (hasViolation)
+      signalPassFailure();
+  }
+
+private:
+  /// Returns failure, after emitting an error on `op`, if any operand of `op`
+  /// has an element type that the selected profile does not allow.
+  LogicalResult validateOperation(Operation *op) {
+    if (profileType != TosaProfileType::BaseInference)
+      return success();
+
+    for (Value operand : op->getOperands()) {
+      // The base inference profile does not allow floating-point operands.
+      Type elementType = getElementTypeOrSelf(operand);
+      if (elementType.isa<FloatType>())
+        return op->emitOpError()
+               << "is not profile-aligned: element type " << elementType
+               << " is not allowed in the base inference profile";
+    }
+    return success();
+  }
+
+  /// Profile parsed from the `profile` option in runOnOperation().
+  TosaProfileType profileType = TosaProfileType::Undefined;
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaValidationPass() {
+  return std::make_unique<TosaValidation>();
+}