diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -75,10 +75,6 @@ class LLVM_Op traits = []> : LLVM_OpBase; -// Compatibility class for LLVM intrinsic operations. -class LLVM_IntrOp traits = []> : - LLVM_Op<"intr."#mnemonic, traits>; - // Case of the LLVM enum attribute backed by I64Attr with customized string // representation that corresponds to what is visible in the textual IR form. // The parameters are as follows: @@ -163,6 +159,14 @@ }]; } +// Base class for LLVM intrinsic operations; not to be used directly. Places +// the intrinsic into the LLVM dialect and prefixes its name with "intr.". +class LLVM_IntrOp overloadedResults, + list overloadedOperands, list traits, + bit hasResult> + : LLVM_IntrOpBase; + // Base class for LLVM intrinsic operations returning no results. Places the // intrinsic into the LLVM dialect and prefixes its name with "intr.". // @@ -179,8 +183,7 @@ // The Op has no results. class LLVM_ZeroResultIntrOp overloadedOperands = [], list traits = []> - : LLVM_IntrOpBase; + : LLVM_IntrOp; // Base class for LLVM intrinsic operations returning one result. Places the // intrinsic into the LLVM dialect and prefixes its name with "intr.". This is @@ -191,8 +194,7 @@ class LLVM_OneResultIntrOp overloadedResults = [], list overloadedOperands = [], list traits = []> - : LLVM_IntrOpBase; + : LLVM_IntrOp; // LLVM vector reduction over a single vector.
class LLVM_VectorReduction diff --git a/mlir/test/mlir-tblgen/llvm-intrinsics.td b/mlir/test/mlir-tblgen/llvm-intrinsics.td --- a/mlir/test/mlir-tblgen/llvm-intrinsics.td +++ b/mlir/test/mlir-tblgen/llvm-intrinsics.td @@ -15,15 +15,15 @@ // CHECK-LABEL: def LLVM_ptrmask // CHECK: LLVM_IntrOp<"ptrmask -// CHECK: Arguments<(ins -// CHECK: Results<(outs -// CHECK: llvm::Function *fn = llvm::Intrinsic::getDeclaration( -// CHECK: module, llvm::Intrinsic::ptrmask, { -// CHECK: opInst.getResult(0).getType().cast().getUnderlyingType(), -// CHECK: opInst.getOperand(0).getType().cast().getUnderlyingType(), -// CHECK: opInst.getOperand(1).getType().cast().getUnderlyingType(), -// CHECK: }); -// CHECK: lookupValues(opInst.getOperands()); +// The result of this intrinsic is overloadable. +// CHECK: [0] +// Both its operands are overloadable. +// CHECK: [0, 1] +// It has no additional traits. +// CHECK: [] +// It has a result. +// CHECK: 1> +// CHECK: Arguments<(ins LLVM_Type, LLVM_Type //---------------------------------------------------------------------------// diff --git a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp --- a/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp +++ b/mlir/tools/mlir-tblgen/LLVMIRIntrinsicGen.cpp @@ -179,31 +179,12 @@ }; } // namespace -/// Emits C++ code constructing an LLVM IR intrinsic given the generated MLIR -/// operation. In LLVM IR, intrinsics are constructed as function calls.
-static void emitBuilder(const LLVMIntrinsic &intr, llvm::raw_ostream &os) { - auto overloadedRes = intr.getOverloadableResultsIdxs(); - auto overloadedOps = intr.getOverloadableOperandsIdxs(); - os << " llvm::Module *module = builder.GetInsertBlock()->getModule();\n"; - os << " llvm::Function *fn = llvm::Intrinsic::getDeclaration(\n"; - os << " module, llvm::Intrinsic::" << intr.getProperRecordName() - << ", {"; - for (unsigned idx : overloadedRes.set_bits()) { - os << "\n opInst.getResult(" << idx << ").getType()" - << ".cast().getUnderlyingType(),"; - } - for (unsigned idx : overloadedOps.set_bits()) { - os << "\n opInst.getOperand(" << idx << ").getType()" - << ".cast().getUnderlyingType(),"; - } - if (overloadedRes.any() || overloadedOps.any()) - os << "\n "; - os << "});\n"; - os << " auto operands =\n"; - os << " lookupValues(opInst.getOperands());\n"; - os << " " << (intr.getNumResults() > 0 ? "$res = " : "") - << "builder.CreateCall(fn, operands);\n"; - os << " "; +/// Prints the elements in "range" separated by commas and surrounded by "[]". +template <typename Range> +static void printBracketedRange(const Range &range, llvm::raw_ostream &os) { + os << '['; + mlir::interleaveComma(range, os); + os << ']'; } /// Emits ODS (TableGen-based) code for `record` representing an LLVM intrinsic. @@ -224,16 +205,16 @@ // Emit the definition. os << "def LLVM_" << intr.getProperRecordName() << " : " << opBaseClass - << "<\"" << intr.getOperationName() << "\", ["; - mlir::interleaveComma(traits, os); - os << "]>, Arguments<(ins" << (operands.empty() ? "" : " "); + << "<\"" << intr.getOperationName() << "\", "; + printBracketedRange(intr.getOverloadableResultsIdxs().set_bits(), os); + os << ", "; + printBracketedRange(intr.getOverloadableOperandsIdxs().set_bits(), os); + os << ", "; + printBracketedRange(traits, os); + os << ", " << (intr.getNumResults() == 0 ? 0 : 1) << ">, Arguments<(ins" + << (operands.empty() ? 
"" : " "); mlir::interleaveComma(operands, os); - os << ")>, Results<(outs" - << (intr.getNumResults() == 0 ? "" : " LLVM_Type:$res") << ")> {\n" - << " let llvmBuilder = [{\n"; - emitBuilder(intr, os); - os << "}];\n"; - os << "}\n\n"; + os << ")>;\n\n"; return false; }