diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -65,14 +65,16 @@
 
 // DEF:      ::mlir::LogicalResult AOpAdaptor::verify
 // DEF:      auto tblgen_aAttr = odsAttrs.get("aAttr");
-// DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'");
-// DEF:        if (!((some-condition))) return emitError(loc, "'test.a_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
+// DEF-NEXT: if (!tblgen_aAttr)
+// DEF-NEXT:   return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'");
+// DEF:      if (tblgen_aAttr && !((some-condition)))
+// DEF-NEXT:   return emitError(loc, "'test.a_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
 // DEF:      auto tblgen_bAttr = odsAttrs.get("bAttr");
-// DEF-NEXT: if (tblgen_bAttr) {
-// DEF-NEXT:   if (!((some-condition))) return emitError(loc, "'test.a_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
+// DEF-NEXT: if (tblgen_bAttr && !((some-condition)))
+// DEF-NEXT:   return emitError(loc, "'test.a_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
 // DEF:      auto tblgen_cAttr = odsAttrs.get("cAttr");
-// DEF-NEXT: if (tblgen_cAttr) {
-// DEF-NEXT:   if (!((some-condition))) return emitError(loc, "'test.a_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
+// DEF-NEXT: if (tblgen_cAttr && !((some-condition)))
+// DEF-NEXT:   return emitError(loc, "'test.a_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
 
 // Test getter methods
 // ---
@@ -177,14 +179,16 @@
 
 // DEF:      ::mlir::LogicalResult AgetOpAdaptor::verify
 // DEF:      auto tblgen_aAttr = odsAttrs.get("aAttr");
-// DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test2.a_get_op' op ""requires attribute 'aAttr'");
-// DEF:        if (!((some-condition))) return emitError(loc, "'test2.a_get_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
+// DEF-NEXT: if (!tblgen_aAttr)
+// DEF-NEXT:   return emitError(loc, "'test2.a_get_op' op ""requires attribute 'aAttr'");
+// DEF:      if (tblgen_aAttr && !((some-condition)))
+// DEF-NEXT:   return emitError(loc, "'test2.a_get_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
 // DEF:      auto tblgen_bAttr = odsAttrs.get("bAttr");
-// DEF-NEXT: if (tblgen_bAttr) {
-// DEF-NEXT:   if (!((some-condition))) return emitError(loc, "'test2.a_get_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
+// DEF-NEXT: if (tblgen_bAttr && !((some-condition)))
+// DEF-NEXT:   return emitError(loc, "'test2.a_get_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
 // DEF:      auto tblgen_cAttr = odsAttrs.get("cAttr");
-// DEF-NEXT: if (tblgen_cAttr) {
-// DEF-NEXT:   if (!((some-condition))) return emitError(loc, "'test2.a_get_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
+// DEF-NEXT: if (tblgen_cAttr && !((some-condition)))
+// DEF-NEXT:   return emitError(loc, "'test2.a_get_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
 
 // Test getter methods
 // ---
@@ -267,19 +271,19 @@
 // ---
 
 // DEF-LABEL: BOpAdaptor::verify
