diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Pass/Pass.h"
@@ -49,421 +50,6 @@
   }
 };
 
-/// Converts all_reduce op to LLVM/NVVM ops.
-struct GPUAllReduceOpLowering : public LLVMOpLowering {
-  using AccumulatorFactory =
-      std::function<Value(Location, Value, Value,
-                          ConversionPatternRewriter &)>;
-
-  explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_)
-      : LLVMOpLowering(gpu::AllReduceOp::getOperationName(),
-                       lowering_.getDialect()->getContext(), lowering_),
-        int32Type(LLVM::LLVMType::getInt32Ty(lowering_.getDialect())) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    Value operand = operands.front();
-
-    // TODO(csigg): Generalize to other types of accumulation.
-    assert(op->getOperand(0).getType().isIntOrFloat());
-
-    // Create the reduction using an accumulator factory.
-    AccumulatorFactory factory =
-        getFactory(cast<gpu::AllReduceOp>(op), operand);
-    assert(factory && "failed to create accumulator factory");
-    Value result = createBlockReduce(loc, operand, factory, rewriter);
-
-    rewriter.replaceOp(op, {result});
-    return matchSuccess();
-  }
-
-private:
-  /// Returns an accumulator factory using either the op attribute or the body
-  /// region.
-  AccumulatorFactory getFactory(gpu::AllReduceOp allReduce,
-                                Value operand) const {
-    if (!allReduce.body().empty()) {
-      return getFactory(allReduce.body());
-    }
-    if (allReduce.op()) {
-      auto type = operand.getType().cast<LLVM::LLVMType>();
-      return getFactory(*allReduce.op(), type.getUnderlyingType());
-    }
-    return AccumulatorFactory();
-  }
-
-  /// Returns an accumulator factory that clones the body. The body's entry
-  /// block is expected to have 2 arguments. The gpu.yield op returns the
-  /// accumulated value of the same type.
-  AccumulatorFactory getFactory(Region &body) const {
-    return AccumulatorFactory([&](Location loc, Value lhs, Value rhs,
-                                  ConversionPatternRewriter &rewriter) {
-      Block *block = rewriter.getInsertionBlock();
-      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
-
-      // Insert accumulator body between the split blocks.
-      BlockAndValueMapping mapping;
-      mapping.map(body.front().getArgument(0), lhs);
-      mapping.map(body.front().getArgument(1), rhs);
-      rewriter.cloneRegionBefore(body, *split->getParent(),
-                                 split->getIterator(), mapping);
-
-      // Add branch before inserted body, into body.
-      block = block->getNextNode();
-      rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value>{},
-                                  llvm::makeArrayRef(block), ValueRange());
-
-      // Replace all gpu.yield ops with branch out of body.
-      for (; block != split; block = block->getNextNode()) {
-        Operation *terminator = block->getTerminator();
-        if (!llvm::isa<gpu::YieldOp>(terminator))
-          continue;
-        rewriter.setInsertionPointToEnd(block);
-        rewriter.replaceOpWithNewOp<LLVM::BrOp>(
-            terminator, ArrayRef<Value>{}, llvm::makeArrayRef(split),
-            ValueRange(terminator->getOperand(0)));
-      }
-
-      // Return accumulator result.
-      rewriter.setInsertionPointToStart(split);
-      return split->addArgument(lhs.getType());
-    });
-  }
-
-  /// Returns an accumulator factory that creates an op specified by opName.
-  AccumulatorFactory getFactory(StringRef opName, llvm::Type *type) const {
-    if (type->isVectorTy() || type->isArrayTy())
-      return getFactory(opName, type->getSequentialElementType());
-
-    bool isFloatingPoint = type->isFloatingPointTy();
-
-    if (opName == "add") {
-      return isFloatingPoint ? getFactory<LLVM::FAddOp>()
-                             : getFactory<LLVM::AddOp>();
-    }
-    if (opName == "mul") {
-      return isFloatingPoint ? getFactory<LLVM::FMulOp>()
-                             : getFactory<LLVM::MulOp>();
-    }
-
-    return AccumulatorFactory();
-  }
-
-  /// Returns an accumulator factory that creates an op of type T.
-  template <typename T> AccumulatorFactory getFactory() const {
-    return [](Location loc, Value lhs, Value rhs,
-              ConversionPatternRewriter &rewriter) {
-      return rewriter.create<T>(loc, lhs.getType(), lhs, rhs);
-    };
-  }
-
-  /// Creates an all_reduce across the block.
-  ///
-  /// First reduce the elements within a warp. The first thread of each warp
-  /// writes the intermediate result to shared memory. After synchronizing the
-  /// block, the first warp reduces the values from shared memory. The result
-  /// is broadcast to all threads through shared memory.
-  ///
-  ///     %warp_reduce = `createWarpReduce(%operand)`
-  ///     %shared_mem_ptr = llvm.mlir.addressof @reduce_buffer
-  ///     %zero = llvm.mlir.constant(0 : i32) : !llvm.i32
-  ///     %lane_id = nvvm.read.ptx.sreg.laneid : !llvm.i32
-  ///     %is_first_lane = llvm.icmp "eq" %lane_id, %zero : !llvm.i1
-  ///     %thread_idx = `getLinearThreadIndex()` : !llvm.i32
-  ///     llvm.cond_br %is_first_lane, ^then1, ^continue1
-  ///   ^then1:
-  ///     %warp_id = `getWarpId()`
-  ///     %store_dst = llvm.getelementptr %shared_mem_ptr[%zero, %warp_id]
-  ///     llvm.store %warp_reduce, %store_dst
-  ///     llvm.br ^continue1
-  ///   ^continue1:
-  ///     nvvm.barrier0
-  ///     %num_warps = `getNumWarps()` : !llvm.i32
-  ///     %is_valid_warp = llvm.icmp "slt" %thread_idx, %num_warps
-  ///     %result_ptr = llvm.getelementptr %shared_mem_ptr[%zero, %zero]
-  ///     llvm.cond_br %is_valid_warp, ^then2, ^continue2
-  ///   ^then2:
-  ///     %load_src = llvm.getelementptr %shared_mem_ptr[%zero, %thread_idx]
-  ///     %value = llvm.load %load_src
-  ///     %result = `createWarpReduce(%value)`
-  ///     llvm.store %result, %result_ptr
-  ///     llvm.br ^continue2
-  ///   ^continue2:
-  ///     nvvm.barrier0
-  ///     %result = llvm.load %result_ptr
-  ///     return %result
-  ///
-  Value createBlockReduce(Location loc, Value operand,
-                          AccumulatorFactory &accumFactory,
-                          ConversionPatternRewriter &rewriter) const {
-    auto type = operand.getType().cast<LLVM::LLVMType>();
-
-    // Create shared memory array to store the warp reduction.
-    auto module = operand.getDefiningOp()->getParentOfType<gpu::GPUModuleOp>();
-    assert(module && "op must belong to a module");
-    Value sharedMemPtr =
-        createSharedMemoryArray(loc, module, type, kWarpSize, rewriter);
-
-    Value zero = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(0u));
-    Value laneId = rewriter.create<NVVM::LaneIdOp>(loc, int32Type);
-    Value isFirstLane = rewriter.create<LLVM::ICmpOp>(
-        loc, LLVM::ICmpPredicate::eq, laneId, zero);
-    Value threadIdx = getLinearThreadIndex(loc, rewriter);
-    Value blockSize = getBlockSize(loc, rewriter);
-    Value activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter);
-
-    // Reduce elements within each warp to produce the intermediate results.
-    Value warpReduce = createWarpReduce(loc, activeWidth, laneId, operand,
-                                        accumFactory, rewriter);
-
-    // Write the intermediate results to shared memory, using the first lane
-    // of each warp.
-    createPredicatedBlock(loc, rewriter, isFirstLane, [&] {
-      Value warpId = getDivideByWarpSize(threadIdx, rewriter);
-      Value storeDst = rewriter.create<LLVM::GEPOp>(
-          loc, type, sharedMemPtr, ArrayRef<Value>({zero, warpId}));
-      rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
-    });
-    rewriter.create<NVVM::Barrier0Op>(loc);
-
-    Value numWarps = getNumWarps(loc, blockSize, rewriter);
-    Value isValidWarp = rewriter.create<LLVM::ICmpOp>(
-        loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps);
-    Value resultPtr = rewriter.create<LLVM::GEPOp>(
-        loc, type, sharedMemPtr, ArrayRef<Value>({zero, zero}));
-
-    // Use the first numWarps threads to reduce the intermediate results from
-    // shared memory. The final result is written to shared memory again.
-    createPredicatedBlock(loc, rewriter, isValidWarp, [&] {
-      Value loadSrc = rewriter.create<LLVM::GEPOp>(
-          loc, type, sharedMemPtr, ArrayRef<Value>({zero, threadIdx}));
-      Value value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
-      Value result = createWarpReduce(loc, numWarps, laneId, value,
-                                      accumFactory, rewriter);
-      rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
-    });
-    rewriter.create<NVVM::Barrier0Op>(loc);
-
-    // Load and return result from shared memory.
-    Value result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
-    return result;
-  }
-
-  /// Creates an if-block skeleton and calls the two factories to generate the
-  /// ops in the `then` and `else` blocks.
-  ///
-  ///     llvm.cond_br %condition, ^then, ^continue
-  ///   ^then:
-  ///     %then_operands = `thenOpsFactory()`
-  ///     llvm.br ^continue(%then_operands)
-  ///   ^else:
-  ///     %else_operands = `elseOpsFactory()`
-  ///     llvm.br ^continue(%else_operands)
-  ///   ^continue(%block_operands):
-  ///
-  template <typename ThenOpsFactory, typename ElseOpsFactory>
-  void createIf(Location loc, ConversionPatternRewriter &rewriter,
-                Value condition, ThenOpsFactory &&thenOpsFactory,
-                ElseOpsFactory &&elseOpsFactory) const {
-    Block *currentBlock = rewriter.getInsertionBlock();
-    auto currentPoint = rewriter.getInsertionPoint();
-
-    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
-    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
-    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
-
-    rewriter.setInsertionPointToEnd(currentBlock);
-    rewriter.create<LLVM::CondBrOp>(loc, llvm::makeArrayRef(condition),
-                                    ArrayRef<Block *>{thenBlock, elseBlock});
-
-    auto addBranch = [&](ValueRange operands) {
-      rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value>{},
-                                  llvm::makeArrayRef(continueBlock),
-                                  llvm::makeArrayRef(operands));
-    };
-
-    rewriter.setInsertionPointToStart(thenBlock);
-    auto thenOperands = thenOpsFactory();
-    addBranch(thenOperands);
-
-    rewriter.setInsertionPointToStart(elseBlock);
-    auto elseOperands = elseOpsFactory();
-    addBranch(elseOperands);
-
-    assert(thenOperands.size() == elseOperands.size());
-    rewriter.setInsertionPointToStart(continueBlock);
-    for (auto operand : thenOperands)
-      continueBlock->addArgument(operand.getType());
-  }
-
-  /// Shortcut for createIf with empty else block and no block operands.
-  template <typename Factory>
-  void createPredicatedBlock(Location loc, ConversionPatternRewriter &rewriter,
-                             Value condition,
-                             Factory &&predicatedOpsFactory) const {
-    createIf(
-        loc, rewriter, condition,
-        [&] {
-          predicatedOpsFactory();
-          return ArrayRef<Value>();
-        },
-        [&] { return ArrayRef<Value>(); });
-  }
-
-  /// Creates a reduction across the first activeWidth lanes of a warp. The
-  /// first lane returns the result; all other lanes' return values are
-  /// undefined.
-  Value createWarpReduce(Location loc, Value activeWidth, Value laneId,
-                         Value operand, AccumulatorFactory accumFactory,
-                         ConversionPatternRewriter &rewriter) const {
-    Value warpSize = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
-    Value isPartialWarp = rewriter.create<LLVM::ICmpOp>(
-        loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize);
-    auto type = operand.getType().cast<LLVM::LLVMType>();
-
-    createIf(
-        loc, rewriter, isPartialWarp,
-        // Generate reduction over a (potentially) partial warp.
-        [&] {
-          Value value = operand;
-          Value one = rewriter.create<LLVM::ConstantOp>(
-              loc, int32Type, rewriter.getI32IntegerAttr(1));
-          // Bit mask of active lanes: `(1 << activeWidth) - 1`.
-          Value activeMask = rewriter.create<LLVM::SubOp>(
-              loc, int32Type,
-              rewriter.create<LLVM::ShlOp>(loc, int32Type, one, activeWidth),
-              one);
-          // Clamp lane: `activeWidth - 1`
-          Value maskAndClamp =
-              rewriter.create<LLVM::SubOp>(loc, int32Type, activeWidth, one);
-          auto dialect = lowering.getDialect();
-          auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
-          auto shflTy = LLVM::LLVMType::getStructTy(dialect, {type, predTy});
-          auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
-
-          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if
-          // source lane is within the active range. All lanes contain the
-          // final result, but only the first lane's result is used.
-          for (int i = 1; i < kWarpSize; i <<= 1) {
-            Value offset = rewriter.create<LLVM::ConstantOp>(
-                loc, int32Type, rewriter.getI32IntegerAttr(i));
-            Value shfl = rewriter.create<NVVM::ShflBflyOp>(
-                loc, shflTy, activeMask, value, offset, maskAndClamp,
-                returnValueAndIsValidAttr);
-            Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
-                loc, predTy, shfl, rewriter.getIndexArrayAttr(1));
-            // Skip the accumulation if the shuffle op read from a lane
-            // outside of the active range.
-            createIf(
-                loc, rewriter, isActiveSrcLane,
-                [&] {
-                  Value shflValue = rewriter.create<LLVM::ExtractValueOp>(
-                      loc, type, shfl, rewriter.getIndexArrayAttr(0));
-                  return SmallVector<Value, 1>{
-                      accumFactory(loc, value, shflValue, rewriter)};
-                },
-                [&] { return llvm::makeArrayRef(value); });
-            value = rewriter.getInsertionBlock()->getArgument(0);
-          }
-          return SmallVector<Value, 1>{value};
-        },
-        // Generate a reduction over the entire warp. This is a specialization
-        // of the above reduction with unconditional accumulation.
-        [&] {
-          Value value = operand;
-          Value activeMask = rewriter.create<LLVM::ConstantOp>(
-              loc, int32Type, rewriter.getI32IntegerAttr(~0u));
-          Value maskAndClamp = rewriter.create<LLVM::ConstantOp>(
-              loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
-          for (int i = 1; i < kWarpSize; i <<= 1) {
-            Value offset = rewriter.create<LLVM::ConstantOp>(
-                loc, int32Type, rewriter.getI32IntegerAttr(i));
-            Value shflValue = rewriter.create<NVVM::ShflBflyOp>(
-                loc, type, activeMask, value, offset, maskAndClamp,
-                /*return_value_and_is_valid=*/UnitAttr());
-            value = accumFactory(loc, value, shflValue, rewriter);
-          }
-          return SmallVector<Value, 1>{value};
-        });
-    return rewriter.getInsertionBlock()->getArgument(0);
-  }
-
-  /// Creates a global array stored in shared memory.
-  Value createSharedMemoryArray(Location loc, gpu::GPUModuleOp module,
-                                LLVM::LLVMType elementType, int numElements,
-                                ConversionPatternRewriter &rewriter) const {
-    OpBuilder builder(module.body());
-
-    auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
-    StringRef name = "reduce_buffer";
-    auto globalOp = builder.create<LLVM::GlobalOp>(
-        loc, arrayType.cast<LLVM::LLVMType>(),
-        /*isConstant=*/false, LLVM::Linkage::Internal, name,
-        /*value=*/Attribute(), gpu::GPUDialect::getWorkgroupAddressSpace());
-
-    return rewriter.create<LLVM::AddressOfOp>(loc, globalOp);
-  }
-
-  /// Returns the index of the thread within the block.
-  Value getLinearThreadIndex(Location loc,
-                             ConversionPatternRewriter &rewriter) const {
-    Value dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
-    Value dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
-    Value idX = rewriter.create<NVVM::ThreadIdXOp>(loc, int32Type);
-    Value idY = rewriter.create<NVVM::ThreadIdYOp>(loc, int32Type);
-    Value idZ = rewriter.create<NVVM::ThreadIdZOp>(loc, int32Type);
-    Value tmp1 = rewriter.create<LLVM::MulOp>(loc, int32Type, idZ, dimY);
-    Value tmp2 = rewriter.create<LLVM::AddOp>(loc, int32Type, tmp1, idY);
-    Value tmp3 = rewriter.create<LLVM::MulOp>(loc, int32Type, tmp2, dimX);
-    return rewriter.create<LLVM::AddOp>(loc, int32Type, tmp3, idX);
-  }
-
-  /// Returns the number of threads in the block.
-  Value getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const {
-    Value dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
-    Value dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
-    Value dimZ = rewriter.create<NVVM::BlockDimZOp>(loc, int32Type);
-    Value dimXY = rewriter.create<LLVM::MulOp>(loc, int32Type, dimX, dimY);
-    return rewriter.create<LLVM::MulOp>(loc, int32Type, dimXY, dimZ);
-  }
-
-  /// Returns the number of warps in the block.
-  Value getNumWarps(Location loc, Value blockSize,
-                    ConversionPatternRewriter &rewriter) const {
-    auto warpSizeMinusOne = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
-    auto biasedBlockSize = rewriter.create<LLVM::AddOp>(
-        loc, int32Type, blockSize, warpSizeMinusOne);
-    return getDivideByWarpSize(biasedBlockSize, rewriter);
-  }
-
-  /// Returns the number of active threads in the warp, not clamped to 32.
-  Value getActiveWidth(Location loc, Value threadIdx, Value blockSize,
-                       ConversionPatternRewriter &rewriter) const {
-    Value threadIdxMask = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(~(kWarpSize - 1)));
-    Value numThreadsWithSmallerWarpId =
-        rewriter.create<LLVM::AndOp>(loc, threadIdx, threadIdxMask);
-    return rewriter.create<LLVM::SubOp>(loc, blockSize,
-                                        numThreadsWithSmallerWarpId);
-  }
-
-  /// Returns value divided by the warp size (i.e. 32).
-  Value getDivideByWarpSize(Value value,
-                            ConversionPatternRewriter &rewriter) const {
-    auto loc = value.getLoc();
-    auto warpSize = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
-    return rewriter.create<LLVM::SDivOp>(loc, int32Type, value, warpSize);
-  }
-
-  LLVM::LLVMType int32Type;
-
-  static constexpr int kWarpSize = 32;
-};
-
 struct GPUShuffleOpLowering : public LLVMOpLowering {
   explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
       : LLVMOpLowering(gpu::ShuffleOp::getOperationName(),
@@ -705,6 +291,14 @@
   void runOnOperation() override {
     gpu::GPUModuleOp m = getOperation();
     OwningRewritePatternList patterns;
+
+    // Apply in-dialect lowering first. In-dialect lowering will replace ops
+    // which need to be lowered further, which is not supported by a single
+    // conversion pass.
+    populateGpuRewritePatterns(m.getContext(), patterns);
+    applyPatternsGreedily(m, patterns);
+    patterns.clear();
+
     NVVMTypeConverter converter(m.getContext());
     populateStdToLLVMConversionPatterns(converter, patterns);
     populateGpuToNVVMConversionPatterns(converter, patterns);
@@ -736,8 +330,8 @@
                                   NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
       GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                   NVVM::GridDimYOp, NVVM::GridDimZOp>,
-      GPUAllReduceOpLowering, GPUShuffleOpLowering, GPUFuncOpLowering,
-      GPUReturnOpLowering>(converter);
+      GPUShuffleOpLowering, GPUFuncOpLowering, GPUReturnOpLowering>(
+          converter);
   patterns.insert<OpToFuncCallLowering<AbsFOp>>(converter, "__nv_fabsf",
                                                 "__nv_fabs");
   patterns.insert<OpToFuncCallLowering<CeilFOp>>(converter, "__nv_ceilf",
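
Note: the runOnOperation() hunk above applies the in-dialect GPU rewrites greedily and only afterwards runs the LLVM/NVVM dialect conversion, since a single conversion pass cannot emit ops that themselves still require conversion. Below is a minimal sketch of how a downstream pipeline might mirror this two-phase structure. Only populateGpuRewritePatterns, applyPatternsGreedily, and the conversion populate functions are taken from this patch; the driver function itself is hypothetical, and NVVMTypeConverter is file-local to LowerGpuOpsToNVVMOps.cpp, so an out-of-tree driver would need its own type converter.

  // Sketch only: phase 1 rewrites gpu.all_reduce within the GPU dialect,
  // phase 2 converts the remaining ops to LLVM/NVVM in one shot.
  // (Hypothetical driver; not part of this patch.)
  static void lowerGpuModuleToNVVM(mlir::gpu::GPUModuleOp m) {
    mlir::OwningRewritePatternList patterns;

    // Phase 1: greedy in-dialect lowering, same calls as the hunk above.
    // This must run to a fixed point first, because its patterns produce
    // new gpu/standard ops that the conversion below still has to handle.
    mlir::populateGpuRewritePatterns(m.getContext(), patterns);
    mlir::applyPatternsGreedily(m, patterns);
    patterns.clear();

    // Phase 2: dialect conversion of the remaining ops to LLVM/NVVM.
    // Stand-in type converter; see NVVMTypeConverter in this file.
    mlir::LLVMTypeConverter converter(m.getContext());
    mlir::populateStdToLLVMConversionPatterns(converter, patterns);
    mlir::populateGpuToNVVMConversionPatterns(converter, patterns);
    // ... set up a ConversionTarget and call applyPartialConversion ...
  }

Running both pattern sets under one conversion would instead reject the intermediate ops produced by the gpu.all_reduce expansion as illegal, which is exactly the limitation the new comment in the patch calls out.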