Index: mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
@@ -0,0 +1,31 @@
+//===- OpenMPToLLVM.h - Utils to convert from the OpenMP dialect ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_OPENMPTOLLVM_OPENMPTOLLVM_H_
+#define MLIR_CONVERSION_OPENMPTOLLVM_OPENMPTOLLVM_H_
+
+#include <memory>
+
+namespace mlir {
+class LLVMTypeConverter;
+class MLIRContext;
+class ModuleOp;
+template <typename T>
+class OperationPass;
+class OwningRewritePatternList;
+
+/// Populate the given list with patterns that convert from OpenMP to LLVM.
+void populateOpenMPToLLVMConversionPatterns(MLIRContext *context,
+                                            LLVMTypeConverter &converter,
+                                            OwningRewritePatternList &patterns);
+
+/// Create a pass to convert OpenMP operations to the LLVMIR dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertOpenMPToLLVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_OPENMPTOLLVM_OPENMPTOLLVM_H_
Index: mlir/include/mlir/Conversion/Passes.h
===================================================================
--- mlir/include/mlir/Conversion/Passes.h
+++ mlir/include/mlir/Conversion/Passes.h
@@ -19,6 +19,7 @@
 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
Index: mlir/include/mlir/Conversion/Passes.td
===================================================================
--- mlir/include/mlir/Conversion/Passes.td
+++ mlir/include/mlir/Conversion/Passes.td
@@ -189,6 +189,16 @@
 }
 
 //===----------------------------------------------------------------------===//
+// OpenMPToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertOpenMPToLLVM : Pass<"convert-openmp-to-llvm", "ModuleOp"> {
+  let summary = "Convert the OpenMP ops to OpenMP ops with LLVM dialect";
+  let constructor = "mlir::createConvertOpenMPToLLVMPass()";
+  let dependentDialects = ["LLVM::LLVMDialect"];
+}
+
+//===----------------------------------------------------------------------===//
 // SCFToStandard
 //===----------------------------------------------------------------------===//
 
Index: mlir/lib/Conversion/CMakeLists.txt
===================================================================
--- mlir/lib/Conversion/CMakeLists.txt
+++ mlir/lib/Conversion/CMakeLists.txt
@@ -8,6 +8,7 @@
 add_subdirectory(LinalgToLLVM)
 add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LinalgToStandard)
+add_subdirectory(OpenMPToLLVM)
 add_subdirectory(SCFToGPU)
 add_subdirectory(SCFToSPIRV)
 add_subdirectory(SCFToStandard)
Index: mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_conversion_library(MLIROpenMPToLLVM
+  OpenMPToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/OpenMPToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+  intrinsics_gen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMIR
+  MLIROpenMP
+  MLIRStandardToLLVM
+  MLIRTransforms
+  )
Index: mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -0,0 +1,76 @@
+//===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+
+using namespace mlir;
+
+namespace {
+struct ParallelOpConversion : public ConvertToLLVMPattern {
+  explicit ParallelOpConversion(MLIRContext *context,
+                                LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(omp::ParallelOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto curOp = cast<omp::ParallelOp>(op);
+    auto newOp = rewriter.create<omp::ParallelOp>(
+        curOp.getLoc(), ArrayRef<Type>(), operands, curOp.getAttrs());
+    rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
+                                newOp.region().end());
+    if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter)))
+      return failure();
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateOpenMPToLLVMConversionPatterns(
+    MLIRContext *context, LLVMTypeConverter &converter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<ParallelOpConversion>(context, converter);
+}
+
+namespace {
+struct ConvertOpenMPToLLVMPass
+    : public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertOpenMPToLLVMPass::runOnOperation() {
+  auto module = getOperation();
+  MLIRContext *context = &getContext();
+
+  // Convert to OpenMP operations with LLVM IR dialect.
+  OwningRewritePatternList patterns;
+  LLVMTypeConverter converter(&getContext());
+  populateStdToLLVMConversionPatterns(converter, patterns);
+  populateOpenMPToLLVMConversionPatterns(context, converter, patterns);
+
+  LLVMConversionTarget target(getContext());
+  target.addDynamicallyLegalOp<omp::ParallelOp>(
+      [&](omp::ParallelOp op) { return converter.isLegal(&op.getRegion()); });
+  target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
+                    omp::BarrierOp, omp::TaskwaitOp>();
+  if (failed(applyPartialConversion(module, target, patterns)))
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() {
+  return std::make_unique<ConvertOpenMPToLLVMPass>();
+}
Index: mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
===================================================================
--- /dev/null
+++ mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt -convert-openmp-to-llvm %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @branch_loop
+func @branch_loop() {
+  %start = constant 0 : index
+  %end = constant 0 : index
+  // CHECK: omp.parallel
+  omp.parallel {
+    // CHECK-NEXT: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : !llvm.i64, !llvm.i64
+    br ^bb1(%start, %end : index, index)
+  // CHECK-NEXT: ^[[BB1]](%[[ARG1:[0-9]+]]: !llvm.i64, %[[ARG2:[0-9]+]]: !llvm.i64):{{.*}}
+  ^bb1(%0: index, %1: index):
+    // CHECK-NEXT: %[[CMP:[0-9]+]] = llvm.icmp "slt" %[[ARG1]], %[[ARG2]] : !llvm.i64
+    %2 = cmpi "slt", %0, %1 : index
+    // CHECK-NEXT: llvm.cond_br %[[CMP]], ^[[BB2:.*]](%{{[0-9]+}}, %{{[0-9]+}} : !llvm.i64, !llvm.i64), ^[[BB3:.*]]
+    cond_br %2, ^bb2(%end, %end : index, index), ^bb3
+  // CHECK-NEXT: ^[[BB2]](%[[ARG3:[0-9]+]]: !llvm.i64, %[[ARG4:[0-9]+]]: !llvm.i64):
+  ^bb2(%3: index, %4: index):
+    // CHECK-NEXT: llvm.br ^[[BB1]](%[[ARG3]], %[[ARG4]] : !llvm.i64, !llvm.i64)
+    br ^bb1(%3, %4 : index, index)
+  // CHECK-NEXT: ^[[BB3]]:
+  ^bb3:
+    omp.flush
+    omp.barrier
+    omp.taskwait
+    omp.taskyield
+    omp.terminator
+  }
+  return
+}