-// DEF: if (!((true)))
-// DEF: if (!((tblgen_bool_attr.isa<::mlir::BoolAttr>())))
-// DEF: if (!(((tblgen_i32_attr.isa<::mlir::IntegerAttr>())) && ((tblgen_i32_attr.cast<::mlir::IntegerAttr>().getType().isSignlessInteger(32)))))
-// DEF: if (!(((tblgen_i64_attr.isa<::mlir::IntegerAttr>())) && ((tblgen_i64_attr.cast<::mlir::IntegerAttr>().getType().isSignlessInteger(64)))))
-// DEF: if (!(((tblgen_f32_attr.isa<::mlir::FloatAttr>())) && ((tblgen_f32_attr.cast<::mlir::FloatAttr>().getType().isF32()))))
-// DEF: if (!(((tblgen_f64_attr.isa<::mlir::FloatAttr>())) && ((tblgen_f64_attr.cast<::mlir::FloatAttr>().getType().isF64()))))
-// DEF: if (!((tblgen_str_attr.isa<::mlir::StringAttr>())))
-// DEF: if (!((tblgen_elements_attr.isa<::mlir::ElementsAttr>())))
-// DEF: if (!((tblgen_function_attr.isa<::mlir::FlatSymbolRefAttr>())))
-// DEF: if (!(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa<SomeType>()))))
-// DEF: if (!((tblgen_array_attr.isa<::mlir::ArrayAttr>())))
-// DEF: if (!(((tblgen_some_attr_array.isa<::mlir::ArrayAttr>())) && (::llvm::all_of(tblgen_some_attr_array.cast<::mlir::ArrayAttr>(), [&](::mlir::Attribute attr) { return (some-condition); }))))
-// DEF: if (!(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>()))))
+// DEF: if (tblgen_any_attr && !((true)))
+// DEF: if (tblgen_bool_attr && !((tblgen_bool_attr.isa<::mlir::BoolAttr>())))
+// DEF: if (tblgen_i32_attr && !(((tblgen_i32_attr.isa<::mlir::IntegerAttr>())) && ((tblgen_i32_attr.cast<::mlir::IntegerAttr>().getType().isSignlessInteger(32)))))
+// DEF: if (tblgen_i64_attr && !(((tblgen_i64_attr.isa<::mlir::IntegerAttr>())) && ((tblgen_i64_attr.cast<::mlir::IntegerAttr>().getType().isSignlessInteger(64)))))
+// DEF: if (tblgen_f32_attr && !(((tblgen_f32_attr.isa<::mlir::FloatAttr>())) && ((tblgen_f32_attr.cast<::mlir::FloatAttr>().getType().isF32()))))
+// DEF: if (tblgen_f64_attr && !(((tblgen_f64_attr.isa<::mlir::FloatAttr>())) && ((tblgen_f64_attr.cast<::mlir::FloatAttr>().getType().isF64()))))
+// DEF: if (tblgen_str_attr && !((tblgen_str_attr.isa<::mlir::StringAttr>())))
+// DEF: if (tblgen_elements_attr && !((tblgen_elements_attr.isa<::mlir::ElementsAttr>())))
+// DEF: if (tblgen_function_attr && !((tblgen_function_attr.isa<::mlir::FlatSymbolRefAttr>())))
+// DEF: if (tblgen_some_type_attr && !(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa<SomeType>()))))
+// DEF: if (tblgen_array_attr && !((tblgen_array_attr.isa<::mlir::ArrayAttr>())))
+// DEF: if (tblgen_some_attr_array && !(((tblgen_some_attr_array.isa<::mlir::ArrayAttr>())) && (::llvm::all_of(tblgen_some_attr_array.cast<::mlir::ArrayAttr>(), [&](::mlir::Attribute attr) { return (some-condition); }))))
+// DEF: if (tblgen_type_attr && !(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>()))))
 
 // Test common attribute kind getters' return types
 // ---
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -51,7 +51,7 @@
 
 // CHECK-LABEL: OpFAdaptor::verify
 // CHECK:       (tblgen_attr.cast<::mlir::IntegerAttr>().getInt() >= 10)
