Store the name of an AbstractOperation as an Identifier instead of a StringRef,
add Identifier::compare, and expose OperationName::getIdentifier().

NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
The line structure matches every hunk's line counts, but the leading
whitespace, the double space in the two removed "This always succeeds." comment
lines, and the template argument lists (EntryType, Attribute, OpFoldResult,
AbstractOperation, Identifier) are inferred. Confirm against the target tree
before applying.

diff --git a/mlir/include/mlir/IR/Identifier.h b/mlir/include/mlir/IR/Identifier.h
--- a/mlir/include/mlir/IR/Identifier.h
+++ b/mlir/include/mlir/IR/Identifier.h
@@ -67,6 +67,9 @@
     return Identifier(static_cast<const EntryType *>(entry));
   }
 
+  /// Compare the underlying StringRef.
+  int compare(Identifier rhs) const { return strref().compare(rhs.strref()); }
+
 private:
   /// This contains the bytes of the string, which is guaranteed to be nul
   /// terminated.
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -82,7 +82,7 @@
   using OperationProperties = uint32_t;
 
   /// This is the name of the operation.
-  const StringRef name;
+  const Identifier name;
 
   /// This is the dialect that this operation belongs to.
   Dialect &dialect;
@@ -171,13 +171,7 @@
                                 SmallVectorImpl<OpFoldResult> &results),
       void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
                                           MLIRContext *context),
-      detail::InterfaceMap &&interfaceMap, bool (&hasTrait)(TypeID traitID))
-      : name(name), dialect(dialect), typeID(typeID),
-        parseAssembly(parseAssembly), printAssembly(printAssembly),
-        verifyInvariants(verifyInvariants), foldHook(foldHook),
-        getCanonicalizationPatterns(getCanonicalizationPatterns),
-        opProperties(opProperties), interfaceMap(std::move(interfaceMap)),
-        hasRawTrait(hasTrait) {}
+      detail::InterfaceMap &&interfaceMap, bool (&hasTrait)(TypeID traitID));
 
   /// The properties of the operation.
   const OperationProperties opProperties;
@@ -302,9 +296,12 @@
   /// Return the operation name with dialect name stripped, if it has one.
   StringRef stripDialect() const;
 
-  /// Return the name of this operation.  This always succeeds.
+  /// Return the name of this operation. This always succeeds.
   StringRef getStringRef() const;
 
+  /// Return the name of this operation as an identifier. This always succeeds.
+  Identifier getIdentifier() const;
+
   /// If this operation has a registered operation description, return it.
   /// Otherwise return null.
   const AbstractOperation *getAbstractOperation() const;
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -668,6 +668,25 @@
   return nullptr;
 }
 
+AbstractOperation::AbstractOperation(
+    StringRef name, Dialect &dialect, OperationProperties opProperties,
+    TypeID typeID,
+    ParseResult (&parseAssembly)(OpAsmParser &parser, OperationState &result),
+    void (&printAssembly)(Operation *op, OpAsmPrinter &p),
+    LogicalResult (&verifyInvariants)(Operation *op),
+    LogicalResult (&foldHook)(Operation *op, ArrayRef<Attribute> operands,
+                              SmallVectorImpl<OpFoldResult> &results),
+    void (&getCanonicalizationPatterns)(OwningRewritePatternList &results,
+                                        MLIRContext *context),
+    detail::InterfaceMap &&interfaceMap, bool (&hasTrait)(TypeID traitID))
+    : name(Identifier::get(name, dialect.getContext())), dialect(dialect),
+      typeID(typeID), parseAssembly(parseAssembly),
+      printAssembly(printAssembly), verifyInvariants(verifyInvariants),
+      foldHook(foldHook),
+      getCanonicalizationPatterns(getCanonicalizationPatterns),
+      opProperties(opProperties), interfaceMap(std::move(interfaceMap)),
+      hasRawTrait(hasTrait) {}
+
 /// Get the dialect that registered the type with the provided typeid.
 const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
   auto &impl = context->getImpl();
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -45,11 +45,16 @@
   return splitName.second.empty() ? splitName.first : splitName.second;
 }
 
-/// Return the name of this operation.  This always succeeds.
+/// Return the name of this operation. This always succeeds.
 StringRef OperationName::getStringRef() const {
+  return getIdentifier().strref();
+}
+
+/// Return the name of this operation as an identifier. This always succeeds.
+Identifier OperationName::getIdentifier() const {
   if (auto *op = representation.dyn_cast<const AbstractOperation *>())
     return op->name;
-  return representation.get<Identifier>().strref();
+  return representation.get<Identifier>();
 }
 
 const AbstractOperation *OperationName::getAbstractOperation() const {
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -863,8 +863,8 @@
   /// Emit a diagnostic at the specified location and return failure.
   InFlightDiagnostic emitError(llvm::SMLoc loc, const Twine &message) override {
     emittedError = true;
-    return parser.emitError(loc, "custom op '" + opDefinition->name + "' " +
-                                     message);
+    return parser.emitError(loc, "custom op '" + opDefinition->name.strref() +
+                                     "' " + message);
   }
 
   llvm::SMLoc getCurrentLocation() override {