diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -22,7 +22,7 @@ class GPUModuleOp; } -/// Configure target to convert from to convert from the GPU dialect to NVVM. +/// Configure target to convert from the GPU dialect to NVVM. void configureGpuToNVVMConversionLegality(ConversionTarget &target); /// Collect a set of patterns to convert from the GPU dialect to NVVM. diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h --- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h +++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h @@ -14,6 +14,7 @@ namespace mlir { class LLVMTypeConverter; class OwningRewritePatternList; +class ConversionTarget; template class OperationPass; @@ -26,6 +27,9 @@ void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); +/// Configure target to convert from the GPU dialect to ROCDL. +void configureGpuToROCDLConversionLegality(ConversionTarget &target); + /// Creates a pass that lowers GPU dialect operations to ROCDL counterparts. The /// index bitwidth used for the lowering of the device side index computations /// is configurable. diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -69,14 +69,7 @@ populateStdToLLVMConversionPatterns(converter, llvmPatterns); populateGpuToROCDLConversionPatterns(converter, llvmPatterns); LLVMConversionTarget target(getContext()); - target.addIllegalDialect(); - target.addIllegalOp(); - target.addIllegalOp(); - target.addLegalDialect(); - // TODO: Remove once we support replacing non-root ops. - target.addLegalOp(); + configureGpuToROCDLConversionLegality(target); if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) signalPassFailure(); } @@ -84,6 +77,19 @@ } // anonymous namespace +void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) { + target.addIllegalOp(); + target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); + target.addLegalDialect(); + target.addIllegalDialect(); + target.addIllegalOp(); + + // TODO: Remove once we support replacing non-root ops. + target.addLegalOp(); +} + void mlir::populateGpuToROCDLConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { populateWithGenerated(converter.getDialect()->getContext(), patterns);