Index: mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
@@ -0,0 +1,34 @@
+//===- ConvertOpenMPToLLVM.h - Utils to convert from the OpenMP dialect ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_OPENMPTOLLVM_CONVERTOPENMPTOLLVM_H_
+#define MLIR_CONVERSION_OPENMPTOLLVM_CONVERTOPENMPTOLLVM_H_
+
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include <memory>
+
+namespace mlir {
+class MLIRContext;
+class ModuleOp;
+template <typename T>
+class OperationPass;
+
+/// Populate the given list with patterns that convert the operands and region
+/// types of OpenMP operations to the LLVM dialect using `converter`.
+void populateOpenMPToLLVMConversionPatterns(MLIRContext *context,
+                                            LLVMTypeConverter &converter,
+                                            OwningRewritePatternList &patterns);
+
+/// Create a pass that converts the types used by OpenMP operations, and the
+/// standard operations around them, to the LLVM dialect.
+std::unique_ptr<OperationPass<ModuleOp>> createConvertOpenMPToLLVMPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_OPENMPTOLLVM_CONVERTOPENMPTOLLVM_H_
Index: mlir/include/mlir/Conversion/Passes.h
===================================================================
--- mlir/include/mlir/Conversion/Passes.h
+++ mlir/include/mlir/Conversion/Passes.h
@@ -19,6 +19,7 @@
 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
 #include "mlir/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
+#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
Index: mlir/include/mlir/Conversion/Passes.td
===================================================================
--- mlir/include/mlir/Conversion/Passes.td
+++ mlir/include/mlir/Conversion/Passes.td
@@ -189,6 +189,15 @@
 }
 
 //===----------------------------------------------------------------------===//
+// OpenMPToLLVM
+//===----------------------------------------------------------------------===//
+
+def ConvertOpenMPToLLVM : Pass<"convert-openmp-to-llvm", "ModuleOp"> {
+  let summary = "Convert the OpenMP ops to OpenMP ops with LLVM dialect";
+  let constructor = "mlir::createConvertOpenMPToLLVMPass()";
+}
+
+//===----------------------------------------------------------------------===//
 // SCFToStandard
 //===----------------------------------------------------------------------===//
 
Index: mlir/lib/Conversion/CMakeLists.txt
===================================================================
--- mlir/lib/Conversion/CMakeLists.txt
+++ mlir/lib/Conversion/CMakeLists.txt
@@ -8,6 +8,7 @@
 add_subdirectory(LinalgToLLVM)
 add_subdirectory(LinalgToSPIRV)
 add_subdirectory(LinalgToStandard)
+add_subdirectory(OpenMPToLLVM)
 add_subdirectory(SCFToGPU)
 add_subdirectory(SCFToSPIRV)
 add_subdirectory(SCFToStandard)
