diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -37,6 +37,11 @@ void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns); +/// Populate GpuSubgroupReduce pattern to NVVM. It generates a specific nvvm +/// op that is not available on every GPU. +void populateGpuSubgroupReduceOpLoweringPattern(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + /// Collect a set of patterns to convert WMMA ops from GPU dialect to NVVM. void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns); @@ -45,7 +50,8 @@ /// index bitwidth used for the lowering of the device side index computations /// is configurable. std::unique_ptr<OperationPass<gpu::GPUModuleOp>> createLowerGpuOpsToNVVMOpsPass( - unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout); + unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout, + bool hasRedux = false); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -359,7 +359,9 @@ let options = [ Option<"indexBitwidth", "index-bitwidth", "unsigned", /*default=kDeriveIndexBitwidthFromDataLayout*/"0", - "Bitwidth of the index type, 0 to use size of machine word"> + "Bitwidth of the index type, 0 to use size of machine word">, + Option<"hasRedux", "has-redux", "bool", /*default=*/"false", + "Target gpu supports redux">, ]; } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -58,6 +58,55 @@ llvm_unreachable("unknown shuffle mode"); } +static 
NVVM::ReduxKind convertReduxKind(gpu::AllReduceOperation mode) { + switch (mode) { + case gpu::AllReduceOperation::ADD: + return NVVM::ReduxKind::ADD; + case gpu::AllReduceOperation::AND: + return NVVM::ReduxKind::AND; + case gpu::AllReduceOperation::MAX: + return NVVM::ReduxKind::MAX; + case gpu::AllReduceOperation::MIN: + return NVVM::ReduxKind::MIN; + case gpu::AllReduceOperation::OR: + return NVVM::ReduxKind::OR; + case gpu::AllReduceOperation::XOR: + return NVVM::ReduxKind::XOR; + case gpu::AllReduceOperation::MUL: + llvm_unreachable("MUL is not supported by redux"); + break; + } + llvm_unreachable("unknown redux mode"); +} + +/// This pattern lowers a gpu.subgroup_reduce op to the nvvm.redux op. The op +/// must be run by the entire subgroup, otherwise it is undefined behaviour. +/// Non-uniform ops, non-i32 values and MUL reductions fail to match. +struct GPUSubgroupReduceOpLowering + : public ConvertOpToLLVMPattern<gpu::SubgroupReduceOp> { + using ConvertOpToLLVMPattern<gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op.getUniform()) + return rewriter.notifyMatchFailure( + op, "cannot be lowered to redux as the op must be run " + "uniformly (entire subgroup)."); + if (!op.getValue().getType().isInteger(32)) + return rewriter.notifyMatchFailure(op, "unsupported data type"); + // convertReduxKind() has no mapping for MUL; bail out before calling it. + if (op.getOp() == gpu::AllReduceOperation::MUL) + return rewriter.notifyMatchFailure(op, "MUL is not supported by redux"); + Location loc = op->getLoc(); + auto int32Type = IntegerType::get(rewriter.getContext(), 32); + Value offset = rewriter.create<LLVM::ConstantOp>(loc, int32Type, -1); + NVVM::ReduxKind mode = convertReduxKind(op.getOp()); + rewriter.replaceOpWithNewOp<NVVM::ReduxOp>(op, int32Type, op.getValue(), + mode, offset); + return success(); + } +}; + struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> { using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern; @@ -155,8 +204,9 @@ struct LowerGpuOpsToNVVMOpsPass : public impl::ConvertGpuOpsToNVVMOpsBase<LowerGpuOpsToNVVMOpsPass> { LowerGpuOpsToNVVMOpsPass() = default; - LowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth) { + LowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth, 
bool hasRedux = false) { this->indexBitwidth = indexBitwidth; + this->hasRedux = hasRedux; } void runOnOperation() override { @@ -229,6 +279,8 @@ populateMemRefToLLVMConversionPatterns(converter, llvmPatterns); populateGpuToNVVMConversionPatterns(converter, llvmPatterns); populateGpuWMMAToNVVMConversionPatterns(converter, llvmPatterns); + if (this->hasRedux) + populateGpuSubgroupReduceOpLoweringPattern(converter, llvmPatterns); LLVMConversionTarget target(getContext()); configureGpuToNVVMConversionLegality(target); if (failed(applyPartialConversion(m, target, std::move(llvmPatterns)))) @@ -259,6 +311,11 @@ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func); } +void mlir::populateGpuSubgroupReduceOpLoweringPattern( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + patterns.add<GPUSubgroupReduceOpLowering>(converter); +} + void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns) { populateWithGenerated(patterns); @@ -323,6 +380,6 @@ } std::unique_ptr<OperationPass<gpu::GPUModuleOp>> -mlir::createLowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth) { - return std::make_unique<LowerGpuOpsToNVVMOpsPass>(indexBitwidth); +mlir::createLowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth, bool hasRedux) { + return std::make_unique<LowerGpuOpsToNVVMOpsPass>(indexBitwidth, hasRedux); } diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -convert-gpu-to-nvvm -split-input-file | FileCheck %s -// RUN: mlir-opt %s -convert-gpu-to-nvvm='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s +// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1' -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s gpu.module @test_module { // CHECK-LABEL: func @gpu_index_ops() @@ -574,3 +574,44 @@ } } +// ----- + +gpu.module 
@test_module { + // CHECK-LABEL: func @subgroup_reduce_add + gpu.func @subgroup_reduce_add(%arg0 : i32) { + // CHECK: nvvm.redux.sync add {{.*}} + %result = gpu.subgroup_reduce add %arg0 uniform {} : (i32) -> (i32) + gpu.return + } + // CHECK-LABEL: func @subgroup_reduce_and + gpu.func @subgroup_reduce_and(%arg0 : i32) { + // CHECK: nvvm.redux.sync and {{.*}} + %result = gpu.subgroup_reduce and %arg0 uniform {} : (i32) -> (i32) + gpu.return + } + // CHECK-LABEL: func @subgroup_reduce_max + gpu.func @subgroup_reduce_max(%arg0 : i32) { + // CHECK: nvvm.redux.sync max {{.*}} + %result = gpu.subgroup_reduce max %arg0 uniform {} : (i32) -> (i32) + gpu.return + } + // CHECK-LABEL: func @subgroup_reduce_min + gpu.func @subgroup_reduce_min(%arg0 : i32) { + // CHECK: nvvm.redux.sync min {{.*}} + %result = gpu.subgroup_reduce min %arg0 uniform {} : (i32) -> (i32) + gpu.return + } + // CHECK-LABEL: func @subgroup_reduce_or + gpu.func @subgroup_reduce_or(%arg0 : i32) { + // CHECK: nvvm.redux.sync or {{.*}} + %result = gpu.subgroup_reduce or %arg0 uniform {} : (i32) -> (i32) + gpu.return + } + // CHECK-LABEL: func @subgroup_reduce_xor + gpu.func @subgroup_reduce_xor(%arg0 : i32) { + // CHECK: nvvm.redux.sync xor {{.*}} + %result = gpu.subgroup_reduce xor %arg0 uniform {} : (i32) -> (i32) + gpu.return + } +} +