diff --git a/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h b/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h --- a/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h +++ b/mlir/include/mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h @@ -26,6 +26,10 @@ /// Create a pass to convert OpenMP operations to the LLVMIR dialect. std::unique_ptr> createConvertOpenMPToLLVMPass(); +/// Create a pass to capture "above-defined" omp::ParallelOp parameters. +std::unique_ptr> +createCaptureOpenMPParallelParametersPass(); + } // namespace mlir #endif // MLIR_CONVERSION_OPENMPTOLLVM_OPENMPTOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -220,6 +220,19 @@ let dependentDialects = ["LLVM::LLVMDialect"]; } +//===----------------------------------------------------------------------===// +// CaptureOpenMPParallelParameters +//===----------------------------------------------------------------------===// + +def CaptureOpenMPParallelParameters : + Pass<"capture-openmp-parallel-parameters", "ModuleOp"> { + let summary = "Captures the omp::ParallelOp 'above-defined' parameters." 
+ " This pass operates on the output of convert-openmp-to-llvm" + " pass - mlir LLVM dialect IR."; + let constructor = "mlir::createCaptureOpenMPParallelParametersPass()"; + let dependentDialects = ["LLVM::LLVMDialect"]; +} + //===----------------------------------------------------------------------===// // PDLToPDLInterp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -12,10 +12,80 @@ #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Transforms/RegionUtils.h" + +#include "llvm/ADT/SetVector.h" using namespace mlir; namespace { +/// Pass to capture the "above-defined" parameters of omp::ParallelOp. +struct CaptureOpenMPParameters + : public CaptureOpenMPParallelParametersBase { + void runOnOperation() override; + CaptureOpenMPParameters() = default; + +private: + // Captures "above-defined" parameters passed to omp::ParallelOp and wraps + // them in a struct to make sure the varargs are passed properly to the + // synthetically generated fork function. + void captureOmpParallelParams(omp::ParallelOp parOp) const { + OpBuilder builder{parOp.getContext()}; + llvm::SetVector values; + + visitUsedValuesDefinedAbove({parOp.region()}, [&](OpOperand *opOperand) { + Value value = opOperand->get(); + + // If it's not an LLVM type, or if it's an LLVM pointer type, we + // don't need or want to capture this value via a struct. + LLVM::LLVMType llvmType = value.getType().dyn_cast(); + if (!llvmType || llvmType.isPointerTy()) + return; + + // Otherwise, we need to capture the value through an alloca'd + // struct. + values.insert(value); + }); + + if (!values.size()) + return; + + // Build the structure. 
+ builder.setInsertionPoint(parOp); + LLVM::LLVMType structTy; + { + SmallVector types; + for (auto val : values) + types.push_back(val.getType().cast()); + structTy = LLVM::LLVMType::getStructTy(parOp.getContext(), types); + } + LLVM::LLVMType structPtrTy = structTy.getPointerTo(); + auto numElements = builder.create( + parOp.getLoc(), LLVM::LLVMType::getInt64Ty(parOp.getContext()), + builder.getIndexAttr(1)); + auto structPtr = builder.create(parOp.getLoc(), structPtrTy, + numElements, 0); + Value srcStructVal = + builder.create(parOp.getLoc(), structTy); + for (auto srcIdx : llvm::enumerate(values)) { + srcStructVal = builder.create( + parOp.getLoc(), srcStructVal, srcIdx.value(), + builder.getI64ArrayAttr(srcIdx.index())); + } + builder.create(parOp.getLoc(), srcStructVal, structPtr); + + // Unpack the structure, rewriting the affected values. + builder.setInsertionPointToStart(&parOp.region().front()); + auto dstStructVal = builder.create(parOp.getLoc(), structPtr); + for (auto srcIdx : llvm::enumerate(values)) { + auto capturedValue = builder.create( + parOp.getLoc(), srcIdx.value().getType(), dstStructVal, + builder.getI64ArrayAttr(srcIdx.index())); + replaceAllUsesInRegionWith(srcIdx.value(), capturedValue, parOp.region()); + } + } +}; + struct ParallelOpConversion : public ConvertToLLVMPattern { explicit ParallelOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) @@ -39,6 +109,11 @@ }; } // namespace +void CaptureOpenMPParameters::runOnOperation() { + ModuleOp module = getOperation(); + module.walk([&](omp::ParallelOp op) { captureOmpParallelParams(op); }); +} + void mlir::populateOpenMPToLLVMConversionPatterns( MLIRContext *context, LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { @@ -71,6 +146,12 @@ signalPassFailure(); } +/// Create a pass to capture "above-defined" omp::ParallelOp parameters. 
+std::unique_ptr> +mlir::createCaptureOpenMPParallelParametersPass() { + return std::make_unique(); +} + std::unique_ptr> mlir::createConvertOpenMPToLLVMPass() { return std::make_unique(); } diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -convert-openmp-to-llvm %s -split-input-file | FileCheck %s +// RUN: mlir-opt -convert-openmp-to-llvm -capture-openmp-parallel-parameters %s -split-input-file | FileCheck %s // CHECK-LABEL: llvm.func @branch_loop func @branch_loop() { @@ -6,7 +6,7 @@ %end = constant 0 : index // CHECK: omp.parallel omp.parallel { - // CHECK-NEXT: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : !llvm.i64, !llvm.i64 + // CHECK: llvm.br ^[[BB1:.*]](%{{[0-9]+}}, %{{[0-9]+}} : !llvm.i64, !llvm.i64 br ^bb1(%start, %end : index, index) // CHECK-NEXT: ^[[BB1]](%[[ARG1:[0-9]+]]: !llvm.i64, %[[ARG2:[0-9]+]]: !llvm.i64):{{.*}} ^bb1(%0: index, %1: index): diff --git a/mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir b/mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/OpenMPToLLVM/openmp_float-parallel_param.mlir @@ -0,0 +1,26 @@ +// RUN: mlir-opt %s -convert-openmp-to-llvm -capture-openmp-parallel-parameters | FileCheck %s + +func @print_memref_f32(memref<*xf32>) +func @plaidml_rt_thread_num() -> index + +func @main() { + %num_threads = constant 4 : index + %B = alloc() : memref<4xf32> + %cf0 = constant 42.0 : f32 + // CHECK: %{{.*}} = llvm.alloca %{{.*}} x !llvm.struct<(struct<({{.*}})>, float)> : (!llvm.i64) -> !llvm.ptr, float)>> + // CHECK: %{{.*}} = llvm.mlir.undef : !llvm.struct<(struct<({{.*}}>)>, float)> + // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(struct<({{.*}})>, float)> + // CHECK: 
%{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(struct<({{.*}})>, float)> + // CHECK: llvm.store %{{.*}}, %{{.*}} : !llvm.ptr, float)>> + omp.parallel num_threads(%num_threads : index) { + // CHECK: %{{.*}} = llvm.load %{{.*}} : !llvm.ptr, float)>> + // CHECK: %{{.*}} = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(struct<({{.*}})>, float)> + // CHECK: %{{.*}} = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(struct<({{.*}})>, float)> + %tid = call @plaidml_rt_thread_num() : () -> index + store %cf0, %B[%tid] : memref<4xf32> + omp.terminator + } + %B_u = memref_cast %B : memref<4xf32> to memref<*xf32> + call @print_memref_f32(%B_u) : (memref<*xf32>) -> () + return +}