NOTE(review): Preamble added during review. `git apply` and `patch` skip any
text that precedes the first `diff --git` header, so these lines do not alter
what the patch applies. They only describe the changes visible below.

Summary of the changes visible in this patch:
  * flang/test/Lower/OpenMP/{flush,parallel-sections,parallel,single}.f90:
    - The operand-type part of the omp.flush and omp.sections /
      omp.parallel / omp.single `allocate(...)` checks is moved out of the
      shared OMPDialect check.
    - It now lives in a FIRDialect line (expects !fir.ref operand types) and
      an LLVM-prefix line (expects !llvm.ptr operand types).
    - The fir-to-llvm-ir RUN lines of flush.f90, parallel.f90 and single.f90
      gain that LLVM prefix.
    - flush.f90 names the prefix LLVMIRDialect, while parallel.f90,
      single.f90 and parallel-sections.f90 use LLVMDialect.
  * mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td:
    - The op with the variadic $varList operand has its `let arguments` line
      replaced. This is presumably omp.flush, given the FlushOp entries added
      in OpenMPToLLVM.cpp; confirm against the source.
    - The difference is not legible in this copy. It is presumably the
      stripped Variadic type constraint.
    - The op also gains getNumVariableOperands(), which returns
      getOperation()->getNumOperands(), and getVariableOperand(i), which
      returns getOperand(i).
    - These are the accessors the conversion pattern below calls.
  * mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp:
    - RegionLessOpConversion is renamed RegionLessOpWithVarOperandsConversion.
      It now asserts that getNumVariableOperands() equals the op's total
      operand count ("unexpected non-variable operands").
    - The legality lambda that checks getRegion(0) now also requires legal
      operand and result types. The visible entries of that op list include
      MasterOp, SectionsOp and SingleOp.
    - The second legality group now also requires legal result types. Its
      visible entries include FlushOp and ThreadprivateOp.
    - populateOpenMPToLLVMConversionPatterns registers five
      RegionOpConversion and four RegionLessOpWithVarOperandsConversion
      instantiations. Their template arguments are not legible in this copy.

NOTE(review): this copy of the patch is damaged and cannot be applied as-is.
  - Its line breaks have been collapsed onto a few physical lines.
  - Text between `<` and `>` appears to have been stripped, for example
    `template -struct`, `Variadic:$varList`, `!fir.ref`, and
    `ConvertOpToLLVMPattern {`.
  - Regenerate it from the original commit instead of repairing it by hand.

