diff --git a/mlir/test/mlir-tblgen/llvm-intrinsics.td b/mlir/test/mlir-tblgen/llvm-intrinsics.td --- a/mlir/test/mlir-tblgen/llvm-intrinsics.td +++ b/mlir/test/mlir-tblgen/llvm-intrinsics.td @@ -6,16 +6,21 @@ // includes from the main file to avoid unnecessary dependencies and decrease // the test cost. The command-line flags further ensure a specific intrinsic is // processed and we only check the ouptut below. +// We also verify emission of type specialization for overloadable intrinsics. // // RUN: cat %S/../../../llvm/include/llvm/IR/Intrinsics.td \ // RUN: | grep -v "llvm/IR/Intrinsics" \ -// RUN: | mlir-tblgen -gen-llvmir-intrinsics -I %S/../../../llvm/include/ --llvmir-intrinsics-filter=vastart \ +// RUN: | mlir-tblgen -gen-llvmir-intrinsics -I %S/../../../llvm/include/ --llvmir-intrinsics-filter=is_constant \ // RUN: | FileCheck %s -// CHECK-LABEL: def LLVM_vastart +// CHECK-LABEL: def LLVM_is_constant // CHECK: LLVM_Op<"intr // CHECK: Arguments<(ins // CHECK: Results<(outs +// CHECK: llvm::Function *fn = llvm::Intrinsic::getDeclaration( +// CHECK: module, llvm::Intrinsic::is_constant, { +// CHECK: opInst.getOperand(0).getType().cast().getUnderlyingType(), +// CHECK: }); //---------------------------------------------------------------------------// diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp @@ -14,7 +14,9 @@ #include "mlir/Support/STLExtras.h" #include "mlir/TableGen/GenInfo.h" +#include "llvm/ADT/SmallBitVector.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/MachineValueType.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" @@ -30,6 +32,38 @@ "substring in their record name"), llvm::cl::cat(IntrinsicGenCat)); +// Used to represent the indices of overloadable operands/results. 
+using IndicesTy = llvm::SmallBitVector;
+
+/// Return the CodeGen value type of an LLVM type record, i.e. the integer
+/// `Value` field of the record's `VT` definition.
+static llvm::MVT::SimpleValueType getValueType(const llvm::Record *rec) {
+  return static_cast<llvm::MVT::SimpleValueType>(
+      rec->getValueAsDef("VT")->getValueAsInt("Value"));
+}
+
+/// Return a bit vector, one bit per entry of the list field `listName`, with
+/// bits set for entries whose value type is iAny/fAny/Any/iPTRAny/vAny.
+static IndicesTy getOverloadableTypeIdxs(const llvm::Record &record,
+                                         const char *listName) {
+  auto defs = record.getValueAsListOfDefs(listName);
+  IndicesTy overloaded(defs.size());
+  for (auto it : llvm::enumerate(defs)) {
+    switch (getValueType(it.value())) {
+    case llvm::MVT::iAny:
+    case llvm::MVT::fAny:
+    case llvm::MVT::Any:
+    case llvm::MVT::iPTRAny:
+    case llvm::MVT::vAny:
+      overloaded.set(it.index());
+      break;
+    default:
+      break;
+    }
+  }
+  return overloaded;
+}
+
 namespace {
 /// A wrapper for LLVM's Tablegen class `Intrinsic` that provides accessors to
 /// the fields of the record.
@@ -108,6 +142,14 @@
     return false;
   }
 
+  IndicesTy getOverloadableOperandsIdxs() const {
+    return getOverloadableTypeIdxs(record, fieldOperands);
+  }
+
+  IndicesTy getOverloadableResultsIdxs() const {
+    return getOverloadableTypeIdxs(record, fieldResults);
+  }
+
 private:
   /// Names of the fields in the Intrinsic LLVM Tablegen class.
   const char *fieldName = "LLVMName";
@@ -122,10 +164,23 @@
 /// Emits C++ code constructing an LLVM IR intrinsic given the generated MLIR
 /// operation. In LLVM IR, intrinsics are constructed as function calls.
static void emitBuilder(const LLVMIntrinsic &intr, llvm::raw_ostream &os) { + auto overloadedRes = intr.getOverloadableResultsIdxs(); + auto overloadedOps = intr.getOverloadableOperandsIdxs(); os << " llvm::Module *module = builder.GetInsertBlock()->getModule();\n"; os << " llvm::Function *fn = llvm::Intrinsic::getDeclaration(\n"; os << " module, llvm::Intrinsic::" << intr.getProperRecordName() - << ");\n"; + << ", {"; + for (unsigned idx : overloadedRes.set_bits()) { + os << "\n opInst.getResult(" << idx << ").getType()" + << ".cast().getUnderlyingType(),"; + } + for (unsigned idx : overloadedOps.set_bits()) { + os << "\n opInst.getOperand(" << idx << ").getType()" + << ".cast().getUnderlyingType(),"; + } + if (overloadedRes.any() || overloadedOps.any()) + os << "\n "; + os << "});\n"; os << " auto operands = llvm::to_vector<8, Value *>(\n"; os << " opInst.operand_begin(), opInst.operand_end());\n"; os << " " << (intr.getNumResults() > 0 ? "$res = " : "")