-// CHECK-SAME:  "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10"
+// CHECK-NEXT:  "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10"
 
 def OpFX : NS_Op<"op_for_int_max_val", []> {
   let arguments = (ins Confined<I32Attr, [IntMaxValue<10>]>:$attr);
@@ -59,7 +59,7 @@
 
 // CHECK-LABEL: OpFXAdaptor::verify
 // CHECK:       (tblgen_attr.cast<::mlir::IntegerAttr>().getInt() <= 10)
-// CHECK-SAME:  "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10"
+// CHECK-NEXT:  "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10"
 
 def OpG : NS_Op<"op_for_arr_min_count", []> {
   let arguments = (ins Confined<ArrayAttr, [ArrayMinCount<8>]>:$attr);
@@ -67,7 +67,7 @@
 
 // CHECK-LABEL: OpGAdaptor::verify
 // CHECK:       (tblgen_attr.cast<::mlir::ArrayAttr>().size() >= 8)
-// CHECK-SAME:  "attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements"
+// CHECK-NEXT:  "attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements"
 
 def OpH : NS_Op<"op_for_arr_value_at_index", []> {
   let arguments = (ins Confined<ArrayAttr, [IntArrayNthElemEq<0, 8>]>:$attr);
@@ -75,7 +75,7 @@
 
 // CHECK-LABEL: OpHAdaptor::verify
 // CHECK: (((tblgen_attr.cast<::mlir::ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<::mlir::ArrayAttr>()[0].cast<::mlir::IntegerAttr>().getInt() == 8)))))
-// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8"
+// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8"
 
 def OpI: NS_Op<"op_for_arr_min_value_at_index", []> {
   let arguments = (ins Confined<ArrayAttr, [IntArrayNthElemMinValue<0, 8>]>:$attr);
@@ -83,7 +83,7 @@
 
 // CHECK-LABEL: OpIAdaptor::verify
 // CHECK: (((tblgen_attr.cast<::mlir::ArrayAttr>().size() > 0)) && ((tblgen_attr.cast<::mlir::ArrayAttr>()[0].cast<::mlir::IntegerAttr>().getInt() >= 8)))))
-// CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8"
+// CHECK-NEXT: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8"
 
 def OpJ: NS_Op<"op_for_TCopVTEtAreSameAt", [
     PredOpTrait<"operands indexed at 0, 2, 3 should all have "
@@ -121,4 +121,4 @@
 
 // CHECK-LABEL: OpLAdaptor::verify
 // CHECK: getValue() == "foo"
-// CHECK-SAME: only value \"foo\" is allowed
+// CHECK-NEXT: only value \"foo\" is allowed
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -42,6 +42,17 @@
 static const char *const odsBuilder = "odsBuilder";
 static const char *const builderOpState = "odsState";
 
+// Code for OpAdaptors to lookup an attribute using strings on the provided
+// DictionaryAttr.
+//
+// {0}: The attribute name.
+static const char *const adaptorGetAttr = "odsAttrs.get(\"{0}\")";
+
+// Code for Ops to lookup an attribute using the cached identifier.
+//
+// {0}: The attribute's getter name.
+static const char *const opGetAttr = "(*this)->getAttr({0}AttrName())";
+
 // The logic to calculate the actual value range for a declared operand/result
 // of an op with variadic operands/results. Note that this logic is not for
 // general use; it assumes all variadic operands/results must have the same
@@ -163,8 +174,7 @@
   const auto &operand = op.getOperand(index);
   if (!operand.name.empty())
     return std::string(operand.name);
-  else
-    return std::string(formatv("{0}_{1}", generatedArgName, index));
+  return std::string(formatv("{0}_{1}", generatedArgName, index));
 }
 
 // Returns true if we can use unwrapped value for the given `attr` in builders.
@@ -370,13 +380,16 @@
 //   an operand (the generated function call returns an OperandRange);
 // - resultGet corresponds to the name of the function to get an result (the
 //   generated function call returns a ValueRange);
+// - opRequired: whether attrGet needs an op instance (takes getter names);
 static void populateSubstitutions(const Operator &op, const char *attrGet,
                                   const char *operandGet, const char *resultGet,
-                                  FmtContext &ctx) {
+                                  FmtContext &ctx, bool opRequired) {
   // Populate substitutions for attributes and named operands.
-  for (const auto &namedAttr : op.getAttributes())
+  for (const auto &namedAttr : op.getAttributes()) {
     ctx.addSubst(namedAttr.name,
-                 formatv("{0}(\"{1}\")", attrGet, namedAttr.name));
+                 formatv(attrGet, opRequired ? op.getGetterName(namedAttr.name)
+                                             : namedAttr.name));
+  }
   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
     auto &value = op.getOperand(i);
     if (value.name.empty())
@@ -413,62 +426,65 @@
                                  const Twine &emitErrorPrefix,
                                  bool emitVerificationRequiringOp,
                                  FmtContext &ctx, OpMethodBody &body) {
+  // Check that a required attribute exists.
+  //
+  // {0}: Attribute variable name.
+  // {1}: Emit error prefix.
+  // {2}: Attribute name.
+  const char *const checkRequiredAttr = R"(
+    if (!{0})
+      return {1}"requires attribute '{2}'");
+  )";
+  // Check the condition on an attribute if it is present. This assumes that
+  // default values are valid.
+  // TODO: verify the default value is valid (perhaps in debug mode only).
+  //
+  // {0}: Attribute variable name.
+  // {1}: Attribute condition code.
+  // {2}: Emit error prefix.
+  // {3}: Attribute name.
+  // {4}: Attribute constraint description (summary).
+  const char *const checkAttrCondition = R"(
+    if ({0} && !({1}))
+      return {2}"attribute '{3}' failed to satisfy constraint: {4}");
+  )";
+
   for (const auto &namedAttr : op.getAttributes()) {
     const auto &attr = namedAttr.attr;
+    StringRef attrName = namedAttr.name;
     if (attr.isDerivedAttr())
       continue;
 
-    auto attrName = namedAttr.name;
     bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
     auto attrPred = attr.getPredicate();
-    auto condition = attrPred.isNull() ? "" : attrPred.getCondition();
-    // There is a condition to emit only if the use of $_op and whether to
-    // emit verifications for op matches.
-    bool hasConditionToEmit = (!(condition.find("$_op") != StringRef::npos) ^
-                               emitVerificationRequiringOp);
+    std::string condition = attrPred.isNull() ? "" : attrPred.getCondition();
+    // If the attribute's condition needs an op but none is available, then the
+    // condition cannot be emitted.
+    bool canEmitCondition =
+        !StringRef(condition).contains("$_op") || emitVerificationRequiringOp;
 
     // Prefix with `tblgen_` to avoid hiding the attribute accessor.
-    auto varName = tblgenNamePrefix + attrName;
-
-    // If the attribute is
-    //  1. Required (not allowed missing) and not in op verification, or
-    //  2. Has a condition that will get verified
-    // then the variable will be used.
-    //
-    // Therefore, for optional attributes whose verification requires that an
-    // op already exists for verification/emitVerificationRequiringOp is set
-    // has nothing that can be verified here.
-    if ((allowMissingAttr || emitVerificationRequiringOp) &&
-        !hasConditionToEmit)
-      continue;
+    // Materialized as a std::string because llvm::Twine must not be stored.
+    std::string varName = (tblgenNamePrefix + attrName).str();
 
-    body << formatv("  {\n    auto {0} = {1}(\"{2}\");\n", varName, attrGet,
-                    attrName);
-
-    if (!emitVerificationRequiringOp && !allowMissingAttr) {
-      body << "    if (!" << varName << ") return " << emitErrorPrefix
-           << "\"requires attribute '" << attrName << "'\");\n";
-    }
-
-    if (!hasConditionToEmit) {
-      body << "  }\n";
+    // If the attribute is not required and we cannot emit the condition, then
+    // there is nothing to be done.
+    if (allowMissingAttr && !canEmitCondition)
       continue;
-    }
 
-    if (allowMissingAttr) {
-      // If the attribute has a default value, then only verify the predicate if
-      // set. This does effectively assume that the default value is valid.
-      // TODO: verify the debug value is valid (perhaps in debug mode only).
-      body << "    if (" << varName << ") {\n";
+    body << formatv("  {\n    auto {0} = {1};", varName,
+                    formatv(attrGet, emitVerificationRequiringOp
+                                         ? op.getGetterName(attrName)
+                                         : attrName));
+
+    if (!allowMissingAttr)
+      body << formatv(checkRequiredAttr, varName, emitErrorPrefix, attrName);
+    if (canEmitCondition) {
+      body << formatv(checkAttrCondition, varName,
+                      tgfmt(condition, &ctx.withSelf(varName)), emitErrorPrefix,
+                      attrName, escapeString(attr.getSummary()));
     }
-
-    body << tgfmt("    if (!($0)) return $1\"attribute '$2' "
-                  "failed to satisfy constraint: $3\");\n",
-                  /*ctx=*/nullptr, tgfmt(condition, &ctx.withSelf(varName)),
-                  emitErrorPrefix, attrName, escapeString(attr.getSummary()));
-    if (allowMissingAttr)
-      body << "    }\n";
-    body << "  }\n";
+    body << "}\n";
   }
 }
 
