diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp
--- a/mlir/lib/TableGen/Constraint.cpp
+++ b/mlir/lib/TableGen/Constraint.cpp
@@ -17,21 +17,28 @@
 
 Constraint::Constraint(const llvm::Record *record)
     : def(record), kind(CK_Uncategorized) {
-  if (record->isSubClassOf("TypeConstraint")) {
+  // Look through OpVariables to their constraint.
+  if (def->isSubClassOf("OpVariable"))
+    def = def->getValueAsDef("constraint");
+  if (def->isSubClassOf("TypeConstraint")) {
     kind = CK_Type;
-  } else if (record->isSubClassOf("AttrConstraint")) {
+  } else if (def->isSubClassOf("AttrConstraint")) {
     kind = CK_Attr;
-  } else if (record->isSubClassOf("RegionConstraint")) {
+  } else if (def->isSubClassOf("RegionConstraint")) {
     kind = CK_Region;
-  } else if (record->isSubClassOf("SuccessorConstraint")) {
+  } else if (def->isSubClassOf("SuccessorConstraint")) {
     kind = CK_Successor;
   } else {
-    assert(record->isSubClassOf("Constraint"));
+    assert(def->isSubClassOf("Constraint"));
   }
 }
 
 Constraint::Constraint(Kind kind, const llvm::Record *record)
-    : def(record), kind(kind) {}
+    : def(record), kind(kind) {
+  // Look through OpVariables to their constraint.
+  if (def->isSubClassOf("OpVariable"))
+    def = def->getValueAsDef("constraint");
+}
 
 Pred Constraint::getPredicate() const {
   auto *val = def->getValue("predicate");
diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -155,7 +155,7 @@
 def NS_FOp : NS_Op<"op_with_all_types_constraint",
     [AllTypesMatch<["a", "b"]>]> {
   let arguments = (ins AnyType:$a);
-  let results = (outs AnyType:$b);
+  let results = (outs Res<AnyType, "output b", []>:$b);
 }
 
 // CHECK-LABEL: class FOp :