diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp @@ -15,6 +15,7 @@ #include "mlir/TableGen/GenInfo.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/MachineValueType.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" @@ -31,6 +32,39 @@ llvm::cl::cat(IntrinsicGenCat)); namespace { + +using IndicesTy = llvm::SmallVector; + +// Static free helper functions for accessing tablegen definitions. +// Return a CodeGen value type entry from a type record. +static llvm::MVT::SimpleValueType getValueType(const llvm::Record *Rec) { + return (llvm::MVT::SimpleValueType)Rec->getValueAsDef("VT")->getValueAsInt( + "Value"); +} + +// Return the indices of the definitions in a list of definitions that +// represent overloadable types +static IndicesTy getOverloadableTypeIdxs(const llvm::Record &record, + const char *listName) { + IndicesTy OverloadedOps; + auto results = record.getValueAsListOfDefs(listName); + for (auto R : llvm::enumerate(results)) { + llvm::MVT::SimpleValueType VT = getValueType(R.value()); + switch (VT) { + case llvm::MVT::iAny: + case llvm::MVT::fAny: + case llvm::MVT::Any: + case llvm::MVT::iPTRAny: + case llvm::MVT::vAny: + OverloadedOps.push_back(R.index()); + break; + default: + continue; + } + } + return OverloadedOps; +} + /// A wrapper for LLVM's Tablegen class `Intrinsic` that provides accessors to /// the fields of the record. class LLVMIntrinsic { @@ -108,6 +142,14 @@ return false; } + IndicesTy getOverloadableOperandsIdxs() const { + return getOverloadableTypeIdxs(record, fieldOperands); + } + + IndicesTy getOverloadableResultsIdxs() const { + return getOverloadableTypeIdxs(record, fieldResults); + } + private: /// Names of the fileds in the Intrinsic LLVM Tablegen class. const char *fieldName = "LLVMName"; @@ -122,10 +164,23 @@ /// Emits C++ code constructing an LLVM IR intrinsic given the generated MLIR /// operation. In LLVM IR, intrinsics are constructed as function calls. static void emitBuilder(const LLVMIntrinsic &intr, llvm::raw_ostream &os) { + auto overloadedOps = intr.getOverloadableOperandsIdxs(); + auto overloadedRes = intr.getOverloadableResultsIdxs(); os << " llvm::Module *module = builder.GetInsertBlock()->getModule();\n"; os << " llvm::Function *fn = llvm::Intrinsic::getDeclaration(\n"; os << " module, llvm::Intrinsic::" << intr.getProperRecordName() - << ");\n"; + << ", {"; + for (int idx : overloadedRes) { + os << "\n opInst.getResult(" << idx << ").getType()" + << ".dyn_cast().getUnderlyingType(),"; + } + for (int idx : overloadedOps) { + os << "\n opInst.getOpOperand(" << idx << ").get().getType()" + << ".dyn_cast().getUnderlyingType(),"; + } + if (!overloadedRes.empty() || !overloadedOps.empty()) + os << "\n "; + os << "});\n"; os << " auto operands = llvm::to_vector<8, Value *>(\n"; os << " opInst.operand_begin(), opInst.operand_end());\n"; os << " " << (intr.getNumResults() > 0 ? "$res = " : "")