@@ -2085,21 +2101,72 @@
   method->body() << "  " << tgfmt(printer, &fctx);
 }
 
+/// Generate verification on native traits requiring attributes.
+static void genNativeTraitAttrVerifier(OpMethodBody &body, const Operator &op,
+                                       const char *const attrGet,
+                                       const Twine &emitError,
+                                       bool opRequired) {
+  // Check that the variadic segment sizes attribute exists and contains the
+  // expected number of elements.
+  //
+  // {0}: Attribute name.
+  // {1}: Expected number of elements.
+  // {2}: "operand" or "result".
+  // {3}: Attribute getter call; yields a null attribute if it is missing.
+  // {4}: Emit error prefix.
+  const char *const checkAttrSizedValueSegmentsCode = R"(
+  {
+    auto sizeAttr = {3}.dyn_cast_or_null<::mlir::DenseIntElementsAttr>();
+    if (!sizeAttr)
+      return {4}"missing segment sizes attribute '{0}'");
+    auto numElements =
+        sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
+    if (numElements != {1})
+      return {4}"'{0}' attribute for specifying {2} segments must have {1} "
+                "elements, but got ") << numElements;
+  }
+  )";
+
+  // Verify a few traits first so that we can use getODSOperands() and
+  // getODSResults() in the rest of the verifier.
+  for (auto &trait : op.getTraits()) {
+    auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait);
+    if (!t)
+      continue;
+    std::string traitName = t->getFullyQualifiedTraitName();
+    if (traitName == "::mlir::OpTrait::AttrSizedOperandSegments") {
+      StringRef attrName = "operand_segment_sizes";
+      body << formatv(
+          checkAttrSizedValueSegmentsCode, attrName, op.getNumOperands(),
+          "operand",
+          formatv(attrGet, opRequired ? op.getGetterName(attrName) : attrName),
+          emitError);
+    } else if (traitName == "::mlir::OpTrait::AttrSizedResultSegments") {
+      StringRef attrName = "result_segment_sizes";
+      body << formatv(
+          checkAttrSizedValueSegmentsCode, attrName, op.getNumResults(),
+          "result",
+          formatv(attrGet, opRequired ? op.getGetterName(attrName) : attrName),
+          emitError);
+    }
+  }
+}
+
 void OpEmitter::genVerifier() {
   auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
   ERROR_IF_PRUNED(method, "verify", op);
   auto &body = method->body();
-  body << "  if (::mlir::failed(" << op.getAdaptorName()
-       << "(*this).verify((*this)->getLoc()))) "
-       << "return ::mlir::failure();\n";
+
+  genNativeTraitAttrVerifier(body, op, opGetAttr, "emitOpError(",
+                             /*opRequired=*/true);
 
   auto *valueInit = def.getValueInit("verifier");
   StringInit *stringInit = dyn_cast<StringInit>(valueInit);
   bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
-  populateSubstitutions(op, "(*this)->getAttr", "this->getODSOperands",
-                        "this->getODSResults", verifyCtx);
+  populateSubstitutions(op, opGetAttr, "this->getODSOperands",
+                        "this->getODSResults", verifyCtx, /*opRequired=*/true);
 
-  genAttributeVerifier(op, "(*this)->getAttr", "emitOpError(",
+  genAttributeVerifier(op, opGetAttr, "emitOpError(",
                        /*emitVerificationRequiringOp=*/true, verifyCtx, body);
   genOperandResultVerifier(body, op.getOperands(), "operand");
   genOperandResultVerifier(body, op.getResults(), "result");
@@ -2530,39 +2597,16 @@
   ERROR_IF_PRUNED(method, "verify", op);
   auto &body = method->body();
 
-  const char *checkAttrSizedValueSegmentsCode = R"(
-  {
-    auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
-    auto numElements = sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
-    if (numElements != {1})
-      return emitError(loc, "'{0}' attribute for specifying {2} segments "
-                       "must have {1} elements, but got ") << numElements;
-  }
-  )";
-
-  // Verify a few traits first so that we can use
-  // getODSOperands()/getODSResults() in the rest of the verifier.
-  for (auto &trait : op.getTraits()) {
-    if (auto *t = dyn_cast<tblgen::NativeOpTrait>(&trait)) {
-      if (t->getFullyQualifiedTraitName() ==
-          "::mlir::OpTrait::AttrSizedOperandSegments") {
-        body << formatv(checkAttrSizedValueSegmentsCode,
-                        "operand_segment_sizes", op.getNumOperands(),
-                        "operand");
-      } else if (t->getFullyQualifiedTraitName() ==
-                 "::mlir::OpTrait::AttrSizedResultSegments") {
-        body << formatv(checkAttrSizedValueSegmentsCode, "result_segment_sizes",
-                        op.getNumResults(), "result");
-      }
-    }
-  }
+  std::string emitError =
+      "emitError(loc, \"'" + op.getOperationName() + "' op \"";
+  genNativeTraitAttrVerifier(body, op, adaptorGetAttr, emitError,
+                             /*opRequired=*/false);
 
   FmtContext verifyCtx;
-  populateSubstitutions(op, "odsAttrs.get", "getODSOperands",
-                        "<no results should be generated>", verifyCtx);
-  genAttributeVerifier(op, "odsAttrs.get",
-                       Twine("emitError(loc, \"'") + op.getOperationName() +
-                           "' op \"",
+  populateSubstitutions(op, adaptorGetAttr, "getODSOperands",
+                        "<no results should be generated>", verifyCtx,
+                        /*opRequired=*/false);
+  genAttributeVerifier(op, adaptorGetAttr, emitError,
                        /*emitVerificationRequiringOp*/ false, verifyCtx, body);
   body << "  return ::mlir::success();";