diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -159,6 +159,16 @@
       return {begin, end};
     }
 
+    /// Adds a new block argument that corresponds to buffers located in
+    /// workgroup memory.
+    BlockArgument addWorkgroupAttribution(Type type) {
+      auto attrName = getNumWorkgroupAttributionsAttrName();
+      auto attr = getAttrOfType<IntegerAttr>(attrName);
+      setAttr(attrName, IntegerAttr::get(attr.getType(), attr.getValue() + 1));
+      return getBody().front().insertArgument(
+          getType().getNumInputs() + attr.getInt(), type);
+    }
+
     /// Returns a list of block arguments that correspond to buffers located in
     /// the private memory.
     ArrayRef<BlockArgument> getPrivateAttributions() {
diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h
--- a/mlir/include/mlir/Dialect/GPU/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Passes.h
@@ -17,11 +17,17 @@
 
 namespace mlir {
 
+class MLIRContext;
 class ModuleOp;
 template <typename T> class OpPassBase;
+class OwningRewritePatternList;
 
 std::unique_ptr<OpPassBase<ModuleOp>> createGpuKernelOutliningPass();
 
+/// Collect a set of patterns to rewrite ops within the GPU dialect.
+void populateGpuRewritePatterns(MLIRContext *context,
+                                OwningRewritePatternList &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_GPU_PASSES_H_
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -87,6 +87,9 @@
   /// Add one argument to the argument list for each type specified in the list.
   iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
 
+  /// Insert one value into the argument list at the specified position.
+  BlockArgument insertArgument(unsigned index, Type type);
+
   /// Erase the argument at 'index' and remove it from the argument list. If
   /// 'updatePredTerms' is set to true, this argument is also removed from the
   /// terminators of each predecessor to this block.
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_llvm_library(MLIRGPU
   IR/GPUDialect.cpp
   IR/DialectRegistration.cpp
+  Transforms/AllReduceLowering.cpp
   Transforms/KernelOutlining.cpp
   Transforms/MemoryPromotion.cpp
 
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -0,0 +1,376 @@
+//===- AllReduceLowering.cpp - Implementation of all-reduce lowering -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements in-dialect lowering of the all-reduce op to a block of
+// simpler instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+struct GpuAllReduceRewriter {
+  using AccumulatorFactory = std::function<Value(Value, Value)>;
+
+  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp_, gpu::AllReduceOp reduceOp_,
+                       PatternRewriter &rewriter_)
+      : funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
+        loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
+        indexType(IndexType::get(reduceOp.getContext())),
+        int32Type(IntegerType::get(/*width=*/32, reduceOp.getContext())) {}
+
+  /// Creates an all_reduce across the workgroup.
+  ///
+  /// First reduce the elements within a subgroup. The first invocation of each
+  /// subgroup writes the intermediate result to workgroup memory. After
+  /// synchronizing the workgroup, the first subgroup reduces the values from
+  /// workgroup memory. The result is broadcast to all invocations through
+  /// workgroup memory.
+  ///
+  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
+  ///     cond_br %is_first_lane, ^then1, ^continue1
+  ///   ^then1:
+  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
+  ///     br ^continue1
+  ///   ^continue1:
+  ///     gpu.barrier
+  ///     %is_valid_subgroup = cmpi "slt", %invocation_idx, %num_subgroups
+  ///     cond_br %is_valid_subgroup, ^then2, ^continue2
+  ///   ^then2:
+  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
+  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
+  ///     store %all_reduce, %workgroup_buffer[%zero]
+  ///     br ^continue2
+  ///   ^continue2:
+  ///     gpu.barrier
+  ///     %result = load %workgroup_buffer[%zero]
+  ///     return %result
+  ///
+  void rewrite() {
+    rewriter.setInsertionPoint(reduceOp);
+
+    // Compute linear invocation index and workgroup size.
+    Value dimX = getDimOp<gpu::BlockDimOp>("x");
+    Value dimY = getDimOp<gpu::BlockDimOp>("y");
+    Value dimZ = getDimOp<gpu::BlockDimOp>("z");
+    Value tidX = getDimOp<gpu::ThreadIdOp>("x");
+    Value tidY = getDimOp<gpu::ThreadIdOp>("y");
+    Value tidZ = getDimOp<gpu::ThreadIdOp>("z");
+    Value tmp1 = create<MulIOp>(int32Type, tidZ, dimY);
+    Value tmp2 = create<AddIOp>(int32Type, tmp1, tidY);
+    Value tmp3 = create<MulIOp>(int32Type, tmp2, dimX);
+    Value tmp4 = create<MulIOp>(int32Type, dimX, dimY);
+    Value invocationIdx = create<AddIOp>(int32Type, tmp3, tidX);
+    Value workgroupSize = create<MulIOp>(int32Type, tmp4, dimZ);
+
+    // Compute lane id (invocation id within the subgroup).
+    Value subgroupMask = create<ConstantIntOp>(kSubgroupSize - 1, int32Type);
+    Value laneId = create<AndOp>(invocationIdx, subgroupMask);
+    Value isFirstLane = create<CmpIOp>(CmpIPredicate::eq, laneId,
+                                       create<ConstantIntOp>(0, int32Type));
+
+    Value numThreadsWithSmallerSubgroupId =
+        create<SubIOp>(invocationIdx, laneId);
+    // The number of active invocations starting from the current subgroup.
+    // The consumers do not require the value to be clamped to the size of the
+    // subgroup.
+    Value activeWidth =
+        create<SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);
+
+    // Create factory for op which accumulates two values.
+    AccumulatorFactory accumFactory = getFactory();
+    assert(accumFactory && "failed to create accumulator factory");
+
+    // Reduce elements within each subgroup to produce the intermediate
+    // results.
+    Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
+                                                reduceOp.value(), accumFactory);
+
+    // Add workgroup buffer to parent function for intermediate result.
+    Value buffer = createWorkgroupBuffer();
+
+    // Write the intermediate results to workgroup memory, using the first lane
+    // of each subgroup.
+    createPredicatedBlock(isFirstLane, [&] {
+      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
+      Value index = create<IndexCastOp>(indexType, subgroupId);
+      create<StoreOp>(subgroupReduce, buffer, index);
+    });
+    create<gpu::BarrierOp>();
+
+    // Compute number of active subgroups.
+    Value biasedBlockSize =
+        create<AddIOp>(int32Type, workgroupSize, subgroupMask);
+    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
+    Value isValidSubgroup =
+        create<CmpIOp>(CmpIPredicate::slt, invocationIdx, numSubgroups);
+
+    // Use the first numSubgroups invocations to reduce the intermediate
+    // results from workgroup memory. The final result is written to workgroup
+    // memory again.
+    Value zero = create<ConstantIndexOp>(0);
+    createPredicatedBlock(isValidSubgroup, [&] {
+      Value index = create<IndexCastOp>(indexType, invocationIdx);
+      Value value = create<LoadOp>(valueType, buffer, index);
+      Value result =
+          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
+      create<StoreOp>(result, buffer, zero);
+    });
+
+    // Synchronize workgroup and load result from workgroup memory.
+    create<gpu::BarrierOp>();
+    Value result = create<LoadOp>(valueType, buffer, zero);
+
+    rewriter.replaceOp(reduceOp, result);
+  }
+
+private:
+  // Shortcut to create an op from rewriter using loc as the first argument.
+  template <typename T, typename... Args>
+  T create(Args... args) {
+    return rewriter.create<T>(loc, std::forward<Args>(args)...);
+  }
+
+  // Creates dimension op of type T, with the result cast to int32.
+  template <typename T>
+  Value getDimOp(StringRef dimension) {
+    Value dim = create<T>(indexType, rewriter.getStringAttr(dimension));
+    return create<IndexCastOp>(int32Type, dim);
+  }
+
+  /// Adds type to funcOp's workgroup attributions.
+  Value createWorkgroupBuffer() {
+    int workgroupMemoryAddressSpace = 3;
+    auto bufferType =
+        MemRefType::get({kSubgroupSize}, valueType, ArrayRef<AffineMap>{},
+                        workgroupMemoryAddressSpace);
+    return funcOp.addWorkgroupAttribution(bufferType);
+  }
+
+  /// Returns an accumulator factory using either the op attribute or the body
+  /// region.
+  AccumulatorFactory getFactory() {
+    auto &body = reduceOp.body();
+    if (!body.empty())
+      return getFactory(body);
+    auto opAttr = reduceOp.op();
+    if (opAttr)
+      return getFactory(*opAttr);
+    return AccumulatorFactory();
+  }
+
+  /// Returns an accumulator factory that clones the body. The body's entry
+  /// block is expected to have 2 arguments. The gpu.yield op is expected to
+  /// return the accumulated value of the same type.
+  AccumulatorFactory getFactory(Region &body) {
+    return AccumulatorFactory([&](Value lhs, Value rhs) {
+      Block *block = rewriter.getInsertionBlock();
+      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
+
+      // Insert accumulator body between split block.
+      BlockAndValueMapping mapping;
+      mapping.map(body.front().getArgument(0), lhs);
+      mapping.map(body.front().getArgument(1), rhs);
+      rewriter.cloneRegionBefore(body, *split->getParent(),
+                                 split->getIterator(), mapping);
+
+      // Add branch before inserted body, into body.
+      block = block->getNextNode();
+      create<BranchOp>(block, ValueRange());
+
+      // Replace all gpu.yield ops with branch out of body.
+      for (; block != split; block = block->getNextNode()) {
+        Operation *terminator = block->getTerminator();
+        if (!isa<gpu::YieldOp>(terminator))
+          continue;
+        rewriter.setInsertionPointToEnd(block);
+        rewriter.replaceOpWithNewOp<BranchOp>(
+            terminator, split, ValueRange(terminator->getOperand(0)));
+      }
+
+      // Return accumulator result.
+      rewriter.setInsertionPointToStart(split);
+      return split->addArgument(lhs.getType());
+    });
+  }
+
+  /// Returns an accumulator factory that creates an op specified by opName.
+  AccumulatorFactory getFactory(StringRef opName) {
+    bool isFloatingPoint = valueType.isa<FloatType>();
+    if (opName == "add")
+      return isFloatingPoint ? getFactory<AddFOp>() : getFactory<AddIOp>();
+    if (opName == "mul")
+      return isFloatingPoint ? getFactory<MulFOp>() : getFactory<MulIOp>();
+    return AccumulatorFactory();
+  }
+
+  /// Returns an accumulator factory that creates an op of type T.
+  template <typename T>
+  AccumulatorFactory getFactory() {
+    return [&](Value lhs, Value rhs) {
+      return create<T>(lhs.getType(), lhs, rhs);
+    };
+  }
+
+  /// Creates an if-block skeleton and calls the two factories to generate the
+  /// ops in the `then` and `else` blocks.
+  ///
+  ///     cond_br %condition, ^then, ^else
+  ///   ^then:
+  ///     %then_operands = `thenOpsFactory()`
+  ///     br ^continue(%then_operands)
+  ///   ^else:
+  ///     %else_operands = `elseOpsFactory()`
+  ///     br ^continue(%else_operands)
+  ///   ^continue(%block_operands):
+  ///
+  template <typename ThenOpsFactory, typename ElseOpsFactory>
+  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
+                ElseOpsFactory &&elseOpsFactory) {
+    Block *currentBlock = rewriter.getInsertionBlock();
+    auto currentPoint = rewriter.getInsertionPoint();
+
+    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
+    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
+    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
+
+    rewriter.setInsertionPointToEnd(currentBlock);
+    create<CondBranchOp>(condition, thenBlock,
+                         /*trueOperands=*/ArrayRef<Value>(), elseBlock,
+                         /*falseOperands=*/ArrayRef<Value>());
+
+    rewriter.setInsertionPointToStart(thenBlock);
+    ValueRange thenOperands = thenOpsFactory();
+    create<BranchOp>(continueBlock, thenOperands);
+
+    rewriter.setInsertionPointToStart(elseBlock);
+    ValueRange elseOperands = elseOpsFactory();
+    create<BranchOp>(continueBlock, elseOperands);
+
+    assert(thenOperands.size() == elseOperands.size());
+    rewriter.setInsertionPointToStart(continueBlock);
+    for (auto operand : thenOperands)
+      continueBlock->addArgument(operand.getType());
+  }
+
+  /// Shortcut for createIf with empty else block and no block operands.
+  template <typename Factory>
+  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
+    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
+                  "predicatedOpsFactory should not return any value");
+    createIf(
+        condition,
+        [&] {
+          predicatedOpsFactory();
+          return ArrayRef<Value>();
+        },
+        [&] { return ArrayRef<Value>(); });
+  }
+
+  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
+  /// the entire subgroup if activeWidth is larger than the subgroup width.
+  /// The first lane returns the result; all other lanes' return values are
+  /// undefined.
+  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
+                             AccumulatorFactory &accumFactory) {
+    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
+    Value isPartialSubgroup =
+        create<CmpIOp>(CmpIPredicate::slt, activeWidth, subgroupSize);
+    SmallVector<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
+    auto xorAttr = rewriter.getStringAttr("xor");
+
+    createIf(
+        isPartialSubgroup,
+        // Generate reduction over a (potentially) partial subgroup.
+        [&] {
+          Value value = operand;
+          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if the
+          // source lane is within the active range. The accumulated value is
+          // available in the first lane.
+          for (int i = 1; i < kSubgroupSize; i <<= 1) {
+            Value offset = create<ConstantIntOp>(i, int32Type);
+            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
+                                                    activeWidth, xorAttr);
+            // Skip the accumulation if the shuffle op read from a lane outside
+            // of the active range.
+            createIf(
+                shuffleOp.getResult(1),
+                [&] {
+                  return SmallVector<Value, 1>{
+                      accumFactory(value, shuffleOp.getResult(0))};
+                },
+                [&] { return llvm::makeArrayRef(value); });
+            value = rewriter.getInsertionBlock()->getArgument(0);
+          }
+          return SmallVector<Value, 1>{value};
+        },
+        // Generate a reduction over the entire subgroup. This is a
+        // specialization of the above reduction with unconditional
+        // accumulation.
+        [&] {
+          Value value = operand;
+          for (int i = 1; i < kSubgroupSize; i <<= 1) {
+            Value offset = create<ConstantIntOp>(i, int32Type);
+            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
+                                                    subgroupSize, xorAttr);
+            value = accumFactory(value, shuffleOp.getResult(0));
+          }
+          return SmallVector<Value, 1>{value};
+        });
+    return rewriter.getInsertionBlock()->getArgument(0);
+  }
+
+  /// Returns value divided by the subgroup size (i.e. 32).
+  Value getDivideBySubgroupSize(Value value) {
+    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
+    return create<SignedDivIOp>(int32Type, value, subgroupSize);
+  }
+
+  gpu::GPUFuncOp funcOp;
+  gpu::AllReduceOp reduceOp;
+  PatternRewriter &rewriter;
+
+  Location loc;
+  Type valueType;
+  Type indexType;
+  Type int32Type;
+
+  static constexpr int kSubgroupSize = 32;
+};
+
+struct GpuAllReduceConversion : public RewritePattern {
+  explicit GpuAllReduceConversion(MLIRContext *context)
+      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    auto funcOp = cast<gpu::GPUFuncOp>(op);
+    auto callback = [&](gpu::AllReduceOp reduceOp) {
+      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
+      // Performing a rewrite invalidates the walk iterator. Report interrupt
+      // so that we can start a new walk until all all_reduce ops are replaced.
+      return WalkResult::interrupt();
+    };
+    while (funcOp.walk(callback).wasInterrupted()) {
+    }
+    return matchSuccess();
+  }
+};
+
+} // namespace
+
+void mlir::populateGpuRewritePatterns(MLIRContext *context,
+                                      OwningRewritePatternList &patterns) {
+  patterns.insert<GpuAllReduceConversion>(context);
+}
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -160,6 +160,13 @@
   return {arguments.data() + initialSize, arguments.data() + arguments.size()};
 }
 
