diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -210,7 +210,9 @@
   // be provided.
   bit skipDefaultBuilders = 0;
 
-  // Generate the verify and getChecked methods.
+  // Generate the verify and getChecked methods. (Note that these methods will
+  // also be generated automatically if any AttrOrTypeParameters have custom
+  // verifier code.)
   bit genVerifyDecl = 0;
 
   // Extra code to include in the class declaration.
@@ -320,6 +322,17 @@
   // made available through `$_ctx`, e.g., for constructing default values for
   // attributes and types.
   string defaultValue = ?;
+  // Custom code to verify parameters used to create an attribute or type. The
+  // code string should represent the body of a function that takes two
+  // arguments:
+  //   - a functionref 'emitError' that returns an InFlightDiagnostic
+  //   - a parameter value argument with a name and type that match the name and
+  //     C++ type of the parameter.
+  // The return type is ::mlir::LogicalResult, where failure() indicates that
+  // verification failed. If the enclosing attribute or type sets
+  // genVerifyDecl = 1, the user-provided verify() is used instead and this
+  // code is not invoked.
+  code verifier = ?;
 }
 class AttrParameter<string type, string desc, string accessorType = "">
   : AttrOrTypeParameter<type, desc, accessorType>;
diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td
--- a/mlir/include/mlir/IR/EnumAttr.td
+++ b/mlir/include/mlir/IR/EnumAttr.td
@@ -337,6 +337,12 @@
                     "an enum of type " # enumInfo.className> {
   let parser = enumInfo.parameterParser;
   let printer = enumInfo.parameterPrinter;
+  // The verifier code refers to the parameter by the C++ name `value`, so the
+  // parameter declared with this class must be named `$value`.
+  let verifier = [{return }] # enumInfo.underlyingToSymbolFnName #
+      [{(static_cast<}] # enumInfo.underlyingType # [{>(value)).hasValue()
+          ? ::mlir::success()
+          : (emitError() << "invalid enum value");}];
 }
 
 // An attribute backed by a C++ enum. The attribute contains a single
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -91,6 +91,9 @@
   /// Get the default value of the parameter if it has one.
   Optional<StringRef> getDefaultValue() const;
 
+  /// If specified, get the custom verifier code for this parameter.
+  Optional<StringRef> getVerifier() const;
+
   /// Return the underlying def of this parameter.
   llvm::Init *getDef() const;
 
@@ -189,6 +192,14 @@
   /// method.
   bool genVerifyDecl() const;
 
+  /// Returns true if automatic verify() declaration was requested
+  /// (genVerifyDecl=1), or if any AttrOrTypeParameters have custom verifier
+  /// code.
+  bool usesCheckedBuilder() const;
+
+  /// Returns true if any parameter has custom verifier code set.
+  bool anyParamHasVerifier() const;
+
   /// Returns the def's extra class declaration code.
   Optional<StringRef> getExtraDecls() const;
 
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -162,6 +162,16 @@
   return def->getValueAsBit("genVerifyDecl");
 }
 
+bool AttrOrTypeDef::usesCheckedBuilder() const {
+  return genVerifyDecl() || anyParamHasVerifier();
+}
+
+bool AttrOrTypeDef::anyParamHasVerifier() const {
+  return llvm::any_of(getParameters(), [](const AttrOrTypeParameter &p) {
+    return p.getVerifier();
+  });
+}
+
 Optional<StringRef> AttrOrTypeDef::getExtraDecls() const {
   auto value = def->getValueAsString("extraClassDeclaration");
   return value.empty() ? Optional<StringRef>() : value;
@@ -273,6 +283,10 @@
   return getDefValue<llvm::StringInit>("defaultValue");
 }
 
+Optional<StringRef> AttrOrTypeParameter::getVerifier() const {
+  return getDefValue<llvm::StringInit>("verifier");
+}
+
 llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
--- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp
@@ -192,7 +192,7 @@
   if (storageCls)
     emitBuilders();
   // Emit the verifier.
-  if (storageCls && def.genVerifyDecl())
+  if (storageCls && def.usesCheckedBuilder())
     emitVerifier();
   // Emit the mnemonic, if there is one, and any associated parser and printer.
   if (def.getMnemonic())
