diff --git a/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h b/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
--- a/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
+++ b/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h
@@ -26,6 +26,10 @@
 /// Create a pass to convert OpenMP operations to the LLVMIR dialect.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertOpenMPToLLVMPass();
 
+/// Create a pass to capture "above-defined" omp::ParallelOp parameters.
+std::unique_ptr<OperationPass<ModuleOp>>
+createCaptureOpenMPParallelParametersPass();
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_OPENMPTOLLVM_OPENMPTOLLVM_H_
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -220,6 +220,16 @@
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// CaptureOpenMPParallelParameters
+//===----------------------------------------------------------------------===//
+
+def CaptureOpenMPParallelParameters : Pass<"capture-openmp-parallel-parameters", "ModuleOp"> {
+  let summary = "Captures the omp::ParallelOp \\\"above-defined\\\" parameters.";
+  let constructor = "mlir::createCaptureOpenMPParallelParametersPass()";
+  let dependentDialects = ["LLVM::LLVMDialect"];
+}
+
 //===----------------------------------------------------------------------===//
 // PDLToPDLInterp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -12,10 +12,83 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Transforms/RegionUtils.h"
+
+#include "llvm/ADT/SetVector.h"
 
 using namespace mlir;
 
 namespace {
+/// Pass that packs the "above-defined" non-pointer LLVM operands used inside
+/// each omp::ParallelOp region into an alloca'd struct, so that the region
+/// reads them through a pointer to that struct.
+struct CaptureOpenMPParametersPass
+    : public PassWrapper<CaptureOpenMPParametersPass,
+                         OperationPass<ModuleOp>> {
+  void runOnOperation() override;
+
+private:
+  // Captures "above-defined" parameters passed to omp::ParallelOp and wraps
+  // them in a struct to make sure the varargs are passed properly to the
+  // synthetically generated fork function.
+  //
+  // The struct is alloca'd and filled immediately before `parOp`, and is
+  // unpacked at the start of its region; every use inside the region is then
+  // rewritten to the unpacked value. Values that do not have an LLVM dialect
+  // type, and LLVM pointers, are left untouched.
+  void captureOmpParallelParams(omp::ParallelOp parOp) const {
+    MLIRContext *context = parOp.getContext();
+    Location loc = parOp.getLoc();
+
+    // Collect the values to capture (deduplicated, in first-use order)
+    // together with their types.
+    llvm::SetVector<Value> values;
+    SmallVector<LLVM::LLVMType, 4> types;
+    visitUsedValuesDefinedAbove({parOp.region()}, [&](OpOperand *opOperand) {
+      Value value = opOperand->get();
+
+      // If it's not an LLVM type, or if it's an LLVM pointer type, we don't
+      // need or want to capture this value via a struct.
+      auto llvmType = value.getType().dyn_cast<LLVM::LLVMType>();
+      if (!llvmType || llvmType.isPointerTy())
+        return;
+
+      // Otherwise, capture the value through an alloca'd struct. A value
+      // used several times is only recorded once.
+      if (values.insert(value))
+        types.push_back(llvmType);
+    });
+
+    if (values.empty())
+      return;
+
+    // Build and fill the struct right before the parallel op.
+    OpBuilder builder(parOp.getOperation());
+    auto structTy = LLVM::LLVMType::getStructTy(context, types);
+    auto numElements = builder.create<LLVM::ConstantOp>(
+        loc, LLVM::LLVMType::getInt64Ty(context), builder.getIndexAttr(1));
+    auto structPtr = builder.create<LLVM::AllocaOp>(
+        loc, structTy.getPointerTo(), numElements, /*alignment=*/0);
+    Value srcStructVal = builder.create<LLVM::UndefOp>(loc, structTy);
+    for (auto srcIdx : llvm::enumerate(values))
+      srcStructVal = builder.create<LLVM::InsertValueOp>(
+          loc, srcStructVal, srcIdx.value(),
+          builder.getI64ArrayAttr(srcIdx.index()));
+    builder.create<LLVM::StoreOp>(loc, srcStructVal, structPtr);
+
+    // Unpack the struct at the start of the region and rewrite the uses.
+    builder.setInsertionPointToStart(&parOp.region().front());
+    auto dstStructVal = builder.create<LLVM::LoadOp>(loc, structPtr);
+    for (auto srcIdx : llvm::enumerate(values)) {
+      Value capturedValue = builder.create<LLVM::ExtractValueOp>(
+          loc, srcIdx.value().getType(), dstStructVal,
+          builder.getI64ArrayAttr(srcIdx.index()));
+      replaceAllUsesInRegionWith(srcIdx.value(), capturedValue,
+                                 parOp.region());
+    }
+  }
+};
+
 struct ParallelOpConversion : public ConvertToLLVMPattern {
   explicit ParallelOpConversion(MLIRContext *context,
                                 LLVMTypeConverter &typeConverter)
@@ -39,6 +112,11 @@
 };
 } // namespace
 
