diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td
@@ -155,6 +155,16 @@
       return {begin, end};
     }
 
+    // Adds a new block argument that corresponds to buffers located in
+    // workgroup memory.
+    BlockArgument *addWorkgroupAttribution(Type type) {
+      auto attrName = getNumWorkgroupAttributionsAttrName();
+      auto attr = getAttrOfType<IntegerAttr>(attrName);
+      setAttr(attrName, IntegerAttr::get(attr.getType(), attr.getValue() + 1));
+      return getBody().front().insertArgument(
+          getType().getNumInputs() + attr.getInt(), type);
+    }
+
     /// Returns a list of block arguments that correspond to buffers located in
     /// the private memory.
     ArrayRef<BlockArgument *> getPrivateAttributions() {
diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h
--- a/mlir/include/mlir/Dialect/GPU/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Passes.h
@@ -17,11 +17,17 @@
 
 namespace mlir {
 
+class MLIRContext;
 class ModuleOp;
 template <typename T> class OpPassBase;
+class OwningRewritePatternList;
 
 std::unique_ptr<OpPassBase<ModuleOp>> createGpuKernelOutliningPass();
 
+/// Collect a set of patterns to rewrite ops within the GPU dialect.
+void populateGpuRewritePatterns(MLIRContext *context,
+                                OwningRewritePatternList &patterns);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_GPU_PASSES_H_
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -82,6 +82,9 @@
   /// Add one argument to the argument list for each type specified in the list.
   iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
 
+  // Add one value to the argument list at the specified position.
+  BlockArgument *insertArgument(unsigned index, Type type);
+
   /// Erase the argument at 'index' and remove it from the argument list. If
   /// 'updatePredTerms' is set to true, this argument is also removed from the
   /// terminators of each predecessor to this block.
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -2,6 +2,7 @@
   IR/GPUDialect.cpp
   IR/DialectRegistration.cpp
   Transforms/KernelOutlining.cpp
+  Transforms/AllReduceLowering.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -0,0 +1,362 @@
+//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements in-dialect lowering of the all-reduce op to a block of
+// simpler instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+struct GpuAllReduceRewriter {
+  using AccumulatorFactory = std::function<Value *(Value *, Value *)>;
+
+  explicit GpuAllReduceRewriter(gpu::GPUFuncOp funcOp_,
+                                gpu::AllReduceOp reduceOp_,
+                                PatternRewriter &rewriter_)
+      : funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
+        loc(reduceOp.getLoc()), valueType(reduceOp.value()->getType()),
+        indexType(IndexType::get(reduceOp.getContext())),
+        int32Type(IntegerType::get(32, reduceOp.getContext())) {}
+
+  /// Creates an all_reduce across the block.
+  ///
+  /// First reduce the elements within a warp. The first thread of each warp
+  /// writes the intermediate result to shared memory. After synchronizing the
+  /// block, the first warp reduces the values from shared memory. The result
+  /// is broadcast to all threads through shared memory.
+  ///
+  ///     %warp_reduce = `createWarpReduce(%operand)`
+  ///     cond_br %is_first_lane, ^then1, ^continue1
+  ///   ^then1:
+  ///     store %warp_reduce, %workgroup_buffer, %warp_id
+  ///     br ^continue1
+  ///   ^continue1:
+  ///     gpu.barrier
+  ///     %is_valid_warp = cmpi "slt" %thread_idx, %num_warps
+  ///     cond_br %is_valid_warp, ^then2, ^continue2
+  ///   ^then2:
+  ///     %partial_reduce = load %workgroup_buffer, %thread_idx
+  ///     %all_reduce = `createWarpReduce(%partial_reduce)`
+  ///     store %all_reduce, %workgroup_buffer, %zero
+  ///     br ^continue2
+  ///   ^continue2:
+  ///     gpu.barrier
+  ///     %result = load %workgroup_buffer, %zero
+  ///     return %result
+  ///
+  void rewrite() {
+    rewriter.setInsertionPoint(reduceOp);
+
+    // Compute linear thread index and block size.
+    Value *dimX = getDimOp<gpu::BlockDimOp>("x");
+    Value *dimY = getDimOp<gpu::BlockDimOp>("y");
+    Value *dimZ = getDimOp<gpu::BlockDimOp>("z");
+    Value *tidX = getDimOp<gpu::ThreadIdOp>("x");
+    Value *tidY = getDimOp<gpu::ThreadIdOp>("y");
+    Value *tidZ = getDimOp<gpu::ThreadIdOp>("z");
+    Value *tmp1 = create<MulIOp>(loc, int32Type, tidZ, dimY);
+    Value *tmp2 = create<AddIOp>(loc, int32Type, tmp1, tidY);
+    Value *tmp3 = create<MulIOp>(loc, int32Type, tmp2, dimX);
+    Value *tmp4 = create<MulIOp>(loc, int32Type, dimX, dimY);
+    Value *threadIdx = create<AddIOp>(loc, int32Type, tmp3, tidX);
+    Value *blockSize = create<MulIOp>(loc, int32Type, tmp4, dimZ);
+
+    // Compute lane id (invocation id within the subgroup).
+    Value *warpMask = create<ConstantIntOp>(loc, kWarpSize - 1, 32);
+    Value *laneId = create<AndOp>(loc, threadIdx, warpMask);
+    Value *isFirstLane = create<CmpIOp>(loc, CmpIPredicate::eq, laneId,
+                                        create<ConstantIntOp>(loc, 0, 32));
+
+    Value *numThreadsWithSmallerWarpId = create<SubIOp>(loc, threadIdx, laneId);
+    // The number of active threads in the warp, not clamped to 32.
+    Value *activeWidth =
+        create<SubIOp>(loc, blockSize, numThreadsWithSmallerWarpId);
+
+    // Create a factory for the op that accumulates two values.
+    AccumulatorFactory accumFactory = getFactory();
+    assert(accumFactory && "failed to create accumulator factory");
+
+    // Reduce elements within each warp to produce the intermediate results.
+    Value *warpReduce =
+        createWarpReduce(activeWidth, laneId, reduceOp.value(), accumFactory);
+
+    // Add workgroup buffer to parent function for intermediate result.
+    Value *buffer = createWorkgroupBuffer();
+
+    // Write the intermediate results to shared memory, using the first lane of
+    // each warp.
+    createPredicatedBlock(isFirstLane, [&] {
+      Value *warpId = getDivideByWarpSize(threadIdx);
+      Value *index = create<IndexCastOp>(loc, indexType, warpId);
+      create<StoreOp>(loc, warpReduce, buffer, index);
+    });
+    create<gpu::BarrierOp>(loc);
+
+    // Compute number of active warps.
+    Value *biasedBlockSize =
+        create<AddIOp>(loc, int32Type, blockSize, warpMask);
+    Value *numWarps = getDivideByWarpSize(biasedBlockSize);
+    Value *isValidWarp =
+        create<CmpIOp>(loc, CmpIPredicate::slt, threadIdx, numWarps);
+
+    // Use the first numWarps threads to reduce the intermediate results from
+    // shared memory. The final result is written to shared memory again.
+    Value *zero = create<ConstantIndexOp>(loc, 0);
+    createPredicatedBlock(isValidWarp, [&] {
+      Value *index = create<IndexCastOp>(loc, indexType, threadIdx);
+      Value *value = create<LoadOp>(loc, valueType, buffer, index);
+      Value *result = createWarpReduce(numWarps, laneId, value, accumFactory);
+      create<StoreOp>(loc, result, buffer, zero);
+    });
+
+    // Synchronize workgroup and load result from shared memory.
+    create<gpu::BarrierOp>(loc);
+    Value *result = create<LoadOp>(loc, valueType, buffer, zero);
+
+    rewriter.replaceOp(reduceOp, result);
+  }
+
+private:
+  // Shortcut to create an op using the rewriter.
+  template <typename T, typename... Args> T create(Args... args) {
+    return rewriter.create<T>(std::forward<Args>(args)...);
+  }
+
+  // Creates dimension op of type T, with the result casted to int32.
+  template <typename T> Value *getDimOp(StringRef dimension) {
+    Value *dim = create<T>(loc, indexType, rewriter.getStringAttr(dimension));
+    return create<IndexCastOp>(loc, int32Type, dim);
+  }
+
+  /// Adds type to funcOp's workgroup attributions.
+  Value *createWorkgroupBuffer() {
+    auto bufferType =
+        MemRefType::get({kWarpSize}, valueType, ArrayRef<AffineMap>{}, 3);
+    return funcOp.addWorkgroupAttribution(bufferType);
+  }
+
+  /// Returns an accumulator factory using either the op attribute or the body
+  /// region.
+  AccumulatorFactory getFactory() {
+    if (!reduceOp.body().empty())
+      return getFactory(reduceOp.body());
+    if (reduceOp.op())
+      return getFactory(*reduceOp.op());
+    return AccumulatorFactory();
+  }
+
+  /// Returns an accumulator factory that clones the body. The body's entry
+  /// block is expected to have 2 arguments. The gpu.yield op returns the
+  /// accumulated value of the same type.
+  AccumulatorFactory getFactory(Region &body) {
+    return AccumulatorFactory([&](Value *lhs, Value *rhs) {
+      Block *block = rewriter.getInsertionBlock();
+      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
+
+      // Insert the accumulator body between the split blocks.
+      BlockAndValueMapping mapping;
+      mapping.map(body.front().getArgument(0), lhs);
+      mapping.map(body.front().getArgument(1), rhs);
+      rewriter.cloneRegionBefore(body, *split->getParent(),
+                                 split->getIterator(), mapping);
+
+      // Add branch before inserted body, into body.
+      block = block->getNextNode();
+      create<BranchOp>(loc, block, ValueRange());
+
+      // Replace all gpu.yield ops with branch out of body.
+      for (; block != split; block = block->getNextNode()) {
+        Operation *terminator = block->getTerminator();
+        if (!llvm::isa<gpu::YieldOp>(terminator))
+          continue;
+        rewriter.setInsertionPointToEnd(block);
+        rewriter.replaceOpWithNewOp<BranchOp>(
+            terminator, split, ValueRange(terminator->getOperand(0)));
+      }
+
+      // Return accumulator result.
+      rewriter.setInsertionPointToStart(split);
+      return split->addArgument(lhs->getType());
+    });
+  }
+
+  /// Returns an accumulator factory that creates an op specified by opName.
+  AccumulatorFactory getFactory(StringRef opName) {
+    bool isFloatingPoint = valueType.isF32();
+    if (opName == "add")
+      return isFloatingPoint ? getFactory<AddFOp>() : getFactory<AddIOp>();
+    if (opName == "mul")
+      return isFloatingPoint ? getFactory<MulFOp>() : getFactory<MulIOp>();
+    return AccumulatorFactory();
+  }
+
+  /// Returns an accumulator factory that creates an op of type T.
+  template <typename T> AccumulatorFactory getFactory() {
+    return [&](Value *lhs, Value *rhs) {
+      return create<T>(loc, lhs->getType(), lhs, rhs);
+    };
+  }
+
+  /// Creates an if-block skeleton and calls the two factories to generate the
+  /// ops in the `then` and `else` block.
+  ///
+  ///     cond_br %condition, ^then, ^else
+  ///   ^then:
+  ///     %then_operands = `thenOpsFactory()`
+  ///     br ^continue(%then_operands)
+  ///   ^else:
+  ///     %else_operands = `elseOpsFactory()`
+  ///     br ^continue(%else_operands)
+  ///   ^continue(%block_operands):
+  ///
+  template <typename ThenOpsFactory, typename ElseOpsFactory>
+  void createIf(Value *condition, ThenOpsFactory &&thenOpsFactory,
+                ElseOpsFactory &&elseOpsFactory) {
+    Block *currentBlock = rewriter.getInsertionBlock();
+    auto currentPoint = rewriter.getInsertionPoint();
+
+    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
+    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
+    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
+
+    rewriter.setInsertionPointToEnd(currentBlock);
+    create<CondBranchOp>(loc, condition, thenBlock,
+                         /*trueOperands=*/ArrayRef<Value *>(), elseBlock,
+                         /*falseOperands=*/ArrayRef<Value *>());
+
+    auto addBranch = [&](ValueRange operands) {
+      create<BranchOp>(loc, continueBlock, operands);
+    };
+
+    rewriter.setInsertionPointToStart(thenBlock);
+    auto thenOperands = thenOpsFactory();
+    addBranch(thenOperands);
+
+    rewriter.setInsertionPointToStart(elseBlock);
+    auto elseOperands = elseOpsFactory();
+    addBranch(elseOperands);
+
+    assert(thenOperands.size() == elseOperands.size());
+    rewriter.setInsertionPointToStart(continueBlock);
+    for (auto *operand : thenOperands)
+      continueBlock->addArgument(operand->getType());
+  }
+
+  /// Shortcut for createIf with empty else block and no block operands.
+  template <typename Factory>
+  void createPredicatedBlock(Value *condition, Factory &&predicatedOpsFactory) {
+    createIf(
+        condition,
+        [&] {
+          predicatedOpsFactory();
+          return ArrayRef<Value *>();
+        },
+        [&] { return ArrayRef<Value *>(); });
+  }
+
+  /// Creates a reduction across the first activeWidth lanes of a warp. The
+  /// first lane returns the result; the return values of all other lanes are
+  /// undefined.
+  Value *createWarpReduce(Value *activeWidth, Value *laneId, Value *operand,
+                          AccumulatorFactory &accumFactory) {
+    Value *warpSize = create<ConstantIntOp>(loc, kWarpSize, 32);
+    Value *isPartialWarp =
+        create<CmpIOp>(loc, CmpIPredicate::slt, activeWidth, warpSize);
+    SmallVector<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
+    auto xorAttr = rewriter.getStringAttr("xor");
+
+    createIf(
+        isPartialWarp,
+        // Generate reduction over a (potentially) partial warp.
+        [&] {
+          Value *value = operand;
+          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
+          // lane is within the active range. The accumulated value is available
+          // in the first lane.
+          for (int i = 1; i < kWarpSize; i <<= 1) {
+            Value *offset = create<ConstantIntOp>(loc, i, 32);
+            auto shuffleOp = create<gpu::ShuffleOp>(
+                loc, shuffleType, value, offset, activeWidth, xorAttr);
+            // Skip the accumulation if the shuffle op read from a lane outside
+            // of the active range.
+            createIf(
+                shuffleOp.getResult(1),
+                [&] {
+                  return llvm::SmallVector<Value *, 1>{
+                      accumFactory(value, shuffleOp.getResult(0))};
+                },
+                [&] { return llvm::makeArrayRef(value); });
+            value = rewriter.getInsertionBlock()->getArgument(0);
+          }
+          return llvm::SmallVector<Value *, 1>{value};
+        },
+        // Generate a reduction over the entire warp. This is a specialization
+        // of the above reduction with unconditional accumulation.
+        [&] {
+          Value *value = operand;
+          for (int i = 1; i < kWarpSize; i <<= 1) {
+            Value *offset = create<ConstantIntOp>(loc, i, 32);
+            auto shuffleOp = create<gpu::ShuffleOp>(loc, shuffleType, value,
+                                                    offset, warpSize, xorAttr);
+            value = accumFactory(value, shuffleOp.getResult(0));
+          }
+          return llvm::SmallVector<Value *, 1>{value};
+        });
+    return rewriter.getInsertionBlock()->getArgument(0);
+  }
+
+  /// Returns value divided by the warp size (i.e. 32).
+  Value *getDivideByWarpSize(Value *value) {
+    Value *warpSize = create<ConstantIntOp>(loc, kWarpSize, 32);
+    return create<DivISOp>(loc, int32Type, value, warpSize);
+  }
+
+  gpu::GPUFuncOp funcOp;
+  gpu::AllReduceOp reduceOp;
+  PatternRewriter &rewriter;
+
+  Location loc;
+  Type valueType;
+  Type indexType;
+  Type int32Type;
+
+  static constexpr int kWarpSize = 32;
+};
+
+struct GpuAllReduceConversion : public RewritePattern {
+  explicit GpuAllReduceConversion(MLIRContext *context)
+      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const override {
+    auto funcOp = llvm::cast<gpu::GPUFuncOp>(op);
+    auto callback = [&](gpu::AllReduceOp reduceOp) {
+      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
+      // Rewriting invalidates the walk iterator; interrupt and restart the
+      // walk until no all_reduce ops remain.
+      return WalkResult::interrupt();
+    };
+    while (funcOp.walk(callback).wasInterrupted()) {
+    }
+    return matchSuccess();
+  }
+};
+} // namespace
+
+void mlir::populateGpuRewritePatterns(MLIRContext *context,
+                                      OwningRewritePatternList &patterns) {
+  patterns.insert<GpuAllReduceConversion>(context);
+}
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -160,6 +160,13 @@
   return {arguments.data() + initialSize, arguments.data() + arguments.size()};
 }
 
+BlockArgument *Block::insertArgument(unsigned index, Type type) {
+  auto *arg = new BlockArgument(type, this);
+  assert(index <= arguments.size());
+  arguments.insert(arguments.begin() + index, arg);
+  return arg;
+}
+
 void Block::eraseArgument(unsigned index, bool updatePredTerms) {
   assert(index < arguments.size());
 
diff --git a/mlir/test/Dialect/GPU/all-reduce.mlir b/mlir/test/Dialect/GPU/all-reduce.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/GPU/all-reduce.mlir
@@ -0,0 +1,180 @@
+// RUN: mlir-opt -test-all-reduce-lowering %s | FileCheck %s
+
+module @kernels attributes {gpu.kernel_module} {
+
+  // CHECK: gpu.func @kernel(%[[varg0:[a-z_0-9]+]]: f32) workgroup(%[[varg1:[a-z_0-9]+]] : memref<32xf32, 3>) kernel {
+  gpu.func @kernel(%arg0 : f32) attributes { gpu.kernel } {
+    // CHECK: %[[vc0_i32:[a-z_0-9]+]] = constant 0 : i32
+    // CHECK: %[[vc31_i32:[a-z_0-9]+]] = constant 31 : i32
+    // CHECK: %[[vc0:[a-z_0-9]+]] = constant 0 : index
+    // CHECK: %[[vc32_i32:[a-z_0-9]+]] = constant 32 : i32
+    // CHECK: %[[vc1_i32:[a-z_0-9]+]] = constant 1 : i32
+    // CHECK: %[[vc2_i32:[a-z_0-9]+]] = constant 2 : i32
+    // CHECK: %[[vc4_i32:[a-z_0-9]+]] = constant 4 : i32
+    // CHECK: %[[vc8_i32:[a-z_0-9]+]] = constant 8 : i32
+    // CHECK: %[[vc16_i32:[a-z_0-9]+]] = constant 16 : i32
+    // CHECK: %[[v0:[a-z_0-9]+]] = "gpu.block_dim"() {dimension = "x"} : () -> index
+    // CHECK: %[[v1:[a-z_0-9]+]] = index_cast %[[v0]] : index to i32
+    // CHECK: %[[v2:[a-z_0-9]+]] = "gpu.block_dim"() {dimension = "y"} : () -> index
+    // CHECK: %[[v3:[a-z_0-9]+]] = index_cast %[[v2]] : index to i32
+    // CHECK: %[[v4:[a-z_0-9]+]] = "gpu.block_dim"() {dimension = "z"} : () -> index
+    // CHECK: %[[v5:[a-z_0-9]+]] = index_cast %[[v4]] : index to i32
+    // CHECK: %[[v6:[a-z_0-9]+]] = "gpu.thread_id"() {dimension = "x"} : () -> index
+    // CHECK: %[[v7:[a-z_0-9]+]] = index_cast %[[v6]] : index to i32
+    // CHECK: %[[v8:[a-z_0-9]+]] = "gpu.thread_id"() {dimension = "y"} : () -> index
+    // CHECK: %[[v9:[a-z_0-9]+]] = index_cast %[[v8]] : index to i32
+    // CHECK: %[[v10:[a-z_0-9]+]] = "gpu.thread_id"() {dimension = "z"} : () -> index
+    // CHECK: %[[v11:[a-z_0-9]+]] = index_cast %[[v10]] : index to i32
+    // CHECK: %[[v12:[a-z_0-9]+]] = muli %[[v11]], %[[v3]] : i32
+    // CHECK: %[[v13:[a-z_0-9]+]] = addi %[[v12]], %[[v9]] : i32
+    // CHECK: %[[v14:[a-z_0-9]+]] = muli %[[v13]], %[[v1]] : i32
+    // CHECK: %[[v15:[a-z_0-9]+]] = muli %[[v1]], %[[v3]] : i32
+    // CHECK: %[[v16:[a-z_0-9]+]] = addi %[[v14]], %[[v7]] : i32
+    // CHECK: %[[v17:[a-z_0-9]+]] = muli %[[v15]], %[[v5]] : i32
+    // CHECK: %[[v18:[a-z_0-9]+]] = and %[[v16]], %[[vc31_i32]] : i32
+    // CHECK: %[[v19:[a-z_0-9]+]] = cmpi "eq", %[[v18]], %[[vc0_i32]] : i32
+    // CHECK: %[[v20:[a-z_0-9]+]] = subi %[[v16]], %[[v18]] : i32
+    // CHECK: %[[v21:[a-z_0-9]+]] = subi %[[v17]], %[[v20]] : i32
+    // CHECK: %[[v22:[a-z_0-9]+]] = cmpi "slt", %[[v21]], %[[vc32_i32]] : i32
+    // CHECK: cond_br %[[v22]], ^bb[[#b1:]], ^bb[[#b17:]]
+    // CHECK: ^bb[[#b1]]:
+    // CHECK: %[[vresult:[a-z_0-9]+]], %[[vvalid:[a-z_0-9]+]] = gpu.shuffle %[[varg0]], %[[vc1_i32]], %[[v21]] xor : f32
+    // CHECK: cond_br %[[vvalid]], ^bb[[#b2:]], ^bb[[#b3:]]
+    // CHECK: ^bb[[#b2]]:
+    // CHECK: %[[v23:[a-z_0-9]+]] = addf %[[varg0]], %[[vresult]] : f32
+    // CHECK: br ^bb[[#b4:]](%[[v23]] : f32)
+    // CHECK: ^bb[[#b3]]:
+    // CHECK: br ^bb[[#b4]](%[[varg0]] : f32)
+    // CHECK: ^bb[[#b4]](%[[v24:[a-z_0-9]+]]: f32):
+    // CHECK: %[[vresult_0:[a-z_0-9]+]], %[[vvalid_1:[a-z_0-9]+]] = gpu.shuffle %[[v24]], %[[vc2_i32]], %[[v21]] xor : f32
+    // CHECK: cond_br %[[vvalid_1]], ^bb[[#b5:]], ^bb[[#b6:]]
+    // CHECK: ^bb[[#b5]]:
+    // CHECK: %[[v25:[a-z_0-9]+]] = addf %[[v24]], %[[vresult_0]] : f32
+    // CHECK: br ^bb[[#b7:]](%[[v25]] : f32)
+    // CHECK: ^bb[[#b6]]:
+    // CHECK: br ^bb[[#b7]](%[[v24]] : f32)
+    // CHECK: ^bb[[#b7]](%[[v26:[a-z_0-9]+]]: f32):
+    // CHECK: %[[vresult_2:[a-z_0-9]+]], %[[vvalid_3:[a-z_0-9]+]] = gpu.shuffle %[[v26]], %[[vc4_i32]], %[[v21]] xor : f32
+    // CHECK: cond_br %[[vvalid_3]], ^bb[[#b8:]], ^bb[[#b9:]]
+    // CHECK: ^bb[[#b8]]:
+    // CHECK: %[[v27:[a-z_0-9]+]] = addf %[[v26]], %[[vresult_2]] : f32
+    // CHECK: br ^bb[[#b10:]](%[[v27]] : f32)
+    // CHECK: ^bb[[#b9]]:
+    // CHECK: br ^bb[[#b10]](%[[v26]] : f32)
+    // CHECK: ^bb[[#b10]](%[[v28:[a-z_0-9]+]]: f32):
+    // CHECK: %[[vresult_4:[a-z_0-9]+]], %[[vvalid_5:[a-z_0-9]+]] = gpu.shuffle %[[v28]], %[[vc8_i32]], %[[v21]] xor : f32
+    // CHECK: cond_br %[[vvalid_5]], ^bb[[#b11:]], ^bb[[#b12:]]
+    // CHECK: ^bb[[#b11]]:
+    // CHECK: %[[v29:[a-z_0-9]+]] = addf %[[v28]], %[[vresult_4]] : f32
+    // CHECK: br ^bb[[#b13:]](%[[v29]] : f32)
+    // CHECK: ^bb[[#b12]]:
+    // CHECK: br ^bb[[#b13]](%[[v28]] : f32)
+    // CHECK: ^bb[[#b13]](%[[v30:[a-z_0-9]+]]: f32):
+    // CHECK: %[[vresult_6:[a-z_0-9]+]], %[[vvalid_7:[a-z_0-9]+]] = gpu.shuffle %[[v30]], %[[vc16_i32]], %[[v21]] xor : f32
+    // CHECK: cond_br %[[vvalid_7]], ^bb[[#b14:]], ^bb[[#b15:]]
+    // CHECK: ^bb[[#b14]]:
+    // CHECK: %[[v31:[a-z_0-9]+]] = addf %[[v30]], %[[vresult_6]] : f32
+    // CHECK: br ^bb[[#b16:]](%[[v31]] : f32)
+    // CHECK: ^bb[[#b15]]:
+    // CHECK: br ^bb[[#b16]](%[[v30]] : f32)
+    // CHECK: ^bb[[#b16]](%[[v32:[a-z_0-9]+]]: f32):
+    // CHECK: br ^bb[[#b18:]](%[[v32]] : f32)
+    // CHECK: ^bb[[#b17]]:
+    // CHECK: %[[vresult_8:[a-z_0-9]+]], %[[vvalid_9:[a-z_0-9]+]] = gpu.shuffle %[[varg0]], %[[vc1_i32]], %[[vc32_i32]] xor : f32
+    // CHECK: %[[v33:[a-z_0-9]+]] = addf %[[varg0]], %[[vresult_8]] : f32
+    // CHECK: %[[vresult_10:[a-z_0-9]+]], %[[vvalid_11:[a-z_0-9]+]] = gpu.shuffle %[[v33]], %[[vc2_i32]], %[[vc32_i32]] xor : f32
+    // CHECK: %[[v34:[a-z_0-9]+]] = addf %[[v33]], %[[vresult_10]] : f32
+    // CHECK: %[[vresult_12:[a-z_0-9]+]], %[[vvalid_13:[a-z_0-9]+]] = gpu.shuffle %[[v34]], %[[vc4_i32]], %[[vc32_i32]] xor : f32
+    // CHECK: %[[v35:[a-z_0-9]+]] = addf %[[v34]], %[[vresult_12]] : f32
+    // CHECK: %[[vresult_14:[a-z_0-9]+]], %[[vvalid_15:[a-z_0-9]+]] = gpu.shuffle %[[v35]], %[[vc8_i32]], %[[vc32_i32]] xor : f32
+    // CHECK: %[[v36:[a-z_0-9]+]] = addf %[[v35]], %[[vresult_14]] : f32
+    // CHECK: %[[vresult_16:[a-z_0-9]+]], %[[vvalid_17:[a-z_0-9]+]] = gpu.shuffle %[[v36]], %[[vc16_i32]], %[[vc32_i32]] xor : f32
+    // CHECK: %[[v37:[a-z_0-9]+]] = addf %[[v36]], %[[vresult_16]] : f32
+    // CHECK: br ^bb[[#b18]](%[[v37]] : f32)
+    // CHECK: ^bb[[#b18]](%[[v38:[a-z_0-9]+]]: f32):
+    // CHECK: cond_br %[[v19]], ^bb[[#b19:]], ^bb[[#b20:]]
+    // CHECK: ^bb[[#b19]]:
+    // CHECK: %[[v39:[a-z_0-9]+]] = divis %[[v16]], %[[vc32_i32]] : i32
+    // CHECK: %[[v40:[a-z_0-9]+]] = index_cast %[[v39]] : i32 to index
+    // CHECK: store %[[v38]], %[[varg1]][%[[v40]]] : memref<32xf32, 3>
+    // CHECK: br ^bb[[#b21:]]
+    // CHECK: ^bb[[#b20]]:
+    // CHECK: br ^bb[[#b21]]
+    // CHECK: ^bb[[#b21]]:
+    // CHECK: gpu.barrier
+    // CHECK: %[[v41:[a-z_0-9]+]] = addi %[[v17]], %[[vc31_i32]] : i32
+    // CHECK: %[[v42:[a-z_0-9]+]] = divis %[[v41]], %[[vc32_i32]] : i32
+    // CHECK: %[[v43:[a-z_0-9]+]] = cmpi "slt", %[[v16]], %[[v42]] : i32
+    // CHECK: cond_br %[[v43]], ^bb[[#b22:]], ^bb[[#b41:]]
+    // CHECK: ^bb[[#b22]]:
+    // CHECK: %[[v44:[a-z_0-9]+]] = index_cast %[[v16]] : i32 to index
+    // CHECK: %[[v45:[a-z_0-9]+]] = load %[[varg1]][%[[v44]]] : memref<32xf32, 3>
+    // CHECK: %[[v46:[a-z_0-9]+]] = cmpi "slt", %[[v42]], %[[vc32_i32]] : i32
+    // CHECK: cond_br %[[v46]], ^bb[[#b23:]], ^bb[[#b39:]]
+    // CHECK: ^bb[[#b23]]:
+    // CHECK: %[[vresult_18:[a-z_0-9]+]], %[[vvalid_19:[a-z_0-9]+]] = gpu.shuffle %[[v45]], %[[vc1_i32]], %[[v42]] xor : f32
+    // CHECK: cond_br %[[vvalid_19]], ^bb[[#b24:]], ^bb[[#b25:]]
+    // CHECK: ^bb[[#b24]]:
+    // CHECK: %[[v47:[a-z_0-9]+]] = addf %[[v45]], %[[vresult_18]] : f32
+    // CHECK: br ^bb[[#b26:]](%[[v47]] : f32)
+    // CHECK: ^bb[[#b25]]:
+    // CHECK: br ^bb[[#b26]](%[[v45]] : f32)
+    // CHECK: ^bb[[#b26]](%[[v48:[a-z_0-9]+]]: f32):
+    // CHECK: %[[vresult_20:[a-z_0-9]+]], %[[vvalid_21:[a-z_0-9]+]] = gpu.shuffle %[[v48]], %[[vc2_i32]], %[[v42]] xor : f32
+    // CHECK: cond_br %[[vvalid_21]], ^bb[[#b27:]], ^bb[[#b28:]]
+    // CHECK: ^bb[[#b27]]:
+    // CHECK: %[[v49:[a-z_0-9]+]] = addf %[[v48]], %[[vresult_20]] : f32
+    // CHECK: br ^bb[[#b29:]](%[[v49]] : f32)
+    // CHECK: ^bb[[#b28]]:
+    // CHECK: br ^bb[[#b29]](%[[v48]] : f32)
+    // CHECK: ^bb[[#b29]](%[[v50:[a-z_0-9]+]]: f32):
+    // CHECK: %[[vresult_22:[a-z_0-9]+]], %[[vvalid_23:[a-z_0-9]+]] = gpu.shuffle %[[v50]], %[[vc4_i32]], %[[v42]] xor : f32
+    // CHECK: cond_br %[[vvalid_23]], ^bb[[#b30:]], ^bb[[#b31:]]
+    // CHECK: ^bb[[#b30]]:
+    // CHECK: %[[v51:[a-z_0-9]+]] = addf %[[v50]], %[[vresult_22]] : f32
+    // CHECK: br ^bb[[#b32:]](%[[v51]] : f32)
+    // CHECK: ^bb[[#b31]]:
+    // CHECK: br ^bb[[#b32]](%[[v50]] : f32)
+    // CHECK: ^bb[[#b32]](%[[v52:[a-z_0-9]+]]: f32):
+    // CHECK: %[[vresult_24:[a-z_0-9]+]], %[[vvalid_25:[a-z_0-9]+]] = gpu.shuffle %[[v52]], %[[vc8_i32]], %[[v42]] xor : f32
+    // CHECK: cond_br %[[vvalid_25]], ^bb[[#b33:]], ^bb[[#b34:]]
+    // CHECK: ^bb[[#b33]]:
+    // CHECK: %[[v53:[a-z_0-9]+]] = addf %[[v52]], %[[vresult_24]] : f32
+    // CHECK: br ^bb[[#b35:]](%[[v53]] : f32)
+    // CHECK: ^bb[[#b34]]:
+    // CHECK: br ^bb[[#b35]](%[[v52]] : f32)
+    // CHECK: ^bb[[#b35]](%[[v54:[a-z_0-9]+]]: f32):
+    // CHECK: %[[vresult_26:[a-z_0-9]+]], %[[vvalid_27:[a-z_0-9]+]] = gpu.shuffle %[[v54]], %[[vc16_i32]], %[[v42]] xor : f32
+    // CHECK: cond_br %[[vvalid_27]], ^bb[[#b36:]], ^bb[[#b37:]]
+    // CHECK: ^bb[[#b36]]:
+    // CHECK: %[[v55:[a-z_0-9]+]] = addf %[[v54]], %[[vresult_26]] : f32
+    // CHECK: br ^bb[[#b38:]](%[[v55]] : f32)
+    // CHECK: ^bb[[#b37]]:
+    // CHECK: br ^bb[[#b38]](%[[v54]] : f32)
+    // CHECK: ^bb[[#b38]](%[[v56:[a-z_0-9]+]]: f32):
+    // CHECK: br ^bb[[#b40:]](%[[v56]] : f32)
+    // CHECK: ^bb[[#b39]]:
+    // CHECK: %[[vresult_28:[a-z_0-9]+]], %[[vvalid_29:[a-z_0-9]+]] = gpu.shuffle %[[v45]], %[[vc1_i32]], %[[vc32_i32]] xor : f32
+    // CHECK: %[[v57:[a-z_0-9]+]] = addf %[[v45]], %[[vresult_28]] : f32
+    // CHECK: %[[vresult_30:[a-z_0-9]+]], %[[vvalid_31:[a-z_0-9]+]] = gpu.shuffle %[[v57]], %[[vc2_i32]], %[[vc32_i32]] xor : f32
+    // CHECK: %[[v58:[a-z_0-9]+]] = addf %[[v57]], %[[vresult_30]] : f32
+    // CHECK: %[[vresult_32:[a-z_0-9]+]], %[[vvalid_33:[a-z_0-9]+]] = gpu.shuffle %[[v58]], %[[vc4_i32]], %[[vc32_i32]] xor : f32
+    // CHECK: %[[v59:[a-z_0-9]+]] = addf %[[v58]], %[[vresult_32]] : f32
+    // CHECK: %[[vresult_34:[a-z_0-9]+]], %[[vvalid_35:[a-z_0-9]+]] = gpu.shuffle %[[v59]], %[[vc8_i32]], %[[vc32_i32]] xor : f32
+    // CHECK: %[[v60:[a-z_0-9]+]] = addf %[[v59]], %[[vresult_34]] : f32
+    // CHECK: %[[vresult_36:[a-z_0-9]+]], %[[vvalid_37:[a-z_0-9]+]] = gpu.shuffle %[[v60]], %[[vc16_i32]], %[[vc32_i32]] xor : f32
+    // CHECK: %[[v61:[a-z_0-9]+]] = addf %[[v60]], %[[vresult_36]] : f32
+    // CHECK: br ^bb[[#b40]](%[[v61]] : f32)
+    // CHECK: ^bb[[#b40]](%[[v62:[a-z_0-9]+]]: f32):
+    // CHECK: store %[[v62]], %[[varg1]][%[[vc0]]] : memref<32xf32, 3>
+    // CHECK: br ^bb[[#b42:]]
+    // CHECK: ^bb[[#b41]]:
+    // CHECK: br ^bb[[#b42]]
+    // CHECK: ^bb[[#b42]]:
+    // CHECK: gpu.barrier
+    // CHECK: %[[v63:[a-z_0-9]+]] = load %[[varg1]][%[[vc0]]] : memref<32xf32, 3>
+    %sum = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32)
+    gpu.return
+  }
+
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_llvm_library(MLIRTestTransforms
+  TestAllReduceLowering.cpp
   TestCallGraph.cpp
   TestConstantFold.cpp
   TestLoopFusion.cpp
@@ -27,6 +28,7 @@
   MLIRAffineOps
   MLIRAnalysis
   MLIRLoopOps
+  MLIRGPU
   MLIRPass
   MLIRTestDialect
   MLIRVectorOps
diff --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
@@ -0,0 +1,32 @@
+//===- TestAllReduceLowering.cpp - Test gpu.all_reduce lowering -----------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains test passes for lowering the gpu.all_reduce op.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+struct TestAllReduceLoweringPass
+    : public ModulePass<TestAllReduceLoweringPass> {
+  void runOnModule() override {
+    OwningRewritePatternList patterns;
+    populateGpuRewritePatterns(&getContext(), patterns);
+    applyPatternsGreedily(getModule(), patterns);
+  }
+};
+} // namespace
+
+static PassRegistration<TestAllReduceLoweringPass>
+    pass("test-all-reduce-lowering",
+         "Lowers gpu.all_reduce ops within the GPU dialect.");
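
Note (illustrative, not part of the patch): besides the `op` attribute used in the test above, the lowering also accepts an explicit accumulator region whose entry block takes two values of the reduced type and terminates in gpu.yield, handled by getFactory(Region &body). A minimal sketch of that form, assuming the same kernel-module framing as the test file and using the standard mulf op as the combiner, is shown below; it would be processed with the same invocation as the RUN line above, i.e. mlir-opt -test-all-reduce-lowering input.mlir. The function name and combiner choice here are hypothetical.

  module @kernels attributes {gpu.kernel_module} {
    gpu.func @kernel_with_region(%arg0 : f32) attributes { gpu.kernel } {
      // Accumulator given as a region: two block arguments of the reduced
      // type, with the combined value returned through gpu.yield.
      %prod = "gpu.all_reduce"(%arg0) ({
      ^bb0(%lhs : f32, %rhs : f32):
        %m = mulf %lhs, %rhs : f32
        "gpu.yield"(%m) : (f32) -> ()
      }) : (f32) -> (f32)
      gpu.return
    }
  }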