Index: mlir/include/mlir/Dialect/GPU/GPUOps.td =================================================================== --- mlir/include/mlir/Dialect/GPU/GPUOps.td +++ mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -482,15 +482,25 @@ }]; } -// These mirror the XLA ComparisonDirection enum. +// add, mul mirror the XLA ComparisonDirection enum. def GPU_AllReduceOpAdd : StrEnumAttrCase<"add">; +def GPU_AllReduceOpAnd : StrEnumAttrCase<"and">; +def GPU_AllReduceOpMax : StrEnumAttrCase<"max">; +def GPU_AllReduceOpMin : StrEnumAttrCase<"min">; def GPU_AllReduceOpMul : StrEnumAttrCase<"mul">; +def GPU_AllReduceOpOr : StrEnumAttrCase<"or">; +def GPU_AllReduceOpXor : StrEnumAttrCase<"xor">; def GPU_AllReduceOperationAttr : StrEnumAttr<"AllReduceOperationAttr", "built-in reduction operations supported by gpu.allreduce.", [ GPU_AllReduceOpAdd, + GPU_AllReduceOpAnd, + GPU_AllReduceOpMax, + GPU_AllReduceOpMin, GPU_AllReduceOpMul, + GPU_AllReduceOpOr, + GPU_AllReduceOpXor ]>; def GPU_AllReduceOp : GPU_Op<"all_reduce", @@ -514,8 +524,8 @@ ``` compute the sum of each work item's %0 value. The first version specifies the accumulation as operation, whereas the second version specifies the - accumulation as code region. The accumulation operation must either be - `add` or `mul`. + accumulation as code region. The accumulation operation must be one of: + `add`, `and`, `max`, `min`, `mul`, `or`, `xor`. Either none or all work items of a workgroup need to execute this op in convergence. Index: mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp =================================================================== --- mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -123,6 +123,33 @@ return isFloatingPoint ? 
getFactory() : getFactory(); } + if (opName == "and") { + assert(!isFloatingPoint && + " accumulator is not compatible with Floating Point type"); + return getFactory(); + } + if (opName == "or") { + assert(!isFloatingPoint && + " accumulator is not compatible with Floating Point type"); + return getFactory(); + } + if (opName == "xor") { + assert(!isFloatingPoint && + " accumulator is not compatible with Floating Point type"); + return getFactory(); + } + if (opName == "max") { + return isFloatingPoint ? getCmpFactory( + LLVM::FCmpPredicate::ugt) + : getCmpFactory( + LLVM::ICmpPredicate::ugt); + } + if (opName == "min") { + return isFloatingPoint ? getCmpFactory( + LLVM::FCmpPredicate::ult) + : getCmpFactory( + LLVM::ICmpPredicate::ult); + } return AccumulatorFactory(); } @@ -135,6 +162,17 @@ }; } + /// Returns an accumulator for comparison such as min, max. T is the type + /// of the compare op and P is the type of the predicate. + template + AccumulatorFactory getCmpFactory(P predicate) const { + return [predicate](Location loc, Value lhs, Value rhs, + ConversionPatternRewriter &rewriter) { + Value cmp = rewriter.create(loc, predicate, lhs, rhs); + return rewriter.create(loc, cmp, lhs, rhs); + }; + } + /// Creates an all_reduce across the block. /// /// First reduce the elements within a warp. The first thread of each warp