Index: mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_conversion_library(MLIROpenMPToLLVM
+  OpenMPToLLVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/OpenMPToLLVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+  intrinsics_gen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRLLVMIR
+  MLIROpenMP
+  MLIRStandardToLLVM
+  MLIRTransforms
+  )
Index: mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -0,0 +1,94 @@
+//===- OpenMPToLLVM.cpp - conversion from OpenMP to LLVM dialect ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
+
+#include "../PassDetail.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+/// Rewrites `omp.parallel` so that it uses operands converted to the LLVM
+/// dialect and so that the types in its region are converted as well. The op
+/// itself stays in the OpenMP dialect.
+struct ParallelOpConversion : public ConvertToLLVMPattern {
+  explicit ParallelOpConversion(MLIRContext *context,
+                                LLVMTypeConverter &typeConverter)
+      : ConvertToLLVMPattern(omp::ParallelOp::getOperationName(), context,
+                             typeConverter) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto curOp = cast<omp::ParallelOp>(op);
+    // Build the replacement from `operands`, i.e. the values already remapped
+    // by the conversion framework, rather than from the original operands of
+    // `op`, whose types have not been converted.
+    auto newOp = rewriter.create<omp::ParallelOp>(
+        curOp.getLoc(), curOp.getOperation()->getResultTypes(), operands,
+        curOp.getOperation()->getAttrs());
+    // Move the body over, then convert the block argument types in it.
+    rewriter.inlineRegionBefore(curOp.region(), newOp.region(),
+                                newOp.region().end());
+    if (failed(rewriter.convertRegionTypes(&newOp.region(), typeConverter)))
+      return failure();
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateOpenMPToLLVMConversionPatterns(
+    MLIRContext *context, LLVMTypeConverter &converter,
+    OwningRewritePatternList &patterns) {
+  patterns.insert<ParallelOpConversion>(context, converter);
+}
+
+namespace {
+/// Pass that applies the standard-to-LLVM and OpenMP-to-LLVM patterns.
+struct ConvertOpenMPToLLVMPass
+    : public ConvertOpenMPToLLVMBase<ConvertOpenMPToLLVMPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertOpenMPToLLVMPass::runOnOperation() {
+  auto module = getOperation();
+  MLIRContext *context = &getContext();
+
+  // Convert to OpenMP operations with LLVM IR dialect
+  OwningRewritePatternList patterns;
+  LLVMTypeConverter converter(context);
+  populateStdToLLVMConversionPatterns(converter, patterns);
+  populateOpenMPToLLVMConversionPatterns(context, converter, patterns);
+
+  // `omp.parallel` is legal only once the types in its region are converted;
+  // the remaining listed OpenMP ops are always legal.
+  LLVMConversionTarget target(*context);
+  target.addDynamicallyLegalOp<omp::ParallelOp>(
+      [&](omp::ParallelOp op) { return converter.isLegal(&op.getRegion()); });
+  target.addLegalOp<omp::TerminatorOp, omp::TaskyieldOp, omp::FlushOp,
+                    omp::BarrierOp, omp::TaskwaitOp>();
+  if (failed(applyPartialConversion(module, target, patterns)))
+    signalPassFailure();
+}
+
+std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() {
+  return std::make_unique<ConvertOpenMPToLLVMPass>();
+}
Index: mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
===================================================================
--- /dev/null
+++ mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-opt -convert-openmp-to-llvm %s -split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @empty() {
+// CHECK-NEXT:    llvm.return
+// CHECK-NEXT:  }
+func @empty() {
+^bb0:
+  return
+}
+
+// CHECK-LABEL: llvm.func @branch_loop() {
+func @branch_loop() {
+// CHECK-NEXT: %[[VAL0:[0-9]+]] = llvm.mlir.constant(0 : index) : !llvm.i64
+  %start = constant 0 : index
+// CHECK-NEXT: %[[VAL1:[0-9]+]] = llvm.mlir.constant(1 : index) : !llvm.i64
+  %step = constant 1 : index
+// CHECK-NEXT: %[[VAL2:[0-9]+]] = llvm.mlir.constant(100 : index) : !llvm.i64
+  %end = constant 100 : index
+// CHECK-NEXT: %[[VAL3:[0-9]+]] = llvm.mlir.constant(4 : index) : !llvm.i64
+  %nthreads = constant 4 : index
+// CHECK-NEXT: omp.parallel num_threads(%[[VAL3]] : !llvm.i64) private(%[[VAL0]] : !llvm.i64, %[[VAL2]] : !llvm.i64) proc_bind(spread) {
+  omp.parallel num_threads(%nthreads : index) private(%start : index, %end : index) proc_bind(spread) {
+// CHECK-NEXT: llvm.br ^[[BB1:.*]](%[[VAL0]] : !llvm.i64)
+    br ^bb1(%start : index)
+// CHECK-NEXT: ^[[BB1]](%[[VAL4:[0-9]+]]: !llvm.i64):
+  ^bb1(%0: index):
+// CHECK-NEXT: %[[VAL5:[0-9]+]] = llvm.icmp "slt" %[[VAL4]], %[[VAL2]] : !llvm.i64
+    %1 = cmpi "slt", %0, %end : index
+// CHECK-NEXT: llvm.cond_br %[[VAL5]], ^[[BB2:.*]], ^[[BB3:.*]]
+    cond_br %1, ^bb2, ^bb3
+// CHECK-NEXT: ^[[BB2]]:
+  ^bb2:
+// CHECK-NEXT: %[[VAL6:[0-9]+]] = llvm.trunc %[[VAL4]] : !llvm.i64 to !llvm.i32
+    %2 = index_cast %0 : index to i32
+// CHECK-NEXT: llvm.call @print_i32(%[[VAL6]]) : (!llvm.i32) -> ()
+    call @print_i32(%2) : (i32) -> ()
+// CHECK-NEXT: %[[VAL7:[0-9]+]] = llvm.add %[[VAL4]], %[[VAL1]] : !llvm.i64
+    %3 = addi %0, %step : index
+// CHECK-NEXT: llvm.br ^[[BB1]](%[[VAL7]] : !llvm.i64)
+    br ^bb1(%3 : index)
+// CHECK-NEXT: ^[[BB3]]:
+  ^bb3:
+// CHECK-NEXT: omp.terminator
+    omp.terminator
+// CHECK-NEXT: }
+  }
+// CHECK-NEXT: llvm.return
+  return
+// CHECK-NEXT: }
+}
+// CHECK-NEXT: llvm.func @print_i32(!llvm.i32)
+func @print_i32(i32)
+
+// CHECK: llvm.func @test_constructs(%[[FN_ARG0:.*]]: !llvm.i32, %[[FN_ARG1:.*]]: !llvm.i32)
+func @test_constructs(%arg0: i32, %arg1: i32) {
+// CHECK-NEXT: omp.parallel shared(%[[FN_ARG0]] : !llvm.i32, %[[FN_ARG1]] : !llvm.i32) {
+  omp.parallel shared(%arg0: i32, %arg1: i32) {
+// CHECK-NEXT: omp.flush
+    omp.flush
+// CHECK-NEXT: omp.barrier
+    omp.barrier
+// CHECK-NEXT: llvm.br ^[[BB1:.*]](%[[FN_ARG0]], %[[FN_ARG1]] : !llvm.i32, !llvm.i32)
+    br ^bb1(%arg0, %arg1 : i32, i32)
+// CHECK-NEXT: ^[[BB1]](%[[BB_ARG0:.*]]: !llvm.i32, %[[BB_ARG1:.*]]: !llvm.i32):
+  ^bb1(%0: i32, %1: i32):
+// CHECK-NEXT: omp.flush %[[BB_ARG0]], %[[BB_ARG1]] : !llvm.i32, !llvm.i32
+    omp.flush %0, %1 : i32, i32
+// CHECK-NEXT: omp.taskyield
+    omp.taskyield
+// CHECK-NEXT: omp.taskwait
+    omp.taskwait
+// CHECK-NEXT: omp.terminator
+    omp.terminator
+// CHECK-NEXT: }
+  }
+// CHECK-NEXT: llvm.return
+  return
+// CHECK-NEXT: }
+}