diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
@@ -680,6 +680,69 @@
   }];
 }
 
+//===--------------------------------------------------------------------===//
+// CallIntrinsicOp
+//===--------------------------------------------------------------------===//
+
+// Generic call to an LLVM intrinsic identified by name. The intrinsic is
+// resolved when the op is translated to LLVM IR (see `llvmBuilder`).
+//
+// The op is intentionally not marked `Pure`: nothing restricts `intrin` to
+// side-effect-free intrinsics (`llvm.lifetime.start` is one that is not), and
+// a result-less op declared pure is trivially dead and may be erased.
+def LLVM_CallIntrinsicOp : LLVM_Op<"call_intrinsic"> {
+  let arguments = (ins StrAttr:$intrin, Variadic<LLVM_Type>:$args);
+  let results = (outs Variadic<LLVM_Type>:$results);
+  let summary = [{Calls an intrinsic.}];
+  let description = [{
+    Call the specified llvm intrinsic. If the intrinsic is overloaded, use
+    the MLIR function type of this op to determine which intrinsic to call.
+
+    Translation to LLVM IR supports at most one result and fails with an
+    error if `intrin` does not name a known LLVM intrinsic.
+  }];
+  let llvmBuilder = [{
+    // NOTE(review): assumes this snippet is expanded inside a function that
+    // returns LogicalResult and provides `op`, `opInst`, `builder` and
+    // `moduleTranslation` -- confirm. Errors are reported as diagnostics
+    // rather than `assert(false)`, which is compiled out of release builds.
+    if (opInst.getNumResults() > 1)
+      return op.emitOpError("expected at most one result");
+
+    llvm::Module *module = builder.GetInsertBlock()->getModule();
+    llvm::Intrinsic::ID id = op.getIntrinsicEnum();
+    if (!id)
+      return op.emitOpError("could not find LLVM intrinsic: ")
+             << op.getIntrinAttr();
+
+    // Overloaded intrinsics need their concrete types inferred from the op.
+    llvm::Function *fn =
+        llvm::Intrinsic::isOverloaded(id)
+            ? getOverloadedDeclaration(opInst, id, module, moduleTranslation)
+            : llvm::Intrinsic::getDeclaration(module, id, {});
+    // Defensive: never hand a null callee to CreateCall.
+    if (!fn)
+      return failure();
+
+    auto *inst = builder.CreateCall(
+        fn, moduleTranslation.lookupValues(opInst.getOperands()));
+    // A void call has no value to register with the translation.
+    if (opInst.getNumResults() == 1)
+      moduleTranslation.mapValue(op->getResults().front()) = inst;
+  }];
+  let assemblyFormat = [{
+    $intrin `(` $args `)` `:` functional-type($args, $results) attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns the LLVM intrinsic ID for `intrin`; the call site treats a
+    /// zero ID as "no such intrinsic".
+    llvm::Intrinsic::ID getIntrinsicEnum() {
+      return llvm::Function::lookupIntrinsicID(getIntrinAttr());
+    }
+  }];
+}
+
 //
 // LLVM Vector Predication operations.
//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -258,6 +258,54 @@
   return position;
 }
 