@@ -238,22 +238,51 @@
 void DefGen::emitBuilders() {
   if (!def.skipDefaultBuilders()) {
     emitDefaultBuilder();
-    if (def.genVerifyDecl())
+    if (def.usesCheckedBuilder())
       emitCheckedBuilder();
   }
   for (auto &builder : def.getBuilders()) {
     emitCustomBuilder(builder);
-    if (def.genVerifyDecl())
+    if (def.usesCheckedBuilder())
       emitCheckedCustomBuilder(builder);
   }
 }
 
 void DefGen::emitVerifier() {
+  assert(def.usesCheckedBuilder() && "verify emitted without checked builder");
   defCls.declare<UsingDeclaration>("Base::getChecked");
-  defCls.declareStaticMethod(
-      "::mlir::LogicalResult", "verify",
+  if (def.genVerifyDecl()) {
+    defCls.declareStaticMethod(
+        "::mlir::LogicalResult", "verify",
+        getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
+                           "emitError"}}));
+    return;
+  }
+  assert(def.anyParamHasVerifier() && "no params with custom verify code");
+  Method *m = defCls.addMethod(
+      "::mlir::LogicalResult", "verify", Method::Static,
       getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>",
                          "emitError"}}));
+  const char *verifyFormat = R"(
+    // Verify parameter '{0}':
+    if (failed(
+          [&emitError]({1} {0}) -> ::mlir::LogicalResult {{
+            {2}
+          }({0})))
+      return ::mlir::failure();
+  )";
+  // Indent the body once: MethodBody::indent() is sticky, so calling it per
+  // parameter would nest each successive check one level deeper.
+  m->body().indent();
+  for (auto &param : def.getParameters()) {
+    if (Optional<StringRef> verifier = param.getVerifier()) {
+      std::string bodyStr(llvm::formatv(verifyFormat, param.getName(),
+                                        param.getCppType(),
+                                        *verifier));
+      m->body().getStream().printReindented(bodyStr);
+    }
+  }
+  m->body() << "return ::mlir::success();\n";
 }
 
 void DefGen::emitParserPrinter() {
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -286,7 +286,7 @@
 
   // Generate call to the attribute or type builder. Use the checked getter
   // if one was generated.
-  if (def.genVerifyDecl()) {
+  if (def.usesCheckedBuilder()) {
     os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()",
                 &ctx, def.getCppClassName());
   } else {
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -10,8 +10,12 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "gtest/gtest.h"
 
+#include "../../test/lib/Dialect/Test/TestAttributes.h"
+#include "../../test/lib/Dialect/Test/TestDialect.h"
+
 using namespace mlir;
 using namespace mlir::detail;
+using namespace test;
 
 template <typename EltTy>
 static void testSplat(Type eltType, const EltTy &splatElt) {
@@ -250,4 +254,38 @@
   EXPECT_TRUE(zeroStringValue.getType() == stringTy);
 }
 
+TEST(EnumAttrTest, GetChecked) {
+  DialectRegistry registry;
+  registry.insert<TestDialect>();
+  MLIRContext context(registry);
+  context.loadDialect<TestDialect>();
+
+  Location loc(UnknownLoc::get(&context));
+  auto emitErrorFn = [&]() { return emitError(loc); };
+
+  // Check that getChecked() returns a non-nullptr attribute when a valid enum
+  // value is provided (int enum).
+  auto enumAttr =
+      TestEnumAttr::getChecked(emitErrorFn, &context, TestEnum::Third);
+  EXPECT_NE(enumAttr, nullptr);
+
+  // Check that getChecked() returns a nullptr attribute when an invalid enum
+  // value is provided (int enum).
+  auto enumAttrInvalid = TestEnumAttr::getChecked(emitErrorFn, &context,
+                                                  static_cast<TestEnum>(-1));
+  EXPECT_EQ(enumAttrInvalid, nullptr);
+
+  // Check that getChecked() returns a non-nullptr attribute when a valid enum
+  // value is provided (bit enum).
+  auto bitEnumAttr = TestBitEnumAttr::getChecked(
+      emitErrorFn, &context, TestBitEnum::Read | TestBitEnum::Write);
+  EXPECT_NE(bitEnumAttr, nullptr);
+
+  // Check that getChecked() returns a nullptr attribute when an invalid enum
+  // value is provided (bit enum).
+  auto bitEnumAttrInvalid = TestBitEnumAttr::getChecked(
+      emitErrorFn, &context, static_cast<TestBitEnum>(0xFFFFFFFF));
+  EXPECT_EQ(bitEnumAttrInvalid, nullptr);
+}
+
 } // namespace