diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td @@ -124,9 +124,7 @@ def Arith_ConstantOp : Op, - TypesMatchWith< - "result and attribute have the same type", - "value", "result", "$_self">]> { + AllTypesMatch<["value", "result"]>]> { let summary = "integer or floating point constant"; let description = [{ The `constant` operation produces an SSA value equal to some integer or @@ -154,8 +152,6 @@ let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result); let builders = [ - OpBuilder<(ins "Attribute":$value), - [{ build($_builder, $_state, value.getType(), value); }]>, OpBuilder<(ins "Attribute":$value, "Type":$type), [{ build($_builder, $_state, type, value); }]>, ]; diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h --- a/mlir/include/mlir/TableGen/CodeGenHelpers.h +++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h @@ -187,19 +187,9 @@ /// ensure that the static functions have a unique name. std::string uniqueOutputLabel; - /// Unique constraints by their predicate and summary. Constraints that share - /// the same predicate may have different descriptions; ensure that the - /// correct error message is reported when verification fails. - struct ConstraintUniquer { - static Constraint getEmptyKey(); - static Constraint getTombstoneKey(); - static unsigned getHashValue(Constraint constraint); - static bool isEqual(Constraint lhs, Constraint rhs); - }; /// Use a MapVector to ensure that functions are generated deterministically. 
- using ConstraintMap = - llvm::MapVector>; + using ConstraintMap = llvm::MapVector>; /// A generic function to emit constraints void emitConstraints(const ConstraintMap &constraints, StringRef selfName, diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h --- a/mlir/include/mlir/TableGen/Constraint.h +++ b/mlir/include/mlir/TableGen/Constraint.h @@ -94,4 +94,20 @@ } // namespace tblgen } // namespace mlir +namespace llvm { +/// Unique constraints by their predicate and summary. Constraints that share +/// the same predicate may have different descriptions; ensure that the +/// correct error message is reported when verification fails. +template <> +struct DenseMapInfo { + using RecordDenseMapInfo = llvm::DenseMapInfo; + + static mlir::tblgen::Constraint getEmptyKey(); + static mlir::tblgen::Constraint getTombstoneKey(); + static unsigned getHashValue(mlir::tblgen::Constraint constraint); + static bool isEqual(mlir::tblgen::Constraint lhs, + mlir::tblgen::Constraint rhs); +}; +} // namespace llvm + #endif // MLIR_TABLEGEN_CONSTRAINT_H_ diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp --- a/mlir/lib/TableGen/Constraint.cpp +++ b/mlir/lib/TableGen/Constraint.cpp @@ -108,3 +108,34 @@ std::vector &&entities) : constraint(constraint), self(std::string(self)), entities(std::move(entities)) {} + +Constraint DenseMapInfo::getEmptyKey() { + return Constraint(RecordDenseMapInfo::getEmptyKey(), + Constraint::CK_Uncategorized); +} + +Constraint DenseMapInfo::getTombstoneKey() { + return Constraint(RecordDenseMapInfo::getTombstoneKey(), + Constraint::CK_Uncategorized); +} + +unsigned DenseMapInfo::getHashValue(Constraint constraint) { + if (constraint == getEmptyKey()) + return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey()); + if (constraint == getTombstoneKey()) { + return RecordDenseMapInfo::getHashValue( + RecordDenseMapInfo::getTombstoneKey()); + } + return 
llvm::hash_combine(constraint.getPredicate(), constraint.getSummary()); +} + +bool DenseMapInfo::isEqual(Constraint lhs, Constraint rhs) { + if (lhs == rhs) + return true; + if (lhs == getEmptyKey() || lhs == getTombstoneKey()) + return false; + if (rhs == getEmptyKey() || rhs == getTombstoneKey()) + return false; + return lhs.getPredicate() == rhs.getPredicate() && + lhs.getSummary() == rhs.getSummary(); +} diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -357,10 +357,6 @@ continue; } - if (getArg(*mi).is()) { - // TODO: Handle attributes. - continue; - } resultTypeMapping[i].emplace_back(*mi); found = true; } diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py --- a/mlir/python/mlir/dialects/_arith_ops_ext.py +++ b/mlir/python/mlir/dialects/_arith_ops_ext.py @@ -41,11 +41,11 @@ loc=None, ip=None): if isinstance(value, int): - super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip) + super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) elif isinstance(value, float): - super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip) + super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) else: - super().__init__(result, value, loc=loc, ip=ip) + super().__init__(value, loc=loc, ip=ip) @classmethod def create_index(cls, value: int, *, loc=None, ip=None): diff --git a/mlir/test/Dialect/Arithmetic/invalid.mlir b/mlir/test/Dialect/Arithmetic/invalid.mlir --- a/mlir/test/Dialect/Arithmetic/invalid.mlir +++ b/mlir/test/Dialect/Arithmetic/invalid.mlir @@ -25,7 +25,7 @@ // ----- func.func @complex_constant_wrong_attribute_type() { - // expected-error @+1 {{'arith.constant' op failed to verify that result and attribute have the same type}} + // expected-error @+1 {{'arith.constant' op failed to verify that all of {value, result} have same type}} %0 = "arith.constant" () {value = 
1.0 : f32} : () -> complex return } @@ -50,7 +50,7 @@ func.func @constant() { ^bb: - %x = "arith.constant"(){value = "xyz"} : () -> i32 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}} + %x = "arith.constant"(){value = "xyz"} : () -> i32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}} return } @@ -58,7 +58,7 @@ func.func @constant_out_of_range() { ^bb: - %x = "arith.constant"(){value = 100} : () -> i1 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}} + %x = "arith.constant"(){value = 100} : () -> i1 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}} return } @@ -66,7 +66,7 @@ func.func @constant_wrong_type() { ^bb: - %x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that result and attribute have the same type}} + %x = "arith.constant"(){value = 10.} : () -> f32 // expected-error {{'arith.constant' op failed to verify that all of {value, result} have same type}} return } diff --git a/mlir/test/IR/diagnostic-handler.mlir b/mlir/test/IR/diagnostic-handler.mlir --- a/mlir/test/IR/diagnostic-handler.mlir +++ b/mlir/test/IR/diagnostic-handler.mlir @@ -5,7 +5,7 @@ // Emit the first available call stack in the fused location. 
func.func @constant_out_of_range() { - // CHECK: mysource1:0:0: error: 'arith.constant' op failed to verify that result and attribute have the same type + // CHECK: mysource1:0:0: error: 'arith.constant' op failed to verify that all of {value, result} have same type // CHECK-NEXT: mysource2:1:0: note: called from // CHECK-NEXT: mysource3:2:0: note: called from %x = "arith.constant"() {value = 100} : () -> i1 loc(fused["bar", callsite("foo"("mysource1":0:0) at callsite("mysource2":1:0 at "mysource3":2:0))]) diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -123,7 +123,8 @@ // CHECK-LABEL: LogicalResult OpL1::inferReturnTypes // CHECK-NOT: } -// CHECK: inferredReturnTypes[0] = operands[0].getType(); +// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType(); +// CHECK: inferredReturnTypes[0] = odsInferredType0; def OpL2 : NS_Op<"op_with_all_types_constraint", [AllTypesMatch<["c", "b"]>, AllTypesMatch<["a", "d"]>]> { @@ -133,5 +134,29 @@ // CHECK-LABEL: LogicalResult OpL2::inferReturnTypes // CHECK-NOT: } -// CHECK: inferredReturnTypes[0] = operands[2].getType(); -// CHECK: inferredReturnTypes[1] = operands[0].getType(); +// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType(); +// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType(); +// CHECK: inferredReturnTypes[0] = odsInferredType0; +// CHECK: inferredReturnTypes[1] = odsInferredType1; + +def OpL3 : NS_Op<"op_with_all_types_constraint", + [AllTypesMatch<["a", "b"]>]> { + let arguments = (ins I32Attr:$a); + let results = (outs AnyType:$b); +} + +// CHECK-LABEL: LogicalResult OpL3::inferReturnTypes +// CHECK-NOT: } +// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").getType(); +// CHECK: inferredReturnTypes[0] = odsInferredType0; + +def OpL4 : NS_Op<"op_with_all_types_constraint", + [AllTypesMatch<["a", "b"]>]> { + let arguments = (ins AnyType:$x, I32Attr:$a); + let results = 
(outs AnyType:$b); +} + +// CHECK-LABEL: LogicalResult OpL4::inferReturnTypes +// CHECK-NOT: } +// CHECK: ::mlir::Type odsInferredType0 = attributes.get("a").getType(); +// CHECK: inferredReturnTypes[0] = odsInferredType0; diff --git a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp --- 
a/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp +++ b/mlir/tools/mlir-tblgen/CodeGenHelpers.cpp @@ -234,41 +234,6 @@ //===----------------------------------------------------------------------===// // Constraint Uniquing -using RecordDenseMapInfo = llvm::DenseMapInfo; - -Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getEmptyKey() { - return Constraint(RecordDenseMapInfo::getEmptyKey(), - Constraint::CK_Uncategorized); -} - -Constraint StaticVerifierFunctionEmitter::ConstraintUniquer::getTombstoneKey() { - return Constraint(RecordDenseMapInfo::getTombstoneKey(), - Constraint::CK_Uncategorized); -} - -unsigned StaticVerifierFunctionEmitter::ConstraintUniquer::getHashValue( - Constraint constraint) { - if (constraint == getEmptyKey()) - return RecordDenseMapInfo::getHashValue(RecordDenseMapInfo::getEmptyKey()); - if (constraint == getTombstoneKey()) { - return RecordDenseMapInfo::getHashValue( - RecordDenseMapInfo::getTombstoneKey()); - } - return llvm::hash_combine(constraint.getPredicate(), constraint.getSummary()); -} - -bool StaticVerifierFunctionEmitter::ConstraintUniquer::isEqual(Constraint lhs, - Constraint rhs) { - if (lhs == rhs) - return true; - if (lhs == getEmptyKey() || lhs == getTombstoneKey()) - return false; - if (rhs == getEmptyKey() || rhs == getTombstoneKey()) - return false; - return lhs.getPredicate() == rhs.getPredicate() && - lhs.getSummary() == rhs.getSummary(); -} - /// An attribute constraint that references anything other than itself and the /// current op cannot be generically extracted into a function. 
Most /// prohibitive are operands and results, which require calls to diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2336,23 +2336,60 @@ fctx.withBuilder("odsBuilder"); body << " ::mlir::Builder odsBuilder(context);\n"; - auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & { - if (!type.isArg()) - return body << tgfmt(*type.getType().getBuilderCall(), &fctx); - auto argIndex = type.getArg(); - assert(!op.getArg(argIndex).is()); + // Preprocess the result types and build all of the types used during + // inference. This limits the amount of duplicated work when a type is used + // to infer multiple others. + llvm::DenseMap constraintsTypes; + llvm::DenseMap argumentsTypes; + int inferredTypeIdx = 0; + for (int i = 0, e = op.getNumResults(); i != e; ++i) { + auto type = op.getSameTypeAsResult(i).front(); + + // If the type isn't an argument, it refers to a buildable type. + if (!type.isArg()) { + auto it = constraintsTypes.try_emplace(type.getType(), inferredTypeIdx); + if (!it.second) + continue; + + // If we haven't seen this constraint, generate a variable for it. + body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = " + << tgfmt(*type.getType().getBuilderCall(), &fctx) << ";\n"; + continue; + } + + // Otherwise, this is an argument. + int argIndex = type.getArg(); + auto it = argumentsTypes.try_emplace(argIndex, inferredTypeIdx); + if (!it.second) + continue; + body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "; + + // If this is an operand, just index into operand list to access the type. 
auto arg = op.getArgToOperandOrAttribute(argIndex); - if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) - return body << "operands[" << arg.operandOrAttributeIndex() - << "].getType()"; - return body << "attributes[" << arg.operandOrAttributeIndex() - << "].getType()"; - }; + if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { + body << "operands[" << arg.operandOrAttributeIndex() << "].getType()"; + + // If this is an attribute, index into the attribute dictionary. + } else { + // This index is into the attribute list, so query it via getAttribute. + const auto &attr = op.getAttribute(arg.operandOrAttributeIndex()); + body << "attributes.get(\"" << attr.name << "\").getType()"; + } + body << ";\n"; + } + // Perform a second pass that handles assigning the inferred types to the + // results. for (int i = 0, e = op.getNumResults(); i != e; ++i) { - body << " inferredReturnTypes[" << i << "] = "; auto types = op.getSameTypeAsResult(i); - emitType(types[0]) << ";\n"; + + // Append the inferred type. + auto type = types.front(); + body << " inferredReturnTypes[" << i << "] = odsInferredType" + << (type.isArg() ? argumentsTypes[type.getArg()] + : constraintsTypes[type.getType()]) + << ";\n"; + if (types.size() == 1) continue; // TODO: We could verify equality here, but skipping that for verification.