+BlockArgument Block::insertArgument(unsigned index, Type type) {
+  auto arg = BlockArgument::create(type, this);
+  assert(index <= arguments.size());
+  arguments.insert(arguments.begin() + index, arg);
+  return arg;
+}
+
 void Block::eraseArgument(unsigned index, bool updatePredTerms) {
   assert(index < arguments.size());
diff --git a/mlir/test/Dialect/GPU/all-reduce.mlir b/mlir/test/Dialect/GPU/all-reduce.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/GPU/all-reduce.mlir
@@ -0,0 +1,183 @@
+// RUN: mlir-opt -test-all-reduce-lowering %s | FileCheck %s
+
+// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
+// CHECK: module @kernels attributes {gpu.kernel_module} {
+module @kernels attributes {gpu.kernel_module} {
+
+  // CHECK-LABEL: gpu.func @kernel(
+  // CHECK-SAME: [[VAL_0:%.*]]: f32) workgroup([[VAL_1:%.*]] : memref<32xf32, 3>) kernel {
+  gpu.func @kernel(%arg0 : f32) attributes { gpu.kernel } {
+    // CHECK: [[VAL_2:%.*]] = constant 31 : i32
+    // CHECK: [[VAL_3:%.*]] = constant 0 : i32
+    // CHECK: [[VAL_4:%.*]] = constant 0 : index
+    // CHECK: [[VAL_5:%.*]] = constant 32 : i32
+    // CHECK: [[VAL_6:%.*]] = constant 1 : i32
+    // CHECK: [[VAL_7:%.*]] = constant 2 : i32
+    // CHECK: [[VAL_8:%.*]] = constant 4 : i32
+    // CHECK: [[VAL_9:%.*]] = constant 8 : i32
+    // CHECK: [[VAL_10:%.*]] = constant 16 : i32
+    // CHECK: [[VAL_11:%.*]] = "gpu.block_dim"() {dimension = "x"} : () -> index
+    // CHECK: [[VAL_12:%.*]] = index_cast [[VAL_11]] : index to i32
+    // CHECK: [[VAL_13:%.*]] = "gpu.block_dim"() {dimension = "y"} : () -> index
+    // CHECK: [[VAL_14:%.*]] = index_cast [[VAL_13]] : index to i32
+    // CHECK: [[VAL_15:%.*]] = "gpu.block_dim"() {dimension = "z"} : () -> index
+    // CHECK: [[VAL_16:%.*]] = index_cast [[VAL_15]] : index to i32
+    // CHECK: [[VAL_17:%.*]] = "gpu.thread_id"() {dimension = "x"} : () -> index
+    // CHECK: [[VAL_18:%.*]] = index_cast [[VAL_17]] : index to i32
+    // CHECK: [[VAL_19:%.*]] = "gpu.thread_id"() {dimension = "y"} : () -> index
+    // CHECK: [[VAL_20:%.*]] = index_cast [[VAL_19]] : index to i32
+    // CHECK: [[VAL_21:%.*]] = "gpu.thread_id"() {dimension = "z"} : () -> index
+    // CHECK: [[VAL_22:%.*]] = index_cast [[VAL_21]] : index to i32
+    // CHECK: [[VAL_23:%.*]] = muli [[VAL_22]], [[VAL_14]] : i32
+    // CHECK: [[VAL_24:%.*]] = addi [[VAL_23]], [[VAL_20]] : i32
+    // CHECK: [[VAL_25:%.*]] = muli [[VAL_24]], [[VAL_12]] : i32
+    // CHECK: [[VAL_26:%.*]] = muli [[VAL_12]], [[VAL_14]] : i32
+    // CHECK: [[VAL_27:%.*]] = addi [[VAL_25]], [[VAL_18]] : i32
+    // CHECK: [[VAL_28:%.*]] = muli [[VAL_26]], [[VAL_16]] : i32
+    // CHECK: [[VAL_29:%.*]] = and [[VAL_27]], [[VAL_2]] : i32
+    // CHECK: [[VAL_30:%.*]] = cmpi "eq", [[VAL_29]], [[VAL_3]] : i32
+    // CHECK: [[VAL_31:%.*]] = subi [[VAL_27]], [[VAL_29]] : i32
+    // CHECK: [[VAL_32:%.*]] = subi [[VAL_28]], [[VAL_31]] : i32
+    // CHECK: [[VAL_33:%.*]] = cmpi "slt", [[VAL_32]], [[VAL_5]] : i32
+    // CHECK: cond_br [[VAL_33]], ^bb1, ^bb17
+    // CHECK: ^bb1:
+    // CHECK: [[VAL_34:%.*]], [[VAL_35:%.*]] = gpu.shuffle [[VAL_0]], [[VAL_6]], [[VAL_32]] xor : f32
+    // CHECK: cond_br [[VAL_35]], ^bb2, ^bb3
+    // CHECK: ^bb2:
+    // CHECK: [[VAL_36:%.*]] = addf [[VAL_0]], [[VAL_34]] : f32
+    // CHECK: br ^bb4([[VAL_36]] : f32)
+    // CHECK: ^bb3:
+    // CHECK: br ^bb4([[VAL_0]] : f32)
+    // CHECK: ^bb4([[VAL_37:%.*]]: f32):
+    // CHECK: [[VAL_38:%.*]], [[VAL_39:%.*]] = gpu.shuffle [[VAL_37]], [[VAL_7]], [[VAL_32]] xor : f32
+    // CHECK: cond_br [[VAL_39]], ^bb5, ^bb6
+    // CHECK: ^bb5:
+    // CHECK: [[VAL_40:%.*]] = addf [[VAL_37]], [[VAL_38]] : f32
+    // CHECK: br ^bb7([[VAL_40]] : f32)
+    // CHECK: ^bb6:
+    // CHECK: br ^bb7([[VAL_37]] : f32)
+    // CHECK: ^bb7([[VAL_41:%.*]]: f32):
+    // CHECK: [[VAL_42:%.*]], [[VAL_43:%.*]] = gpu.shuffle [[VAL_41]], [[VAL_8]], [[VAL_32]] xor : f32
+    // CHECK: cond_br [[VAL_43]], ^bb8, ^bb9
+    // CHECK: ^bb8:
+    // CHECK: [[VAL_44:%.*]] = addf [[VAL_41]], [[VAL_42]] : f32
+    // CHECK: br ^bb10([[VAL_44]] : f32)
+    // CHECK: ^bb9:
+    // CHECK: br ^bb10([[VAL_41]] : f32)
+    // CHECK: ^bb10([[VAL_45:%.*]]: f32):
+    // CHECK: [[VAL_46:%.*]], [[VAL_47:%.*]] = gpu.shuffle [[VAL_45]], [[VAL_9]], [[VAL_32]] xor : f32
+    // CHECK: cond_br [[VAL_47]], ^bb11, ^bb12
+    // CHECK: ^bb11:
+    // CHECK: [[VAL_48:%.*]] = addf [[VAL_45]], [[VAL_46]] : f32
+    // CHECK: br ^bb13([[VAL_48]] : f32)
+    // CHECK: ^bb12:
+    // CHECK: br ^bb13([[VAL_45]] : f32)
+    // CHECK: ^bb13([[VAL_49:%.*]]: f32):
+    // CHECK: [[VAL_50:%.*]], [[VAL_51:%.*]] = gpu.shuffle [[VAL_49]], [[VAL_10]], [[VAL_32]] xor : f32
+    // CHECK: cond_br [[VAL_51]], ^bb14, ^bb15
+    // CHECK: ^bb14:
+    // CHECK: [[VAL_52:%.*]] = addf [[VAL_49]], [[VAL_50]] : f32
+    // CHECK: br ^bb16([[VAL_52]] : f32)
+    // CHECK: ^bb15:
+    // CHECK: br ^bb16([[VAL_49]] : f32)
+    // CHECK: ^bb16([[VAL_53:%.*]]: f32):
+    // CHECK: br ^bb18([[VAL_53]] : f32)
+    // CHECK: ^bb17:
+    // CHECK: [[VAL_54:%.*]], [[VAL_55:%.*]] = gpu.shuffle [[VAL_0]], [[VAL_6]], [[VAL_5]] xor : f32
+    // CHECK: [[VAL_56:%.*]] = addf [[VAL_0]], [[VAL_54]] : f32
+    // CHECK: [[VAL_57:%.*]], [[VAL_58:%.*]] = gpu.shuffle [[VAL_56]], [[VAL_7]], [[VAL_5]] xor : f32
+    // CHECK: [[VAL_59:%.*]] = addf [[VAL_56]], [[VAL_57]] : f32
+    // CHECK: [[VAL_60:%.*]], [[VAL_61:%.*]] = gpu.shuffle [[VAL_59]], [[VAL_8]], [[VAL_5]] xor : f32
+    // CHECK: [[VAL_62:%.*]] = addf [[VAL_59]], [[VAL_60]] : f32
+    // CHECK: [[VAL_63:%.*]], [[VAL_64:%.*]] = gpu.shuffle [[VAL_62]], [[VAL_9]], [[VAL_5]] xor : f32
+    // CHECK: [[VAL_65:%.*]] = addf [[VAL_62]], [[VAL_63]] : f32
+    // CHECK: [[VAL_66:%.*]], [[VAL_67:%.*]] = gpu.shuffle [[VAL_65]], [[VAL_10]], [[VAL_5]] xor : f32
+    // CHECK: [[VAL_68:%.*]] = addf [[VAL_65]], [[VAL_66]] : f32
+    // CHECK: br ^bb18([[VAL_68]] : f32)
+    // CHECK: ^bb18([[VAL_69:%.*]]: f32):
+    // CHECK: cond_br [[VAL_30]], ^bb19, ^bb20
+    // CHECK: ^bb19:
+    // CHECK: [[VAL_70:%.*]] = divi_signed [[VAL_27]], [[VAL_5]] : i32
+    // CHECK: [[VAL_71:%.*]] = index_cast [[VAL_70]] : i32 to index
+    // CHECK: store [[VAL_69]], [[VAL_1]]{{\[}}[[VAL_71]]] : memref<32xf32, 3>
+    // CHECK: br ^bb21
+    // CHECK: ^bb20:
+    // CHECK: br ^bb21
+    // CHECK: ^bb21:
+    // CHECK: gpu.barrier
+    // CHECK: [[VAL_72:%.*]] = addi [[VAL_28]], [[VAL_2]] : i32
+    // CHECK: [[VAL_73:%.*]] = divi_signed [[VAL_72]], [[VAL_5]] : i32
+    // CHECK: [[VAL_74:%.*]] = cmpi "slt", [[VAL_27]], [[VAL_73]] : i32
+    // CHECK: cond_br [[VAL_74]], ^bb22, ^bb41
+    // CHECK: ^bb22:
+    // CHECK: [[VAL_75:%.*]] = index_cast [[VAL_27]] : i32 to index
+    // CHECK: [[VAL_76:%.*]] = load [[VAL_1]]{{\[}}[[VAL_75]]] : memref<32xf32, 3>
+    // CHECK: [[VAL_77:%.*]] = cmpi "slt", [[VAL_73]], [[VAL_5]] : i32
+    // CHECK: cond_br [[VAL_77]], ^bb23, ^bb39
+    // CHECK: ^bb23:
+    // CHECK: [[VAL_78:%.*]], [[VAL_79:%.*]] = gpu.shuffle [[VAL_76]], [[VAL_6]], [[VAL_73]] xor : f32
+    // CHECK: cond_br [[VAL_79]], ^bb24, ^bb25
+    // CHECK: ^bb24:
+    // CHECK: [[VAL_80:%.*]] = addf [[VAL_76]], [[VAL_78]] : f32
+    // CHECK: br ^bb26([[VAL_80]] : f32)
+    // CHECK: ^bb25:
+    // CHECK: br ^bb26([[VAL_76]] : f32)
+    // CHECK: ^bb26([[VAL_81:%.*]]: f32):
+    // CHECK: [[VAL_82:%.*]], [[VAL_83:%.*]] = gpu.shuffle [[VAL_81]], [[VAL_7]], [[VAL_73]] xor : f32
+    // CHECK: cond_br [[VAL_83]], ^bb27, ^bb28
+    // CHECK: ^bb27:
+    // CHECK: [[VAL_84:%.*]] = addf [[VAL_81]], [[VAL_82]] : f32
+    // CHECK: br ^bb29([[VAL_84]] : f32)
+    // CHECK: ^bb28:
+    // CHECK: br ^bb29([[VAL_81]] : f32)
+    // CHECK: ^bb29([[VAL_85:%.*]]: f32):
+    // CHECK: [[VAL_86:%.*]], [[VAL_87:%.*]] = gpu.shuffle [[VAL_85]], [[VAL_8]], [[VAL_73]] xor : f32
+    // CHECK: cond_br [[VAL_87]], ^bb30, ^bb31
+    // CHECK: ^bb30:
+    // CHECK: [[VAL_88:%.*]] = addf [[VAL_85]], [[VAL_86]] : f32
+    // CHECK: br ^bb32([[VAL_88]] : f32)
+    // CHECK: ^bb31:
+    // CHECK: br ^bb32([[VAL_85]] : f32)
+    // CHECK: ^bb32([[VAL_89:%.*]]: f32):
+    // CHECK: [[VAL_90:%.*]], [[VAL_91:%.*]] = gpu.shuffle [[VAL_89]], [[VAL_9]], [[VAL_73]] xor : f32
+    // CHECK: cond_br [[VAL_91]], ^bb33, ^bb34
+    // CHECK: ^bb33:
+    // CHECK: [[VAL_92:%.*]] = addf [[VAL_89]], [[VAL_90]] : f32
+    // CHECK: br ^bb35([[VAL_92]] : f32)
+    // CHECK: ^bb34:
+    // CHECK: br ^bb35([[VAL_89]] : f32)
+    // CHECK: ^bb35([[VAL_93:%.*]]: f32):
+    // CHECK: [[VAL_94:%.*]], [[VAL_95:%.*]] = gpu.shuffle [[VAL_93]], [[VAL_10]], [[VAL_73]] xor : f32
+    // CHECK: cond_br [[VAL_95]], ^bb36, ^bb37
+    // CHECK: ^bb36:
+    // CHECK: [[VAL_96:%.*]] = addf [[VAL_93]], [[VAL_94]] : f32
+    // CHECK: br ^bb38([[VAL_96]] : f32)
+    // CHECK: ^bb37:
+    // CHECK: br ^bb38([[VAL_93]] : f32)
+    // CHECK: ^bb38([[VAL_97:%.*]]: f32):
+    // CHECK: br ^bb40([[VAL_97]] : f32)
+    // CHECK: ^bb39:
+    // CHECK: [[VAL_98:%.*]], [[VAL_99:%.*]] = gpu.shuffle [[VAL_76]], [[VAL_6]], [[VAL_5]] xor : f32
+    // CHECK: [[VAL_100:%.*]] = addf [[VAL_76]], [[VAL_98]] : f32
+    // CHECK: [[VAL_101:%.*]], [[VAL_102:%.*]] = gpu.shuffle [[VAL_100]], [[VAL_7]], [[VAL_5]] xor : f32
+    // CHECK: [[VAL_103:%.*]] = addf [[VAL_100]], [[VAL_101]] : f32
+    // CHECK: [[VAL_104:%.*]], [[VAL_105:%.*]] = gpu.shuffle [[VAL_103]], [[VAL_8]], [[VAL_5]] xor : f32
+    // CHECK: [[VAL_106:%.*]] = addf [[VAL_103]], [[VAL_104]] : f32
+    // CHECK: [[VAL_107:%.*]], [[VAL_108:%.*]] = gpu.shuffle [[VAL_106]], [[VAL_9]], [[VAL_5]] xor : f32
+    // CHECK: [[VAL_109:%.*]] = addf [[VAL_106]], [[VAL_107]] : f32
+    // CHECK: [[VAL_110:%.*]], [[VAL_111:%.*]] = gpu.shuffle [[VAL_109]], [[VAL_10]], [[VAL_5]] xor : f32
+    // CHECK: [[VAL_112:%.*]] = addf [[VAL_109]], [[VAL_110]] : f32
+    // CHECK: br ^bb40([[VAL_112]] : f32)
+    // CHECK: ^bb40([[VAL_113:%.*]]: f32):
+    // CHECK: store [[VAL_113]], [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, 3>
+    // CHECK: br ^bb42
+    // CHECK: ^bb41:
+    // CHECK: br ^bb42
+    // CHECK: ^bb42:
+    // CHECK: gpu.barrier
+    // CHECK: [[VAL_114:%.*]] = load [[VAL_1]]{{\[}}[[VAL_4]]] : memref<32xf32, 3>
+    %sum = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32)
+    gpu.return
+  }
+
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_llvm_library(MLIRTestTransforms
+  TestAllReduceLowering.cpp
   TestCallGraph.cpp
   TestConstantFold.cpp
   TestLoopFusion.cpp
@@ -30,6 +31,7 @@
   MLIREDSC
+  MLIRGPU
   MLIRLoopOps
   MLIRPass
   MLIRTestDialect
   MLIRVectorOps
diff --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
@@ -0,0 +1,32 @@
+//===- TestAllReduceLowering.cpp - Test gpu.all_reduce lowering ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for lowering the gpu.all_reduce op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+struct TestAllReduceLoweringPass
+    : public ModulePass<TestAllReduceLoweringPass> {
+  void runOnModule() override {
+    OwningRewritePatternList patterns;
+    populateGpuRewritePatterns(&getContext(), patterns);
+    applyPatternsGreedily(getModule(), patterns);
+  }
+};
+} // namespace
+
+static PassRegistration<TestAllReduceLoweringPass>
+    pass("test-all-reduce-lowering",
+         "Lowers gpu.all_reduce ops within the GPU dialect.");