diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -426,11 +426,7 @@
       result.addAttributes(attributes);
       result.addOperands(operands);
     }]>];
-  let verifier = [{
-    if (getNumResults() > 1)
-      return emitOpError("must have 0 or 1 result");
-    return success();
-  }];
+  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return parseCallOp(parser, result); }];
   let printer = [{ printCallOp(p, *this); }];
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -531,9 +531,75 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::CallOp.
+// Verifying/Printing/parsing for LLVM::CallOp.
 //===----------------------------------------------------------------------===//
 
+/// Verifies an LLVM::CallOp. The op must produce at most one result. For a
+/// direct call (one carrying a "callee" symbol attribute), the symbol must
+/// resolve to an LLVMFuncOp, and the call's operands (count and types) and
+/// result type, if present, must agree with the callee's LLVM function type.
+/// Indirect calls are not type-checked yet.
+static LogicalResult verify(CallOp &op) {
+  if (op.getNumResults() > 1)
+    return op.emitOpError("must have 0 or 1 result");
+
+  // Check if there is a callee attribute specified or if it is an indirect
+  // call.
+  auto fnAttr = op.getAttrOfType<FlatSymbolRefAttr>("callee");
+  if (!fnAttr) {
+    // FIXME: we should verify the type of the callee for indirect calls as
+    // well.
+    return success();
+  }
+
+  // Check caller/callee matching operands/result types: first resolve the
+  // callee symbol to an LLVM function.
+  Operation *callee =
+      SymbolTable::lookupNearestSymbolFrom(op, fnAttr.getValue());
+  if (!callee)
+    return op.emitOpError()
+           << "'" << fnAttr.getValue()
+           << "' does not reference a symbol in the current scope";
+  auto fn = dyn_cast<LLVMFuncOp>(callee);
+  if (!fn)
+    return op.emitOpError() << "'" << fnAttr.getValue()
+                            << "' does not reference a valid LLVM function";
+
+  // Verify that the operand and result types match the callee.
+  LLVMType fnType = fn.getType();
+
+  if (!fnType.isFunctionVarArg() &&
+      fnType.getFunctionNumParams() != op.getNumOperands())
+    return op.emitOpError()
+           << "incorrect number of operands (" << op.getNumOperands()
+           << ") for callee (expecting: " << fnType.getFunctionNumParams()
+           << ")";
+
+  // Only a variadic callee can reach this point with a differing operand
+  // count; it still needs at least one operand per fixed parameter.
+  if (fnType.getFunctionNumParams() > op.getNumOperands())
+    return op.emitOpError()
+           << "incorrect number of operands (" << op.getNumOperands()
+           << ") for varargs callee (expecting at least: "
+           << fnType.getFunctionNumParams() << ")";
+
+  // Fixed parameters must match exactly; trailing variadic operands are not
+  // checked.
+  for (unsigned i = 0, e = fnType.getFunctionNumParams(); i != e; ++i)
+    if (op.getOperand(i).getType() != fnType.getFunctionParamType(i))
+      return op.emitOpError() << "operand type mismatch for operand " << i
+                              << ": " << op.getOperand(i).getType()
+                              << " != " << fnType.getFunctionParamType(i);
+
+  if (op.getNumResults() &&
+      op.getResult(0).getType() != fnType.getFunctionResultType())
+    return op.emitOpError()
+           << "result type mismatch: " << op.getResult(0).getType()
+           << " != " << fnType.getFunctionResultType();
+
+  return success();
+}
+
 static void printCallOp(OpAsmPrinter &p, CallOp &op) {
   auto callee = op.callee();
   bool isDirect = callee.hasValue();