diff --git a/mlir/include/mlir/TableGen/Predicate.h b/mlir/include/mlir/TableGen/Predicate.h
--- a/mlir/include/mlir/TableGen/Predicate.h
+++ b/mlir/include/mlir/TableGen/Predicate.h
@@ -58,6 +58,13 @@
   // Get the location of the predicate.
   ArrayRef<llvm::SMLoc> getLoc() const;
 
+  // Get an opaque pointer to this predicate (the underlying TableGen record).
+  const void *getAsOpaquePointer() const { return def; }
+  // Construct a predicate from a pointer returned by getAsOpaquePointer().
+  static Pred getFromOpaquePointer(const void *ptr) {
+    return Pred(reinterpret_cast<const llvm::Record *>(ptr));
+  }
+
 protected:
   // The TableGen definition of this predicate.
   const llvm::Record *def;
diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td
--- a/mlir/test/mlir-tblgen/predicate.td
+++ b/mlir/test/mlir-tblgen/predicate.td
@@ -13,19 +13,24 @@
 
 def OpA : NS_Op<"op_for_CPred_containing_multiple_same_placeholder", []> {
   let arguments = (ins I32OrF32:$x);
+  let results = (outs Variadic<I32OrF32>:$y);
 }
 
 // CHECK: static ::mlir::LogicalResult [[$INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK: if (!((type.isInteger(32) || type.isF32()))) {
-// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
+// CHECK-NEXT: if (!((type.isInteger(32) || type.isF32()))) {
+// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
+
+// Check that no duplicate verifier is generated for the same predicate.
+// CHECK-NOT: if (!((type.isInteger(32) || type.isF32()))) {
+// CHECK-NOT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be 32-bit integer or floating-point type, but got " << type;
 
 // CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
-// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
+// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) {
+// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of any type values, but got " << type;
 
 // CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]](
-// CHECK: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
-// CHECK: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
+// CHECK-NEXT: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) {
+// CHECK-NEXT: return op->emitOpError(valueKind) << " #" << valueGroupStartIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type;
 
 // CHECK-LABEL: OpA::verify
 // CHECK: auto valueGroup0 = getODSOperands(0);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -29,6 +29,8 @@
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 
+#include <unordered_map>
+
 #define DEBUG_TYPE "mlir-tblgen-opdefgen"
 
 using namespace llvm;
@@ -215,19 +217,42 @@
         typeConstraints.insert(result.constraint.getAsOpaquePointer());
   }
 
+  // Record the mapping from predicate to constraint. If two constraints have
+  // the same predicate and constraint summary, they can share the same
+  // verification function.
+  std::unordered_multimap<const void *, const void *> predToConstraint;
   FmtContext fctx;
   for (auto it : llvm::enumerate(typeConstraints)) {
+    std::string name;
+    Constraint constraint = Constraint::getFromOpaquePointer(it.value());
+    Pred pred = constraint.getPredicate();
+    // Reuse the function name of an earlier constraint with the same predicate
+    // and summary; equal_range visits every entry stored for this predicate.
+    auto candidates = predToConstraint.equal_range(pred.getAsOpaquePointer());
+    for (auto iter = candidates.first; iter != candidates.second; ++iter) {
+      Constraint built = Constraint::getFromOpaquePointer(iter->second);
+      if (constraint.getSummary() == built.getSummary()) {
+        name = getTypeConstraintFn(built).str();
+        break;
+      }
+    }
+
+    if (!name.empty()) {
+      localTypeConstraints.try_emplace(it.value(), name);
+      continue;
+    }
+
     // Generate an obscure and unique name for this type constraint.
-    std::string name = (Twine("__mlir_ods_local_type_constraint_") +
-                        uniqueOutputLabel + Twine(it.index()))
-                           .str();
+    name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
+            Twine(it.index()))
+               .str();
+    predToConstraint.emplace(pred.getAsOpaquePointer(), it.value());
     localTypeConstraints.try_emplace(it.value(), name);
 
     // Only generate the methods if we are generating definitions.
     if (emitDecl)
       continue;
 
-    Constraint constraint = Constraint::getFromOpaquePointer(it.value());
     os << "static ::mlir::LogicalResult " << name
        << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
           "valueKind, unsigned valueGroupStartIndex) {\n";