diff --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md --- a/mlir/docs/AttributesAndTypes.md +++ b/mlir/docs/AttributesAndTypes.md @@ -215,6 +215,8 @@ may suggest, `AttrParameter` is intended for parameters on Attributes, `TypeParameter` for Type parameters, and `AttrOrTypeParameters` for either. +#### Parameter Pitfalls + Below is an easy parameter pitfall, and highlights when to use these parameter classes. @@ -244,6 +246,8 @@ let parameters = (ins ArrayRefIntParam:$dims); ``` +#### Other Parameter Fields + Below contains descriptions for other various available fields: The `allocator` code block has the following substitutions: @@ -256,6 +260,35 @@ - `$_lhs` is an instance of the parameter type. - `$_rhs` is an instance of the parameter type. +##### Parameter Verification + +Parameters can optionally provide custom verification code that is used to +validate parameter values when creating attributes or types. This code runs +when the "checked" builders (`getChecked`) are used (see the +[Builders](#builders) section below); if verification fails, a null attribute +or type is returned instead of a new instance. + +Parameter verification code is specified via the `verifier` field of the +TableGen record for the parameter class. Inside that code, the parameter value +is available as `value`, and `emitError()` can be used to report a diagnostic. +In the example below, a `NonNegativeFloatAttr` is created only if its float +value is non-negative. + +```tablegen +def NonNegativeAPFloatParam : APFloatParameter<"non-negative float"> { + let verifier = [{ return ::mlir::success(!value.isNegative()); }]; +} + +def NonNegativeFloatAttr : Test_Attr<"NonNegativeFloat"> { + let parameters = (ins NonNegativeAPFloatParam:$nnfloat); +} +``` + +For more information on attribute/type verification, see the +[Verification](#verification) section below. + +#### Specialized Parameter Classes + MLIR includes several specialized classes for common situations: - `APFloatParameter` for APFloats. 
@@ -825,11 +858,33 @@ ### Verification -If the `genVerifyDecl` field is set, additional verification methods are -generated on the class. +Attribute and type classes verify their inputs in the static +`verifyInvariants()` method. When no verification is necessary, the default +implementation provided by the storage base class is used. + +There are two phases of verification of inputs when creating instances of +attributes and types: + +- parameter verification +- custom verification + +Parameter verification operates on a single parameter (or argument) to the +attribute/type creation function, and is specified via the `verifier` field of +the parameter record. See the [Parameter Verification](#parameter-verification) +section above for an example. + +Custom (user-provided) verification occurs in the static `verify()` method of +the class, can perform arbitrary checks, and runs only after all parameter +verifiers succeed. Its declaration is generated in the header for the +attribute/type C++ class when the `genVerifyDecl` field (of type `bit`) is set. - `static LogicalResult verify(function_ref emitError, parameters...)` +Since parameter verification is applied to a single parameter only, custom +verification is necessary when success or failure depends on more than one +attribute/type parameter. Users may also wish to use the custom `verify()` +method to avoid defining complex code in TableGen input files. + These methods are used to verify the parameters provided to the attribute or type class on construction, and emit any necessary diagnostics. This method is automatically invoked from the builders of the attribute or type class. diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -210,7 +210,9 @@ // be provided. bit skipDefaultBuilders = 0; - // Generate the verify and getChecked methods. + // Generate the verify and getChecked methods.
(Note that the getChecked methods will + // also be generated automatically if any AttrOrTypeParameters have custom + // verifier code.) bit genVerifyDecl = 0; // Extra code to include in the class declaration. @@ -322,6 +324,15 @@ // made available through `$_ctxt`, e.g., for constructing default values for // attributes and types. string defaultValue = ?; + // Custom code to verify parameters used to create an attribute or type. The + // code string is used as the body of a function in which two names are in + // scope: + // - 'emitError', a function_ref that returns an InFlightDiagnostic + // - 'value', the parameter value, whose type is the C++ type of the + // parameter. + // The return type is ::mlir::LogicalResult, where failure() indicates that + // verification failed. + code verifier = ?; } class AttrParameter : AttrOrTypeParameter; diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -555,6 +555,7 @@ }]; let parameters = (ins AttributeSelfTypeParameter<"">:$type, APFloatParameter<"">:$value); + let builders = [ AttrBuilderWithInferredContext<(ins "Type":$type, "const APFloat &":$value), [{ @@ -845,15 +846,14 @@ "DenseElementsAttr":$values); let builders = [ AttrBuilderWithInferredContext<(ins "ShapedType":$type, - "DenseElementsAttr":$indices, + "DenseIntElementsAttr":$indices, "DenseElementsAttr":$values), [{ assert(indices.getType().getElementType().isInteger(64) && "expected sparse indices to be 64-bit integer values"); assert((type.isa()) && "type must be ranked tensor or vector"); assert(type.hasStaticShape() && "type must have static shape"); - return $_get(type.getContext(), type, - indices.cast(), values); + return $_get(type.getContext(), type, indices, values); }]>, ]; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td --- 
a/mlir/include/mlir/IR/EnumAttr.td +++ b/mlir/include/mlir/IR/EnumAttr.td @@ -337,6 +337,10 @@ "an enum of type " # enumInfo.className> { let parser = enumInfo.parameterParser; let printer = enumInfo.parameterPrinter; + let verifier = [{return }] # enumInfo.underlyingToSymbolFnName # + [{(static_cast<}] # enumInfo.underlyingType # [{>(value)).hasValue() + ? ::mlir::success() + : (emitError() << "invalid enum value");}]; } // An attribute backed by a C++ enum. The attribute contains a single diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -148,8 +148,8 @@ template static ConcreteT get(MLIRContext *ctx, Args... args) { // Ensure that the invariants are correct for construction. - assert( - succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...))); + assert(succeeded( + ConcreteT::verifyInvariants(getDefaultDiagnosticEmitFn(ctx), args...))); return UniquerT::template get(ctx, args...); } @@ -169,7 +169,7 @@ static ConcreteT getChecked(function_ref emitErrorFn, MLIRContext *ctx, Args... args) { // If the construction invariants fail then we return a null attribute. - if (failed(ConcreteT::verify(emitErrorFn, args...))) + if (failed(ConcreteT::verifyInvariants(emitErrorFn, args...))) return ConcreteT(); return UniquerT::template get(ctx, args...); } @@ -196,6 +196,14 @@ return success(); } + /// Default implementation that just forwards to ConcreteT::verify(). + template + static LogicalResult + verifyInvariants(function_ref emitErrorFn, + Args... args) { + return ConcreteT::verify(emitErrorFn, args...); + } + /// Utility for easy access to the storage instance.
ImplType *getImpl() const { return static_cast(this->impl); } diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -97,6 +97,9 @@ /// Get the default value of the parameter if it has one. Optional getDefaultValue() const; + /// If specified, get the custom verifier code for this parameter. + Optional getVerifier() const; + /// Return the underlying def of this parameter. llvm::Init *getDef() const; @@ -195,6 +198,14 @@ /// method. bool genVerifyDecl() const; + /// Returns true if checked builders (getChecked) are generated, i.e. if a + /// user verify() declaration was requested (genVerifyDecl=1) or if any + /// AttrOrTypeParameters have custom verifier code. + bool usesCheckedBuilder() const; + + /// Returns true if any parameter specifies custom `verifier` code. + bool anyParamHasVerifier() const; + /// Returns the def's extra class declaration code. Optional getExtraDecls() const; diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -612,7 +612,7 @@ MlirAttribute denseValues) { return wrap( SparseElementsAttr::get(unwrap(shapedType).cast(), - unwrap(denseIndices).cast(), + unwrap(denseIndices).cast(), unwrap(denseValues).cast())); } diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -8,6 +8,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Visitors.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -989,7 +989,8 @@ RankedTensorType::get({0, type.getRank()}, indiceEltType); ShapedType valuesType
= RankedTensorType::get({0}, type.getElementType()); return getChecked( - loc, type, DenseElementsAttr::get(indicesType, ArrayRef()), + loc, type, + DenseIntElementsAttr::get(indicesType, ArrayRef()), DenseElementsAttr::get(valuesType, ArrayRef())); } @@ -1028,7 +1029,8 @@ // Otherwise, set the shape to the one parsed by the literal parser. indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); } - auto indices = indiceParser.getAttr(indicesLoc, indicesType); + auto indices = indiceParser.getAttr(indicesLoc, indicesType) + .cast(); // If the values are a splat, set the shape explicitly based on the number of // indices. The number of indices is encoded in the first dimension of the diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -170,6 +170,16 @@ return def->getValueAsBit("genVerifyDecl"); } +bool AttrOrTypeDef::usesCheckedBuilder() const { + return genVerifyDecl() || anyParamHasVerifier(); +} + +bool AttrOrTypeDef::anyParamHasVerifier() const { + return llvm::any_of(getParameters(), [](const AttrOrTypeParameter &p) { + return p.getVerifier(); + }); +} + Optional AttrOrTypeDef::getExtraDecls() const { auto value = def->getValueAsString("extraClassDeclaration"); return value.empty() ?
Optional() : value; @@ -283,6 +293,10 @@ return getDefValue("defaultValue"); } +Optional AttrOrTypeParameter::getVerifier() const { + return getDefValue("verifier"); +} + llvm::Init *AttrOrTypeParameter::getDef() const { return def->getArg(index); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -248,4 +248,12 @@ let assemblyFormat = "`<` $handle `>`"; } +def NonNegativeAPFloatParam : APFloatParameter<"non-negative float"> { + let verifier = [{ return ::mlir::success(!value.isNegative()); }]; +} + +def NonNegativeFloatAttr : Test_Attr<"NonNegativeFloat"> { + let parameters = (ins NonNegativeAPFloatParam:$nnfloat); +} + #endif // TEST_ATTRDEFS diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -89,8 +89,8 @@ void emitTopLevelDeclarations(); /// Emit attribute or type builders. void emitBuilders(); - /// Emit a verifier for the def. - void emitVerifier(); + /// Emit verifiers for the def. + void emitVerifiers(); /// Emit parsers and printers. void emitParserPrinter(); /// Emit parameter accessors, if required. @@ -181,8 +181,8 @@ if (storageCls) emitBuilders(); // Emit the verifier. - if (storageCls && def.genVerifyDecl()) - emitVerifier(); + if (storageCls) + emitVerifiers(); // Emit the mnemonic, if there is one, and any associated parser and printer.
if (def.getMnemonic()) emitParserPrinter(); @@ -227,22 +227,73 @@ void DefGen::emitBuilders() { if (!def.skipDefaultBuilders()) { emitDefaultBuilder(); - if (def.genVerifyDecl()) + if (def.usesCheckedBuilder()) emitCheckedBuilder(); } for (auto &builder : def.getBuilders()) { emitCustomBuilder(builder); - if (def.genVerifyDecl()) + if (def.usesCheckedBuilder()) emitCheckedCustomBuilder(builder); } } -void DefGen::emitVerifier() { +void DefGen::emitVerifiers() { defCls.declare("Base::getChecked"); - defCls.declareStaticMethod( - "::mlir::LogicalResult", "verify", + + // When no custom verifier code is present, defer to the default + // implementations of verifyInvariants() and verify() in the base class. + if (!def.usesCheckedBuilder()) + return; + + // Define the verifyInvariants() method, which verifies parameters and then + // invokes the user-provided verify() method (if genVerifyDecl=1) or returns + // success (if genVerifyDecl=0). + Method *viMethod = defCls.addMethod( + "::mlir::LogicalResult", "verifyInvariants", Method::Static, getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}})); + + // Generate a check for each attribute/type parameter with verifier code. + // A default capture avoids -Wunused-lambda-capture when emitError is unused. + const char *verifyFormat = R"( + // Verify parameter '{0}': + if (failed( + [&]({1} value) -> ::mlir::LogicalResult {{ + {2} + }({0}))) + return ::mlir::failure(); + )"; + viMethod->body().indent(); + for (auto &param : params) { + if (param.getVerifier()) { + std::string bodyStr(llvm::formatv(verifyFormat, param.getName(), + param.getCppType(), + *param.getVerifier())); + viMethod->body().getStream().printReindented(bodyStr); + } + } + + if (def.genVerifyDecl()) { + // Declare the verify() method, which is implemented by the user.
+ defCls.declareStaticMethod( + "::mlir::LogicalResult", "verify", + getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", + "emitError"}})); + + // Call the user verify() function; the body is already indented above. + MethodBody &body = viMethod->body(); + SmallVector argNames; + argNames.push_back("emitError"); + for (auto &param : params) + argNames.push_back(param.getName()); + + body << "return verify("; + llvm::interleaveComma(argNames, body); + body << ");\n"; + } else { + // No user verify() function: all parameter verifiers have passed. + viMethod->body() << "return ::mlir::success();\n"; + } } void DefGen::emitParserPrinter() { diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -288,7 +288,7 @@ // Generate call to the attribute or type builder. Use the checked getter // if one was generated. - if (def.genVerifyDecl()) { + if (def.usesCheckedBuilder()) { os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()", &ctx, def.getCppClassName()); } else { diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -10,8 +10,12 @@ #include "mlir/IR/BuiltinTypes.h" #include "gtest/gtest.h" +#include "../../test/lib/Dialect/Test/TestAttributes.h" +#include "../../test/lib/Dialect/Test/TestDialect.h" + using namespace mlir; using namespace mlir::detail; +using namespace test; template static void testSplat(Type eltType, const EltTy &splatElt) { @@ -250,4 +254,50 @@ EXPECT_TRUE(zeroStringValue.getType() == stringTy); } +TEST(EnumAttrTest, GetChecked) { + DialectRegistry registry; + registry.insert(); + MLIRContext context(registry); + context.loadDialect(); + + Location loc(UnknownLoc::get(&context)); + auto emitErrorFn = [&]() { return emitError(loc); }; + + // Check that getChecked() returns a
non-nullptr attribute when a valid enum + // value is provided (int enum). + auto enumAttr = + TestEnumAttr::getChecked(emitErrorFn, &context, TestEnum::Third); + EXPECT_NE(enumAttr, nullptr); + + // Check that getChecked() returns a nullptr attribute when an invalid enum + // value is provided (int enum). + auto enumAttrInvalid = TestEnumAttr::getChecked(emitErrorFn, &context, + static_cast(-1)); + EXPECT_EQ(enumAttrInvalid, nullptr); + + // Check that getChecked() returns a non-nullptr attribute when a valid enum + // value is provided (bit enum). + auto bitEnumAttr = TestBitEnumAttr::getChecked( + emitErrorFn, &context, TestBitEnum::Read | TestBitEnum::Write); + EXPECT_NE(bitEnumAttr, nullptr); + + // Check that getChecked() returns a nullptr attribute when an invalid enum + // value is provided (bit enum). + auto bitEnumAttrInvalid = TestBitEnumAttr::getChecked( + emitErrorFn, &context, static_cast(0xFFFFFFFF)); + EXPECT_EQ(bitEnumAttrInvalid, nullptr); + + // Check that getChecked() returns a non-nullptr attribute when a non- + // negative value is provided. + auto nonNegAttr = + NonNegativeFloatAttr::getChecked(emitErrorFn, &context, APFloat(1.0)); + EXPECT_NE(nonNegAttr, nullptr); + + // Check that getChecked() returns a nullptr attribute when a negative + // value is provided. + auto nonNegAttrInvalid = + NonNegativeFloatAttr::getChecked(emitErrorFn, &context, APFloat(-1.0)); + EXPECT_EQ(nonNegAttrInvalid, nullptr); +} + } // namespace