Index: mlir/include/mlir/Dialect/GPU/GPUOps.td
===================================================================
--- mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -482,15 +482,25 @@
   }];
 }
 
-// These mirror the XLA ComparisonDirection enum.
+// `add` and `mul` mirror the XLA ComparisonDirection enum.
 def GPU_AllReduceOpAdd : StrEnumAttrCase<"add">;
+def GPU_AllReduceOpAnd : StrEnumAttrCase<"and">;
+def GPU_AllReduceOpMax : StrEnumAttrCase<"max">;
+def GPU_AllReduceOpMin : StrEnumAttrCase<"min">;
 def GPU_AllReduceOpMul : StrEnumAttrCase<"mul">;
+def GPU_AllReduceOpOr : StrEnumAttrCase<"or">;
+def GPU_AllReduceOpXor : StrEnumAttrCase<"xor">;
 
 def GPU_AllReduceOperationAttr : StrEnumAttr<"AllReduceOperationAttr",
     "built-in reduction operations supported by gpu.allreduce.",
     [
       GPU_AllReduceOpAdd,
+      GPU_AllReduceOpAnd,
+      GPU_AllReduceOpMax,
+      GPU_AllReduceOpMin,
       GPU_AllReduceOpMul,
+      GPU_AllReduceOpOr,
+      GPU_AllReduceOpXor
     ]>;
 
 def GPU_AllReduceOp : GPU_Op<"all_reduce",
@@ -514,8 +524,8 @@
     ```
     compute the sum of each work item's %0 value. The first version specifies
     the accumulation as operation, whereas the second version specifies the
-    accumulation as code region. The accumulation operation must either be
-    `add` or `mul`.
+    accumulation as code region. The accumulation operation must be one of:
+    `add`, `and`, `max`, `min`, `mul`, `or`, `xor`.
 
     Either none or all work items of a workgroup need to execute this op
     in convergence.
Index: mlir/include/mlir/ExecutionEngine/RunnerUtils.h
===================================================================
--- mlir/include/mlir/ExecutionEngine/RunnerUtils.h
+++ mlir/include/mlir/ExecutionEngine/RunnerUtils.h
@@ -211,6 +211,8 @@
 extern "C" MLIR_RUNNERUTILS_EXPORT void
 _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M);
 
+extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_i32(int64_t rank,
+                                                         void *ptr);
 extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_f32(int64_t rank,
                                                          void *ptr);
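For reference, the new print_memref_i32 entry point follows the same calling
convention as the existing f32 printers: the plain-C symbol wraps a
(rank, pointer-to-ranked-descriptor) pair into an UnrankedMemRefType<int32_t>
and dispatches on rank. A minimal host-side sketch of that convention follows
(standalone C++, not part of the patch; the struct definitions mirror
mlir/ExecutionEngine/CRunnerUtils.h rather than including it, and linking
against libmlir_runner_utils is assumed):

#include <cstdint>

// Local mirrors of the descriptor structs from CRunnerUtils.h so the sketch
// is self-contained; the layout must match the MLIR memref ABI.
template <typename T, int N> struct StridedMemRefType {
  T *basePtr;
  T *data;
  int64_t offset;
  int64_t sizes[N];
  int64_t strides[N];
};
template <typename T> struct UnrankedMemRefType {
  int64_t rank;
  void *descriptor; // points at a StridedMemRefType<T, rank>
};

// Provided by libmlir_runner_utils (added by this patch).
extern "C" void print_memref_i32(int64_t rank, void *ptr);

int main() {
  int32_t buffer[2] = {16, 11};
  StridedMemRefType<int32_t, 1> ranked{buffer, buffer, /*offset=*/0,
                                       /*sizes=*/{2}, /*strides=*/{1}};
  // print_memref_i32 packs {rank, &ranked} into an UnrankedMemRefType<int32_t>
  // and forwards to _mlir_ciface_print_memref_i32, which prints the memref
  // metadata followed by the data, e.g. "[16, 11]".
  print_memref_i32(/*rank=*/1, &ranked);
  return 0;
}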
Index: mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
===================================================================
--- mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -123,18 +123,51 @@
       return isFloatingPoint ? getFactory<LLVM::FMulOp>()
                              : getFactory<LLVM::MulOp>();
     }
+    if (opName == "and") {
+      return getFactory<LLVM::AndOp>();
+    }
+    if (opName == "or") {
+      return getFactory<LLVM::OrOp>();
+    }
+    if (opName == "xor") {
+      return getFactory<LLVM::XOrOp>();
+    }
+    if (opName == "max") {
+      return isFloatingPoint ? getCmpFactory<LLVM::FCmpOp, LLVM::FCmpPredicate,
+                                             LLVM::FCmpPredicate::ugt>()
+                             : getCmpFactory<LLVM::ICmpOp, LLVM::ICmpPredicate,
+                                             LLVM::ICmpPredicate::ugt>();
+    }
+    if (opName == "min") {
+      return isFloatingPoint ? getCmpFactory<LLVM::FCmpOp, LLVM::FCmpPredicate,
+                                             LLVM::FCmpPredicate::ult>()
+                             : getCmpFactory<LLVM::ICmpOp, LLVM::ICmpPredicate,
+                                             LLVM::ICmpPredicate::ult>();
+    }
     return AccumulatorFactory();
   }
 
   /// Returns an accumulator factory that creates an op of type T.
-  template <typename T> AccumulatorFactory getFactory() const {
+  template <typename T>
+  AccumulatorFactory getFactory() const {
     return [](Location loc, Value lhs, Value rhs,
               ConversionPatternRewriter &rewriter) {
       return rewriter.create<T>(loc, lhs.getType(), lhs, rhs);
     };
   }
 
+  /// Returns an accumulator for a comparison such as min or max. T is the
+  /// type of the compare op; `predicate` selects the comparison direction.
+  template <typename T, typename PredicateT, PredicateT predicate>
+  AccumulatorFactory getCmpFactory() const {
+    return [](Location loc, Value lhs, Value rhs,
+              ConversionPatternRewriter &rewriter) {
+      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
+      return rewriter.create<LLVM::SelectOp>(loc, cmp, lhs, rhs);
+    };
+  }
+
   /// Creates an all_reduce across the block.
   ///
   /// First reduce the elements within a warp. The first thread of each warp
@@ -705,9 +738,9 @@
       GPUAllReduceOpLowering, GPUShuffleOpLowering, GPUFuncOpLowering,
       GPUReturnOpLowering>(converter);
   patterns.insert<OpToFuncCallLowering<AbsFOp>>(converter, "__nv_fabsf",
-                                                 "__nv_fabs");
+                                                "__nv_fabs");
   patterns.insert<OpToFuncCallLowering<CeilFOp>>(converter, "__nv_ceilf",
-                                                  "__nv_ceil");
+                                                 "__nv_ceil");
   patterns.insert<OpToFuncCallLowering<CosOp>>(converter, "__nv_cosf",
                                                "__nv_cos");
   patterns.insert<OpToFuncCallLowering<ExpOp>>(converter, "__nv_expf",
Index: mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
===================================================================
--- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -148,6 +148,14 @@
     }
     if (yieldCount == 0)
       return allReduce.emitError("expected gpu.yield op in region");
+  } else {
+    StringRef opName = *allReduce.op();
+    if ((opName == "and" || opName == "or" || opName == "xor") &&
+        !allReduce.getType().isa<IntegerType>()) {
+      return allReduce.emitError()
+             << "`" << opName << "`"
+             << " accumulator is only compatible with Integer type";
+    }
   }
   return success();
 }
Index: mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
===================================================================
--- mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -212,6 +212,25 @@
     if (opName == "add")
       return isFloatingPoint ? getFactory<AddFOp>() : getFactory<AddIOp>();
     if (opName == "mul")
       return isFloatingPoint ? getFactory<MulFOp>() : getFactory<MulIOp>();
+    if (opName == "and") {
+      return getFactory<AndOp>();
+    }
+    if (opName == "or") {
+      return getFactory<OrOp>();
+    }
+    if (opName == "xor") {
+      return getFactory<XOrOp>();
+    }
+    if (opName == "max") {
+      return isFloatingPoint
+                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::UGT>()
+                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ugt>();
+    }
+    if (opName == "min") {
+      return isFloatingPoint
+                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::ULT>()
+                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ult>();
+    }
     return AccumulatorFactory();
   }
 
@@ -222,6 +241,16 @@
     };
   }
 
+  /// Returns an accumulator for a comparison such as min or max. T is the
+  /// type of the compare op; `predicate` selects the comparison direction.
+  template <typename T, typename PredicateT, PredicateT predicate>
+  AccumulatorFactory getCmpFactory() const {
+    return [&](Value lhs, Value rhs) {
+      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
+      return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
+    };
+  }
+
   /// Creates an if-block skeleton and calls the two factories to generate the
   /// ops in the `then` and `else` block..
   ///
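Both lowerings implement min and max as a compare followed by a select rather
than a dedicated min/max instruction. The scalar sketch below (standalone C++,
not part of the patch; accumulateMax/accumulateMin are hypothetical stand-ins
for the op pair a getCmpFactory accumulator emits) shows the idiom, and why
the unsigned predicates still produce the expected results for the
non-negative test data used in this patch:

#include <cassert>
#include <cstdint>

// Hypothetical scalar model of the cmp+select sequence emitted per pair of
// partial results; the casts model the `ugt`/`ult` unsigned predicates.
static int32_t accumulateMax(int32_t lhs, int32_t rhs) {
  bool cmp = static_cast<uint32_t>(lhs) > static_cast<uint32_t>(rhs); // cmp ugt
  return cmp ? lhs : rhs;                                             // select
}

static int32_t accumulateMin(int32_t lhs, int32_t rhs) {
  bool cmp = static_cast<uint32_t>(lhs) < static_cast<uint32_t>(rhs); // cmp ult
  return cmp ? lhs : rhs;                                             // select
}

int main() {
  // Matches the first row of the cuda-runner tests below: {0, 1, 2, 4, 8, 16}.
  int32_t data[6] = {0, 1, 2, 4, 8, 16};
  int32_t maxAcc = data[0], minAcc = data[0];
  for (int i = 1; i < 6; ++i) {
    maxAcc = accumulateMax(maxAcc, data[i]);
    minAcc = accumulateMin(minAcc, data[i]);
  }
  assert(maxAcc == 16 && minAcc == 0);
  return 0;
}

The same cmp+select shape is visible in the FileCheck expectations of
all-reduce-max.mlir below: each reduction step is a `cmpf "ugt"` feeding a
`select`.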
Index: mlir/lib/ExecutionEngine/RunnerUtils.cpp
===================================================================
--- mlir/lib/ExecutionEngine/RunnerUtils.cpp
+++ mlir/lib/ExecutionEngine/RunnerUtils.cpp
@@ -41,6 +41,22 @@
   }
 }
 
+extern "C" void _mlir_ciface_print_memref_i32(UnrankedMemRefType<int32_t> *M) {
+  printUnrankedMemRefMetaData(std::cout, *M);
+  int rank = M->rank;
+  void *ptr = M->descriptor;
+
+  switch (rank) {
+    MEMREF_CASE(int32_t, 0);
+    MEMREF_CASE(int32_t, 1);
+    MEMREF_CASE(int32_t, 2);
+    MEMREF_CASE(int32_t, 3);
+    MEMREF_CASE(int32_t, 4);
+  default:
+    assert(0 && "Unsupported rank to print");
+  }
+}
+
 extern "C" void _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M) {
   printUnrankedMemRefMetaData(std::cout, *M);
   int rank = M->rank;
@@ -57,6 +73,13 @@
   }
 }
 
+extern "C" void print_memref_i32(int64_t rank, void *ptr) {
+  UnrankedMemRefType<int32_t> descriptor;
+  descriptor.rank = rank;
+  descriptor.descriptor = ptr;
+  _mlir_ciface_print_memref_i32(&descriptor);
+}
+
 extern "C" void print_memref_f32(int64_t rank, void *ptr) {
   UnrankedMemRefType<float> descriptor;
   descriptor.rank = rank;
Index: mlir/test/Dialect/GPU/all-reduce-max.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/GPU/all-reduce-max.mlir
@@ -0,0 +1,203 @@
+// RUN: mlir-opt -test-all-reduce-lowering %s | FileCheck %s
+
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// CHECK: module @kernels attributes {gpu.kernel_module} {
+module @kernels attributes {gpu.kernel_module} {
+
+  // CHECK-LABEL: gpu.func @kernel(
+  // CHECK-SAME: [[VAL_0:%.*]]: f32) workgroup([[VAL_1:%.*]] : memref<32xf32, 3>) kernel {
+  gpu.func @kernel(%arg0 : f32) attributes { gpu.kernel } {
+    // CHECK: [[VAL_2:%.*]] = constant 31 : i32
+    // CHECK: [[VAL_3:%.*]] = constant 0 : i32
+    // CHECK: [[VAL_4:%.*]] = constant 0 : index
+    // CHECK: [[VAL_5:%.*]] = constant 32 : i32
+    // CHECK: [[VAL_6:%.*]] = constant 1 : i32
+    // CHECK: [[VAL_7:%.*]] = constant 2 : i32
+    // CHECK: [[VAL_8:%.*]] = constant 4 : i32
+    // CHECK: [[VAL_9:%.*]] = constant 8 : i32
+    // CHECK: [[VAL_10:%.*]] = constant 16 : i32
+    // CHECK: [[VAL_11:%.*]] = "gpu.block_dim"() {dimension = "x"} : () -> index
+    // CHECK: [[VAL_12:%.*]] = index_cast [[VAL_11]] : index to i32
+    // CHECK: [[VAL_13:%.*]] = "gpu.block_dim"() {dimension = "y"} : () -> index
+    // CHECK: [[VAL_14:%.*]] = index_cast [[VAL_13]] : index to i32
+    // CHECK: [[VAL_15:%.*]] = "gpu.block_dim"() {dimension = "z"} : () -> index
+    // CHECK: [[VAL_16:%.*]] = index_cast [[VAL_15]] : index to i32
+    // CHECK: [[VAL_17:%.*]] = "gpu.thread_id"() {dimension = "x"} : () -> index
+    // CHECK: [[VAL_18:%.*]] = index_cast [[VAL_17]] : index to i32
+    // CHECK: [[VAL_19:%.*]] = "gpu.thread_id"() {dimension = "y"} : () -> index
+    // CHECK: [[VAL_20:%.*]] = index_cast [[VAL_19]] : index to i32
+    // CHECK: [[VAL_21:%.*]] = "gpu.thread_id"() {dimension = "z"} : () -> index
+    // CHECK: [[VAL_22:%.*]] = index_cast [[VAL_21]] : index to i32
+    // CHECK: [[VAL_23:%.*]] = muli [[VAL_22]], [[VAL_14]] : i32
+    // CHECK: [[VAL_24:%.*]] = addi [[VAL_23]], [[VAL_20]] : i32
+    // CHECK: [[VAL_25:%.*]] = muli [[VAL_24]], [[VAL_12]] : i32
+    // CHECK: [[VAL_26:%.*]] = muli [[VAL_12]], [[VAL_14]] : i32
+    // CHECK: [[VAL_27:%.*]] = addi [[VAL_25]], [[VAL_18]] : i32
+    // CHECK: [[VAL_28:%.*]] = muli [[VAL_26]], [[VAL_16]] : i32
+    // CHECK: [[VAL_29:%.*]] = and [[VAL_27]], [[VAL_2]] : i32
+    // CHECK: [[VAL_30:%.*]] = cmpi "eq", [[VAL_29]], [[VAL_3]] : i32
+    // CHECK: [[VAL_31:%.*]] = subi
[[VAL_27]], [[VAL_29]] : i32 + // CHECK: [[VAL_32:%.*]] = subi [[VAL_28]], [[VAL_31]] : i32 + // CHECK: [[VAL_33:%.*]] = cmpi "slt", [[VAL_32]], [[VAL_5]] : i32 + // CHECK: cond_br [[VAL_33]], ^bb1, ^bb17 + // CHECK: ^bb1: + // CHECK: [[VAL_34:%.*]], [[VAL_35:%.*]] = gpu.shuffle [[VAL_0]], [[VAL_6]], [[VAL_32]] xor : f32 + // CHECK: cond_br [[VAL_35]], ^bb2, ^bb3 + // CHECK: ^bb2: + // CHECK: [[VAL_36:%.*]] = cmpf "ugt", [[VAL_0]], [[VAL_34]] : f32 + // CHECK: [[VAL_37:%.*]] = select [[VAL_36]], [[VAL_0]], [[VAL_34]] : f32 + // CHECK: br ^bb4([[VAL_37]] : f32) + // CHECK: ^bb3: + // CHECK: br ^bb4([[VAL_0]] : f32) + // CHECK: ^bb4([[VAL_38:%.*]]: f32): + // CHECK: [[VAL_39:%.*]], [[VAL_40:%.*]] = gpu.shuffle [[VAL_38]], [[VAL_7]], [[VAL_32]] xor : f32 + // CHECK: cond_br [[VAL_40]], ^bb5, ^bb6 + // CHECK: ^bb5: + // CHECK: [[VAL_41:%.*]] = cmpf "ugt", [[VAL_38]], [[VAL_39]] : f32 + // CHECK: [[VAL_42:%.*]] = select [[VAL_41]], [[VAL_38]], [[VAL_39]] : f32 + // CHECK: br ^bb7([[VAL_42]] : f32) + // CHECK: ^bb6: + // CHECK: br ^bb7([[VAL_38]] : f32) + // CHECK: ^bb7([[VAL_43:%.*]]: f32): + // CHECK: [[VAL_44:%.*]], [[VAL_45:%.*]] = gpu.shuffle [[VAL_43]], [[VAL_8]], [[VAL_32]] xor : f32 + // CHECK: cond_br [[VAL_45]], ^bb8, ^bb9 + // CHECK: ^bb8: + // CHECK: [[VAL_46:%.*]] = cmpf "ugt", [[VAL_43]], [[VAL_44]] : f32 + // CHECK: [[VAL_47:%.*]] = select [[VAL_46]], [[VAL_43]], [[VAL_44]] : f32 + // CHECK: br ^bb10([[VAL_47]] : f32) + // CHECK: ^bb9: + // CHECK: br ^bb10([[VAL_43]] : f32) + // CHECK: ^bb10([[VAL_48:%.*]]: f32): + // CHECK: [[VAL_49:%.*]], [[VAL_50:%.*]] = gpu.shuffle [[VAL_48]], [[VAL_9]], [[VAL_32]] xor : f32 + // CHECK: cond_br [[VAL_50]], ^bb11, ^bb12 + // CHECK: ^bb11: + // CHECK: [[VAL_51:%.*]] = cmpf "ugt", [[VAL_48]], [[VAL_49]] : f32 + // CHECK: [[VAL_52:%.*]] = select [[VAL_51]], [[VAL_48]], [[VAL_49]] : f32 + // CHECK: br ^bb13([[VAL_52]] : f32) + // CHECK: ^bb12: + // CHECK: br ^bb13([[VAL_48]] : f32) + // CHECK: ^bb13([[VAL_53:%.*]]: f32): + // CHECK: [[VAL_54:%.*]], [[VAL_55:%.*]] = gpu.shuffle [[VAL_53]], [[VAL_10]], [[VAL_32]] xor : f32 + // CHECK: cond_br [[VAL_55]], ^bb14, ^bb15 + // CHECK: ^bb14: + // CHECK: [[VAL_56:%.*]] = cmpf "ugt", [[VAL_53]], [[VAL_54]] : f32 + // CHECK: [[VAL_57:%.*]] = select [[VAL_56]], [[VAL_53]], [[VAL_54]] : f32 + // CHECK: br ^bb16([[VAL_57]] : f32) + // CHECK: ^bb15: + // CHECK: br ^bb16([[VAL_53]] : f32) + // CHECK: ^bb16([[VAL_58:%.*]]: f32): + // CHECK: br ^bb18([[VAL_58]] : f32) + // CHECK: ^bb17: + // CHECK: [[VAL_59:%.*]], [[VAL_60:%.*]] = gpu.shuffle [[VAL_0]], [[VAL_6]], [[VAL_5]] xor : f32 + // CHECK: [[VAL_61:%.*]] = cmpf "ugt", [[VAL_0]], [[VAL_59]] : f32 + // CHECK: [[VAL_62:%.*]] = select [[VAL_61]], [[VAL_0]], [[VAL_59]] : f32 + // CHECK: [[VAL_63:%.*]], [[VAL_64:%.*]] = gpu.shuffle [[VAL_62]], [[VAL_7]], [[VAL_5]] xor : f32 + // CHECK: [[VAL_65:%.*]] = cmpf "ugt", [[VAL_62]], [[VAL_63]] : f32 + // CHECK: [[VAL_66:%.*]] = select [[VAL_65]], [[VAL_62]], [[VAL_63]] : f32 + // CHECK: [[VAL_67:%.*]], [[VAL_68:%.*]] = gpu.shuffle [[VAL_66]], [[VAL_8]], [[VAL_5]] xor : f32 + // CHECK: [[VAL_69:%.*]] = cmpf "ugt", [[VAL_66]], [[VAL_67]] : f32 + // CHECK: [[VAL_70:%.*]] = select [[VAL_69]], [[VAL_66]], [[VAL_67]] : f32 + // CHECK: [[VAL_71:%.*]], [[VAL_72:%.*]] = gpu.shuffle [[VAL_70]], [[VAL_9]], [[VAL_5]] xor : f32 + // CHECK: [[VAL_73:%.*]] = cmpf "ugt", [[VAL_70]], [[VAL_71]] : f32 + // CHECK: [[VAL_74:%.*]] = select [[VAL_73]], [[VAL_70]], [[VAL_71]] : f32 + // CHECK: [[VAL_75:%.*]], [[VAL_76:%.*]] = gpu.shuffle 
[[VAL_74]], [[VAL_10]], [[VAL_5]] xor : f32 + // CHECK: [[VAL_77:%.*]] = cmpf "ugt", [[VAL_74]], [[VAL_75]] : f32 + // CHECK: [[VAL_78:%.*]] = select [[VAL_77]], [[VAL_74]], [[VAL_75]] : f32 + // CHECK: br ^bb18([[VAL_78]] : f32) + // CHECK: ^bb18([[VAL_79:%.*]]: f32): + // CHECK: cond_br [[VAL_30]], ^bb19, ^bb20 + // CHECK: ^bb19: + // CHECK: [[VAL_80:%.*]] = divi_signed [[VAL_27]], [[VAL_5]] : i32 + // CHECK: [[VAL_81:%.*]] = index_cast [[VAL_80]] : i32 to index + // CHECK: store [[VAL_79]], [[VAL_1]]{{\[}}[[VAL_81]]] : memref<32xf32, 3> + // CHECK: br ^bb21 + // CHECK: ^bb20: + // CHECK: br ^bb21 + // CHECK: ^bb21: + // CHECK: gpu.barrier + // CHECK: [[VAL_82:%.*]] = addi [[VAL_28]], [[VAL_2]] : i32 + // CHECK: [[VAL_83:%.*]] = divi_signed [[VAL_82]], [[VAL_5]] : i32 + // CHECK: [[VAL_84:%.*]] = cmpi "slt", [[VAL_27]], [[VAL_83]] : i32 + // CHECK: cond_br [[VAL_84]], ^bb22, ^bb41 + // CHECK: ^bb22: + // CHECK: [[VAL_85:%.*]] = index_cast [[VAL_27]] : i32 to index + // CHECK: [[VAL_86:%.*]] = load [[VAL_1]]{{\[}}[[VAL_85]]] : memref<32xf32, 3> + // CHECK: [[VAL_87:%.*]] = cmpi "slt", [[VAL_83]], [[VAL_5]] : i32 + // CHECK: cond_br [[VAL_87]], ^bb23, ^bb39 + // CHECK: ^bb23: + // CHECK: [[VAL_88:%.*]], [[VAL_89:%.*]] = gpu.shuffle [[VAL_86]], [[VAL_6]], [[VAL_83]] xor : f32 + // CHECK: cond_br [[VAL_89]], ^bb24, ^bb25 + // CHECK: ^bb24: + // CHECK: [[VAL_90:%.*]] = cmpf "ugt", [[VAL_86]], [[VAL_88]] : f32 + // CHECK: [[VAL_91:%.*]] = select [[VAL_90]], [[VAL_86]], [[VAL_88]] : f32 + // CHECK: br ^bb26([[VAL_91]] : f32) + // CHECK: ^bb25: + // CHECK: br ^bb26([[VAL_86]] : f32) + // CHECK: ^bb26([[VAL_92:%.*]]: f32): + // CHECK: [[VAL_93:%.*]], [[VAL_94:%.*]] = gpu.shuffle [[VAL_92]], [[VAL_7]], [[VAL_83]] xor : f32 + // CHECK: cond_br [[VAL_94]], ^bb27, ^bb28 + // CHECK: ^bb27: + // CHECK: [[VAL_95:%.*]] = cmpf "ugt", [[VAL_92]], [[VAL_93]] : f32 + // CHECK: [[VAL_96:%.*]] = select [[VAL_95]], [[VAL_92]], [[VAL_93]] : f32 + // CHECK: br ^bb29([[VAL_96]] : f32) + // CHECK: ^bb28: + // CHECK: br ^bb29([[VAL_92]] : f32) + // CHECK: ^bb29([[VAL_97:%.*]]: f32): + // CHECK: [[VAL_98:%.*]], [[VAL_99:%.*]] = gpu.shuffle [[VAL_97]], [[VAL_8]], [[VAL_83]] xor : f32 + // CHECK: cond_br [[VAL_99]], ^bb30, ^bb31 + // CHECK: ^bb30: + // CHECK: [[VAL_100:%.*]] = cmpf "ugt", [[VAL_97]], [[VAL_98]] : f32 + // CHECK: [[VAL_101:%.*]] = select [[VAL_100]], [[VAL_97]], [[VAL_98]] : f32 + // CHECK: br ^bb32([[VAL_101]] : f32) + // CHECK: ^bb31: + // CHECK: br ^bb32([[VAL_97]] : f32) + // CHECK: ^bb32([[VAL_102:%.*]]: f32): + // CHECK: [[VAL_103:%.*]], [[VAL_104:%.*]] = gpu.shuffle [[VAL_102]], [[VAL_9]], [[VAL_83]] xor : f32 + // CHECK: cond_br [[VAL_104]], ^bb33, ^bb34 + // CHECK: ^bb33: + // CHECK: [[VAL_105:%.*]] = cmpf "ugt", [[VAL_102]], [[VAL_103]] : f32 + // CHECK: [[VAL_106:%.*]] = select [[VAL_105]], [[VAL_102]], [[VAL_103]] : f32 + // CHECK: br ^bb35([[VAL_106]] : f32) + // CHECK: ^bb34: + // CHECK: br ^bb35([[VAL_102]] : f32) + // CHECK: ^bb35([[VAL_107:%.*]]: f32): + // CHECK: [[VAL_108:%.*]], [[VAL_109:%.*]] = gpu.shuffle [[VAL_107]], [[VAL_10]], [[VAL_83]] xor : f32 + // CHECK: cond_br [[VAL_109]], ^bb36, ^bb37 + // CHECK: ^bb36: + // CHECK: [[VAL_110:%.*]] = cmpf "ugt", [[VAL_107]], [[VAL_108]] : f32 + // CHECK: [[VAL_111:%.*]] = select [[VAL_110]], [[VAL_107]], [[VAL_108]] : f32 + // CHECK: br ^bb38([[VAL_111]] : f32) + // CHECK: ^bb37: + // CHECK: br ^bb38([[VAL_107]] : f32) + // CHECK: ^bb38([[VAL_112:%.*]]: f32): + // CHECK: br ^bb40([[VAL_112]] : f32) + // CHECK: ^bb39: + // CHECK: 
[[VAL_113:%.*]], [[VAL_114:%.*]] = gpu.shuffle [[VAL_86]], [[VAL_6]], [[VAL_5]] xor : f32 + // CHECK: [[VAL_115:%.*]] = cmpf "ugt", [[VAL_86]], [[VAL_113]] : f32 + // CHECK: [[VAL_116:%.*]] = select [[VAL_115]], [[VAL_86]], [[VAL_113]] : f32 + // CHECK: [[VAL_117:%.*]], [[VAL_118:%.*]] = gpu.shuffle [[VAL_116]], [[VAL_7]], [[VAL_5]] xor : f32 + // CHECK: [[VAL_119:%.*]] = cmpf "ugt", [[VAL_116]], [[VAL_117]] : f32 + // CHECK: [[VAL_120:%.*]] = select [[VAL_119]], [[VAL_116]], [[VAL_117]] : f32 + // CHECK: [[VAL_121:%.*]], [[VAL_122:%.*]] = gpu.shuffle [[VAL_120]], [[VAL_8]], [[VAL_5]] xor : f32 + // CHECK: [[VAL_123:%.*]] = cmpf "ugt", [[VAL_120]], [[VAL_121]] : f32 + // CHECK: [[VAL_124:%.*]] = select [[VAL_123]], [[VAL_120]], [[VAL_121]] : f32 + // CHECK: [[VAL_125:%.*]], [[VAL_126:%.*]] = gpu.shuffle [[VAL_124]], [[VAL_9]], [[VAL_5]] xor : f32 + // CHECK: [[VAL_127:%.*]] = cmpf "ugt", [[VAL_124]], [[VAL_125]] : f32 + // CHECK: [[VAL_128:%.*]] = select [[VAL_127]], [[VAL_124]], [[VAL_125]] : f32 + // CHECK: [[VAL_129:%.*]], [[VAL_130:%.*]] = gpu.shuffle [[VAL_128]], [[VAL_10]], [[VAL_5]] xor : f32 + // CHECK: [[VAL_131:%.*]] = cmpf "ugt", [[VAL_128]], [[VAL_129]] : f32 + // CHECK: [[VAL_132:%.*]] = select [[VAL_131]], [[VAL_128]], [[VAL_129]] : f32 + // CHECK: br ^bb40([[VAL_132]] : f32) + // CHECK: ^bb40([[VAL_133:%.*]]: f32): + // CHECK: store [[VAL_133]], [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, 3> + // CHECK: br ^bb42 + // CHECK: ^bb41: + // CHECK: br ^bb42 + // CHECK: ^bb42: + // CHECK: gpu.barrier + // CHECK: [[VAL_134:%.*]] = load [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, 3> + %sum = "gpu.all_reduce"(%arg0) ({}) {op = "max"} : (f32) -> (f32) + gpu.return + } + +} Index: mlir/test/Dialect/GPU/invalid.mlir =================================================================== --- mlir/test/Dialect/GPU/invalid.mlir +++ mlir/test/Dialect/GPU/invalid.mlir @@ -255,6 +255,14 @@ // ----- +func @reduce_invalid_op_type(%arg0 : f32) { + // expected-error@+1 {{`and` accumulator is only compatible with Integer type}} + %res = "gpu.all_reduce"(%arg0) ({}) {op = "and"} : (f32) -> (f32) + return +} + +// ----- + func @reduce_incorrect_region_arguments(%arg0 : f32) { // expected-error@+1 {{expected two region arguments}} %res = "gpu.all_reduce"(%arg0) ({ Index: mlir/test/mlir-cuda-runner/all-reduce-and.mlir =================================================================== --- /dev/null +++ mlir/test/mlir-cuda-runner/all-reduce-and.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s + +func @main() { + %data = alloc() : memref<2x6xi32> + %sum_and = alloc() : memref<2xi32> + %sum_or = alloc() : memref<2xi32> + %sum_min = alloc() : memref<2xi32> + %cst0 = constant 0 : i32 + %cst1 = constant 1 : i32 + %cst2 = constant 2 : i32 + %cst4 = constant 4 : i32 + %cst8 = constant 8 : i32 + %cst16 = constant 16 : i32 + + %cst3 = constant 3 : i32 + %cst6 = constant 6 : i32 + %cst7 = constant 7 : i32 + %cst10 = constant 10 : i32 + %cst11 = constant 11 : i32 + + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %c4 = constant 4 : index + %c5 = constant 5 : index + %c6 = constant 6 : index + + store %cst0, %data[%c0, %c0] : memref<2x6xi32> + store %cst1, %data[%c0, %c1] : memref<2x6xi32> + store %cst2, %data[%c0, %c2] : memref<2x6xi32> + store %cst4, %data[%c0, %c3] : memref<2x6xi32> + store 
%cst8, %data[%c0, %c4] : memref<2x6xi32> + store %cst16, %data[%c0, %c5] : memref<2x6xi32> + + store %cst2, %data[%c1, %c0] : memref<2x6xi32> + store %cst3, %data[%c1, %c1] : memref<2x6xi32> + store %cst6, %data[%c1, %c2] : memref<2x6xi32> + store %cst7, %data[%c1, %c3] : memref<2x6xi32> + store %cst10, %data[%c1, %c4] : memref<2x6xi32> + store %cst11, %data[%c1, %c5] : memref<2x6xi32> + + // AND + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) { + %val = load %data[%bx, %tx] : memref<2x6xi32> + %reduced_and = "gpu.all_reduce"(%val) ({}) { op = "and" } : (i32) -> (i32) + store %reduced_and, %sum_and[%bx] : memref<2xi32> + gpu.terminator + } + + %ptr_and = memref_cast %sum_and : memref<2xi32> to memref<*xi32> + call @print_memref_i32(%ptr_and) : (memref<*xi32>) -> () + // CHECK: [0, 2] + + return +} + +func @print_memref_i32(memref<*xi32>) + Index: mlir/test/mlir-cuda-runner/all-reduce-max.mlir =================================================================== --- /dev/null +++ mlir/test/mlir-cuda-runner/all-reduce-max.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s + +func @main() { + %data = alloc() : memref<2x6xi32> + %sum = alloc() : memref<2xi32> + %cst0 = constant 0 : i32 + %cst1 = constant 1 : i32 + %cst2 = constant 2 : i32 + %cst4 = constant 4 : i32 + %cst8 = constant 8 : i32 + %cst16 = constant 16 : i32 + + %cst3 = constant 3 : i32 + %cst6 = constant 6 : i32 + %cst7 = constant 7 : i32 + %cst10 = constant 10 : i32 + %cst11 = constant 11 : i32 + + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %c4 = constant 4 : index + %c5 = constant 5 : index + %c6 = constant 6 : index + + store %cst0, %data[%c0, %c0] : memref<2x6xi32> + store %cst1, %data[%c0, %c1] : memref<2x6xi32> + store %cst2, %data[%c0, %c2] : memref<2x6xi32> + store %cst4, %data[%c0, %c3] : memref<2x6xi32> + store %cst8, %data[%c0, %c4] : memref<2x6xi32> + store %cst16, %data[%c0, %c5] : memref<2x6xi32> + + store %cst2, %data[%c1, %c0] : memref<2x6xi32> + store %cst3, %data[%c1, %c1] : memref<2x6xi32> + store %cst6, %data[%c1, %c2] : memref<2x6xi32> + store %cst7, %data[%c1, %c3] : memref<2x6xi32> + store %cst10, %data[%c1, %c4] : memref<2x6xi32> + store %cst11, %data[%c1, %c5] : memref<2x6xi32> + + // MAX + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) { + %val = load %data[%bx, %tx] : memref<2x6xi32> + %reduced = "gpu.all_reduce"(%val) ({}) { op = "max" } : (i32) -> (i32) + store %reduced, %sum[%bx] : memref<2xi32> + gpu.terminator + } + + %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32> + call @print_memref_i32(%ptr) : (memref<*xi32>) -> () + // CHECK: [16, 11] + + return +} + +func @print_memref_i32(memref<*xi32>) + Index: mlir/test/mlir-cuda-runner/all-reduce-min.mlir =================================================================== --- /dev/null +++ mlir/test/mlir-cuda-runner/all-reduce-min.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s + +func @main() { + %data = alloc() : memref<2x6xi32> + 
%sum = alloc() : memref<2xi32> + %cst0 = constant 0 : i32 + %cst1 = constant 1 : i32 + %cst2 = constant 2 : i32 + %cst4 = constant 4 : i32 + %cst8 = constant 8 : i32 + %cst16 = constant 16 : i32 + + %cst3 = constant 3 : i32 + %cst6 = constant 6 : i32 + %cst7 = constant 7 : i32 + %cst10 = constant 10 : i32 + %cst11 = constant 11 : i32 + + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %c4 = constant 4 : index + %c5 = constant 5 : index + %c6 = constant 6 : index + + store %cst0, %data[%c0, %c0] : memref<2x6xi32> + store %cst1, %data[%c0, %c1] : memref<2x6xi32> + store %cst2, %data[%c0, %c2] : memref<2x6xi32> + store %cst4, %data[%c0, %c3] : memref<2x6xi32> + store %cst8, %data[%c0, %c4] : memref<2x6xi32> + store %cst16, %data[%c0, %c5] : memref<2x6xi32> + + store %cst2, %data[%c1, %c0] : memref<2x6xi32> + store %cst3, %data[%c1, %c1] : memref<2x6xi32> + store %cst6, %data[%c1, %c2] : memref<2x6xi32> + store %cst7, %data[%c1, %c3] : memref<2x6xi32> + store %cst10, %data[%c1, %c4] : memref<2x6xi32> + store %cst11, %data[%c1, %c5] : memref<2x6xi32> + + // MIN + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) { + %val = load %data[%bx, %tx] : memref<2x6xi32> + %reduced = "gpu.all_reduce"(%val) ({}) { op = "min" } : (i32) -> (i32) + store %reduced, %sum[%bx] : memref<2xi32> + gpu.terminator + } + + %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32> + call @print_memref_i32(%ptr) : (memref<*xi32>) -> () + // CHECK: [0, 2] + + return +} + +func @print_memref_i32(memref<*xi32>) + Index: mlir/test/mlir-cuda-runner/all-reduce-or.mlir =================================================================== --- /dev/null +++ mlir/test/mlir-cuda-runner/all-reduce-or.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s + +func @main() { + %data = alloc() : memref<2x6xi32> + %sum = alloc() : memref<2xi32> + %cst0 = constant 0 : i32 + %cst1 = constant 1 : i32 + %cst2 = constant 2 : i32 + %cst4 = constant 4 : i32 + %cst8 = constant 8 : i32 + %cst16 = constant 16 : i32 + + %cst3 = constant 3 : i32 + %cst6 = constant 6 : i32 + %cst7 = constant 7 : i32 + %cst10 = constant 10 : i32 + %cst11 = constant 11 : i32 + + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %c4 = constant 4 : index + %c5 = constant 5 : index + %c6 = constant 6 : index + + store %cst0, %data[%c0, %c0] : memref<2x6xi32> + store %cst1, %data[%c0, %c1] : memref<2x6xi32> + store %cst2, %data[%c0, %c2] : memref<2x6xi32> + store %cst4, %data[%c0, %c3] : memref<2x6xi32> + store %cst8, %data[%c0, %c4] : memref<2x6xi32> + store %cst16, %data[%c0, %c5] : memref<2x6xi32> + + store %cst2, %data[%c1, %c0] : memref<2x6xi32> + store %cst3, %data[%c1, %c1] : memref<2x6xi32> + store %cst6, %data[%c1, %c2] : memref<2x6xi32> + store %cst7, %data[%c1, %c3] : memref<2x6xi32> + store %cst10, %data[%c1, %c4] : memref<2x6xi32> + store %cst11, %data[%c1, %c5] : memref<2x6xi32> + + // OR + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) { + %val = load %data[%bx, %tx] : memref<2x6xi32> + %reduced = "gpu.all_reduce"(%val) ({}) { op = "or" } : (i32) -> (i32) + store 
%reduced, %sum[%bx] : memref<2xi32> + gpu.terminator + } + + %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32> + call @print_memref_i32(%ptr) : (memref<*xi32>) -> () + // CHECK: [31, 15] + + return +} + +func @print_memref_i32(memref<*xi32>) + Index: mlir/test/mlir-cuda-runner/all-reduce-xor.mlir =================================================================== --- /dev/null +++ mlir/test/mlir-cuda-runner/all-reduce-xor.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-cuda-runner %s --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s + +func @main() { + %data = alloc() : memref<2x6xi32> + %sum = alloc() : memref<2xi32> + %cst0 = constant 0 : i32 + %cst1 = constant 1 : i32 + %cst2 = constant 2 : i32 + %cst4 = constant 4 : i32 + %cst8 = constant 8 : i32 + %cst16 = constant 16 : i32 + + %cst3 = constant 3 : i32 + %cst6 = constant 6 : i32 + %cst7 = constant 7 : i32 + %cst10 = constant 10 : i32 + %cst11 = constant 11 : i32 + + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + %c3 = constant 3 : index + %c4 = constant 4 : index + %c5 = constant 5 : index + %c6 = constant 6 : index + + store %cst0, %data[%c0, %c0] : memref<2x6xi32> + store %cst1, %data[%c0, %c1] : memref<2x6xi32> + store %cst2, %data[%c0, %c2] : memref<2x6xi32> + store %cst4, %data[%c0, %c3] : memref<2x6xi32> + store %cst8, %data[%c0, %c4] : memref<2x6xi32> + store %cst16, %data[%c0, %c5] : memref<2x6xi32> + + store %cst2, %data[%c1, %c0] : memref<2x6xi32> + store %cst3, %data[%c1, %c1] : memref<2x6xi32> + store %cst6, %data[%c1, %c2] : memref<2x6xi32> + store %cst7, %data[%c1, %c3] : memref<2x6xi32> + store %cst10, %data[%c1, %c4] : memref<2x6xi32> + store %cst11, %data[%c1, %c5] : memref<2x6xi32> + + // XOR + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1) + threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) { + %val = load %data[%bx, %tx] : memref<2x6xi32> + %reduced = "gpu.all_reduce"(%val) ({}) { op = "xor" } : (i32) -> (i32) + store %reduced, %sum[%bx] : memref<2xi32> + gpu.terminator + } + + %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32> + call @print_memref_i32(%ptr) : (memref<*xi32>) -> () + // CHECK: [31, 1] + + return +} + +func @print_memref_i32(memref<*xi32>) +
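As a cross-check on the FileCheck expectations in the five cuda-runner tests
above, the per-row reduction results can be recomputed on the host. The
standalone C++ sketch below (not part of the patch) folds the same 2x6 test
data with each of the new accumulators and asserts the values that appear in
the CHECK lines:

#include <cassert>
#include <cstdint>

// The same 2x6 i32 test data stored by the cuda-runner tests above.
static const int32_t kData[2][6] = {{0, 1, 2, 4, 8, 16},
                                    {2, 3, 6, 7, 10, 11}};

int main() {
  for (int row = 0; row < 2; ++row) {
    int32_t andAcc = kData[row][0], orAcc = kData[row][0],
            xorAcc = kData[row][0], maxAcc = kData[row][0],
            minAcc = kData[row][0];
    for (int col = 1; col < 6; ++col) {
      int32_t v = kData[row][col];
      andAcc &= v;
      orAcc |= v;
      xorAcc ^= v;
      maxAcc = maxAcc > v ? maxAcc : v;
      minAcc = minAcc < v ? minAcc : v;
    }
    // Row 0: and=0, or=31, xor=31, max=16, min=0.
    // Row 1: and=2, or=15, xor=1,  max=11, min=2.
    assert(andAcc == (row == 0 ? 0 : 2));   // all-reduce-and: CHECK: [0, 2]
    assert(orAcc == (row == 0 ? 31 : 15));  // all-reduce-or:  CHECK: [31, 15]
    assert(xorAcc == (row == 0 ? 31 : 1));  // all-reduce-xor: CHECK: [31, 1]
    assert(maxAcc == (row == 0 ? 16 : 11)); // all-reduce-max: CHECK: [16, 11]
    assert(minAcc == (row == 0 ? 0 : 2));   // all-reduce-min: CHECK: [0, 2]
  }
  return 0;
}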