diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp
--- a/flang/lib/Lower/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP.cpp
@@ -15,6 +15,7 @@
 #include "flang/Lower/FIRBuilder.h"
 #include "flang/Lower/PFTBuilder.h"
 #include "flang/Parser/parse-tree.h"
+#include "flang/Semantics/tools.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 
@@ -87,19 +88,34 @@
     auto &firOpBuilder = absConv.getFirOpBuilder();
     auto currentLocation = absConv.getCurrentLocation();
     auto insertPt = firOpBuilder.saveInsertionPoint();
+
+    // Clauses.
+    // FIXME: Add support for other clauses.
+    mlir::Value numThreads;
+
+    const auto &parallelOpClauseList =
+        std::get<Fortran::parser::OmpClauseList>(blockDirective.t);
+    for (const auto &clause : parallelOpClauseList.v) {
+      if (const auto *numThreadsClause =
+              std::get_if<Fortran::parser::OmpClause::NumThreads>(&clause.u)) {
+        // OMPIRBuilder expects the `NUM_THREADS` clause as a `Value`.
+        numThreads = absConv.genExprValue(
+            *Fortran::semantics::GetExpr(numThreadsClause->v));
+      }
+    }
     llvm::ArrayRef<mlir::Type> argTy;
-    mlir::ValueRange range;
-    llvm::SmallVector<int32_t, 6> operandSegmentSizes(6 /*Size=*/,
-                                                      0 /*Value=*/);
-    // create and insert the operation.
+    Attribute defaultValue, procBindValue;
+    // Create and insert the operation.
+    // Create the Op with empty ranges for clauses that are yet to be lowered.
     auto parallelOp = firOpBuilder.create<mlir::omp::ParallelOp>(
-        currentLocation, argTy, range);
-    parallelOp.setAttr(mlir::omp::ParallelOp::getOperandSegmentSizeAttr(),
-                       firOpBuilder.getI32VectorAttr(operandSegmentSizes));
-    parallelOp.getRegion().push_back(new Block{});
+        currentLocation, argTy, Value(), numThreads,
+        defaultValue.dyn_cast_or_null<StringAttr>(), ValueRange(), ValueRange(),
+        ValueRange(), ValueRange(),
+        procBindValue.dyn_cast_or_null<StringAttr>());
+    firOpBuilder.createBlock(&parallelOp.getRegion());
     auto &block = parallelOp.getRegion().back();
     firOpBuilder.setInsertionPointToStart(&block);
-    // ensure the block is well-formed.
+    // Ensure the block is well-formed.
     firOpBuilder.create<mlir::omp::TerminatorOp>(currentLocation);
     firOpBuilder.restoreInsertionPoint(insertPt);
   }
diff --git a/flang/unittests/Lower/OpenMPLoweringTest.cpp b/flang/unittests/Lower/OpenMPLoweringTest.cpp
--- a/flang/unittests/Lower/OpenMPLoweringTest.cpp
+++ b/flang/unittests/Lower/OpenMPLoweringTest.cpp
@@ -8,6 +8,7 @@
 
 #include "gtest/gtest.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/InitAllDialects.h"
 #include "flang/Parser/parse-tree.h"
@@ -16,6 +17,7 @@
 class OpenMPLoweringTest : public testing::Test {
 protected:
   void SetUp() override {
+    ctx.loadDialect<mlir::StandardOpsDialect>();
     ctx.loadDialect<mlir::omp::OpenMPDialect>();
     mlir::registerAllDialects(ctx.getDialectRegistry());
     mlirOpBuilder.reset(new mlir::OpBuilder(&ctx));
@@ -81,21 +83,30 @@
   // OpenMPDialect.
   EXPECT_EQ(parallelDirective.v, llvm::omp::Directive::OMPD_parallel);
   auto insertPt = mlirOpBuilder->saveInsertionPoint();
+  mlir::Value numThreads;
+  mlir::Attribute defaultValue, procBindValue;
+
+  // Construct a dummy value for the `NUM_THREADS` clause.
+  numThreads = mlirOpBuilder->create<mlir::ConstantOp>(
+      mlirOpBuilder->getUnknownLoc(), mlirOpBuilder->getIntegerType(32),
+      mlirOpBuilder->getIntegerAttr(mlirOpBuilder->getIntegerType(32), 4));
+
   llvm::ArrayRef<mlir::Type> argTy;
-  mlir::ValueRange range;
-  llvm::SmallVector<int32_t, 6> operandSegmentSizes(6 /*Size=*/, 0 /*Value=*/);
-  // create and insert the operation.
+  // Create and insert the operation.
   auto parallelOp = mlirOpBuilder->create<mlir::omp::ParallelOp>(
-      mlirOpBuilder->getUnknownLoc(), argTy, range);
-  parallelOp.setAttr(mlir::omp::ParallelOp::getOperandSegmentSizeAttr(),
-      mlirOpBuilder->getI32VectorAttr(operandSegmentSizes));
+      mlirOpBuilder->getUnknownLoc(), argTy, mlir::Value(), numThreads,
+      defaultValue.dyn_cast_or_null<mlir::StringAttr>(), mlir::ValueRange(),
+      mlir::ValueRange(), mlir::ValueRange(), mlir::ValueRange(),
+      procBindValue.dyn_cast_or_null<mlir::StringAttr>());
   parallelOp.getRegion().push_back(new mlir::Block{});
   auto &block = parallelOp.getRegion().back();
   mlirOpBuilder->setInsertionPointToStart(&block);
-  // ensure the block is well-formed.
+  // Ensure the block is well-formed.
   mlirOpBuilder->create<mlir::omp::TerminatorOp>(
       mlirOpBuilder->getUnknownLoc());
   mlirOpBuilder->restoreInsertionPoint(insertPt);
+  EXPECT_TRUE(succeeded(parallelOp.verify()));
+  EXPECT_EQ(parallelOp.num_threads_var(), numThreads);
 }
 
 // main() from gtest_main