diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -38,7 +38,6 @@
   MLIRMemRef
   MLIRSideEffectInterfaces
   MLIRSupport
-  MLIRLLVMIR
   )
 
 add_mlir_dialect_library(MLIRGPUTransforms
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/GPU/GPUDialect.h"
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -227,8 +226,14 @@
     // Check that `launch_func` refers to a well-formed kernel function.
     Operation *kernelFunc = module.lookupSymbol(launchOp.kernelAttr());
     auto kernelGPUFunction = dyn_cast_or_null<gpu::GPUFuncOp>(kernelFunc);
-    auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
-    if (!kernelGPUFunction && !kernelLLVMFunction)
+
+    // If the kernel isn't a GPU function, check to see that it is at least a
+    // function of some kind. This allows for handling when the kernel
+    // function is in another form mid-conversion.
+    auto kernelConvertedFunction =
+        kernelGPUFunction ? nullptr
+                          : dyn_cast_or_null<FunctionOpInterface>(kernelFunc);
+    if (!kernelGPUFunction && !kernelConvertedFunction)
       return launchOp.emitOpError("kernel function '")
              << launchOp.kernel() << "' is undefined";
     if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
@@ -236,11 +241,11 @@
       return launchOp.emitOpError("kernel function is missing the '")
              << GPUDialect::getKernelFuncAttrName() << "' attribute";
 
-    // TODO: if the kernel function has been converted to
-    // the LLVM dialect but the caller hasn't (which happens during the
-    // separate compilation), do not check type correspondence as it would
-    // require the verifier to be aware of the LLVM type conversion.
-    if (kernelLLVMFunction)
+    // TODO: if the kernel function has been converted already but the caller
+    // hasn't (which happens during separate compilation), do not check type
+    // correspondence as it would require the verifier to be aware of the type
+    // conversion.
+    if (kernelConvertedFunction)
       return success();
 
     unsigned actualNumArguments = launchOp.getNumKernelOperands();