diff --git a/flang/test/Lower/OpenMP/flush.f90 b/flang/test/Lower/OpenMP/flush.f90 --- a/flang/test/Lower/OpenMP/flush.f90 +++ b/flang/test/Lower/OpenMP/flush.f90 @@ -1,14 +1,16 @@ ! This test checks lowering of OpenMP Flush Directive. !RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes="FIRDialect,OMPDialect" -!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --cfg-conversion | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefixes="OMPDialect" +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --cfg-conversion | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefixes="LLVMIRDialect,OMPDialect" subroutine flush_standalone(a, b, c) integer, intent(inout) :: a, b, c !$omp flush(a,b,c) !$omp flush -!OMPDialect: omp.flush(%{{.*}}, %{{.*}}, %{{.*}} : !fir.ref, !fir.ref, !fir.ref) +!OMPDialect: omp.flush(%{{.*}}, %{{.*}}, %{{.*}} : +!FIRDialect: !fir.ref, !fir.ref, !fir.ref) +!LLVMIRDialect: !llvm.ptr, !llvm.ptr, !llvm.ptr) !OMPDialect: omp.flush end subroutine flush_standalone @@ -19,7 +21,9 @@ !$omp parallel !OMPDialect: omp.parallel { -!OMPDialect: omp.flush(%{{.*}}, %{{.*}}, %{{.*}} : !fir.ref, !fir.ref, !fir.ref) +!OMPDialect: omp.flush(%{{.*}}, %{{.*}}, %{{.*}} : +!FIRDialect: !fir.ref, !fir.ref, !fir.ref) +!LLVMIRDialect: !llvm.ptr, !llvm.ptr, !llvm.ptr) !OMPDialect: omp.flush !$omp flush(a,b,c) !$omp flush diff --git a/flang/test/Lower/OpenMP/parallel-sections.f90 b/flang/test/Lower/OpenMP/parallel-sections.f90 --- a/flang/test/Lower/OpenMP/parallel-sections.f90 +++ b/flang/test/Lower/OpenMP/parallel-sections.f90 @@ -41,7 +41,9 @@ !FIRDialect: %[[allocator:.*]] = arith.constant 1 : i32 !LLVMDialect: %[[allocator:.*]] = llvm.mlir.constant(1 : i32) : i32 !OMPDialect: omp.parallel { - !OMPDialect: omp.sections allocate(%[[allocator]] : i32 -> %{{.*}} : !fir.ref) { + !OMPDialect: omp.sections allocate( + !FIRDialect: %[[allocator]] : i32 -> %{{.*}} : !fir.ref) { + !LLVMDialect: %[[allocator]] : i32 -> %{{.*}} : !llvm.ptr) { !$omp
parallel sections allocate(omp_high_bw_mem_alloc: x) !OMPDialect: omp.section { !$omp section diff --git a/flang/test/Lower/OpenMP/parallel.f90 b/flang/test/Lower/OpenMP/parallel.f90 --- a/flang/test/Lower/OpenMP/parallel.f90 +++ b/flang/test/Lower/OpenMP/parallel.f90 @@ -1,5 +1,5 @@ !RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes="FIRDialect,OMPDialect" -!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefixes="OMPDialect" +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefixes="LLVMDialect,OMPDialect" !FIRDialect-LABEL: func @_QPparallel_simple subroutine parallel_simple() @@ -152,7 +152,10 @@ subroutine parallel_allocate() use omp_lib integer :: x - !OMPDialect: omp.parallel allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref) { + !OMPDialect: omp.parallel allocate( + !FIRDialect: %{{.+}} : i32 -> %{{.+}} : !fir.ref + !LLVMDialect: %{{.+}} : i32 -> %{{.+}} : !llvm.ptr + !OMPDialect: ) { !$omp parallel allocate(omp_high_bw_mem_alloc: x) private(x) !FIRDialect: arith.addi x = x + 12 @@ -191,7 +194,10 @@ !OMPDialect: omp.terminator !$omp end parallel - !OMPDialect: omp.parallel if({{.*}} : i1) num_threads({{.*}} : i32) allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref) { + !OMPDialect: omp.parallel if({{.*}} : i1) num_threads({{.*}} : i32) allocate( + !FIRDialect: %{{.+}} : i32 -> %{{.+}} : !fir.ref + !LLVMDialect: %{{.+}} : i32 -> %{{.+}} : !llvm.ptr + !OMPDialect: ) { !$omp parallel num_threads(num_threads) if(alpha .le.
0) allocate(omp_high_bw_mem_alloc: alpha) private(alpha) !FIRDialect: fir.call call f3() diff --git a/flang/test/Lower/OpenMP/single.f90 b/flang/test/Lower/OpenMP/single.f90 --- a/flang/test/Lower/OpenMP/single.f90 +++ b/flang/test/Lower/OpenMP/single.f90 @@ -1,5 +1,5 @@ !RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s --check-prefixes="FIRDialect,OMPDialect" -!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --cfg-conversion | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefixes="OMPDialect" +!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | fir-opt --cfg-conversion | fir-opt --fir-to-llvm-ir | FileCheck %s --check-prefixes="LLVMDialect,OMPDialect" !=============================================================================== ! Single construct @@ -55,7 +55,10 @@ integer :: x !OMPDialect: omp.parallel { !$omp parallel - !OMPDialect: omp.single allocate(%{{.+}} : i32 -> %{{.+}} : !fir.ref) { + !OMPDialect: omp.single allocate( + !FIRDialect: %{{.+}} : i32 -> %{{.+}} : !fir.ref + !LLVMDialect: %{{.+}} : i32 -> %{{.+}} : !llvm.ptr + !OMPDialect: ) { !$omp single allocate(omp_high_bw_mem_alloc: x) private(x) !FIRDialect: arith.addi x = x + 12 diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -574,9 +574,19 @@ specified or implied. }]; - let arguments = (ins Variadic:$varList); + let arguments = (ins Variadic:$varList); let assemblyFormat = [{ ( `(` $varList^ `:` type($varList) `)` )? attr-dict}]; + let extraClassDeclaration = [{ + /// The number of variable operands. + unsigned getNumVariableOperands() { + return getOperation()->getNumOperands(); + } + /// The i-th variable operand passed.
+ Value getVariableOperand(unsigned i) { + return getOperand(i); + } + }]; } //===----------------------------------------------------------------------===// // 2.14.5 target construct diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -47,7 +47,8 @@ }; template -struct RegionLessOpConversion : public ConvertOpToLLVMPattern { +struct RegionLessOpWithVarOperandsConversion + : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(T curOp, typename T::Adaptor adaptor, @@ -57,6 +58,9 @@ if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes))) return failure(); SmallVector convertedOperands; + assert(curOp.getNumVariableOperands() == + curOp.getOperation()->getNumOperands() && + "unexpected non-variable operands"); for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) { Value originalVariableOperand = curOp.getVariableOperand(idx); if (!originalVariableOperand) @@ -78,23 +82,31 @@ void mlir::configureOpenMPToLLVMConversionLegality( ConversionTarget &target, LLVMTypeConverter &typeConverter) { target.addDynamicallyLegalOp( - [&](Operation *op) { return typeConverter.isLegal(&op->getRegion(0)); }); + mlir::omp::MasterOp, mlir::omp::SectionsOp, + mlir::omp::SingleOp>([&](Operation *op) { + return typeConverter.isLegal(&op->getRegion(0)) && + typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); target .addDynamicallyLegalOp([&](Operation *op) { - return typeConverter.isLegal(op->getOperandTypes()); - }); + mlir::omp::FlushOp, mlir::omp::ThreadprivateOp>( + [&](Operation *op) { + return typeConverter.isLegal(op->getOperandTypes()) && + typeConverter.isLegal(op->getResultTypes()); + }); } void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter
&converter, RewritePatternSet &patterns) { - patterns.add, - RegionOpConversion, - RegionOpConversion, - RegionLessOpConversion, - RegionLessOpConversion, - RegionLessOpConversion>(converter); + patterns.add< + RegionOpConversion, RegionOpConversion, + RegionOpConversion, RegionOpConversion, + RegionOpConversion, + RegionLessOpWithVarOperandsConversion, + RegionLessOpWithVarOperandsConversion, + RegionLessOpWithVarOperandsConversion, + RegionLessOpWithVarOperandsConversion>(converter); } namespace {