diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -426,11 +426,7 @@
       result.addAttributes(attributes);
       result.addOperands(operands);
     }]>];
-  let verifier = [{
-    if (getNumResults() > 1)
-      return emitOpError("must have 0 or 1 result");
-    return success();
-  }];
+  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return parseCallOp(parser, result); }];
   let printer = [{ printCallOp(p, *this); }];
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -531,9 +531,70 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Printing/parsing for LLVM::CallOp.
+// Verifying/Printing/parsing for LLVM::CallOp.
 //===----------------------------------------------------------------------===//
 
+/// Verifies an LLVM::CallOp. The op must produce at most one result. For a
+/// direct call, the `callee` symbol must resolve to an LLVMFuncOp whose type
+/// agrees with the call's operands and result.
+static LogicalResult verify(CallOp &op) {
+  if (op.getNumResults() > 1)
+    return op.emitOpError("must have 0 or 1 result");
+
+  // A missing `callee` attribute means this is an indirect call.
+  Optional<StringRef> calleeName = op.callee();
+  if (!calleeName) {
+    // FIXME: we should verify the type of the callee for indirect calls as
+    // well.
+    return success();
+  }
+
+  // Resolve the callee symbol so its signature can be checked.
+  Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName);
+  if (!callee)
+    return op.emitOpError()
+           << "'" << *calleeName
+           << "' does not reference a symbol in the current scope";
+  auto fn = dyn_cast<LLVMFuncOp>(callee);
+  if (!fn)
+    return op.emitOpError() << "'" << *calleeName
+                            << "' does not reference a valid LLVM function";
+
+  // Verify that the operand and result types match the callee.
+  auto fnType = fn.getType();
+
+  // A non-variadic callee takes exactly its declared parameters.
+  if (!fnType.isFunctionVarArg() &&
+      fnType.getFunctionNumParams() != op.getNumOperands())
+    return op.emitOpError()
+           << "incorrect number of operands (" << op.getNumOperands()
+           << ") for callee (expecting: " << fnType.getFunctionNumParams()
+           << ")";
+
+  // This can only fire for a variadic callee: it needs at least the declared
+  // parameters; any trailing operands are the variadic arguments.
+  if (fnType.getFunctionNumParams() > op.getNumOperands())
+    return op.emitOpError()
+           << "incorrect number of operands (" << op.getNumOperands()
+           << ") for varargs callee (expecting at least: "
+           << fnType.getFunctionNumParams() << ")";
+
+  // Only the declared parameters have a type to check against.
+  for (unsigned i = 0, e = fnType.getFunctionNumParams(); i != e; ++i)
+    if (op.getOperand(i).getType() != fnType.getFunctionParamType(i))
+      return op.emitOpError() << "operand type mismatch for operand " << i
+                              << ": " << op.getOperand(i).getType()
+                              << " != " << fnType.getFunctionParamType(i);
+
+  if (op.getNumResults() &&
+      op.getResult(0).getType() != fnType.getFunctionResultType())
+    return op.emitOpError()
+           << "result type mismatch: " << op.getResult(0).getType()
+           << " != " << fnType.getFunctionResultType();
+
+  return success();
+}
+
 static void printCallOp(OpAsmPrinter &p, CallOp &op) {
   auto callee = op.callee();
   bool isDirect = callee.hasValue();
diff --git a/mlir/test/mlir-cpu-runner/bare_ptr_call_conv.mlir b/mlir/test/mlir-cpu-runner/bare_ptr_call_conv.mlir
--- a/mlir/test/mlir-cpu-runner/bare_ptr_call_conv.mlir
+++ b/mlir/test/mlir-cpu-runner/bare_ptr_call_conv.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm='use-bare-ptr-memref-call-conv=1' | mlir-cpu-runner -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext -entry-point-result=void | FileCheck %s
 
+// This test uses a direct llvm.call to a function that is not in the standard
+// dialect, which fails the verifier.
+// XFAIL: *
+
 // Verify bare pointer memref calling convention. `simple_add1_add2_test`
 // gets two 2xf32 memrefs, adds 1.0f to the first one and 2.0f to the second
 // one. 'main' calls 'simple_add1_add2_test' with {1, 1} and {2, 2} so {2, 2}