diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -1080,31 +1080,6 @@ }]; } -//===--------------------------------------------------------------------===// -// CallIntrinsicOp -//===--------------------------------------------------------------------===// - -def LLVM_CallIntrinsicOp - : LLVM_Op<"call_intrinsic", - [DeclareOpInterfaceMethods]> { - let summary = "Call to an LLVM intrinsic function."; - let description = [{ - Call the specified llvm intrinsic. If the intrinsic is overloaded, use - the MLIR function type of this op to determine which intrinsic to call. - }]; - let arguments = (ins StrAttr:$intrin, Variadic:$args, - DefaultValuedAttr:$fastmathFlags); - let results = (outs Variadic:$results); - let llvmBuilder = [{ - return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation); - }]; - let assemblyFormat = [{ - $intrin `(` $args `)` `:` functional-type($args, $results) - custom(attr-dict) - }]; -} - // // LLVM Vector Predication operations. // diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -1759,4 +1759,30 @@ }]; } +//===--------------------------------------------------------------------===// +// CallIntrinsicOp +//===--------------------------------------------------------------------===// + +def LLVM_CallIntrinsicOp + : LLVM_Op<"call_intrinsic", + [DeclareOpInterfaceMethods]> { + let summary = "Call to an LLVM intrinsic function."; + let description = [{ + Call the specified llvm intrinsic. If the intrinsic is overloaded, use + the MLIR function type of this op to determine which intrinsic to call. + }]; + let arguments = (ins StrAttr:$intrin, Variadic:$args, + DefaultValuedAttr:$fastmathFlags); + let results = (outs Optional:$results); + let llvmBuilder = [{ + return convertCallLLVMIntrinsicOp(op, builder, moduleTranslation); + }]; + let assemblyFormat = [{ + $intrin `(` $args `)` `:` functional-type($args, $results) attr-dict + }]; + + let hasVerifier = 1; +} + #endif // LLVMIR_OPS diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -10,6 +10,7 @@ // MLIR, and the LLVM IR dialect. It also registers the dialect. // //===----------------------------------------------------------------------===// + #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "LLVMInlining.h" #include "TypeDetail.h" @@ -2785,6 +2786,59 @@ return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue()); } +//===----------------------------------------------------------------------===// +// Utilities for LLVM::MetadataOp +//===----------------------------------------------------------------------===// + +void MetadataOp::build(OpBuilder &builder, OperationState &result, + StringRef symName, bool createBodyBlock, + ArrayRef attributes) { + result.addAttribute(getSymNameAttrName(result.name), + builder.getStringAttr(symName)); + result.attributes.append(attributes.begin(), attributes.end()); + Region *body = result.addRegion(); + if (createBodyBlock) + body->emplaceBlock(); +} + +ParseResult MetadataOp::parse(OpAsmParser &parser, OperationState &result) { + StringAttr symName; + if (parser.parseSymbolName(symName, getSymNameAttrName(result.name), + result.attributes) || + parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + + Region *bodyRegion = result.addRegion(); + if (parser.parseRegion(*bodyRegion)) + return failure(); + + // If the region appeared to be empty to parseRegion(), + // add the body block explicitly. + if (bodyRegion->empty()) + bodyRegion->emplaceBlock(); + + return success(); +} + +void MetadataOp::print(OpAsmPrinter &printer) { + printer << ' '; + printer.printSymbolName(getSymName()); + printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), + {getSymNameAttrName().getValue()}); + printer << ' '; + printer.printRegion(getBody()); +} + +//===----------------------------------------------------------------------===// +// CallIntrinsicOp +//===----------------------------------------------------------------------===// + +LogicalResult CallIntrinsicOp::verify() { + if (!getIntrin().startswith("llvm.")) + return emitOpError() << "intrinsic name must start with 'llvm.'"; + return success(); +} + //===----------------------------------------------------------------------===// // OpAsmDialectInterface //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -58,11 +58,19 @@ return position; } +/// Convert an LLVM type to a string for printing in diagnostics. +static std::string diagStr(const llvm::Type *type) { + std::string str; + llvm::raw_string_ostream os(str); + type->print(os); + return os.str(); +} + /// Get the declaration of an overloaded llvm intrinsic. First we get the /// overloaded argument types and/or result type from the CallIntrinsicOp, and /// then use those to get the correct declaration of the overloaded intrinsic. static FailureOr -getOverloadedDeclaration(CallIntrinsicOp &op, llvm::Intrinsic::ID id, +getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id, llvm::Module *module, LLVM::ModuleTranslation &moduleTranslation) { SmallVector allArgTys; @@ -86,7 +94,9 @@ if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef, overloadedArgTys) != llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) { - return op.emitOpError("intrinsic type is not a match"); + return mlir::emitError(op.getLoc(), "call intrinsic signature ") + << diagStr(ft) << " to overloaded intrinsic " << op.getIntrinAttr() + << " does not match any of the overloads"; } ArrayRef overloadedArgTysRef = overloadedArgTys; @@ -101,8 +111,8 @@ llvm::Intrinsic::ID id = llvm::Function::lookupIntrinsicID(op.getIntrinAttr()); if (!id) - return op.emitOpError() - << "couldn't find intrinsic: " << op.getIntrinAttr(); + return mlir::emitError(op.getLoc(), "could not find LLVM intrinsic: ") + << op.getIntrinAttr(); llvm::Function *fn = nullptr; if (llvm::Intrinsic::isOverloaded(id)) { @@ -114,6 +124,44 @@ } else { fn = llvm::Intrinsic::getDeclaration(module, id, {}); } + + // Check the result type of the call. + const llvm::Type *intrinType = + op.getNumResults() == 0 + ? llvm::Type::getVoidTy(module->getContext()) + : moduleTranslation.convertType(op.getResultTypes().front()); + if (intrinType != fn->getReturnType()) { + return mlir::emitError(op.getLoc(), "intrinsic call returns ") + << diagStr(intrinType) << " but " << op.getIntrinAttr() + << " actually returns " << diagStr(fn->getReturnType()); + } + + // Check the argument types of the call. If the function is variadic, check + // the subrange of required arguments. + if (!fn->getFunctionType()->isVarArg() && + op.getNumOperands() != fn->arg_size()) { + return mlir::emitError(op.getLoc(), "intrinsic call has ") + << op.getNumOperands() << " operands but " << op.getIntrinAttr() + << " expects " << fn->arg_size(); + } + if (fn->getFunctionType()->isVarArg() && + op.getNumOperands() < fn->arg_size()) { + return mlir::emitError(op.getLoc(), "intrinsic call has ") + << op.getNumOperands() << " operands but variadic " + << op.getIntrinAttr() << " expects at least " << fn->arg_size(); + } + // Check the arguments up to the number the function requires. + for (unsigned i = 0, e = fn->arg_size(); i != e; ++i) { + const llvm::Type *expected = fn->getArg(i)->getType(); + const llvm::Type *actual = + moduleTranslation.convertType(op.getOperandTypes()[i]); + if (actual != expected) { + return mlir::emitError(op.getLoc(), "intrinsic call operand #") + << i << " has type " << diagStr(actual) << " but " + << op.getIntrinAttr() << " expects " << diagStr(expected); + } + } + FastmathFlagsInterface itf = op; builder.setFastMathFlags(getFastmathFlags(itf)); diff --git a/mlir/test/Dialect/LLVMIR/call-intrin.mlir b/mlir/test/Dialect/LLVMIR/call-intrin.mlir --- a/mlir/test/Dialect/LLVMIR/call-intrin.mlir +++ b/mlir/test/Dialect/LLVMIR/call-intrin.mlir @@ -1,82 +1,107 @@ // RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s -// CHECK: ; ModuleID = 'LLVMDialectModule' -// CHECK: source_filename = "LLVMDialectModule" -// CHECK: declare ptr @malloc(i64) -// CHECK: declare void @free(ptr) // CHECK: define <4 x float> @round_sse41() { -// CHECK: %1 = call reassoc <4 x float> @llvm.x86.sse41.round.ss(<4 x float> , <4 x float> , i32 1) +// CHECK: %1 = call reassoc <4 x float> @llvm.x86.sse41.round.ss(<4 x float> {{.*}}, <4 x float> {{.*}}, i32 1) // CHECK: ret <4 x float> %1 // CHECK: } llvm.func @round_sse41() -> vector<4xf32> { - %0 = llvm.mlir.constant(1 : i32) : i32 - %1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32> - %res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32> {fastmathFlags = #llvm.fastmath} - llvm.return %res: vector<4xf32> + %0 = llvm.mlir.constant(1 : i32) : i32 + %1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32> + %res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i32) -> vector<4xf32> {fastmathFlags = #llvm.fastmath} + llvm.return %res: vector<4xf32> } // ----- -// CHECK: ; ModuleID = 'LLVMDialectModule' -// CHECK: source_filename = "LLVMDialectModule" - -// CHECK: declare ptr @malloc(i64) - -// CHECK: declare void @free(ptr) - // CHECK: define float @round_overloaded() { // CHECK: %1 = call float @llvm.round.f32(float 1.000000e+00) // CHECK: ret float %1 // CHECK: } llvm.func @round_overloaded() -> f32 { - %0 = llvm.mlir.constant(1.0 : f32) : f32 - %res = llvm.call_intrinsic "llvm.round"(%0) : (f32) -> f32 {} - llvm.return %res: f32 + %0 = llvm.mlir.constant(1.0 : f32) : f32 + %res = llvm.call_intrinsic "llvm.round"(%0) : (f32) -> f32 {} + llvm.return %res: f32 } // ----- -// CHECK: ; ModuleID = 'LLVMDialectModule' -// CHECK: source_filename = "LLVMDialectModule" -// CHECK: declare ptr @malloc(i64) -// CHECK: declare void @free(ptr) // CHECK: define void @lifetime_start() { // CHECK: %1 = alloca float, i8 1, align 4 // CHECK: call void @llvm.lifetime.start.p0(i64 4, ptr %1) // CHECK: ret void // CHECK: } llvm.func @lifetime_start() { - %0 = llvm.mlir.constant(4 : i64) : i64 - %1 = llvm.mlir.constant(1 : i8) : i8 - %2 = llvm.alloca %1 x f32 : (i8) -> !llvm.ptr - llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) : (i64, !llvm.ptr) -> () {} - llvm.return + %0 = llvm.mlir.constant(4 : i64) : i64 + %1 = llvm.mlir.constant(1 : i8) : i8 + %2 = llvm.alloca %1 x f32 : (i8) -> !llvm.ptr + llvm.call_intrinsic "llvm.lifetime.start"(%0, %2) : (i64, !llvm.ptr) -> () {} + llvm.return } // ----- +// CHECK-LABEL: define void @variadic() llvm.func @variadic() { - %0 = llvm.mlir.constant(1 : i8) : i8 - %1 = llvm.alloca %0 x f32 : (i8) -> !llvm.ptr - llvm.call_intrinsic "llvm.localescape"(%1, %1) : (!llvm.ptr, !llvm.ptr) -> () - llvm.return + %0 = llvm.mlir.constant(1 : i8) : i8 + %1 = llvm.alloca %0 x f32 : (i8) -> !llvm.ptr + // CHECK: call void (...) @llvm.localescape(ptr %1, ptr %1) + llvm.call_intrinsic "llvm.localescape"(%1, %1) : (!llvm.ptr, !llvm.ptr) -> () + llvm.return } // ----- llvm.func @no_intrinsic() { - // expected-error@below {{'llvm.call_intrinsic' op couldn't find intrinsic: "llvm.does_not_exist"}} - // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}} - llvm.call_intrinsic "llvm.does_not_exist"() : () -> () - llvm.return + // expected-error@below {{could not find LLVM intrinsic: "llvm.does_not_exist"}} + // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}} + llvm.call_intrinsic "llvm.does_not_exist"() : () -> () + llvm.return } // ----- llvm.func @bad_types() { - %0 = llvm.mlir.constant(1 : i8) : i8 - // expected-error@below {{'llvm.call_intrinsic' op intrinsic type is not a match}} - // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}} - llvm.call_intrinsic "llvm.round"(%0) : (i8) -> i8 {} - llvm.return + %0 = llvm.mlir.constant(1 : i8) : i8 + // expected-error@below {{call intrinsic signature i8 (i8) to overloaded intrinsic "llvm.round" does not match any of the overloads}} + // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}} + llvm.call_intrinsic "llvm.round"(%0) : (i8) -> i8 {} + llvm.return +} + +// ----- + +llvm.func @bad_result() { + // expected-error @below {{intrinsic call returns void but "llvm.x86.sse41.round.ss" actually returns <4 x float>}} + // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}} + llvm.call_intrinsic "llvm.x86.sse41.round.ss"() : () -> () + llvm.return +} + +// ----- + +llvm.func @bad_result() { + // expected-error @below {{intrinsic call returns <8 x float> but "llvm.x86.sse41.round.ss" actually returns <4 x float>}} + // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}} + llvm.call_intrinsic "llvm.x86.sse41.round.ss"() : () -> (vector<8xf32>) + llvm.return +} + +// ----- + +llvm.func @bad_args() { + // expected-error @below {{intrinsic call has 0 operands but "llvm.x86.sse41.round.ss" expects 3}} + // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}} + llvm.call_intrinsic "llvm.x86.sse41.round.ss"() : () -> (vector<4xf32>) + llvm.return +} + +// ----- + +llvm.func @bad_args() { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.mlir.constant(dense<0.2> : vector<4xf32>) : vector<4xf32> + // expected-error @below {{intrinsic call operand #2 has type i64 but "llvm.x86.sse41.round.ss" expects i32}} + // expected-error@below {{LLVM Translation failed for operation: llvm.call_intrinsic}} + %res = llvm.call_intrinsic "llvm.x86.sse41.round.ss"(%1, %1, %0) : (vector<4xf32>, vector<4xf32>, i64) -> vector<4xf32> {fastmathFlags = #llvm.fastmath} + llvm.return }