Index: flang/lib/Lower/OpenMP.cpp =================================================================== --- flang/lib/Lower/OpenMP.cpp +++ flang/lib/Lower/OpenMP.cpp @@ -974,8 +974,8 @@ auto taskOp = firOpBuilder.create( currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr, /*in_reduction_vars=*/ValueRange(), - /*in_reductions=*/nullptr, priorityClauseOperand, allocateOperands, - allocatorOperands); + /*in_reductions=*/nullptr, priorityClauseOperand, /*depends=*/nullptr, + /*depend_vars=*/ValueRange(), allocateOperands, allocatorOperands); createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList); } else if (blockDirective.v == llvm::omp::OMPD_taskgroup) { // TODO: Add task_reduction support Index: mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td =================================================================== --- mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -479,6 +479,27 @@ // 2.10.1 task Construct //===----------------------------------------------------------------------===// +// Dependence kinds accepted by the `depend` clause of the task construct. +def ClauseTaskDependIn : I32EnumAttrCase<"taskdependin", 0>; +def ClauseTaskDependOut : I32EnumAttrCase<"taskdependout", 1>; +def ClauseTaskDependInOut : I32EnumAttrCase<"taskdependinout", 2>; + +def ClauseTaskDepend : I32EnumAttr< + "ClauseTaskDepend", + "task depend clause", + [ClauseTaskDependIn, ClauseTaskDependOut, ClauseTaskDependInOut]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::omp"; +} +def ClauseTaskDependAttr : + EnumAttr { + let assemblyFormat = "`(` $value `)`"; +} +def TaskDependArrayAttr : + TypedArrayAttrBase { + let constBuilderCall = ?; + } + def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments, OutlineableOpenMPOpInterface, AutomaticAllocationScope, ReductionClauseInterface]> { @@ -519,6 +540,12 @@ default priority-value when no priority clause is specified should be assumed to be zero (the lowest priority). 
+ The `depend_vars` argument is a variadic list of values and `depends` is + an optional array attribute holding one dependence kind (`taskdependin`, + `taskdependout` or `taskdependinout`) per entry of `depend_vars`. Together + they specify the dependencies of this particular task in relation to other + tasks. + The `allocators_vars` and `allocate_vars` arguments are a variadic list of values that specify the memory allocator to be used to obtain storage for private values. @@ -533,6 +560,8 @@ Variadic:$in_reduction_vars, OptionalAttr:$in_reductions, Optional:$priority, + OptionalAttr:$depends, + Variadic:$depend_vars, Variadic:$allocate_vars, Variadic:$allocators_vars); let regions = (region AnyRegion:$region); @@ -551,6 +580,10 @@ $allocate_vars, type($allocate_vars), $allocators_vars, type($allocators_vars) ) `)` + |`depend` `(` + custom( + $depend_vars, type($depend_vars), $depends + ) `)` ) $region attr-dict }]; let extraClassDeclaration = [{ Index: mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp =================================================================== --- mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -113,10 +113,10 @@ void mlir::configureOpenMPToLLVMConversionLegality( ConversionTarget &target, LLVMTypeConverter &typeConverter) { - target.addDynamicallyLegalOp([&](Operation *op) { + target.addDynamicallyLegalOp< + mlir::omp::CriticalOp, mlir::omp::ParallelOp, mlir::omp::WsLoopOp, + mlir::omp::SimdLoopOp, mlir::omp::MasterOp, mlir::omp::SectionsOp, + mlir::omp::SingleOp, mlir::omp::TaskOp>([&](Operation *op) { return typeConverter.isLegal(&op->getRegion(0)) && typeConverter.isLegal(op->getOperandTypes()) && typeConverter.isLegal(op->getResultTypes()); @@ -142,6 +142,7 @@ RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, RegionOpConversion, + RegionOpConversion, RegionLessOpWithVarOperandsConversion, RegionLessOpWithVarOperandsConversion, 
RegionLessOpWithVarOperandsConversion, Index: mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp 
=================================================================== --- mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -464,6 +464,76 @@ return success(); } +//===----------------------------------------------------------------------===// +// Parser, printer and verifier for DependVarList +//===----------------------------------------------------------------------===// + +/// depend-entry-list ::= depend-entry +/// | depend-entry-list `,` depend-entry +/// depend-entry ::= depend-kind `->` ssa-id `:` type +static ParseResult +parseDependVarList(OpAsmParser &parser, + SmallVectorImpl &operands, + SmallVectorImpl &types, ArrayAttr &dependsArray) { + SmallVector dependVec; + if (failed(parser.parseCommaSeparatedList([&]() { + StringRef keyword; + auto keywordLoc = parser.getCurrentLocation(); + if (parser.parseKeyword(&keyword) || parser.parseArrow() || + parser.parseOperand(operands.emplace_back()) || + parser.parseColonType(types.emplace_back())) + return failure(); + auto keywordDepend = symbolizeClauseTaskDepend(keyword); + if (!keywordDepend) { + parser.emitError(keywordLoc, "unknown depend kind '") + << keyword << "'"; + return failure(); + } + dependVec.emplace_back( + ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend)); + return success(); + }))) + return failure(); + SmallVector depends(dependVec.begin(), dependVec.end()); + dependsArray = ArrayAttr::get(parser.getContext(), depends); + return success(); +} + +/// Print Depend clause as a comma-separated list of +/// `depend-kind -> ssa-id : type` entries. +static void printDependVarList(OpAsmPrinter &p, Operation *op, + OperandRange dependVars, TypeRange dependTypes, + std::optional depends) { + // The printer may run on an op that has not been verified yet, so do not + // assume `depends` is present or holds one entry per dependence variable. 
+ if (!depends) + return; + unsigned numEntries = std::min(depends->size(), dependVars.size()); + for (unsigned i = 0; i < numEntries; ++i) { + if (i != 0) + p << ", "; + p << stringifyClauseTaskDepend( + (*depends)[i].cast().getValue()) + << " -> " << dependVars[i] << " : " << dependTypes[i]; + } +} + +/// Verifies Depend clause: `depends` must hold exactly one kind per entry of +/// `dependVars`, and must be absent when there are no dependence variables. +static LogicalResult verifyDependVarList(Operation *op, + std::optional depends, + OperandRange dependVars) { + if (dependVars.empty()) { + if (depends) + return op->emitOpError() << "unexpected depend values"; + return success(); + } + if (!depends || depends->size() != dependVars.size()) + return op->emitOpError() << "expected as many depend values" + " as depend variables"; + return success(); +} + 
//===----------------------------------------------------------------------===// // Parser, printer and verifier for Synchronization Hint (2.17.12) //===----------------------------------------------------------------------===// @@ -957,7 +1027,9 @@ // TaskOp //===----------------------------------------------------------------------===// LogicalResult TaskOp::verify() { - return verifyReductionVarList(*this, getInReductions(), getInReductionVars()); + if (failed(verifyDependVarList(*this, getDepends(), getDependVars()))) + return failure(); + return verifyReductionVarList(*this, getInReductions(), getInReductionVars()); } //===----------------------------------------------------------------------===// Index: mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp =================================================================== --- mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -693,11 +693,40 @@ convertOmpOpRegions(taskOp.getRegion(), "omp.task.region", builder, moduleTranslation, bodyGenStatus); }; + + SmallVector dds; + if (!taskOp.getDependVars().empty() && taskOp.getDepends()) { + for (auto dep : + llvm::zip(taskOp.getDependVars(), taskOp.getDepends()->getValue())) { + llvm::omp::RTLDependenceKindTy type; + switch ( + std::get<1>(dep).cast().getValue()) { + case mlir::omp::ClauseTaskDepend::taskdependin: + type = llvm::omp::RTLDependenceKindTy::DepIn; + break; + // The OpenMP runtime requires that the codegen for 'depend' clause for + // 'out' dependency kind must be the same as codegen for 'depend' clause + // with 'inout' dependency. 
+ case mlir::omp::ClauseTaskDepend::taskdependout: + case mlir::omp::ClauseTaskDepend::taskdependinout: + type = llvm::omp::RTLDependenceKindTy::DepInOut; + break; + } + llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep)); + // NOTE(review): the dependence size is taken from the type of the value + // itself, so a pointer dependence records the pointer's own size (see the + // `store i64 8` check in openmp-llvm.mlir), not the pointee's; confirm. + llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal); + dds.push_back(dd); + } + } + llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTask( - ompLoc, allocaIP, bodyCB, !taskOp.getUntied())); + ompLoc, allocaIP, bodyCB, !taskOp.getUntied(), /*Final*/ nullptr, + /*IfCondition*/ nullptr, dds)); return bodyGenStatus; } Index: mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir =================================================================== --- mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -148,6 +148,23 @@ // ----- +// CHECK-LABEL: @task_depend +// CHECK: (%[[ARG0:.*]]: !llvm.ptr) { +// CHECK: omp.task depend(taskdependin -> %[[ARG0]] : !llvm.ptr) { +// CHECK: omp.terminator +// CHECK: } +// CHECK: llvm.return +// CHECK: } + +func.func @task_depend(%arg0: !llvm.ptr) { + omp.task depend(taskdependin -> %arg0 : !llvm.ptr) { + omp.terminator + } + return +} + +// ----- + // CHECK-LABEL: @_QPomp_target_data // CHECK: (%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr, %[[ARG2:.*]]: !llvm.ptr, %[[ARG3:.*]]: !llvm.ptr) // CHECK: omp.target_enter_data map((to -> %[[ARG0]] : !llvm.ptr), (to -> %[[ARG1]] : !llvm.ptr), (always, alloc -> %[[ARG2]] : !llvm.ptr)) Index: mlir/test/Dialect/OpenMP/invalid.mlir =================================================================== --- mlir/test/Dialect/OpenMP/invalid.mlir +++ 
mlir/test/Dialect/OpenMP/invalid.mlir @@ -1238,6 +1238,16 @@ // ----- +func.func @omp_task_depend(%data_var: memref) { + // expected-error @below {{op 
expected as many depend values as depend variables}} + "omp.task"(%data_var) ({ + "omp.terminator"() : () -> () + }) {depends = [], operand_segment_sizes = array} : (memref) -> () + "func.return"() : () -> () +} + +// ----- + func.func @omp_task(%ptr: !llvm.ptr) { // expected-error @below {{op expected symbol reference @add_f32 to point to a reduction declaration}} omp.task in_reduction(@add_f32 -> %ptr : !llvm.ptr) { Index: mlir/test/Dialect/OpenMP/ops.mlir =================================================================== --- mlir/test/Dialect/OpenMP/ops.mlir +++ mlir/test/Dialect/OpenMP/ops.mlir @@ -1566,6 +1566,19 @@ return } +// CHECK-LABEL: @omp_task_depend +// CHECK-SAME: (%arg0: memref, %arg1: memref) { +func.func @omp_task_depend(%arg0: memref, %arg1: memref) { + // CHECK: omp.task depend(taskdependin -> %arg0 : memref, taskdependin -> %arg1 : memref, taskdependinout -> %arg0 : memref) { + omp.task depend(taskdependin -> %arg0 : memref, taskdependin -> %arg1 : memref, taskdependinout -> %arg0 : memref) { + // CHECK: "test.foo"() : () -> () + "test.foo"() : () -> () + // CHECK: omp.terminator + omp.terminator + } + return +} + func.func @omp_threadprivate() { %0 = arith.constant 1 : i32 %1 = arith.constant 2 : i32 Index: mlir/test/Target/LLVMIR/openmp-llvm.mlir =================================================================== --- mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -2237,6 +2237,55 @@ // CHECK: ret void +// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) { +// CHECK: call void @[[outlined_fn]]() +// CHECK: ret i32 0 +// CHECK: } + +// ----- + +// CHECK-LABEL: define void @omp_task_with_deps +// CHECK-SAME: (ptr %[[zaddr:.+]]) +// CHECK: %[[dep_arr_addr:.+]] = alloca [1 x %struct.kmp_dep_info], align 8 +// CHECK: %[[dep_arr_addr_0:.+]] = getelementptr inbounds [1 x %struct.kmp_dep_info], ptr %[[dep_arr_addr]], i64 0, i64 0 +// CHECK: %[[dep_arr_addr_0_val:.+]] = getelementptr inbounds 
%struct.kmp_dep_info, ptr %[[dep_arr_addr_0]], i32 0, i32 0 +// CHECK: %[[dep_arr_addr_0_val_int:.+]] = ptrtoint ptr %[[zaddr]] to i64 +// CHECK: store i64 %[[dep_arr_addr_0_val_int]], ptr %[[dep_arr_addr_0_val]], align 4 +// CHECK: %[[dep_arr_addr_0_size:.+]] = getelementptr inbounds %struct.kmp_dep_info, ptr %[[dep_arr_addr_0]], i32 0, i32 1 +// CHECK: store i64 8, ptr %[[dep_arr_addr_0_size]], align 4 +// CHECK: %[[dep_arr_addr_0_kind:.+]] = getelementptr inbounds %struct.kmp_dep_info, ptr %[[dep_arr_addr_0]], i32 0, i32 2 +// CHECK: store i8 1, ptr %[[dep_arr_addr_0_kind]], align 1 +llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr) { + // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}}) + // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc + // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 0, + // CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]]) + // CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]], {{.*}}) + omp.task depend(taskdependin -> %zaddr : !llvm.ptr) { + %n = llvm.mlir.constant(1 : i64) : i64 + %valaddr = llvm.alloca %n x i32 : (i64) -> !llvm.ptr + %val = llvm.load %valaddr : !llvm.ptr + %double = llvm.add %val, %val : i32 + llvm.store %double, %valaddr : !llvm.ptr + omp.terminator + } + llvm.return +} + +// CHECK: define internal void @[[outlined_fn:.+]]() +// CHECK: task.alloca{{.*}}: +// CHECK: br label %[[task_body:[^, ]+]] +// CHECK: [[task_body]]: +// CHECK: br label %[[task_region:[^, ]+]] +// CHECK: [[task_region]]: +// CHECK: %[[alloca:.+]] = alloca i32, i64 1 +// CHECK: %[[val:.+]] = load i32, ptr %[[alloca]] +// CHECK: %[[newval:.+]] = add i32 %[[val]], %[[val]] +// CHECK: store i32 %[[newval]], ptr %{{[^, ]+}} +// CHECK: br label %[[exit_stub:[^, ]+]] +// CHECK: [[exit_stub]]: +// CHECK: ret void + // CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) { // CHECK: call void @[[outlined_fn]]() // CHECK: ret i32 0