+/// Returns the declaration of the overloaded intrinsic `id` specialized for
+/// the operand and result types of `opInst`.
+///
+/// The operand types and the first result type (void if the op has no
+/// results) are converted to LLVM types and matched against the intrinsic's
+/// signature table to infer the overloaded types. If the signature or the
+/// variadic-ness does not match, an error is emitted on `opInst` and nullptr
+/// is returned; callers must treat a null result as a translation failure.
+static llvm::Function *
+getOverloadedDeclaration(Operation &opInst, llvm::Intrinsic::ID id,
+                         llvm::Module *module,
+                         LLVM::ModuleTranslation &moduleTranslation) {
+  SmallVector<llvm::Type *, 8> allArgTys;
+  allArgTys.reserve(opInst.getNumOperands());
+  for (Type type : opInst.getOperandTypes())
+    allArgTys.push_back(moduleTranslation.convertType(type));
+
+  llvm::Type *resTy =
+      opInst.getNumResults() == 0
+          ? llvm::Type::getVoidTy(module->getContext())
+          : moduleTranslation.convertType(opInst.getResult(0).getType());
+  llvm::FunctionType *ft =
+      llvm::FunctionType::get(resTy, allArgTys, /*isVarArg=*/false);
+
+  SmallVector<llvm::Intrinsic::IITDescriptor, 8> table;
+  getIntrinsicInfoTableEntries(id, table);
+  ArrayRef<llvm::Intrinsic::IITDescriptor> tableRef = table;
+
+  // Diagnose mismatches instead of asserting: `assert` is compiled out of
+  // release builds, which would let translation continue with bogus types.
+  SmallVector<llvm::Type *, 8> overloadedArgTys;
+  if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef,
+                                               overloadedArgTys) !=
+      llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) {
+    opInst.emitOpError("intrinsic type is not a match");
+    return nullptr;
+  }
+
+  // matchIntrinsicVarArg returns true on a mismatch.
+  if (llvm::Intrinsic::matchIntrinsicVarArg(ft->isVarArg(), tableRef)) {
+    opInst.emitOpError("intrinsic variadic-ness is not a match");
+    return nullptr;
+  }
+
+  return llvm::Intrinsic::getDeclaration(module, id, overloadedArgTys);
+}
+
 static LogicalResult
 convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
                      LLVM::ModuleTranslation &moduleTranslation) {
@@ -272,8 +320,8 @@
   // Emit function calls. If the "callee" attribute is present, this is a
   // direct function call and we also need to look up the remapped function
   // itself. Otherwise, this is an indirect call and the callee is the first
-  // operand, look it up as a normal value. Return the llvm::Value representing
-  // the function result, which may be of llvm::VoidTy type.
+  // operand, look it up as a normal value. Return the llvm::Value
+  // representing the function result, which may be of llvm::VoidTy type.
   auto convertCall = [&](Operation &op) -> llvm::Value * {
     auto operands = moduleTranslation.lookupValues(op.getOperands());
     ArrayRef<llvm::Value *> operandsRef(operands);
@@ -404,8 +452,8 @@
     return success();
   }
 
-  // Emit branches. We need to look up the remapped blocks and ignore the block
-  // arguments that were transformed into PHI nodes.
+  // Emit branches. We need to look up the remapped blocks and ignore the
+  // block arguments that were transformed into PHI nodes.
   if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
     llvm::BranchInst *branch =
         builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor()));
diff --git a/mlir/test/Dialect/LLVMIR/call-intrin.mlir b/mlir/test/Dialect/LLVMIR/call-intrin.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/call-intrin.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// CHECK: ; ModuleID = 'LLVMDialectModule'
+// CHECK: source_filename = "LLVMDialectModule"
+// CHECK: declare ptr @malloc(i64)
+// CHECK: declare void @free(ptr)
+// CHECK: define <4 x float> @round_sse41() {
+// CHECK: %1 = call <4 x float> @llvm.x86.sse41.round.ss(<4 x float> <float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000>, <4 x float> <float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000, float 0x3FC99999A0000000>, i32 1)
+// CHECK: ret <4 x float> %1
+// CHECK: }
+llvm.func @round_sse41() -> vector<4xf32> {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  %1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32>
+  %res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32> {}
+  llvm.return %res: vector<4xf32>
+}
+
+// -----
+
+// CHECK: ; ModuleID = 'LLVMDialectModule'
+// CHECK: source_filename = "LLVMDialectModule"
+
+// CHECK: declare ptr @malloc(i64)
+
+// CHECK: declare void @free(ptr)
+
+// CHECK: define float @round_overloaded() {
+// CHECK: %1 = call float @llvm.round.f32(float 1.000000e+00)
+// CHECK: ret float %1
+// CHECK: }
+llvm.func @round_overloaded() -> f32 {
+  %0 = llvm.mlir.constant(1.0 : f32) : f32
+  %res = llvm.call_intrinsic "llvm.round"(%0) : (f32) -> f32 {}
+  llvm.return %res: f32
+}
+
+// -----
+
+// CHECK: ; ModuleID = 'LLVMDialectModule'
+// CHECK: source_filename = "LLVMDialectModule"
+// CHECK: declare ptr @malloc(i64)
+// CHECK: declare void @free(ptr)
+// CHECK: define void @lifetime_start() {
+// CHECK: %1 = alloca float, i8 1, align 4
+// CHECK: call void @llvm.lifetime.start.p0(i64 4, ptr %1)
+// CHECK: ret void
+// CHECK: }
+llvm.func @lifetime_start() {
+  %0 = llvm.mlir.constant(4 : i64) : i64
+  %1 = llvm.mlir.constant(1 : i8) : i8
+  %2 = llvm.alloca %1 x f32 : (i8) -> !llvm.ptr
+  llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) : (i64, !llvm.ptr) -> () {}
+  llvm.return
+}