Verify llvm.call against its callee.

Replaces the inline TableGen verifier of LLVM::CallOp with a C++
verify(CallOp &) that, in addition to the "0 or 1 result" check, resolves the
callee (symbol for direct calls, first operand for indirect calls) and checks
the operand count, operand types and result type against the callee's
function type. Adds negative tests for each new diagnostic.

diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -426,11 +426,7 @@
       result.addAttributes(attributes);
       result.addOperands(operands);
     }]>];
-  let verifier = [{
-    if (getNumResults() > 1)
-      return emitOpError("must have 0 or 1 result");
-    return success();
-  }];
+  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return parseCallOp(parser, result); }];
   let printer = [{ printCallOp(p, *this); }];
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -531,9 +531,83 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::CallOp.
+// Verifying/Printing/parsing for LLVM::CallOp.
 //===----------------------------------------------------------------------===//
 
+static LogicalResult verify(CallOp &op) {
+  if (op.getNumResults() > 1)
+    return op.emitOpError("must have 0 or 1 result");
+
+  // Type for the callee, we'll get it differently depending on whether it is a
+  // direct or indirect call.
+  LLVMType fnType;
+
+  bool isIndirect = false;
+
+  // If this is an indirect call, the callee attribute is missing.
+  Optional<StringRef> calleeName = op.callee();
+  if (!calleeName) {
+    isIndirect = true;
+    if (!op.getNumOperands())
+      return op.emitOpError(
+          "must have either a `callee` attribute or at least an operand");
+    fnType = op.getOperand(0).getType().dyn_cast<LLVMType>();
+    if (!fnType)
+      return op.emitOpError("indirect call to a non-llvm type: ")
+             << op.getOperand(0).getType();
+    auto ptrType = fnType.dyn_cast<LLVMPointerType>();
+    if (!ptrType)
+      return op.emitOpError("indirect call expects a pointer as callee: ")
+             << fnType;
+    fnType = ptrType.getElementType();
+  } else {
+    Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName);
+    if (!callee)
+      return op.emitOpError()
+             << "'" << *calleeName
+             << "' does not reference a symbol in the current scope";
+    auto fn = dyn_cast<LLVMFuncOp>(callee);
+    if (!fn)
+      return op.emitOpError() << "'" << *calleeName
+                              << "' does not reference a valid LLVM function";
+
+    fnType = fn.getType();
+  }
+  if (!fnType.isFunctionTy())
+    return op.emitOpError("callee does not have a functional type: ") << fnType;
+
+  // Verify that the operand and result types match the callee.
+
+  if (!fnType.isFunctionVarArg() &&
+      fnType.getFunctionNumParams() != (op.getNumOperands() - isIndirect))
+    return op.emitOpError()
+           << "incorrect number of operands ("
+           << (op.getNumOperands() - isIndirect)
+           << ") for callee (expecting: " << fnType.getFunctionNumParams()
+           << ")";
+
+  if (fnType.getFunctionNumParams() > (op.getNumOperands() - isIndirect))
+    return op.emitOpError() << "incorrect number of operands ("
+                            << (op.getNumOperands() - isIndirect)
+                            << ") for varargs callee (expecting at least: "
+                            << fnType.getFunctionNumParams() << ")";
+
+  for (unsigned i = 0, e = fnType.getFunctionNumParams(); i != e; ++i)
+    if (op.getOperand(i + isIndirect).getType() !=
+        fnType.getFunctionParamType(i))
+      return op.emitOpError() << "operand type mismatch for operand " << i
+                              << ": " << op.getOperand(i + isIndirect).getType()
+                              << " != " << fnType.getFunctionParamType(i);
+
+  if (op.getNumResults() &&
+      op.getResult(0).getType() != fnType.getFunctionResultType())
+    return op.emitOpError()
+           << "result type mismatch: " << op.getResult(0).getType()
+           << " != " << fnType.getFunctionResultType();
+
+  return success();
+}
+
 static void printCallOp(OpAsmPrinter &p, CallOp &op) {
   auto callee = op.callee();
   bool isDirect = callee.hasValue();
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -125,6 +125,75 @@
 
 // -----
 
+func @invalid_call() {
+  // expected-error@+1 {{'llvm.call' op must have either a `callee` attribute or at least an operand}}
+  "llvm.call"() : () -> ()
+}
+
+// -----
+
+func @call_non_function_type(%callee : !llvm.func<i8 (i8)>, %arg : !llvm.i8) {
+  // expected-error@+1 {{expected function type}}
+  llvm.call %callee(%arg) : !llvm.func<i8 (i8)>
+}
+
+// -----
+
+func @call_unknown_symbol() {
+  // expected-error@+1 {{'llvm.call' op 'missing_callee' does not reference a symbol in the current scope}}
+  llvm.call @missing_callee() : () -> ()
+}
+
+// -----
+
+func @standard_func_callee()
+
+func @call_non_llvm() {
+  // expected-error@+1 {{'llvm.call' op 'standard_func_callee' does not reference a valid LLVM function}}
+  llvm.call @standard_func_callee() : () -> ()
+}
+
+// -----
+
+func @call_non_llvm_indirect(%arg0 : i32) {
+  // expected-error@+1 {{'llvm.call' op operand #0 must be LLVM dialect type, but got 'i32'}}
+  "llvm.call"(%arg0) : (i32) -> ()
+}
+
+// -----
+
+llvm.func @callee_func(!llvm.i8) -> ()
+
+func @callee_arg_mismatch(%arg0 : !llvm.i32) {
+  // expected-error@+1 {{'llvm.call' op operand type mismatch for operand 0: '!llvm.i32' != '!llvm.i8'}}
+  llvm.call @callee_func(%arg0) : (!llvm.i32) -> ()
+}
+
+// -----
+
+func @indirect_callee_arg_mismatch(%arg0 : !llvm.i32, %callee : !llvm.ptr<func<void (i8)>>) {
+  // expected-error@+1 {{'llvm.call' op operand type mismatch for operand 0: '!llvm.i32' != '!llvm.i8'}}
+  "llvm.call"(%callee, %arg0) : (!llvm.ptr<func<void (i8)>>, !llvm.i32) -> ()
+}
+
+// -----
+
+llvm.func @callee_func() -> (!llvm.i8)
+
+func @callee_return_mismatch() {
+  // expected-error@+1 {{'llvm.call' op result type mismatch: '!llvm.i32' != '!llvm.i8'}}
+  %res = llvm.call @callee_func() : () -> (!llvm.i32)
+}
+
+// -----
+
+func @indirect_callee_return_mismatch(%callee : !llvm.ptr<func<i8 ()>>) {
+  // expected-error@+1 {{'llvm.call' op result type mismatch: '!llvm.i32' != '!llvm.i8'}}
+  "llvm.call"(%callee) : (!llvm.ptr<func<i8 ()>>) -> (!llvm.i32)
+}
+
+// -----
+
 func @call_too_many_results(%callee : () -> (i32,i32)) {
   // expected-error@+1 {{expected function with 0 or 1 result}}
   llvm.call %callee() : () -> (i32, i32)