diff --git a/mlir/include/mlir/TableGen/Predicate.h b/mlir/include/mlir/TableGen/Predicate.h
--- a/mlir/include/mlir/TableGen/Predicate.h
+++ b/mlir/include/mlir/TableGen/Predicate.h
@@ -14,6 +14,8 @@
 #define MLIR_TABLEGEN_PREDICATE_H_
 
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/Hashing.h"
 
 #include <string>
 #include <vector>
@@ -59,6 +61,8 @@
   ArrayRef<llvm::SMLoc> getLoc() const;
 
 protected:
+  friend llvm::DenseMapInfo<Pred>;
+
   // The TableGen definition of this predicate.
   const llvm::Record *def;
 };
@@ -116,4 +120,36 @@
 } // end namespace tblgen
 } // end namespace mlir
 
+namespace llvm {
+template <>
+struct DenseMapInfo<mlir::tblgen::Pred> {
+  // The empty and tombstone keys must be distinct from each other and from
+  // every key a client may insert (possibly a default-constructed Pred), so
+  // they reuse the sentinel pointers of DenseMapInfo<const llvm::Record *>.
+  static mlir::tblgen::Pred getEmptyKey() {
+    return makeKey(DenseMapInfo<const llvm::Record *>::getEmptyKey());
+  }
+  static mlir::tblgen::Pred getTombstoneKey() {
+    return makeKey(DenseMapInfo<const llvm::Record *>::getTombstoneKey());
+  }
+  static unsigned getHashValue(mlir::tblgen::Pred pred) {
+    return llvm::hash_value(pred.def);
+  }
+  static bool isEqual(mlir::tblgen::Pred lhs, mlir::tblgen::Pred rhs) {
+    // Compare the record pointers directly: this is consistent with
+    // getHashValue and never dereferences a sentinel key.
+    return lhs.def == rhs.def;
+  }
+
+private:
+  // Wraps `def` in a Pred without going through a Pred constructor that may
+  // inspect the record (sentinel pointers must never be dereferenced).
+  static mlir::tblgen::Pred makeKey(const llvm::Record *def) {
+    mlir::tblgen::Pred pred;
+    pred.def = def;
+    return pred;
+  }
+};
+} // end namespace llvm
+
 #endif // MLIR_TABLEGEN_PREDICATE_H_
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -13,19 +13,24 @@
 
 def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
   let arguments = (ins I32OrF32:$x);
+  let results = (outs Variadic<I32OrF32>:$y);
 }
 
 // CHECK: static ::mlir::LogicalResult [[$INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK: if (!((type.isInteger(32) || type.isF32()))) {
-// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
+// CHECK-NEXT: if (!((type.isInteger(32) || type.isF32()))) {
+// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
+
+// Check that no duplicate verifier with the same predicate is generated.
+// CHECK-NOT: if (!((type.isInteger(32) || type.isF32()))) {
+// CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
 
 // CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
-// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
+// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
+// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
 
 // CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
-// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
+// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
+// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
 
 // CHECK-LABEL: OpA::verify
 // CHECK: auto valueGroup0 = getODSOperands(0);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -216,19 +216,48 @@
       typeConstraints.insert(result.constraint.getAsOpaquePointer());
   }
 
+  // Record, for each predicate, the constraints whose verification functions
+  // have already been generated. Two constraints that have the same predicate
+  // and the same summary can share one verification function. A DenseMap
+  // holds a single entry per key, and one predicate may be shared by several
+  // constraints with different summaries, so keep a list per predicate.
+  llvm::DenseMap<Pred, llvm::SmallVector<const void *, 1>> predToConstraints;
   FmtContext fctx;
   for (auto it : llvm::enumerate(typeConstraints)) {
+    Constraint constraint = Constraint::getFromOpaquePointer(it.value());
+    Pred pred = constraint.getPredicate();
+
+    // Different constraints may have the same predicate, for example
+    // ConstraintA and Variadic<ConstraintA>: Variadic<> doesn't introduce a
+    // new predicate. Reuse an existing verification function only if the
+    // summary matches as well; otherwise the wrong message would be reported
+    // when verification fails.
+    std::string name;
+    auto &candidates = predToConstraints[pred];
+    for (const void *candidate : candidates) {
+      Constraint built = Constraint::getFromOpaquePointer(candidate);
+      if (constraint.getSummary() == built.getSummary()) {
+        name = getTypeConstraintFn(built).str();
+        break;
+      }
+    }
+
+    if (!name.empty()) {
+      localTypeConstraints.try_emplace(it.value(), name);
+      continue;
+    }
+
     // Generate an obscure and unique name for this type constraint.
-    std::string name = (Twine("__mlir_ods_local_type_constraint_") +
-                        uniqueOutputLabel + Twine(it.index()))
-                           .str();
+    name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
+            Twine(it.index()))
+               .str();
+    candidates.push_back(it.value());
     localTypeConstraints.try_emplace(it.value(), name);
 
     // Only generate the methods if we are generating definitions.
     if (emitDecl)
       continue;
 
-    Constraint constraint = Constraint::getFromOpaquePointer(it.value());
     os << "static ::mlir::LogicalResult " << name
        << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
           "valueKind, unsigned valueGroupStartIndex) {\n";