diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
--- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
+++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h
@@ -37,6 +37,11 @@
 void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                          RewritePatternSet &patterns);
 
+/// Populate the pattern lowering gpu.subgroup_reduce to nvvm.redux.sync, an
+/// op that requires sm_80 or newer and thus is not available on every GPU.
+void populateGpuSubgroupReduceOpLoweringPattern(LLVMTypeConverter &converter,
+                                                RewritePatternSet &patterns);
+
 /// Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM.
 void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                              RewritePatternSet &patterns);
@@ -45,7 +50,8 @@
 /// index bitwidth used for the lowering of the device side index computations
 /// is configurable.
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>> createLowerGpuOpsToNVVMOpsPass(
-    unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);
+    unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout,
+    bool hasRedux = false);
 
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -359,7 +359,9 @@
   let options = [
     Option<"indexBitwidth", "index-bitwidth", "unsigned",
           /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
-           "Bitwidth of the index type, 0 to use size of machine word">
+           "Bitwidth of the index type, 0 to use size of machine word">,
+    Option<"hasRedux", "has-redux", "bool", /*default=*/"false",
+           "Target gpu supports redux">,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -135,6 +135,45 @@
   let assemblyFormat = "$arg attr-dict `:` type($res)";
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM redux op definitions
+//===----------------------------------------------------------------------===//
+
+def ReduxKindNone : I32EnumAttrCase<"NONE", 0, "none">;
+def ReduxKindAdd : I32EnumAttrCase<"ADD", 1, "add">;
+def ReduxKindAnd : I32EnumAttrCase<"AND", 2, "and">;
+def ReduxKindMax : I32EnumAttrCase<"MAX", 3, "max">;
+def ReduxKindMin : I32EnumAttrCase<"MIN", 4, "min">;
+def ReduxKindOr : I32EnumAttrCase<"OR", 5, "or">;
+def ReduxKindUmax : I32EnumAttrCase<"UMAX", 6, "umax">;
+def ReduxKindUmin : I32EnumAttrCase<"UMIN", 7, "umin">;
+def ReduxKindXor : I32EnumAttrCase<"XOR", 8, "xor">;
+
+/// Enum attribute of the different reduction kinds.
+def ReduxKind : I32EnumAttr<"ReduxKind", "NVVM redux kind",
+  [ReduxKindAdd, ReduxKindAnd, ReduxKindMax, ReduxKindMin, ReduxKindOr,
+   ReduxKindUmax, ReduxKindUmin, ReduxKindXor]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+
+def ReduxKindAttr : EnumAttr<NVVM_Dialect, ReduxKind, "redux_kind">;
+
+def NVVM_ReduxOp :
+  NVVM_Op<"redux.sync">,
+  Results<(outs LLVM_Type:$res)>,
+  Arguments<(ins LLVM_Type:$val,
+                 ReduxKindAttr:$kind,
+                 I32:$mask_and_clamp)> {
+  string llvmBuilder = [{
+    auto intId = getReduxIntrinsicId($_resultType, $kind);
+    $res = createIntrinsicCall(builder, intId, {$val, $mask_and_clamp});
+  }];
+  let assemblyFormat = [{
+    $kind $val `,` $mask_and_clamp attr-dict `:` type($val) `->` type($res)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM synchronization op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -58,6 +58,55 @@
   llvm_unreachable("unknown shuffle mode");
 }
 
+static NVVM::ReduxKind convertReduxKind(gpu::AllReduceOperation mode) {
+  switch (mode) {
+  case gpu::AllReduceOperation::ADD:
+    return NVVM::ReduxKind::ADD;
+  case gpu::AllReduceOperation::AND:
+    return NVVM::ReduxKind::AND;
+  case gpu::AllReduceOperation::MAX:
+    return NVVM::ReduxKind::MAX;
+  case gpu::AllReduceOperation::MIN:
+    return NVVM::ReduxKind::MIN;
+  case gpu::AllReduceOperation::OR:
+    return NVVM::ReduxKind::OR;
+  case gpu::AllReduceOperation::XOR:
+    return NVVM::ReduxKind::XOR;
+  case gpu::AllReduceOperation::MUL:
+    llvm_unreachable("MUL is not supported by redux");
+    break;
+  }
+  llvm_unreachable("unknown redux mode");
+}
+
+/// Lowers gpu.subgroup_reduce into nvvm.redux.sync. The op must be executed by
+/// the entire subgroup uniformly, otherwise the behaviour is undefined.
+struct GPUSubgroupReduceOpLowering
+    : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> {
+  using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op.getUniform())
+      return rewriter.notifyMatchFailure(
+          op, "cannot be lowered to redux as the op must be run "
+              "uniformly (entire subgroup)");
+    if (!op.getValue().getType().isInteger(32))
+      return rewriter.notifyMatchFailure(op, "unsupported data type");
+
+    Location loc = op->getLoc();
+    auto int32Type = IntegerType::get(rewriter.getContext(), 32);
+    Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1);
+    NVVM::ReduxKind mode = convertReduxKind(op.getOp());
+
+    auto reduxOp = rewriter.create<NVVM::ReduxOp>(loc, int32Type,
+                                                  op.getValue(), mode, offset);
+
+    rewriter.replaceOp(op, reduxOp->getResult(0));
+    return success();
+  }
+};
+
 struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
   using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern;
 
@@ -155,8 +204,9 @@
 struct LowerGpuOpsToNVVMOpsPass
     : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> {
   LowerGpuOpsToNVVMOpsPass() = default;
-  LowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth) {
+  LowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth, bool hasRedux = false) {
     this->indexBitwidth = indexBitwidth;
+    this->hasRedux = hasRedux;
   }
 
   void runOnOperation() override {
@@ -229,6 +279,8 @@
     populateMemRefToLLVMConversionPatterns(converter, llvmPatterns);
     populateGpuToNVVMConversionPatterns(converter, llvmPatterns);
     populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns);
+    if (this->hasRedux)
+      populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns);
     LLVMConversionTarget target(getContext());
     configureGpuToNVVMConversionLegality(target);
     if (failed(applyPartialConversion(m, target, std::move(llvmPatterns))))
@@ -259,6 +311,11 @@
   patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
 }
 
+void mlir::populateGpuSubgroupReduceOpLoweringPattern(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<GPUSubgroupReduceOpLowering>(converter);
+}
+
 void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns) {
   populateWithGenerated(patterns);
@@ -323,6 +380,6 @@
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
-mlir::createLowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth) {
-  return std::make_unique<LowerGpuOpsToNVVMOpsPass>(indexBitwidth);
+mlir::createLowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth, bool hasRedux) {
+  return std::make_unique<LowerGpuOpsToNVVMOpsPass>(indexBitwidth, hasRedux);
 }
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -25,6 +25,32 @@
 using namespace mlir::LLVM;
 using mlir::LLVM::detail::createIntrinsicCall;
 
+static llvm::Intrinsic::ID getReduxIntrinsicId(llvm::Type *resultType,
+                                               NVVM::ReduxKind kind) {
+  if (!resultType->isIntegerTy(32))
+    llvm_unreachable("unsupported data type for redux");
+
+  switch (kind) {
+  case NVVM::ReduxKind::ADD:
+    return llvm::Intrinsic::nvvm_redux_sync_add;
+  case NVVM::ReduxKind::UMAX:
+    return llvm::Intrinsic::nvvm_redux_sync_umax;
+  case NVVM::ReduxKind::UMIN:
+    return llvm::Intrinsic::nvvm_redux_sync_umin;
+  case NVVM::ReduxKind::AND:
+    return llvm::Intrinsic::nvvm_redux_sync_and;
+  case NVVM::ReduxKind::OR:
+    return llvm::Intrinsic::nvvm_redux_sync_or;
+  case NVVM::ReduxKind::XOR:
+    return llvm::Intrinsic::nvvm_redux_sync_xor;
+  case NVVM::ReduxKind::MAX:
+    return llvm::Intrinsic::nvvm_redux_sync_max;
+  case NVVM::ReduxKind::MIN:
+    return llvm::Intrinsic::nvvm_redux_sync_min;
+  }
+  llvm_unreachable("unknown redux kind");
+}
+
 static llvm::Intrinsic::ID getShflIntrinsicId(llvm::Type *resultType,
                                               NVVM::ShflKind kind,
                                               bool withPredicate) {
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -convert-gpu-to-nvvm -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -convert-gpu-to-nvvm='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
+// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s
 
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_index_ops()
@@ -574,3 +574,44 @@
   }
 }
 
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: func @subgroup_reduce_add
+  gpu.func @subgroup_reduce_add(%arg0 : i32) {
+    // CHECK: nvvm.redux.sync add {{.*}}
+    %result = gpu.subgroup_reduce add %arg0 uniform {} : (i32) -> (i32)
+    gpu.return
+  }
+  // CHECK-LABEL: func @subgroup_reduce_and
+  gpu.func @subgroup_reduce_and(%arg0 : i32) {
+    // CHECK: nvvm.redux.sync and {{.*}}
+    %result = gpu.subgroup_reduce and %arg0 uniform {} : (i32) -> (i32)
+    gpu.return
+  }
+  // CHECK-LABEL: @subgroup_reduce_max
+  gpu.func @subgroup_reduce_max(%arg0 : i32) {
+    // CHECK: nvvm.redux.sync max {{.*}}
+    %result = gpu.subgroup_reduce max %arg0 uniform {} : (i32) -> (i32)
+    gpu.return
+  }
+  // CHECK-LABEL: @subgroup_reduce_min
+  gpu.func @subgroup_reduce_min(%arg0 : i32) {
+    // CHECK: nvvm.redux.sync min {{.*}}
+    %result = gpu.subgroup_reduce min %arg0 uniform {} : (i32) -> (i32)
+    gpu.return
+  }
+  // CHECK-LABEL: @subgroup_reduce_or
+  gpu.func @subgroup_reduce_or(%arg0 : i32) {
+    // CHECK: nvvm.redux.sync or {{.*}}
+    %result = gpu.subgroup_reduce or %arg0 uniform {} : (i32) -> (i32)
+    gpu.return
+  }
+  // CHECK-LABEL: @subgroup_reduce_xor
+  gpu.func @subgroup_reduce_xor(%arg0 : i32) {
+    // CHECK: nvvm.redux.sync xor {{.*}}
+    %result = gpu.subgroup_reduce xor %arg0 uniform {} : (i32) -> (i32)
+    gpu.return
+  }
+}
+
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -310,6 +310,29 @@
   %l4 = nvvm.ldmatrix %arg0 {num = 4 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> !llvm.struct<(i32, i32, i32, i32)>
   llvm.return
 }
+
+// CHECK-LABEL: llvm.func @redux_sync
+llvm.func @redux_sync(%value : i32, %offset : i32) -> i32 {
+  // CHECK: nvvm.redux.sync add %{{.*}}
+  %r1 = nvvm.redux.sync add %value, %offset : i32 -> i32
+  // CHECK: nvvm.redux.sync max %{{.*}}
+  %r2 = nvvm.redux.sync max %value, %offset : i32 -> i32
+  // CHECK: nvvm.redux.sync min %{{.*}}
+  %r3 = nvvm.redux.sync min %value, %offset : i32 -> i32
+  // CHECK: nvvm.redux.sync umax %{{.*}}
+  %r5 = nvvm.redux.sync umax %value, %offset : i32 -> i32
+  // CHECK: nvvm.redux.sync umin %{{.*}}
+  %r6 = nvvm.redux.sync umin %value, %offset : i32 -> i32
+  // CHECK: nvvm.redux.sync and %{{.*}}
+  %r7 = nvvm.redux.sync and %value, %offset : i32 -> i32
+  // CHECK: nvvm.redux.sync or %{{.*}}
+  %r8 = nvvm.redux.sync or %value, %offset : i32 -> i32
+  // CHECK: nvvm.redux.sync xor %{{.*}}
+  %r9 = nvvm.redux.sync xor %value, %offset : i32 -> i32
+  llvm.return %r1 : i32
+}
+
+
 // -----
 
 // expected-error@below {{attribute attached to unexpected op}}
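
Note (illustrative sketch, not part of the patch): with 'has-redux=1', a uniform
32-bit gpu.subgroup_reduce is rewritten to nvvm.redux.sync using a full-warp
membermask of -1, and mlir-translate then maps that op to the corresponding
llvm.nvvm.redux.sync.* intrinsic, which is only available on sm_80 and newer.
The value names %x, %sum and %mask below are made up for the example; the
pipeline is roughly:

  // GPU dialect input.
  %sum = gpu.subgroup_reduce add %x uniform {} : (i32) -> (i32)

  // After mlir-opt -convert-gpu-to-nvvm='has-redux=1'.
  %mask = llvm.mlir.constant(-1 : i32) : i32
  %sum = nvvm.redux.sync add %x, %mask : i32 -> i32

  // After mlir-translate --mlir-to-llvmir.
  %sum = call i32 @llvm.nvvm.redux.sync.add(i32 %x, i32 -1)

Without the option (or on targets older than sm_80), gpu.subgroup_reduce keeps
being lowered through the existing shuffle-based path.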