[mlir][gpu] Verify launch_func kernels without depending on the LLVM dialect

The gpu.launch_func kernel check no longer special-cases LLVM::LLVMFuncOp,
so mlir/lib/Dialect/GPU/CMakeLists.txt drops its MLIRLLVMIR dependency and
GPUDialect.cpp drops the LLVMDialect.h include.  Any referenced symbol that
implements FunctionOpInterface is accepted as a kernel; a symbol that
resolves to a non-function op is rejected with "referenced kernel '...' is
not a function" plus a note pointing at the definition.  Kernel argument
types are still compared against the launch operands only when the kernel
is a gpu.func.  A negative test covering the non-function case is added to
invalid.mlir.

NOTE(review): the angle-bracket tokens (GPUFuncOp, LLVM::LLVMFuncOp,
FunctionOpInterface, mlir::UnitAttr, !llvm.ptr<i8>) and the whitespace in
the context lines were missing from the text this patch was reconstructed
from; confirm them against the upstream tree before applying.

diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -38,7 +38,6 @@
   MLIRMemRef
   MLIRSideEffectInterfaces
   MLIRSupport
-  MLIRLLVMIR
   )
 
 add_mlir_dialect_library(MLIRGPUTransforms
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/GPU/GPUDialect.h"
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -226,21 +225,28 @@
     // Check that `launch_func` refers to a well-formed kernel function.
     Operation *kernelFunc = module.lookupSymbol(launchOp.kernelAttr());
-    auto kernelGPUFunction = dyn_cast_or_null<gpu::GPUFuncOp>(kernelFunc);
-    auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
-    if (!kernelGPUFunction && !kernelLLVMFunction)
+    if (!kernelFunc)
       return launchOp.emitOpError("kernel function '")
              << launchOp.kernel() << "' is undefined";
 
+    auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
+    if (!kernelConvertedFunction) {
+      InFlightDiagnostic diag = launchOp.emitOpError()
+                                << "referenced kernel '" << launchOp.kernel()
+                                << "' is not a function";
+      diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
+      return diag;
+    }
+
     if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
             GPUDialect::getKernelFuncAttrName()))
       return launchOp.emitOpError("kernel function is missing the '")
              << GPUDialect::getKernelFuncAttrName() << "' attribute";
 
-    // TODO: if the kernel function has been converted to
-    // the LLVM dialect but the caller hasn't (which happens during the
-    // separate compilation), do not check type correspondence as it would
-    // require the verifier to be aware of the LLVM type conversion.
-    if (kernelLLVMFunction)
+    // TODO: If the kernel isn't a GPU function (which happens during separate
+    // compilation), do not check type correspondence as it would require the
+    // verifier to be aware of the type conversion.
+    auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
+    if (!kernelGPUFunction)
       return success();
 
     unsigned actualNumArguments = launchOp.getNumKernelOperands();
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -120,6 +120,21 @@
 
 // -----
 
+module attributes {gpu.container_module} {
+  gpu.module @kernels {
+    // expected-note@+1 {{see the kernel definition here}}
+    memref.global "private" @kernel_1 : memref<4xi32>
+  }
+
+  func @launch_func_kernel_not_function(%sz : index) {
+    // expected-error@+1 {{referenced kernel '@kernels::@kernel_1' is not a function}}
+    gpu.launch_func @kernels::@kernel_1 blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
+    return
+  }
+}
+
+// -----
+
 module attributes {gpu.container_module} {
   module @kernels {
     gpu.func @kernel_1(%arg1 : !llvm.ptr<i8>) kernel {