diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -84,8 +84,12 @@
   let options = [
       Option<"profileName", "profile", "std::string",
-             /*default=*/"\"undefined\"",
-             "Validation if ops match for given profile">];
+             /*default=*/"\"undefined\"",
+             "Validate if operations match for the given profile">,
+      Option<"StrictOperationSpecAlignment", "strict-op-spec-alignment", "bool",
+             /*default=*/"false",
+             "Verify if the properties of certain operations align with the spec requirement">,
+  ];
 }
 
 #endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -35,22 +35,88 @@
 
 namespace {
 
+// Perform checks on operations that are strictly required by the
+// specification.
+class OperationValidator {
+public:
+  explicit OperationValidator() { populateConstantOperandChecks(); }
+
+  LogicalResult applyConstantOperandCheck(Operation *op) {
+    for (auto &pat : patterns) {
+      if (failed(pat(op)))
+        return failure();
+    }
+    return success();
+  }
+
+private:
+  static LogicalResult checkConstantOperandPad(Operation *op);
+  static LogicalResult checkConstantOperandTranspose(Operation *op);
+  static LogicalResult checkConstantOperandFullyConnected(Operation *op);
+
+  void populateConstantOperandChecks() {
+    patterns.emplace_back(checkConstantOperandPad);
+    patterns.emplace_back(checkConstantOperandTranspose);
+    patterns.emplace_back(checkConstantOperandFullyConnected);
+  }
+
+  SmallVector<std::function<LogicalResult(Operation *)>> patterns;
+};
+
+LogicalResult OperationValidator::checkConstantOperandPad(Operation *op) {
+  if (auto pad_op = dyn_cast<tosa::PadOp>(op)) {
+    DenseElementsAttr paddings;
+    if (!matchPattern(pad_op.getPadding(), m_Constant(&paddings)))
+      return op->emitOpError("padding of pad is not constant");
+
+    DenseElementsAttr pad_const;
+    // Assume this op is zero-padding if pad_const is not present.
+    if (pad_op.getPadConst() &&
+        !matchPattern(pad_op.getPadConst(), m_Constant(&pad_const)))
+      return op->emitOpError("pad_const of pad is not constant");
+  }
+  return success();
+}
+
+LogicalResult OperationValidator::checkConstantOperandTranspose(Operation *op) {
+  if (auto transpose_op = dyn_cast<tosa::TransposeOp>(op)) {
+    DenseElementsAttr perms;
+    if (!matchPattern(transpose_op.getPerms(), m_Constant(&perms)))
+      return op->emitOpError("perms of transpose is not constant");
+  }
+  return success();
+}
+
+LogicalResult
+OperationValidator::checkConstantOperandFullyConnected(Operation *op) {
+  if (auto fc_op = dyn_cast<tosa::FullyConnectedOp>(op)) {
+    DenseElementsAttr weight;
+    if (!matchPattern(fc_op.getWeight(), m_Constant(&weight)))
+      return op->emitOpError("weight of fully_connected is not constant");
+
+    DenseElementsAttr bias;
+    if (!matchPattern(fc_op.getBias(), m_Constant(&bias)))
+      return op->emitOpError("bias of fully_connected is not constant");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Validation Pass.
 //===----------------------------------------------------------------------===//
 
-struct TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
+class TosaValidation : public tosa::impl::TosaValidationBase<TosaValidation> {
 public:
   explicit TosaValidation() = default;
-
-private:
   void runOnOperation() override;
 
+private:
   std::optional<TosaProfileEnum> profileType;
 };
 
 void TosaValidation::runOnOperation() {
   profileType = symbolizeEnum<TosaProfileEnum>(profileName);
 
+  auto op_validator(std::make_unique<OperationValidator>());
   getOperation().walk([&](Operation *op) {
     for (Value operand : op->getOperands()) {
@@ -62,6 +128,11 @@
         return signalPassFailure();
      }
    }
+
+    // Some uses of TOSA rely on the constant operands of particular operations.
+    if (StrictOperationSpecAlignment &&
+        failed(op_validator->applyConstantOperandCheck(op)))
+      signalPassFailure();
  });
 }
 } // namespace
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate=strict-op-spec-alignment
 
 
 func.func @test_conv2d(%arg0: tensor<1x29x29x4xf32>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {
@@ -37,3 +37,47 @@
 }
 
+// -----
+
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> {
+  // expected-error@+1 {{'tosa.pad' op padding of pad is not constant}}
+  %0 = "tosa.pad"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+
+func.func @test_pad_non_const(%arg0: tensor<13x21x3xi8>, %arg1: tensor<i8>) -> tensor<13x21x3xi8> {
+  %0 = "tosa.const"() {value = dense<[[0, 0], [0, 1], [0, 1]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
+  // expected-error@+1 {{'tosa.pad' op pad_const of pad is not constant}}
+  %1 = "tosa.pad"(%arg0, %0, %arg1) : (tensor<13x21x3xi8>, tensor<3x2xi32>, tensor<i8>) -> tensor<13x21x3xi8>
+  return %1 : tensor<13x21x3xi8>
+}
+
+// -----
+
+func.func @test_transpose_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3xi32>) -> tensor<3x13x21xf32> {
+  // expected-error@+1 {{'tosa.transpose' op perms of transpose is not constant}}
+  %0 = "tosa.transpose"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
+  return %0 : tensor<3x13x21xf32>
+}
+
+// -----
+
+func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2x3xf32>) -> tensor<273x2xf32> {
+  %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32>
+  %1 = "tosa.reshape"(%arg0) {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>
+  // expected-error@+1 {{'tosa.fully_connected' op weight of fully_connected is not constant}}
+  %2 = "tosa.fully_connected"(%1, %arg1, %0) : (tensor<273x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32>
+  return %2 : tensor<273x2xf32>
+}
+
+// -----
+
+func.func @test_fully_connected_non_const(%arg0: tensor<13x21x3xf32>, %arg1: tensor<2xf32>) -> tensor<273x2xf32> {
+  %0 = "tosa.const"() {value = dense<[[-0.613216758, -0.63714242, -0.73500061], [0.180762768, 0.773053169, -0.933686495]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+  %1 = "tosa.reshape"(%arg0) {new_shape = array<i64: 273, 3>} : (tensor<13x21x3xf32>) -> tensor<273x3xf32>
+  // expected-error@+1 {{'tosa.fully_connected' op bias of fully_connected is not constant}}
+  %2 = "tosa.fully_connected"(%1, %0, %arg1) : (tensor<273x3xf32>, tensor<2x3xf32>, tensor<2xf32>) -> tensor<273x2xf32>
+  return %2 : tensor<273x2xf32>
+}
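
For reference, a minimal illustrative way to exercise the new strict check outside the lit tests above (the file name and function name below are hypothetical; the pass invocation mirrors the RUN line in invalid.mlir). When the padding operand comes from tosa.const, the check is satisfied and the pass emits no diagnostic:

// pad_const.mlir (illustrative): pad with a constant padding operand.
func.func @pad_with_const_padding(%arg0: tensor<13x21x3xf32>) -> tensor<13x23x5xf32> {
  %padding = "tosa.const"() {value = dense<[[0, 0], [1, 1], [1, 1]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32>
  %0 = "tosa.pad"(%arg0, %padding) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x23x5xf32>
  return %0 : tensor<13x23x5xf32>
}

// Run the validation pass with the strict alignment option enabled:
//   mlir-opt pad_const.mlir --tosa-validate=strict-op-spec-alignment

Passing the padding as a block argument instead of a tosa.const result reproduces the "padding of pad is not constant" error covered by the tests.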