diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h --- a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h +++ b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h @@ -111,6 +111,7 @@ LogicalResult convertGlobals(); LogicalResult convertOneFunction(LLVMFuncOp func); LogicalResult convertBlock(Block &bb, bool ignoreArguments); + void captureOmpParallelParams(LLVMFuncOp func); llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -17,11 +17,13 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/Module.h" #include "mlir/IR/RegionGraphTraits.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Support/LLVM.h" #include "mlir/Target/LLVMIR/TypeTranslation.h" +#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/PostOrderIterator.h" @@ -806,7 +808,74 @@ return success(); } +// Collects values defined above an omp::ParallelOp but used inside its region +// and wraps the non-pointer LLVM ones in an alloca'd struct, to make sure the +// varargs are passed properly to the synthetically generated fork function. +void ModuleTranslation::captureOmpParallelParams(LLVMFuncOp func) { + OpBuilder builder{func.getContext()}; + func.walk([&](omp::ParallelOp parOp) { + llvm::SetVector values; + + visitUsedValuesDefinedAbove({parOp.region()}, [&](OpOperand *opOperand) { + auto value = opOperand->get(); + + // If it's not an LLVM type, or if it's an LLVM pointer type, we + // don't need or want to smuggle this value in via a struct. 
+ auto llvmType = value.getType().dyn_cast(); + if (!llvmType || llvmType.isPointerTy()) + return; + + // Otherwise, we need to smuggle the value through an alloca'd + // struct. + values.insert(value); + }); + + if (!values.size()) + return; + + // Build the struct before the op: alloca it, fill it in, and store it. + builder.setInsertionPoint(parOp); + LLVM::LLVMType structTy; + { + SmallVector types; + for (auto val : values) { + types.push_back(val.getType().cast()); + } + structTy = LLVM::LLVMType::getStructTy(func.getContext(), types); + } + auto structPtrTy = structTy.getPointerTo(); + auto numElements = builder.create( + parOp.getLoc(), LLVM::LLVMType::getInt64Ty(func.getContext()), + builder.getIndexAttr(1)); + auto structPtr = builder.create(parOp.getLoc(), structPtrTy, + numElements, 0); + Value srcStructVal = + builder.create(parOp.getLoc(), structTy); + for (auto srcIdx : llvm::enumerate(values)) { + srcStructVal = builder.create( + parOp.getLoc(), srcStructVal, srcIdx.value(), + builder.getI64ArrayAttr(srcIdx.index())); + } + builder.create(parOp.getLoc(), srcStructVal, structPtr); + + // Unpack the structure at the region entry, rewriting uses in the region. + builder.setInsertionPointToStart(&parOp.region().front()); + auto dstStructVal = builder.create(parOp.getLoc(), structPtr); + for (auto srcIdx : llvm::enumerate(values)) { + auto capturedValue = builder.create( + parOp.getLoc(), srcIdx.value().getType(), dstStructVal, + builder.getI64ArrayAttr(srcIdx.index())); + replaceAllUsesInRegionWith(srcIdx.value(), capturedValue, parOp.region()); + } + }); +} + LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { + // Capture and pass the parameters for the omp::ParallelOp. + // If we don't do this the SSE(UP) passed parameters will not be passed + // properly. + captureOmpParallelParams(func); + // Clear the block and value mappings, they are only relevant within one // function. 
blockMapping.clear(); diff --git a/mlir/test/Target/openmp_float-parallel_param.mlir b/mlir/test/Target/openmp_float-parallel_param.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/openmp_float-parallel_param.mlir @@ -0,0 +1,78 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +module { + llvm.func @malloc(!llvm.i64) -> !llvm.ptr + llvm.func @print_memref_f32(%arg0: !llvm.i64, %arg1: !llvm.ptr) { + %0 = llvm.mlir.undef : !llvm.struct<(i64, ptr)> + %1 = llvm.insertvalue %arg0, %0[0] : !llvm.struct<(i64, ptr)> + %2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(i64, ptr)> + %3 = llvm.mlir.constant(1 : index) : !llvm.i64 + %4 = llvm.alloca %3 x !llvm.struct<(i64, ptr)> : (!llvm.i64) -> !llvm.ptr)>> + llvm.store %2, %4 : !llvm.ptr)>> + llvm.call @_mlir_ciface_print_memref_f32(%4) : (!llvm.ptr)>>) -> () + llvm.return + } + llvm.func @_mlir_ciface_print_memref_f32(!llvm.ptr)>>) + llvm.func @plaidml_rt_thread_num() -> !llvm.i64 { + %0 = llvm.call @_mlir_ciface_plaidml_rt_thread_num() : () -> !llvm.i64 + llvm.return %0 : !llvm.i64 + } + llvm.func @_mlir_ciface_plaidml_rt_thread_num() -> !llvm.i64 + llvm.func @main() { + %0 = llvm.mlir.constant(4 : index) : !llvm.i64 + %1 = llvm.mlir.constant(4 : index) : !llvm.i64 + %2 = llvm.mlir.null : !llvm.ptr + %3 = llvm.mlir.constant(1 : index) : !llvm.i64 + %4 = llvm.getelementptr %2[%3] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr + %5 = llvm.ptrtoint %4 : !llvm.ptr to !llvm.i64 + %6 = llvm.mul %1, %5 : !llvm.i64 + %7 = llvm.call @malloc(%6) : (!llvm.i64) -> !llvm.ptr + %8 = llvm.bitcast %7 : !llvm.ptr to !llvm.ptr + %9 = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %10 = llvm.insertvalue %8, %9[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %11 = llvm.insertvalue %8, %10[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %12 = llvm.mlir.constant(0 : index) : !llvm.i64 + %13 = llvm.insertvalue %12, %11[2] : !llvm.struct<(ptr, ptr, 
i64, array<1 x i64>, array<1 x i64>)> + %14 = llvm.mlir.constant(1 : index) : !llvm.i64 + %15 = llvm.insertvalue %1, %13[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %16 = llvm.insertvalue %14, %15[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %17 = llvm.mlir.constant(4.200000e+01 : f32) : !llvm.float + // CHECK: %{{.*}} = alloca { {{.*}}, float }, i64 1 + // CHECK-NEXT %{{.*}} = insertvalue { {{.*}}, float } undef, {{.*}}, 0 + // CHECK-NEXT %{{.*}} = insertvalue { {{.*}}, float } %{{.*}}, float 4.200000e+01, 1 + // CHECK-NEXT store { {{.*}}, float } %{{.*}}, { {{.*}}, float }* %{{.*}} + omp.parallel num_threads(%0 : !llvm.i64) { + // CHECK: omp.par.region: + // CHECK-NEXT br label %omp.par.region1 + // CHECK-NEXT omp.par.region1: + // CHECK-NEXT %{{.*}} = load { {{.*}}, float }, { {{.*}}, float }* %{{.*}} + // CHECK-NEXT %{{.*}} = extractvalue { {{.*}}, float } %{{.*}}, 0 + // CHECK-NEXT %{{.*}} = extractvalue { {{.*}}, float } %{{.*}}, 1 + %27 = llvm.call @plaidml_rt_thread_num() : () -> !llvm.i64 + %28 = llvm.extractvalue %16[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %29 = llvm.mlir.constant(0 : index) : !llvm.i64 + %30 = llvm.mlir.constant(1 : index) : !llvm.i64 + %31 = llvm.mul %27, %30 : !llvm.i64 + %32 = llvm.add %29, %31 : !llvm.i64 + %33 = llvm.getelementptr %28[%32] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr + llvm.store %17, %33 : !llvm.ptr + omp.terminator + } + %18 = llvm.mlir.constant(1 : index) : !llvm.i64 + %19 = llvm.alloca %18 x !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> : (!llvm.i64) -> !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>> + llvm.store %16, %19 : !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>> + %20 = llvm.bitcast %19 : !llvm.ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>> to !llvm.ptr + %21 = llvm.mlir.constant(1 : i64) : !llvm.i64 + %22 = llvm.mlir.undef : !llvm.struct<(i64, ptr)> + %23 = llvm.insertvalue %21, %22[0] : 
!llvm.struct<(i64, ptr)> + %24 = llvm.insertvalue %20, %23[1] : !llvm.struct<(i64, ptr)> + %25 = llvm.extractvalue %24[0] : !llvm.struct<(i64, ptr)> + %26 = llvm.extractvalue %24[1] : !llvm.struct<(i64, ptr)> + llvm.call @print_memref_f32(%25, %26) : (!llvm.i64, !llvm.ptr) -> () + llvm.return + } + llvm.func @_mlir_ciface_main() { + llvm.call @main() : () -> () + llvm.return + } +}