+void CaptureOpenMPParametersPass::runOnOperation() {
+  ModuleOp module = getOperation();
+  module.walk([&](omp::ParallelOp op) { captureOmpParallelParams(op); });
+}
+
 void mlir::populateOpenMPToLLVMConversionPatterns(
     MLIRContext *context, LLVMTypeConverter &converter,
     OwningRewritePatternList &patterns) {
@@ -71,6 +149,12 @@
     signalPassFailure();
 }
 
+/// Create a pass to capture "above-defined" omp::ParallelOp parameters.
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createCaptureOpenMPParallelParametersPass() {
+  return std::make_unique<CaptureOpenMPParametersPass>();
+}
+
 std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertOpenMPToLLVMPass() {
   return std::make_unique<ConvertOpenMPToLLVMPass>();
 }
diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-openmp-to-llvm %s -split-input-file | FileCheck %s
+// RUN: mlir-opt -convert-openmp-to-llvm -capture-openmp-parallel-parameters %s -split-input-file | FileCheck %s
 
 // CHECK-LABEL: llvm.func @branch_loop
 func @branch_loop() {
@@ -6,7 +6,7 @@
   %end = constant 0 : index
   // CHECK: omp.parallel
   omp.parallel {
-    // CHECK-NEXT: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : !llvm.i64, !llvm.i64
+    // CHECK: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : !llvm.i64, !llvm.i64
     br ^bb1(%start, %end : index, index)
     // CHECK-NEXT: ^[[BB1]](%[[ARG1:[0-9]+]]: !llvm.i64, %[[ARG2:[0-9]+]]: !llvm.i64):{{.*}}
     ^bb1(%0: index, %1: index):
diff --git a/mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir b/mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir
@@ -0,0 +1,78 @@
+// RUN: mlir-opt %s -convert-openmp-to-llvm -capture-openmp-parallel-parameters | FileCheck %s
+
+module {
+  llvm.func @malloc(!llvm.i64) -> !llvm.ptr<i8>
+  llvm.func @print_memref_f32(%arg0: !llvm.i64, %arg1: !llvm.ptr<i8>) {
+    %0 = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
+    %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(i64, ptr<i8>)>
+    %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(i64, ptr<i8>)>
+    %3 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %4 = llvm.alloca %3 x !llvm.struct<(i64, ptr<i8>)> : (!llvm.i64) -> !llvm.ptr<struct<(i64, ptr<i8>)>>
+    llvm.store %2, %4 : !llvm.ptr<struct<(i64, ptr<i8>)>>
+    llvm.call @_mlir_ciface_print_memref_f32(%4) : (!llvm.ptr<struct<(i64, ptr<i8>)>>) -> ()
+    llvm.return
+  }
+  llvm.func @_mlir_ciface_print_memref_f32(!llvm.ptr<struct<(i64, ptr<i8>)>>)
+  llvm.func @plaidml_rt_thread_num() -> !llvm.i64 {
+    %0 = llvm.call @_mlir_ciface_plaidml_rt_thread_num() : () -> !llvm.i64
+    llvm.return %0 : !llvm.i64
+  }
+  llvm.func @_mlir_ciface_plaidml_rt_thread_num() -> !llvm.i64
+  llvm.func @main() {
+    %0 = llvm.mlir.constant(4 : index) : !llvm.i64
+    %1 = llvm.mlir.constant(4 : index) : !llvm.i64
+    %2 = llvm.mlir.null : !llvm.ptr<float>
+    %3 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %4 = llvm.getelementptr %2[%3] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
+    %5 = llvm.ptrtoint %4 : !llvm.ptr<float> to !llvm.i64
+    %6 = llvm.mul %1, %5 : !llvm.i64
+    %7 = llvm.call @malloc(%6) : (!llvm.i64) -> !llvm.ptr<i8>
+    %8 = llvm.bitcast %7 : !llvm.ptr<i8> to !llvm.ptr<float>
+    %9 = llvm.mlir.undef : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %10 = llvm.insertvalue %8, %9[0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %11 = llvm.insertvalue %8, %10[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %12 = llvm.mlir.constant(0 : index) : !llvm.i64
+    %13 = llvm.insertvalue %12, %11[2] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %14 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %15 = llvm.insertvalue %1, %13[3, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %16 = llvm.insertvalue %14, %15[4, 0] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+    %17 = llvm.mlir.constant(4.200000e+01 : f32) : !llvm.float
+    // CHECK: %[[PTR:.*]] = llvm.alloca %{{.*}} x !llvm.struct<(struct<({{.*}})>, float)> : (!llvm.i64) -> !llvm.ptr<struct<(struct<({{.*}})>, float)>>
+    // CHECK-NEXT: %[[UNDEF:.*]] = llvm.mlir.undef : !llvm.struct<(struct<({{.*}})>, float)>
+    // CHECK-NEXT: %[[INS0:.*]] = llvm.insertvalue %{{.*}}, %[[UNDEF]][0] : !llvm.struct<(struct<({{.*}})>, float)>
+    // CHECK-NEXT: %[[INS1:.*]] = llvm.insertvalue %{{.*}}, %[[INS0]][1] : !llvm.struct<(struct<({{.*}})>, float)>
+    // CHECK-NEXT: llvm.store %[[INS1]], %[[PTR]] : !llvm.ptr<struct<(struct<({{.*}})>, float)>>
+    // CHECK-NEXT: omp.parallel
+    omp.parallel num_threads(%0 : !llvm.i64) {
+      // CHECK-NEXT: %[[LOADED:.*]] = llvm.load %[[PTR]] : !llvm.ptr<struct<(struct<({{.*}})>, float)>>
+      // CHECK-NEXT: %{{.*}} = llvm.extractvalue %[[LOADED]][0] : !llvm.struct<(struct<({{.*}})>, float)>
+      // CHECK-NEXT: %{{.*}} = llvm.extractvalue %[[LOADED]][1] : !llvm.struct<(struct<({{.*}})>, float)>
+      %27 = llvm.call @plaidml_rt_thread_num() : () -> !llvm.i64
+      %28 = llvm.extractvalue %16[1] : !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>
+      %29 = llvm.mlir.constant(0 : index) : !llvm.i64
+      %30 = llvm.mlir.constant(1 : index) : !llvm.i64
+      %31 = llvm.mul %27, %30 : !llvm.i64
+      %32 = llvm.add %29, %31 : !llvm.i64
+      %33 = llvm.getelementptr %28[%32] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
+      llvm.store %17, %33 : !llvm.ptr<float>
+      omp.terminator
+    }
+    %18 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %19 = llvm.alloca %18 x !llvm.struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)> : (!llvm.i64) -> !llvm.ptr<struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>>
+    llvm.store %16, %19 : !llvm.ptr<struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>>
+    %20 = llvm.bitcast %19 : !llvm.ptr<struct<(ptr<float>, ptr<float>, i64, array<1 x i64>, array<1 x i64>)>> to !llvm.ptr<i8>
+    %21 = llvm.mlir.constant(1 : i64) : !llvm.i64
+    %22 = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
+    %23 = llvm.insertvalue %21, %22[0] : !llvm.struct<(i64, ptr<i8>)>
+    %24 = llvm.insertvalue %20, %23[1] : !llvm.struct<(i64, ptr<i8>)>
+    %25 = llvm.extractvalue %24[0] : !llvm.struct<(i64, ptr<i8>)>
+    %26 = llvm.extractvalue %24[1] : !llvm.struct<(i64, ptr<i8>)>
+    llvm.call @print_memref_f32(%25, %26) : (!llvm.i64, !llvm.ptr<i8>) -> ()
+    llvm.return
+  }
+  llvm.func @_mlir_ciface_main() {
+    llvm.call @main() : () -> ()
+    llvm.return
+  }
+}
+