Replace AbstractOperation's `classof` function reference with a `ClassID *classID`
field identifying the concrete Op class.

- Op<ConcreteType>::classof now compares ClassID::getID<ConcreteType>() against the
  registered AbstractOperation's classID, instead of comparing the address of its
  own `classof` function with the stored function reference.
- AbstractOperation::get<T>() fills the new field with ClassID::getID<T>().
- The private AbstractOperation constructor takes `ClassID *classID` in the
  position formerly occupied by `bool (&classof)(Operation *op)`.

diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -984,7 +984,7 @@
   /// Return true if this "op class" can match against the specified operation.
   static bool classof(Operation *op) {
     if (auto *abstractOp = op->getAbstractOperation())
-      return &classof == abstractOp->classof;
+      return ClassID::getID<ConcreteType>() == abstractOp->classID;
     return op->getName().getStringRef() == ConcreteType::getOperationName();
   }
 
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -90,8 +90,8 @@
   /// This is the dialect that this operation belongs to.
   Dialect &dialect;
 
-  /// Return true if this "op class" can match against the specified operation.
-  bool (&classof)(Operation *op);
+  /// The unique identifier of the derived Op class.
+  ClassID *classID;
 
   /// Use the specified object to parse this ops custom assembly format.
   ParseResult (&parseAssembly)(OpAsmParser &parser, OperationState &result);
@@ -158,15 +158,16 @@
   /// operations they contain.
   template <typename T> static AbstractOperation get(Dialect &dialect) {
     return AbstractOperation(
-        T::getOperationName(), dialect, T::getOperationProperties(), T::classof,
-        T::parseAssembly, T::printAssembly, T::verifyInvariants, T::foldHook,
-        T::getCanonicalizationPatterns, T::getRawInterface, T::hasTrait);
+        T::getOperationName(), dialect, T::getOperationProperties(),
+        ClassID::getID<T>(), T::parseAssembly, T::printAssembly,
+        T::verifyInvariants, T::foldHook, T::getCanonicalizationPatterns,
+        T::getRawInterface, T::hasTrait);
   }
 
 private:
   AbstractOperation(
       StringRef name, Dialect &dialect, OperationProperties opProperties,
-      bool (&classof)(Operation *op),
+      ClassID *classID,
       ParseResult (&parseAssembly)(OpAsmParser &parser, OperationState &result),
       void (&printAssembly)(Operation *op, OpAsmPrinter &p),
       LogicalResult (&verifyInvariants)(Operation *op),
@@ -176,7 +177,7 @@
                                           MLIRContext *context),
       void *(&getRawInterface)(ClassID *interfaceID),
       bool (&hasTrait)(ClassID *traitID))
-      : name(name), dialect(dialect), classof(classof),
+      : name(name), dialect(dialect), classID(classID),
         parseAssembly(parseAssembly), printAssembly(printAssembly),
         verifyInvariants(verifyInvariants), foldHook(foldHook),
         getCanonicalizationPatterns(getCanonicalizationPatterns),