diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -205,7 +205,8 @@ // Create and insert the operation. auto parallelOp = firOpBuilder.create( currentLocation, argTy, ifClauseOperand, numThreadsClauseOperand, - ValueRange(), ValueRange(), + /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(), + /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr, procBindClauseOperand.dyn_cast_or_null()); // Handle attribute based clauses. for (const auto &clause : parallelOpClauseList.v) { diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -66,7 +66,7 @@ def ParallelOp : OpenMP_Op<"parallel", [ AutomaticAllocationScope, AttrSizedOperandSegments, DeclareOpInterfaceMethods, - RecursiveSideEffects]> { + RecursiveSideEffects, ReductionClauseInterface]> { let summary = "parallel construct"; let description = [{ The parallel construct includes a region of code which is to be executed @@ -83,6 +83,18 @@ The $allocators_vars and $allocate_vars parameters are a variadic list of values that specify the memory allocator to be used to obtain storage for private values. + Reductions can be performed in a parallel construct by specifying reduction + accumulator variables in `reduction_vars` and symbols referring to reduction + declarations in the `reductions` attribute. Each reduction is identified + by the accumulator it uses and accumulators must not be repeated in the same + reduction. The `omp.reduction` operation accepts the accumulator and a + partial value which is considered to be produced by the thread for the + given reduction. If multiple values are produced for the same accumulator, + i.e. there are multiple `omp.reduction`s, the last value is taken. The + reduction declaration specifies how to combine the values from each thread + into the final value, which is available in the accumulator after all the + threads complete. + The optional $proc_bind_val attribute controls the thread affinity for the execution of the parallel region. }]; @@ -91,6 +103,8 @@ Optional:$num_threads_var, Variadic:$allocate_vars, Variadic:$allocators_vars, + Variadic:$reduction_vars, + OptionalAttr:$reductions, OptionalAttr:$proc_bind_val); let regions = (region AnyRegion:$region); @@ -99,7 +113,11 @@ OpBuilder<(ins CArg<"ArrayRef", "{}">:$attributes)> ]; let assemblyFormat = [{ - oilist( `if` `(` $if_expr_var `:` type($if_expr_var) `)` + oilist( `reduction` `(` + custom( + $reduction_vars, type($reduction_vars), $reductions + ) `)` + | `if` `(` $if_expr_var `:` type($if_expr_var) `)` | `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)` | `allocate` `(` custom( @@ -110,6 +128,12 @@ ) $region attr-dict }]; let hasVerifier = 1; + let extraClassDeclaration = [{ + // TODO: remove this once emitAccessorPrefix is set to + // kEmitAccessorPrefix_Prefixed for the dialect. + /// Returns the reduction variables + operand_range getReductionVars() { return reduction_vars(); } + }]; } def TerminatorOp : OpenMP_Op<"terminator", [Terminator]> { @@ -156,7 +180,8 @@ let assemblyFormat = "$region attr-dict"; } -def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments]> { +def SectionsOp : OpenMP_Op<"sections", [AttrSizedOperandSegments, + ReductionClauseInterface]> { let summary = "sections construct"; let description = [{ The sections construct is a non-iterative worksharing construct that @@ -207,6 +232,13 @@ let hasVerifier = 1; let hasRegionVerifier = 1; + + let extraClassDeclaration = [{ + // TODO: remove this once emitAccessorPrefix is set to + // kEmitAccessorPrefix_Prefixed for the dialect. + /// Returns the reduction variables + operand_range getReductionVars() { return reduction_vars(); } + }]; } //===----------------------------------------------------------------------===// @@ -247,7 +279,7 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments, AllTypesMatch<["lowerBound", "upperBound", "step"]>, - RecursiveSideEffects]> { + RecursiveSideEffects, ReductionClauseInterface]> { let summary = "workshare loop construct"; let description = [{ The workshare loop construct specifies that the iterations of the loop(s) @@ -338,6 +370,11 @@ /// Returns the number of reduction variables. unsigned getNumReductionVars() { return reduction_vars().size(); } + + // TODO: remove this once emitAccessorPrefix is set to + // kEmitAccessorPrefix_Prefixed for the dialect. + /// Returns the reduction variables + operand_range getReductionVars() { return reduction_vars(); } }]; let hasCustomAssemblyFormat = 1; let assemblyFormat = [{ diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td @@ -31,4 +31,18 @@ ]; } +def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> { + let description = [{ + OpenMP operations that support reduction clause have this interface. + }]; + + let cppNamespace = "::mlir::omp"; + + let methods = [ + InterfaceMethod< + "Get reduction vars", "::mlir::Operation::operand_range", + "getReductionVars">, + ]; +} + #endif // OpenMP_OPS_INTERFACES diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -27,6 +27,7 @@ #include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc" #include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc" +#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc" #include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc" using namespace mlir; @@ -58,19 +59,6 @@ MemRefType::attachInterface>(*getContext()); } -//===----------------------------------------------------------------------===// -// ParallelOp -//===----------------------------------------------------------------------===// - -void ParallelOp::build(OpBuilder &builder, OperationState &state, - ArrayRef attributes) { - ParallelOp::build( - builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, - /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(), - /*proc_bind_val=*/nullptr); - state.addAttributes(attributes); -} - //===----------------------------------------------------------------------===// // Parser and printer for Allocate Clause //===----------------------------------------------------------------------===// @@ -142,13 +130,6 @@ p << stringifyEnum(attr.getValue()); } -LogicalResult ParallelOp::verify() { - if (allocate_vars().size() != allocators_vars().size()) - return emitError( - "expected equal sizes for allocate and allocator variables"); - return success(); -} - //===----------------------------------------------------------------------===// // Parser and printer for Linear Clause //===----------------------------------------------------------------------===// @@ -469,6 +450,27 @@ return success(); } +//===----------------------------------------------------------------------===// +// ParallelOp +//===----------------------------------------------------------------------===// + +void ParallelOp::build(OpBuilder &builder, OperationState &state, + ArrayRef attributes) { + ParallelOp::build( + builder, state, /*if_expr_var=*/nullptr, /*num_threads_var=*/nullptr, + /*allocate_vars=*/ValueRange(), /*allocators_vars=*/ValueRange(), + /*reduction_vars=*/ValueRange(), /*reductions=*/nullptr, + /*proc_bind_val=*/nullptr); + state.addAttributes(attributes); +} + +LogicalResult ParallelOp::verify() { + if (allocate_vars().size() != allocators_vars().size()) + return emitError( + "expected equal sizes for allocate and allocator variables"); + return verifyReductionVarList(*this, reductions(), reduction_vars()); +} + //===----------------------------------------------------------------------===// // Verifier for SectionsOp //===----------------------------------------------------------------------===// @@ -709,13 +711,17 @@ } LogicalResult ReductionOp::verify() { - // TODO: generalize this to an op interface when there is more than one op - // that supports reductions. - auto container = (*this)->getParentOfType(); - for (unsigned i = 0, e = container.getNumReductionVars(); i < e; ++i) - if (container.reduction_vars()[i] == accumulator()) - return success(); - + auto *op = (*this)->getParentWithTrait(); + if (!op) + return emitOpError() << "must be used within an operation supporting " + "reduction clause interface"; + while (op) { + for (const auto &var : + cast(op).getReductionVars()) + if (var == accumulator()) + return success(); + op = op->getParentWithTrait(); + } return emitOpError() << "the accumulator is not used by the parent"; } diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -59,7 +59,7 @@ // CHECK: omp.parallel num_threads(%{{.*}} : si32) allocate(%{{.*}} : memref -> %{{.*}} : memref) "omp.parallel"(%num_threads, %data_var, %data_var) ({ omp.terminator - }) {operand_segment_sizes = dense<[0,1,1,1]>: vector<4xi32>} : (si32, memref, memref) -> () + }) {operand_segment_sizes = dense<[0,1,1,1,0]> : vector<5xi32>} : (si32, memref, memref) -> () // CHECK: omp.barrier omp.barrier @@ -68,22 +68,22 @@ // CHECK: omp.parallel if(%{{.*}}) allocate(%{{.*}} : memref -> %{{.*}} : memref) "omp.parallel"(%if_cond, %data_var, %data_var) ({ omp.terminator - }) {operand_segment_sizes = dense<[1,0,1,1]> : vector<4xi32>} : (i1, memref, memref) -> () + }) {operand_segment_sizes = dense<[1,0,1,1,0]> : vector<5xi32>} : (i1, memref, memref) -> () // test without allocate // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : si32) "omp.parallel"(%if_cond, %num_threads) ({ omp.terminator - }) {operand_segment_sizes = dense<[1,1,0,0]> : vector<4xi32>} : (i1, si32) -> () + }) {operand_segment_sizes = dense<[1,1,0,0,0]> : vector<5xi32>} : (i1, si32) -> () omp.terminator - }) {operand_segment_sizes = dense<[1,1,1,1]> : vector<4xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref, memref) -> () + }) {operand_segment_sizes = dense<[1,1,1,1,0]> : vector<5xi32>, proc_bind_val = #omp<"procbindkind spread">} : (i1, si32, memref, memref) -> () // test with multiple parameters for single variadic argument // CHECK: omp.parallel allocate(%{{.*}} : memref -> %{{.*}} : memref) "omp.parallel" (%data_var, %data_var) ({ omp.terminator - }) {operand_segment_sizes = dense<[0,0,1,1]> : vector<4xi32>} : (memref, memref) -> () + }) {operand_segment_sizes = dense<[0,0,1,1,0]> : vector<5xi32>} : (memref, memref) -> () return } @@ -407,7 +407,8 @@ omp.yield } -func @reduction(%lb : index, %ub : index, %step : index) { +// CHECK-LABEL: func @wsloop_reduction +func @wsloop_reduction(%lb : index, %ub : index, %step : index) { %c1 = arith.constant 1 : i32 %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr // CHECK: reduction(@add_f32 -> %{{.+}} : !llvm.ptr) @@ -421,6 +422,65 @@ return } +// CHECK-LABEL: func @parallel_reduction +func @parallel_reduction() { + %c1 = arith.constant 1 : i32 + %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr + // CHECK: omp.parallel reduction(@add_f32 -> {{.+}} : !llvm.ptr) + omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) { + %1 = arith.constant 2.0 : f32 + // CHECK: omp.reduction %{{.+}}, %{{.+}} + omp.reduction %1, %0 : !llvm.ptr + omp.terminator + } + return +} + +// CHECK: func @parallel_wsloop_reduction +func @parallel_wsloop_reduction(%lb : index, %ub : index, %step : index) { + %c1 = arith.constant 1 : i32 + %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr + // CHECK: omp.parallel reduction(@add_f32 -> %{{.+}} : !llvm.ptr) { + omp.parallel reduction(@add_f32 -> %0 : !llvm.ptr) { + // CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}}) + omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { + %1 = arith.constant 2.0 : f32 + // CHECK: omp.reduction %{{.+}}, %{{.+}} : !llvm.ptr + omp.reduction %1, %0 : !llvm.ptr + // CHECK: omp.yield + omp.yield + } + // CHECK: omp.terminator + omp.terminator + } + return +} + +// CHECK-LABEL: func @sections_reduction +func @sections_reduction() { + %c1 = arith.constant 1 : i32 + %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr + // CHECK: omp.sections reduction(@add_f32 -> {{.+}} : !llvm.ptr) + omp.sections reduction(@add_f32 -> %0 : !llvm.ptr) { + // CHECK: omp.section + omp.section { + %1 = arith.constant 2.0 : f32 + // CHECK: omp.reduction %{{.+}}, %{{.+}} + omp.reduction %1, %0 : !llvm.ptr + omp.terminator + } + // CHECK: omp.section + omp.section { + %1 = arith.constant 3.0 : f32 + // CHECK: omp.reduction %{{.+}}, %{{.+}} + omp.reduction %1, %0 : !llvm.ptr + omp.terminator + } + omp.terminator + } + return +} + // CHECK: omp.reduction.declare // CHECK-LABEL: @add2_f32 omp.reduction.declare @add2_f32 : f32 @@ -438,9 +498,10 @@ } // CHECK-NOT: atomic -func @reduction2(%lb : index, %ub : index, %step : index) { +// CHECK-LABEL: func @wsloop_reduction2 +func @wsloop_reduction2(%lb : index, %ub : index, %step : index) { %0 = memref.alloca() : memref<1xf32> - // CHECK: reduction + // CHECK: omp.wsloop reduction(@add2_f32 -> %{{.+}} : memref<1xf32>) omp.wsloop reduction(@add2_f32 -> %0 : memref<1xf32>) for (%iv) : index = (%lb) to (%ub) step (%step) { %1 = arith.constant 2.0 : f32 @@ -451,6 +512,61 @@ return } +// CHECK-LABEL: func @parallel_reduction2 +func @parallel_reduction2() { + %0 = memref.alloca() : memref<1xf32> + // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : memref<1xf32>) + omp.parallel reduction(@add2_f32 -> %0 : memref<1xf32>) { + %1 = arith.constant 2.0 : f32 + // CHECK: omp.reduction + omp.reduction %1, %0 : memref<1xf32> + omp.terminator + } + return +} + +// CHECK: func @parallel_wsloop_reduction2 +func @parallel_wsloop_reduction2(%lb : index, %ub : index, %step : index) { + %c1 = arith.constant 1 : i32 + %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr + // CHECK: omp.parallel reduction(@add2_f32 -> %{{.+}} : !llvm.ptr) { + omp.parallel reduction(@add2_f32 -> %0 : !llvm.ptr) { + // CHECK: omp.wsloop for (%{{.+}}) : index = (%{{.+}}) to (%{{.+}}) step (%{{.+}}) + omp.wsloop for (%iv) : index = (%lb) to (%ub) step (%step) { + %1 = arith.constant 2.0 : f32 + // CHECK: omp.reduction %{{.+}}, %{{.+}} : !llvm.ptr + omp.reduction %1, %0 : !llvm.ptr + // CHECK: omp.yield + omp.yield + } + // CHECK: omp.terminator + omp.terminator + } + return +} + +// CHECK-LABEL: func @sections_reduction2 +func @sections_reduction2() { + %0 = memref.alloca() : memref<1xf32> + // CHECK: omp.sections reduction(@add2_f32 -> %{{.+}} : memref<1xf32>) + omp.sections reduction(@add2_f32 -> %0 : memref<1xf32>) { + omp.section { + %1 = arith.constant 2.0 : f32 + // CHECK: omp.reduction + omp.reduction %1, %0 : memref<1xf32> + omp.terminator + } + omp.section { + %1 = arith.constant 2.0 : f32 + // CHECK: omp.reduction + omp.reduction %1, %0 : memref<1xf32> + omp.terminator + } + omp.terminator + } + return +} + // CHECK: omp.critical.declare @mutex1 hint(uncontended) omp.critical.declare @mutex1 hint(uncontended) // CHECK: omp.critical.declare @mutex2 hint(contended)