Add dependent-trait support to OpTrait and verify declaration order.

OpBase.td: OpTrait gains a `dependentTraits` list, fed by a new trailing
`traits` template argument (default `[]`) that NativeOpTrait,
ParamNativeOpTrait, GenInternalOpTrait, PredOpTrait and OpInterfaceTrait
forward to it.

Operator.cpp: while collecting an op's traits, each trait record must be a
subclass of OpTrait, and every entry of its `dependentTraits` must already be
present in the set of traits seen so far; otherwise a fatal error is printed
at the op's location.

op-error.td: ERROR12 / ERROR13 cover a native trait and an interface whose
dependent trait is not listed before them.

NOTE(review): the extracted text of this patch had lost its line breaks and
every angle-bracketed argument list that begins with a letter. Those lists
were reconstructed from the identifiers the surrounding lines reference
(`name`, `prop`, `params`, `descr`, `pred`, `verifyBody`, `traits`). The
arguments of `SmallPtrSet<...>` and of `Op<...>` in the two new test defs are
presumed, not visible in the extracted text -- confirm against the upstream
revision before applying. Blank lines were placed so that every hunk matches
the line counts in its `@@` header.

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2011,8 +2011,11 @@
 //===----------------------------------------------------------------------===//
 
 // OpTrait represents a trait regarding an operation.
-// TODO: Remove this class in favor of using Trait.
-class OpTrait;
+class OpTrait<list<OpTrait> traits = []> {
+  // Specify the list of traits that need to be verified before the verification
+  // of this OpTrait.
+  list<OpTrait> dependentTraits = traits;
+}
 
 // Define a OpTrait corresponding to a list of OpTraits, this allows for
 // specifying a list of traits as trait. Avoids needing to do
@@ -2023,11 +2026,15 @@
 }
 
 // These classes are used to define operation specific traits.
-class NativeOpTrait<string name> : NativeTrait<name, "Op">, OpTrait;
-class ParamNativeOpTrait<string prop, string params>
-    : ParamNativeTrait<prop, params, "Op">, OpTrait;
-class GenInternalOpTrait<string prop> : GenInternalTrait<prop, "Op">, OpTrait;
-class PredOpTrait<string descr, Pred pred> : PredTrait<descr, pred>, OpTrait;
+class NativeOpTrait<string name, list<OpTrait> traits = []>
+    : NativeTrait<name, "Op">, OpTrait<traits> {
+}
+class ParamNativeOpTrait<string prop, string params, list<OpTrait> traits = []>
+    : ParamNativeTrait<prop, params, "Op">, OpTrait<traits>;
+class GenInternalOpTrait<string prop, list<OpTrait> traits = []>
+    : GenInternalTrait<prop, "Op">, OpTrait<traits>;
+class PredOpTrait<string descr, Pred pred, list<OpTrait> traits = []>
+    : PredTrait<descr, pred>, OpTrait<traits>;
 
 // Op defines an affine scope.
 def AffineScope : NativeOpTrait<"AffineScope">;
@@ -2167,8 +2174,9 @@
 // OpInterfaceTrait corresponds to a specific 'OpInterface' class defined in
 // C++. The purpose to wrap around C++ symbol string with this class is to make
 // interfaces specified for ops in TableGen less alien and more integrated.
-class OpInterfaceTrait<string name, code verifyBody = [{}]>
-  : InterfaceTrait<name>, OpTrait {
+class OpInterfaceTrait<string name, code verifyBody = [{}],
+                       list<OpTrait> traits = []>
+  : InterfaceTrait<name>, OpTrait<traits> {
   // Specify the body of the verification function. `$_op` will be replaced with
   // the operation being verified.
   code verify = verifyBody;
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -540,6 +540,22 @@
   SmallPtrSet<const llvm::Init *, 32> traitSet;
   traits.reserve(traitSet.size());
 
+  auto verifyTraitValidity = [&](Record *trait) {
+    if (!trait->isSubClassOf("OpTrait"))
+      PrintFatalError(def.getLoc(), trait->getValueAsString("trait") +
+                                        " is not an OpTrait");
+
+    auto *dependentTraits = trait->getValueAsListInit("dependentTraits");
+    for (auto *traitInit : *dependentTraits)
+      if (traitSet.find(traitInit) == traitSet.end())
+        PrintFatalError(
+            def.getLoc(),
+            trait->getValueAsString("trait") + " requires " +
+                cast<DefInit>(traitInit)->getDef()->getValueAsString(
+                    "trait") +
+                " to be declared before itself");
+  };
+
   std::function<void(llvm::ListInit *)> insert;
   insert = [&](llvm::ListInit *traitList) {
     for (auto *traitInit : *traitList) {
@@ -548,6 +564,11 @@
         insert(def->getValueAsListInit("traits"));
         continue;
       }
+
+      // Verify if the trait is an op trait and has all the dependent traits
+      // come before itself.
+      verifyTraitValidity(def);
+
       // Keep traits in the same order while skipping over duplicates.
       if (traitSet.insert(traitInit).second)
         traits.push_back(Trait::create(traitInit));
diff --git a/mlir/test/mlir-tblgen/op-error.td b/mlir/test/mlir-tblgen/op-error.td
--- a/mlir/test/mlir-tblgen/op-error.td
+++ b/mlir/test/mlir-tblgen/op-error.td
@@ -9,6 +9,8 @@
 // RUN: not mlir-tblgen -gen-op-decls -I %S/../../include -DERROR9 %s 2>&1 | FileCheck --check-prefix=ERROR9 %s
 // RUN: not mlir-tblgen -gen-op-decls -I %S/../../include -DERROR10 %s 2>&1 | FileCheck --check-prefix=ERROR10 %s
 // RUN: not mlir-tblgen -gen-op-decls -I %S/../../include -DERROR11 %s 2>&1 | FileCheck --check-prefix=ERROR11 %s
+// RUN: not mlir-tblgen -gen-op-decls -I %S/../../include -DERROR12 %s 2>&1 | FileCheck --check-prefix=ERROR12 %s
+// RUN: not mlir-tblgen -gen-op-decls -I %S/../../include -DERROR13 %s 2>&1 | FileCheck --check-prefix=ERROR13 %s
 
 include "mlir/IR/OpBase.td"
 
@@ -104,3 +106,21 @@
   let regions = (region AnyRegion:$target);
 }
 #endif
+
+#ifdef ERROR12
+def OpTraitA : NativeOpTrait<"OpTraitA"> {}
+def OpTraitB : NativeOpTrait<"OpTraitB", [OpTraitA]> {}
+
+// ERROR12: error: OpTraitB requires OpTraitA to be declared before itself
+def OpTraitWithoutDependentTrait : Op<Test_Dialect, "default_value", [OpTraitB]> {}
+#endif
+
+#ifdef ERROR13
+def OpTraitA : NativeOpTrait<"OpTraitA"> {}
+def OpInterfaceB : OpInterface<"OpInterfaceB"> {
+  let dependentTraits = [OpTraitA];
+}
+
+// ERROR13: error: OpInterfaceB::Trait requires OpTraitA to be declared before itself
+def OpInterfaceWithoutDependentTrait : Op<Test_Dialect, "default_value", [OpInterfaceB]> {}
+#endif