diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -14,6 +14,7 @@ namespace mlir { class LLVMTypeConverter; class OwningRewritePatternList; +class ConversionTarget; template class OperationPass; @@ -21,6 +22,9 @@ class GPUModuleOp; } +/// Configure target to convert from the GPU dialect to NVVM. +void configureGpuToNVVMConversionLegality(ConversionTarget &target); + /// Collect a set of patterns to convert from the GPU dialect to NVVM. void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -135,14 +135,7 @@ populateStdToLLVMConversionPatterns(converter, llvmPatterns); populateGpuToNVVMConversionPatterns(converter, llvmPatterns); LLVMConversionTarget target(getContext()); - target.addIllegalDialect(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addLegalDialect(); - // TODO: Remove once we support replacing non-root ops. - target.addLegalOp(); + configureGpuToNVVMConversionLegality(target); if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) signalPassFailure(); } @@ -150,6 +143,19 @@ } // anonymous namespace +void mlir::configureGpuToNVVMConversionLegality(ConversionTarget &target) { + target.addIllegalOp(); + target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); + target.addLegalDialect<::mlir::NVVM::NVVMDialect>(); + target.addIllegalDialect(); + target.addIllegalOp(); + + // TODO: Remove once we support replacing non-root ops.
+ target.addLegalOp(); +} + void mlir::populateGpuToNVVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { populateWithGenerated(converter.getDialect()->getContext(), patterns);