diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -42,7 +42,12 @@ // Builders //===----------------------------------------------------------------------===// -// Class for defining a custom getter. +// Class for defining a custom getters and verifiers (MethodSpec) +// +// The MethodSpec (method specification) class can be used to provide function +// signatures (and optionally implementations) in a tablegen record. +// +// Custom getters: // // TableGen generates several generic getter methods for each attribute and type // by default, corresponding to the specified dag parameters. If the default @@ -96,10 +101,30 @@ // This is necessary because the `body` is also used to generate `getChecked` // methods, which have a different underlying `Base::get*` call. // -class AttrOrTypeBuilder { +// Custom verifiers: +// +// A `verify()` function is necessary when an attribute or type implements the +// `getChecked()` static method. The `getChecked()` arguments are forwarded to +// the `verify()` method, and if that method returns `failure()`, a null +// attribute will be returned to the `getChecked()` caller. +// +// If `genVerifyDecl=1` and no custom verifiers are provided, the `verify()` +// function will be declared in the class header, and the user must provide an +// implementation of the function in a separate source file. (The `getChecked()` +// function will be generated automatically when `genVerifyDecl=1`.) +// +// If one or more custom verifiers are provided in the `verifiers` field, the +// associated declarations will be generated in the header file. Providing a +// body for the verifier is optional - if no body/implementation is provided, +// the user must provide one in a separate source file. (The `getChecked()` +// implementation will be generated automatically when `verifiers` is nonempty.) +class MethodSpec { dag dagParams = parameters; code body = bodyCode; +} +class AttrOrTypeBuilder : + MethodSpec { // The context parameter can be inferred from one of the other parameters and // is not implicitly added to the parameter list. bit hasInferredContextParam = 0; @@ -121,6 +146,10 @@ class TypeBuilderWithInferredContext : AttrOrTypeBuilderWithInferredContext; +class AttrOrTypeVerifier : + MethodSpec { +} + //===----------------------------------------------------------------------===// // Definitions //===----------------------------------------------------------------------===// @@ -183,6 +212,9 @@ // Note that builders should only be provided when a def has parameters. list builders = ?; + // Note that verifiers should only be provided when a def has parameters. + list verifiers = ?; + // The list of traits attached to this def. list traits = defTraits; diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td --- a/mlir/include/mlir/IR/EnumAttr.td +++ b/mlir/include/mlir/IR/EnumAttr.td @@ -396,6 +396,15 @@ // The default assembly format for enum attributes. Selected to best work with // operation assembly formats. let assemblyFormat = "$value"; + + // Provide a verifier for enums so that getChecked() will return a null + // attribute when the argument is not a valid enum value. + let verifiers = [ + AttrOrTypeVerifier<(ins returnType:$value), [{ + auto uval = static_cast<}] # enumInfo.underlyingType # [{>(value); + return }] # enumInfo.underlyingToSymbolFnName # [{(uval) ? ::mlir::success() : ::mlir::failure(); + }]> + ]; } #endif // ENUMATTR_TD diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -15,7 +15,7 @@ #define MLIR_TABLEGEN_ATTRORTYPEDEF_H #include "mlir/Support/LLVM.h" -#include "mlir/TableGen/Builder.h" +#include "mlir/TableGen/MethodSpec.h" #include "mlir/TableGen/Trait.h" namespace llvm { @@ -33,9 +33,9 @@ //===----------------------------------------------------------------------===// /// Wrapper class that represents a Tablegen AttrOrTypeBuilder. -class AttrOrTypeBuilder : public Builder { +class AttrOrTypeBuilder : public MethodSpec { public: - using Builder::Builder; + using MethodSpec::MethodSpec; /// Returns true if this builder is able to infer the MLIRContext parameter. bool hasInferredContextParameter() const; @@ -189,6 +189,10 @@ /// method. bool genVerifyDecl() const; + /// Returns true if verifier methods were provided in the record, or if + /// automatic verify() declaration was requested (genVerifyDecl=1). + bool usesCheckedBuilder() const; + /// Returns the def's extra class declaration code. Optional getExtraDecls() const; @@ -202,6 +206,9 @@ /// Returns the builders of this def. ArrayRef getBuilders() const { return builders; } + /// Returns the verifiers of this def. + ArrayRef getVerifiers() const { return verifiers; } + /// Returns the traits of this def. ArrayRef getTraits() const { return traits; } @@ -224,6 +231,9 @@ /// The builders of this definition. SmallVector builders; + /// The verifiers of this definition. + SmallVector verifiers; + /// The traits of this definition. SmallVector traits; diff --git a/mlir/include/mlir/TableGen/Builder.h b/mlir/include/mlir/TableGen/MethodSpec.h rename from mlir/include/mlir/TableGen/Builder.h rename to mlir/include/mlir/TableGen/MethodSpec.h --- a/mlir/include/mlir/TableGen/Builder.h +++ b/mlir/include/mlir/TableGen/MethodSpec.h @@ -1,4 +1,4 @@ -//===- Builder.h - Builder classes ------------------------------*- C++ -*-===// +//===- MethodSpec.h - MethodSpec classes ------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,13 +6,13 @@ // //===----------------------------------------------------------------------===// // -// Builder wrapper to simplify using TableGen Record for building -// operations/types/etc. +// Method specification wrapper to simplify using TableGen records for building +// operations/attributes/types/etc. // //===----------------------------------------------------------------------===// -#ifndef MLIR_TABLEGEN_BUILDER_H_ -#define MLIR_TABLEGEN_BUILDER_H_ +#ifndef MLIR_TABLEGEN_METHODSPEC_H_ +#define MLIR_TABLEGEN_METHODSPEC_H_ #include "mlir/Support/LLVM.h" #include "llvm/ADT/ArrayRef.h" @@ -28,11 +28,13 @@ namespace mlir { namespace tblgen { -/// Wrapper class with helper methods for accessing Builders defined in -/// TableGen. -class Builder { +/// Wrapper class with helper methods for accessing MethodSpecs defined in +/// TableGen. A MethodSpec (method specification) provides a function +/// signature (declaration), and (optionally) a function body +/// (implementation). +class MethodSpec { public: - /// This class represents a single parameter to a builder method. + /// This class represents a single parameter to a method specification. class Parameter { public: /// Return a string containing the C++ type of this parameter. @@ -58,28 +60,28 @@ const llvm::Init *def; // Allow access to the constructor. - friend Builder; + friend MethodSpec; }; - /// Construct a builder from the given Record instance. - Builder(const llvm::Record *record, ArrayRef loc); + /// Construct a MethodSpec from the given Record instance. + MethodSpec(const llvm::Record *record, ArrayRef loc); - /// Return a list of parameters used in this build method. + /// Return a list of parameters used in this method. ArrayRef getParameters() const { return parameters; } - /// Return an optional string containing the body of the builder. + /// Return an optional string containing the body of the method spec. Optional getBody() const; protected: - /// The TableGen definition of this builder. + /// The TableGen definition of this method specification. const llvm::Record *def; private: - /// A collection of parameters to the builder. + /// A collection of parameters to the method specification. SmallVector parameters; }; } // namespace tblgen } // namespace mlir -#endif // MLIR_TABLEGEN_BUILDER_H_ +#endif // MLIR_TABLEGEN_METHODSPEC_H_ diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -16,8 +16,8 @@ #include "mlir/Support/LLVM.h" #include "mlir/TableGen/Argument.h" #include "mlir/TableGen/Attribute.h" -#include "mlir/TableGen/Builder.h" #include "mlir/TableGen/Dialect.h" +#include "mlir/TableGen/MethodSpec.h" #include "mlir/TableGen/Region.h" #include "mlir/TableGen/Successor.h" #include "mlir/TableGen/Trait.h" @@ -300,7 +300,7 @@ OperandOrAttribute getArgToOperandOrAttribute(int index) const; // Returns the builders of this operation. - ArrayRef getBuilders() const { return builders; } + ArrayRef getBuilders() const { return builders; } // Returns the preferred getter name for the accessor. std::string getGetterName(StringRef name) const { @@ -362,7 +362,7 @@ SmallVector attrOrOperandMapping; // The builders of this operator. - SmallVector builders; + SmallVector builders; // The number of native attributes stored in the leading positions of // `attributes`. diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -16,6 +16,24 @@ using namespace mlir; using namespace mlir::tblgen; +template +void populateMethodSpecList(SmallVector &v, const llvm::Record *def, + const char *name) { + auto *methodList = dyn_cast_or_null(def->getValueInit(name)); + if (methodList && !methodList->empty()) { + for (llvm::Init *init : methodList->getValues()) { + T method(cast(init)->getDef(), def->getLoc()); + + // Ensure that all parameters have names. + for (const typename T::Parameter ¶m : method.getParameters()) { + if (!param.getName()) + PrintFatalError(def->getLoc(), "method parameters must have a name"); + } + v.emplace_back(method); + } + } +} + //===----------------------------------------------------------------------===// // AttrOrTypeBuilder //===----------------------------------------------------------------------===// @@ -30,23 +48,9 @@ //===----------------------------------------------------------------------===// AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) { - // Populate the builders. - auto *builderList = - dyn_cast_or_null(def->getValueInit("builders")); - if (builderList && !builderList->empty()) { - for (llvm::Init *init : builderList->getValues()) { - AttrOrTypeBuilder builder(cast(init)->getDef(), - def->getLoc()); - - // Ensure that all parameters have names. - for (const AttrOrTypeBuilder::Parameter ¶m : - builder.getParameters()) { - if (!param.getName()) - PrintFatalError(def->getLoc(), "builder parameters must have a name"); - } - builders.emplace_back(builder); - } - } + // Populate the builders and verifiers. + populateMethodSpecList(builders, def, "builders"); + populateMethodSpecList(verifiers, def, "verifiers"); // Populate the traits. if (auto *traitList = def->getValueAsListInit("traits")) { @@ -162,6 +166,10 @@ return def->getValueAsBit("genVerifyDecl"); } +bool AttrOrTypeDef::usesCheckedBuilder() const { + return !verifiers.empty() || def->getValueAsBit("genVerifyDecl"); +} + Optional AttrOrTypeDef::getExtraDecls() const { auto value = def->getValueAsString("extraClassDeclaration"); return value.empty() ? Optional() : value; diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt --- a/mlir/lib/TableGen/CMakeLists.txt +++ b/mlir/lib/TableGen/CMakeLists.txt @@ -12,12 +12,12 @@ Argument.cpp Attribute.cpp AttrOrTypeDef.cpp - Builder.cpp Class.cpp Constraint.cpp Dialect.cpp Format.cpp Interfaces.cpp + MethodSpec.cpp Operator.cpp Pass.cpp Pattern.cpp diff --git a/mlir/lib/TableGen/Builder.cpp b/mlir/lib/TableGen/MethodSpec.cpp rename from mlir/lib/TableGen/Builder.cpp rename to mlir/lib/TableGen/MethodSpec.cpp --- a/mlir/lib/TableGen/Builder.cpp +++ b/mlir/lib/TableGen/MethodSpec.cpp @@ -1,4 +1,4 @@ -//===- Builder.cpp - Builder definitions ----------------------------------===// +//===- MethodSpec.cpp - MethodSpec definitions ----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/TableGen/Builder.h" +#include "mlir/TableGen/MethodSpec.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -14,11 +14,11 @@ using namespace mlir::tblgen; //===----------------------------------------------------------------------===// -// Builder::Parameter +// MethodSpec::Parameter //===----------------------------------------------------------------------===// /// Return a string containing the C++ type of this parameter. -StringRef Builder::Parameter::getCppType() const { +StringRef MethodSpec::Parameter::getCppType() const { if (const auto *stringInit = dyn_cast(def)) return stringInit->getValue(); const llvm::Record *record = cast(def)->getDef(); @@ -27,7 +27,7 @@ /// Return an optional string containing the default value to use for this /// parameter. -Optional Builder::Parameter::getDefaultValue() const { +Optional MethodSpec::Parameter::getDefaultValue() const { if (isa(def)) return llvm::None; const llvm::Record *record = cast(def)->getDef(); @@ -36,16 +36,16 @@ } //===----------------------------------------------------------------------===// -// Builder +// MethodSpec //===----------------------------------------------------------------------===// -Builder::Builder(const llvm::Record *record, ArrayRef loc) +MethodSpec::MethodSpec(const llvm::Record *record, ArrayRef loc) : def(record) { - // Initialize the parameters of the builder. + // Initialize the parameters of the method specification. const llvm::DagInit *dag = def->getValueAsDag("dagParams"); auto *defInit = dyn_cast(dag->getOperator()); if (!defInit || !defInit->getDef()->getName().equals("ins")) - PrintFatalError(def->getLoc(), "expected 'ins' in builders"); + PrintFatalError(def->getLoc(), "expected 'ins' in methodspec"); bool seenDefaultValue = false; for (unsigned i = 0, e = dag->getNumArgs(); i < e; ++i) { @@ -67,8 +67,8 @@ } } -/// Return an optional string containing the body of the builder. -Optional Builder::getBody() const { +/// Return an optional string containing the body of the method spec. +Optional MethodSpec::getBody() const { Optional body = def->getValueAsOptionalString("body"); return body && !body->empty() ? body : llvm::None; } diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -67,6 +67,19 @@ } } +static SmallVector +getCustomVerifierParams(std::initializer_list prefix, + const MethodSpec &verifier) { + auto params = verifier.getParameters(); + SmallVector verifierParams; + verifierParams.append(prefix.begin(), prefix.end()); + for (auto ¶m : params) { + verifierParams.emplace_back(param.getCppType(), *param.getName(), + param.getDefaultValue()); + } + return verifierParams; +} + //===----------------------------------------------------------------------===// // DefGen //===----------------------------------------------------------------------===// @@ -110,7 +123,7 @@ void emitInterfaceMethods(); //===--------------------------------------------------------------------===// - // Builder Emission + // Builder/Verifier Emission /// Emit the default builder `Attribute::get` void emitDefaultBuilder(); @@ -120,6 +133,8 @@ void emitCustomBuilder(const AttrOrTypeBuilder &builder); /// Emit a checked custom builder. void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder); + /// Emit a custom verifier. + void emitCustomVerifier(const MethodSpec &verifier); //===--------------------------------------------------------------------===// // Interface Method Emission @@ -192,7 +207,7 @@ if (storageCls) emitBuilders(); // Emit the verifier. - if (storageCls && def.genVerifyDecl()) + if (storageCls && def.usesCheckedBuilder()) emitVerifier(); // Emit the mnemonic, if there is one, and any associated parser and printer. if (def.getMnemonic()) @@ -238,22 +253,32 @@ void DefGen::emitBuilders() { if (!def.skipDefaultBuilders()) { emitDefaultBuilder(); - if (def.genVerifyDecl()) + if (def.usesCheckedBuilder()) emitCheckedBuilder(); } for (auto &builder : def.getBuilders()) { emitCustomBuilder(builder); - if (def.genVerifyDecl()) + if (def.usesCheckedBuilder()) emitCheckedCustomBuilder(builder); } } void DefGen::emitVerifier() { + assert(def.usesCheckedBuilder() && "verify() emitted without getChecked()"); defCls.declare("Base::getChecked"); - defCls.declareStaticMethod( - "::mlir::LogicalResult", "verify", - getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", - "emitError"}})); + if (!def.getVerifiers().empty()) { + // The record has one or more verifiers. + for (auto &verifier : def.getVerifiers()) + emitCustomVerifier(verifier); + } else { + // The record has genVerifyDecl=1, and no tblgen verifiers were provided. + // Generate a verify() declaration corresponding to the default builder, + // and the user must provide a definition. + defCls.declareStaticMethod( + "::mlir::LogicalResult", "verify", + getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", + "emitError"}})); + } } void DefGen::emitParserPrinter() { @@ -403,6 +428,19 @@ m->body().indent().getStream().printReindented(bodyStr); } +void DefGen::emitCustomVerifier(const MethodSpec &verifier) { + // Don't emit a body if there isn't one. + auto props = verifier.getBody() ? Method::Static : Method::StaticDeclaration; + Method *m = defCls.addMethod( + "::mlir::LogicalResult", "verify", props, + getCustomVerifierParams( + {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}}, + verifier)); + if (!verifier.getBody()) + return; + m->body().indent().getStream().printReindented(*verifier.getBody()); +} + //===----------------------------------------------------------------------===// // Interface Method Emission diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -286,7 +286,7 @@ // Generate call to the attribute or type builder. Use the checked getter // if one was generated. - if (def.genVerifyDecl()) { + if (def.usesCheckedBuilder()) { os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()", &ctx, def.getCppClassName()); } else { diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1732,8 +1732,8 @@ /// Returns a signature of the builder. Updates the context `fctx` to enable /// replacement of $_builder and $_state in the body. static SmallVector -getBuilderSignature(const Builder &builder) { - ArrayRef params(builder.getParameters()); +getBuilderSignature(const MethodSpec &builder) { + ArrayRef params(builder.getParameters()); // Inject builder and state arguments. SmallVector arguments; @@ -1760,7 +1760,7 @@ void OpEmitter::genBuilder() { // Handle custom builders if provided. - for (const Builder &builder : op.getBuilders()) { + for (const MethodSpec &builder : op.getBuilders()) { SmallVector arguments = getBuilderSignature(builder); Optional body = builder.getBody(); diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -10,8 +10,12 @@ #include "mlir/IR/BuiltinTypes.h" #include "gtest/gtest.h" +#include "../../test/lib/Dialect/Test/TestAttributes.h" +#include "../../test/lib/Dialect/Test/TestDialect.h" + using namespace mlir; using namespace mlir::detail; +using namespace test; template static void testSplat(Type eltType, const EltTy &splatElt) { @@ -250,4 +254,38 @@ EXPECT_TRUE(zeroStringValue.getType() == stringTy); } +TEST(EnumAttrTest, GetChecked) { + DialectRegistry registry; + registry.insert(); + MLIRContext context(registry); + context.loadDialect(); + + Location loc(UnknownLoc::get(&context)); + auto emitErrorFn = [&]() { return emitError(loc); }; + + // Check that getChecked() returns a non-nullptr attribute when a valid enum + // value is provided (int enum). + auto enumAttr = + TestEnumAttr::getChecked(emitErrorFn, &context, TestEnum::Third); + EXPECT_NE(enumAttr, nullptr); + + // Check that getChecked() returns a nullptr attribute when an invalid enum + // value is provided (int enum). + auto enumAttrInvalid = TestEnumAttr::getChecked(emitErrorFn, &context, + static_cast(-1)); + EXPECT_EQ(enumAttrInvalid, nullptr); + + // Check that getChecked() returns a non-nullptr attribute when a valid enum + // value is provided (bit enum). + auto bitEnumAttr = TestBitEnumAttr::getChecked( + emitErrorFn, &context, TestBitEnum::Read | TestBitEnum::Write); + EXPECT_NE(bitEnumAttr, nullptr); + + // Check that getChecked() returns a nullptr attribute when an invalid enum + // value is provided (bit enum). + auto bitEnumAttrInvalid = TestBitEnumAttr::getChecked( + emitErrorFn, &context, static_cast(0xFFFFFFFF)); + EXPECT_EQ(bitEnumAttrInvalid, nullptr); +} + } // namespace