diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h --- a/mlir/include/mlir/TableGen/Constraint.h +++ b/mlir/include/mlir/TableGen/Constraint.h @@ -52,6 +52,13 @@ Kind getKind() const { return kind; } + /// Get an opaque pointer to the constraint's underlying llvm::Record. + const void *getAsOpaquePointer() const { return def; } + /// Construct a constraint from a pointer returned by getAsOpaquePointer(). + static Constraint getFromOpaquePointer(const void *ptr) { + return Constraint(static_cast<const llvm::Record *>(ptr)); + } + protected: Constraint(Kind kind, const llvm::Record *record); diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -15,10 +15,22 @@ let arguments = (ins I32OrF32:$x); } +// CHECK: static ::mlir::LogicalResult [[$INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( +// CHECK: if (!((type.isInteger(32) || type.isF32()))) { +// CHECK: return op->emitOpError(valueKind) << " #" << valueIndex << " must be 32-bit integer or floating-point type, but got " << type; + +// CHECK: static ::mlir::LogicalResult [[$TENSOR_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( +// CHECK: if (!(((type.isa<::mlir::TensorType>())) && ((true)))) { +// CHECK: return op->emitOpError(valueKind) << " #" << valueIndex << " must be tensor of any type values, but got " << type; + +// CHECK: static ::mlir::LogicalResult [[$TENSOR_INTEGER_FLOAT_CONSTRAINT:__mlir_ods_local_type_constraint.*]]( +// CHECK: if (!(((type.isa<::mlir::TensorType>())) && (((type.cast<::mlir::ShapedType>().getElementType().isF32())) || ((type.cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) { +// CHECK: return op->emitOpError(valueKind) << " #" << valueIndex << " must be tensor of 32-bit float or 32-bit signless integer values, but got " << type; + // CHECK-LABEL: OpA::verify // CHECK: auto valueGroup0 = getODSOperands(0); // CHECK: 
for (::mlir::Value v : valueGroup0) { -// CHECK: if (!((v.getType().isInteger(32) || v.getType().isF32()))) +// CHECK: if (::mlir::failed([[$INTEGER_FLOAT_CONSTRAINT]] def OpB : NS_Op<"op_for_And_PredOpTrait", [ PredOpTrait<"both first and second holds", @@ -93,4 +105,4 @@ // CHECK-LABEL: OpK::verify // CHECK: auto valueGroup0 = getODSOperands(0); // CHECK: for (::mlir::Value v : valueGroup0) { -// CHECK: if (!(((v.getType().isa<::mlir::TensorType>())) && (((v.getType().cast<::mlir::ShapedType>().getElementType().isF32())) || ((v.getType().cast<::mlir::ShapedType>().getElementType().isSignlessInteger(32)))))) +// CHECK: if (::mlir::failed([[$TENSOR_INTEGER_FLOAT_CONSTRAINT]] diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -118,6 +118,144 @@ )"; +//===----------------------------------------------------------------------===// +// StaticVerifierFunctionEmitter +//===----------------------------------------------------------------------===// + +namespace { +/// This class deduplicates shared operation verification code by emitting +/// static functions alongside the op definitions. These methods are local to +/// the definition file, and are invoked within the operation verify methods. +/// An example is shown below: +/// +/// static LogicalResult localVerify(...) +/// +/// LogicalResult OpA::verify(...) { +/// if (failed(localVerify(...))) +/// return failure(); +/// ... +/// } +/// +/// LogicalResult OpB::verify(...) { +/// if (failed(localVerify(...))) +/// return failure(); +/// ... +/// } +/// +class StaticVerifierFunctionEmitter { +public: + StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records, + ArrayRef<llvm::Record *> opDefs, + raw_ostream &os, bool emitDecl); + + /// Get the name of the local function used for the given type constraint.
+ /// These functions are used for operand and result constraints and have the + /// form: + /// LogicalResult(Operation *op, Type type, StringRef valueKind, + /// unsigned valueIndex); + StringRef getTypeConstraintFn(const Constraint &constraint) const { + auto it = localTypeConstraints.find(constraint.getAsOpaquePointer()); + assert(it != localTypeConstraints.end() && "expected valid constraint fn"); + return it->second; + } + +private: + /// Returns an identifier-safe label derived from the input file name. + static std::string getUniqueName(const llvm::RecordKeeper &records); + + /// Emit local methods for the type constraints used within the provided op + /// definitions. + void emitTypeConstraintMethods(ArrayRef<llvm::Record *> opDefs, + raw_ostream &os, bool emitDecl); + + /// A unique label for the file currently being generated. This is used to + /// ensure that the local functions have a unique name. + std::string uniqueOutputLabel; + + /// Maps each type constraint (by opaque pointer) to the name of the local + /// function used to verify operands and results against it. + llvm::DenseMap<const void *, std::string> localTypeConstraints; +}; +} // namespace + +StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter( + const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs, + raw_ostream &os, bool emitDecl) + : uniqueOutputLabel(getUniqueName(records)) { + llvm::Optional<NamespaceEmitter> namespaceEmitter; + if (!emitDecl) { + os << formatv(opCommentHeader, "Local Utility Method", "Definitions"); + namespaceEmitter.emplace(os, Operator(*opDefs[0]).getDialect()); + } + + emitTypeConstraintMethods(opDefs, os, emitDecl); +} + +std::string StaticVerifierFunctionEmitter::getUniqueName( + const llvm::RecordKeeper &records) { + // Use the input file name when generating a unique name. + std::string inputFilename = records.getInputFilename(); + + // Drop all but the base filename. + StringRef nameRef = llvm::sys::path::filename(inputFilename); + nameRef.consume_back(".td"); + + // Sanitize any invalid characters.
+ std::string uniqueName; + for (char c : nameRef) { + if (llvm::isAlnum(c) || c == '_') + uniqueName.push_back(c); + else + uniqueName.append(llvm::utohexstr(static_cast<unsigned char>(c))); + } + return uniqueName; +} + +void StaticVerifierFunctionEmitter::emitTypeConstraintMethods( + ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) { + // Collect a set of all of the used type constraints within the operation + // definitions. + llvm::SetVector<const void *> typeConstraints; + for (Record *def : opDefs) { + Operator op(*def); + for (NamedTypeConstraint &operand : op.getOperands()) + if (operand.hasPredicate()) + typeConstraints.insert(operand.constraint.getAsOpaquePointer()); + for (NamedTypeConstraint &result : op.getResults()) + if (result.hasPredicate()) + typeConstraints.insert(result.constraint.getAsOpaquePointer()); + } + + FmtContext fctx; + for (auto it : llvm::enumerate(typeConstraints)) { + // Generate an obscure, unique name; '_' keeps the label and index apart. + std::string name = (Twine("__mlir_ods_local_type_constraint_") + + uniqueOutputLabel + "_" + Twine(it.index())) + .str(); + localTypeConstraints.try_emplace(it.value(), name); + + // Only generate the methods if we are generating definitions.
+ if (emitDecl) + continue; + + Constraint constraint = Constraint::getFromOpaquePointer(it.value()); + os << "static ::mlir::LogicalResult " << name + << "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef " + "valueKind, unsigned valueIndex) {\n"; + + os << " if (!(" + << tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type")) + << ")) {\n" + << formatv( + " return op->emitOpError(valueKind) << \" #\" << " + "valueIndex << \" must be {0}, but got \" << type;\n", + constraint.getDescription()) + << " }\n" + << " return ::mlir::success();\n" + << "}\n\n"; + } +} + //===----------------------------------------------------------------------===// // Utility structs and functions //===----------------------------------------------------------------------===// @@ -165,11 +303,16 @@ // Helper class to emit a record into the given output stream. class OpEmitter { public: - static void emitDecl(const Operator &op, raw_ostream &os); - static void emitDef(const Operator &op, raw_ostream &os); + static void + emitDecl(const Operator &op, raw_ostream &os, + const StaticVerifierFunctionEmitter &staticVerifierEmitter); + static void + emitDef(const Operator &op, raw_ostream &os, + const StaticVerifierFunctionEmitter &staticVerifierEmitter); private: - OpEmitter(const Operator &op); + OpEmitter(const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter); void emitDecl(raw_ostream &os); void emitDef(raw_ostream &os); @@ -322,6 +465,9 @@ // The format context for verification code generation. FmtContext verifyCtx; + + // The emitter containing all of the locally emitted verification functions.
+ const StaticVerifierFunctionEmitter &staticVerifierEmitter; }; } // end anonymous namespace @@ -435,9 +581,11 @@ } } -OpEmitter::OpEmitter(const Operator &op) +OpEmitter::OpEmitter(const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter) : def(op.getDef()), op(op), - opClass(op.getCppClassName(), op.getExtraClassDeclaration()) { + opClass(op.getCppClassName(), op.getExtraClassDeclaration()), + staticVerifierEmitter(staticVerifierEmitter) { verifyCtx.withOp("(*this->getOperation())"); genTraits(); @@ -465,12 +613,16 @@ genSideEffectInterfaceMethods(); } -void OpEmitter::emitDecl(const Operator &op, raw_ostream &os) { - OpEmitter(op).emitDecl(os); +void OpEmitter::emitDecl( + const Operator &op, raw_ostream &os, + const StaticVerifierFunctionEmitter &staticVerifierEmitter) { + OpEmitter(op, staticVerifierEmitter).emitDecl(os); } -void OpEmitter::emitDef(const Operator &op, raw_ostream &os) { - OpEmitter(op).emitDef(os); +void OpEmitter::emitDef( + const Operator &op, raw_ostream &os, + const StaticVerifierFunctionEmitter &staticVerifierEmitter) { + OpEmitter(op, staticVerifierEmitter).emitDef(os); } void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); } @@ -1893,23 +2045,16 @@ // Otherwise, if there is no predicate there is nothing left to do. if (!hasPredicate) continue; - // Emit a loop to check all the dynamic values in the pack.
+ StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn( + staticValue.value().constraint); body << " for (::mlir::Value v : valueGroup" << staticValue.index() - << ") {\n"; - - auto constraint = staticValue.value().constraint; - body << " (void)v;\n" - << " if (!(" - << tgfmt(constraint.getConditionTemplate(), - &fctx.withSelf("v.getType()")) - << ")) {\n" - << formatv(" return emitOpError(\"{0} #\") << index " - "<< \" must be {1}, but got \" << v.getType();\n", - valueKind, constraint.getDescription()) - << " }\n" // if + << ") {\n" + << " if (::mlir::failed(" << constraintFn + << "(getOperation(), v.getType(), \"" << valueKind << "\", index)))\n" + << " return ::mlir::failure();\n" << " ++index;\n" - << " }\n"; // for + << " }\n"; } body << " }\n"; @@ -2250,7 +2395,8 @@ } // Emits the opcode enum and op classes. -static void emitOpClasses(const std::vector<Record *> &defs, raw_ostream &os, +static void emitOpClasses(const RecordKeeper &recordKeeper, + const std::vector<Record *> &defs, raw_ostream &os, bool emitDecl) { // First emit forward declaration for each class, this allows them to refer // to each others in traits for example. @@ -2266,17 +2412,23 @@ } IfDefScope scope("GET_OP_CLASSES", os); + if (defs.empty()) + return; + + // Generate all of the locally instantiated methods first.
+ StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os, + emitDecl); for (auto *def : defs) { Operator op(*def); NamespaceEmitter emitter(os, op.getDialect()); if (emitDecl) { os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); OpOperandAdaptorEmitter::emitDecl(op, os); - OpEmitter::emitDecl(op, os); + OpEmitter::emitDecl(op, os, staticVerifierEmitter); } else { os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); OpOperandAdaptorEmitter::emitDef(op, os); - OpEmitter::emitDef(op, os); + OpEmitter::emitDef(op, os, staticVerifierEmitter); } } } @@ -2331,7 +2483,7 @@ emitSourceFileHeader("Op Declarations", os); const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op"); - emitOpClasses(defs, os, /*emitDecl=*/true); + emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/true); return false; } @@ -2341,7 +2493,7 @@ const auto &defs = getAllDerivedDefinitions(recordKeeper, "Op"); emitOpList(defs, os); - emitOpClasses(defs, os, /*emitDecl=*/false); + emitOpClasses(recordKeeper, defs, os, /*emitDecl=*/false); return false; }