diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -261,6 +261,16 @@
   // Requires: all result types are known.
   ArrayRef<ArgOrType> getSameTypeAsResult(int index) const;
 
+  // Describes an op argument: whether it is an operand or an attribute, and
+  // its index into the corresponding list (operands or attributes).
+  struct OperandOrAttribute {
+    enum class Kind { Operand, Attribute } kind;
+    int operandOrAttributeIndex;
+  };
+
+  // Returns the OperandOrAttribute corresponding to the argument at `index`.
+  OperandOrAttribute getArgToOperandOrAttribute(int index) const;
+
 private:
   // Populates the vectors containing operands, attributes, results and traits.
   void populateOpStructure();
@@ -303,6 +313,9 @@
   // The argument with the same type as the result.
   SmallVector<SmallVector<ArgOrType, 2>, 4> resultTypeMapping;
 
+  // Maps each argument index to its kind and its operand or attribute index.
+  SmallVector<OperandOrAttribute, 4> attrOrOperandMapping;
+
   // The number of native attributes stored in the leading positions of
   // `attributes`.
   int numNativeAttributes;
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -436,9 +436,13 @@
       argDef = argDef->getValueAsDef("constraint");
 
     if (argDef->isSubClassOf(typeConstraintClass)) {
+      attrOrOperandMapping.push_back(
+          {OperandOrAttribute::Kind::Operand, operandIndex});
       arguments.emplace_back(&operands[operandIndex++]);
     } else {
       assert(argDef->isSubClassOf(attrClass));
+      attrOrOperandMapping.push_back(
+          {OperandOrAttribute::Kind::Attribute, attrIndex});
       arguments.emplace_back(&attributes[attrIndex++]);
     }
   }
@@ -581,3 +585,8 @@
     -> VariableDecorator {
   return VariableDecorator(cast<DefInit>(init)->getDef());
 }
+
+auto tblgen::Operator::getArgToOperandOrAttribute(int index) const
+    -> OperandOrAttribute {
+  return attrOrOperandMapping[index];
+}
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td
--- a/mlir/test/mlir-tblgen/op-result.td
+++ b/mlir/test/mlir-tblgen/op-result.td
@@ -1,6 +1,7 @@
 // RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
 
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 
 def Test_Dialect : Dialect {
   let name = "test";
@@ -111,3 +112,15 @@
 
 // CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange input)
 // CHECK: odsState.addTypes({input.front().getType()});
+
+// Test result type inference when attributes and operands are interleaved.
+//
+def OpL : NS_Op<"op_with_all_types_constraint",
+    [AllTypesMatch<["a", "b"]>]> {
+  let arguments = (ins I32Attr:$attr1, AnyType:$a);
+  let results = (outs Res<AnyType, "output b", []>:$b);
+}
+
+// CHECK-LABEL: LogicalResult OpL::inferReturnTypes
+// CHECK-NOT: }
+// CHECK: inferredReturnTypes[0] = operands[0].getType();
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1601,7 +1601,14 @@
     if (type.isArg()) {
       auto argIndex = type.getArg();
       assert(!op.getArg(argIndex).is<NamedAttribute *>());
-      return os << "operands[" << argIndex << "].getType()";
+      // `argIndex` counts attributes and operands together; translate it into
+      // an index into `operands`. Attribute arguments are excluded above.
+      auto arg = op.getArgToOperandOrAttribute(argIndex);
+      assert(arg.kind ==
+                 mlir::tblgen::Operator::OperandOrAttribute::Kind::Operand &&
+             "result type can only be inferred from an operand");
+      return os << "operands[" << arg.operandOrAttributeIndex
+                << "].getType()";
     } else {
       return os << tgfmt(*type.getType().getBuilderCall(), &fctx);
     }