Only treat GPU-mapped `scf.parallel` ops as illegal in SCF-to-GPU conversion.

Adds `configureParallelLoopToGPULegality(ConversionTarget &)`. It registers a
dynamic legality callback for `scf.parallel`: an op is legal exactly when it
does NOT carry the `gpu::getMappingAttrName()` attribute. SCFToGPUPass.cpp now
calls it in place of its previous `addIllegalOp` call.

NOTE(review): this patch was recovered from a copy whose newlines and `<...>`
template argument lists had been stripped. Line structure, continuation
indentation and template arguments were reconstructed. Before applying,
verify the context lines against the target tree, in particular the
`addLegalDialect<...>` lines in SCFToGPUPass.cpp.

diff --git a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
--- a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
+++ b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
@@ -12,9 +12,10 @@
 
 namespace mlir {
 class AffineForOp;
+class ConversionTarget;
+struct LogicalResult;
 class MLIRContext;
 class OwningRewritePatternList;
-struct LogicalResult;
 class Value;
 
 namespace scf {
@@ -44,6 +45,10 @@
 void populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns,
                                        MLIRContext *ctx);
 
+/// Configures the rewrite target such that only `scf.parallel` operations that
+/// are not rewritten by the provided patterns are legal.
+void configureParallelLoopToGPULegality(ConversionTarget &target);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_SCFTOGPU_SCFTOGPU_H_
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -640,3 +640,9 @@
                                              MLIRContext *ctx) {
   patterns.insert<ParallelToGpuLaunchLowering>(ctx);
 }
+
+void mlir::configureParallelLoopToGPULegality(ConversionTarget &target) {
+  target.addDynamicallyLegalOp<scf::ParallelOp>([](scf::ParallelOp parallelOp) {
+    return !parallelOp.getAttr(gpu::getMappingAttrName());
+  });
+}
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
@@ -53,7 +53,7 @@
     target.addLegalDialect<AffineDialect>();
     target.addLegalDialect<gpu::GPUDialect>();
     target.addLegalDialect<scf::SCFDialect>();
-    target.addIllegalOp<scf::ParallelOp>();
+    configureParallelLoopToGPULegality(target);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();