diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -76,12 +76,19 @@ Variadic:$firstprivate_vars, Variadic:$shared_vars, Variadic:$copyin_vars, - OptionalAttr:$proc_bind_val); + UnitAttr:$procBindValueMaster, + UnitAttr:$procBindValueClose, + UnitAttr:$procBindValueSpread); let regions = (region AnyRegion:$region); let parser = [{ return parseParallelOp(parser, result); }]; let printer = [{ return printParallelOp(p, *this); }]; + let extraClassDeclaration = [{ + static StringRef getProcBindValueMasterAttrName() { return "master"; } + static StringRef getProcBindValueCloseAttrName() { return "close"; } + static StringRef getProcBindValueSpreadAttrName() { return "spread"; } + }]; } def TerminatorOp : OpenMP_Op<"terminator", [Terminator]> { diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -92,8 +92,12 @@ if (auto def = op.default_val()) p << " default(" << def->drop_front(3) << ")"; - if (auto bind = op.proc_bind_val()) - p << " proc_bind(" << bind << ")"; + if (op.procBindValueMaster()) + p << " proc_bind(" << op.getProcBindValueMasterAttrName() << ")"; + if (op.procBindValueClose()) + p << " proc_bind(" << op.getProcBindValueCloseAttrName() << ")"; + if (op.procBindValueSpread()) + p << " proc_bind(" << op.getProcBindValueSpreadAttrName() << ")"; p.printRegion(op.getRegion()); } @@ -217,8 +221,13 @@ if (parser.parseLParen() || parser.parseKeyword(&bind) || parser.parseRParen()) return failure(); - auto attr = parser.getBuilder().getStringAttr(bind); - result.addAttribute("proc_bind_val", attr); + auto attr = parser.getBuilder().getUnitAttr(); + if (bind == "master") + result.addAttribute("procBindValueMaster", attr); + if (bind == "close") + result.addAttribute("procBindValueClose", attr); + if (bind == "spread") + result.addAttribute("procBindValueSpread", attr); } else { return parser.emitError(parser.getNameLoc()) << keyword << " is not a valid clause for the " << opName diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -443,15 +443,23 @@ // called for variables which have destructors/finalizers. auto finiCB = [&](InsertPointTy codeGenIP) {}; + auto parallelOp = cast(opInst); llvm::Value *ifCond = nullptr; - if (auto ifExprVar = cast(opInst).if_expr_var()) - ifCond = valueMapping.lookup(ifExprVar); llvm::Value *numThreads = nullptr; - if (auto numThreadsVar = cast(opInst).num_threads_var()) + if (auto ifExprVar = parallelOp.if_expr_var()) + ifCond = valueMapping.lookup(ifExprVar); + if (auto numThreadsVar = parallelOp.num_threads_var()) numThreads = valueMapping.lookup(numThreadsVar); llvm::omp::ProcBindKind pbKind = llvm::omp::OMP_PROC_BIND_default; - if (auto bind = cast(opInst).proc_bind_val()) - pbKind = llvm::omp::getProcBindKind(bind.getValue()); + if (parallelOp.procBindValueMaster()) + pbKind = + llvm::omp::getProcBindKind(parallelOp.getProcBindValueMasterAttrName()); + if (parallelOp.procBindValueClose()) + pbKind = + llvm::omp::getProcBindKind(parallelOp.getProcBindValueCloseAttrName()); + if (parallelOp.procBindValueSpread()) + pbKind = + llvm::omp::getProcBindKind(parallelOp.getProcBindValueSpreadAttrName()); // TODO: Is the Parallel construct cancellable? bool isCancellable = false; // TODO: Determine the actual alloca insertion point, e.g., the function