diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Pass/Pass.h"
@@ -28,415 +29,6 @@
 
 namespace {
 
-/// Converts all_reduce op to LLVM/NVVM ops.
-struct GPUAllReduceOpLowering : public ConvertToLLVMPattern {
-  using AccumulatorFactory =
-      std::function<Value(Location, Value, Value, ConversionPatternRewriter &)>;
-
-  explicit GPUAllReduceOpLowering(LLVMTypeConverter &lowering_)
-      : ConvertToLLVMPattern(gpu::AllReduceOp::getOperationName(),
-                             lowering_.getDialect()->getContext(), lowering_),
-        int32Type(LLVM::LLVMType::getInt32Ty(lowering_.getDialect())) {}
-
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op->getLoc();
-    Value operand = operands.front();
-
-    // TODO(csigg): Generalize to other types of accumulation.
-    assert(op->getOperand(0).getType().isSignlessIntOrFloat());
-
-    // Create the reduction using an accumulator factory.
-    AccumulatorFactory factory =
-        getFactory(cast<gpu::AllReduceOp>(op), operand);
-    assert(factory && "failed to create accumulator factory");
-    Value result = createBlockReduce(loc, operand, factory, rewriter);
-
-    rewriter.replaceOp(op, {result});
-    return matchSuccess();
-  }
-
-private:
-  /// Returns an accumulator factory using either the op attribute or the body
-  /// region.
-  AccumulatorFactory getFactory(gpu::AllReduceOp allReduce,
-                                Value operand) const {
-    if (!allReduce.body().empty()) {
-      return getFactory(allReduce.body());
-    }
-    if (allReduce.op()) {
-      auto type = operand.getType().cast<LLVM::LLVMType>();
-      return getFactory(*allReduce.op(), type.getUnderlyingType());
-    }
-    return AccumulatorFactory();
-  }
-
-  /// Returns an accumulator factory that clones the body. The body's entry
-  /// block is expected to have 2 arguments. The gpu.yield returns the
-  /// accumulated value of the same type.
-  AccumulatorFactory getFactory(Region &body) const {
-    return AccumulatorFactory([&](Location loc, Value lhs, Value rhs,
-                                  ConversionPatternRewriter &rewriter) {
-      Block *block = rewriter.getInsertionBlock();
-      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());
-
-      // Insert the accumulator body between the split blocks.
-      BlockAndValueMapping mapping;
-      mapping.map(body.front().getArgument(0), lhs);
-      mapping.map(body.front().getArgument(1), rhs);
-      rewriter.cloneRegionBefore(body, *split->getParent(),
-                                 split->getIterator(), mapping);
-
-      // Add branch before inserted body, into body.
-      block = block->getNextNode();
-      rewriter.create<LLVM::BrOp>(loc, ValueRange(), block);
-
-      // Replace all gpu.yield ops with branch out of body.
-      for (; block != split; block = block->getNextNode()) {
-        Operation *terminator = block->getTerminator();
-        if (!llvm::isa<gpu::YieldOp>(terminator))
-          continue;
-        rewriter.setInsertionPointToEnd(block);
-        rewriter.replaceOpWithNewOp<LLVM::BrOp>(
-            terminator, terminator->getOperand(0), split);
-      }
-
-      // Return accumulator result.
-      rewriter.setInsertionPointToStart(split);
-      return split->addArgument(lhs.getType());
-    });
-  }
-
-  /// Returns an accumulator factory that creates an op specified by opName.
-  AccumulatorFactory getFactory(StringRef opName, llvm::Type *type) const {
-    if (type->isVectorTy() || type->isArrayTy())
-      return getFactory(opName, type->getSequentialElementType());
-
-    bool isFloatingPoint = type->isFloatingPointTy();
-
-    if (opName == "add") {
-      return isFloatingPoint ? getFactory<LLVM::FAddOp>()
-                             : getFactory<LLVM::AddOp>();
-    }
-    if (opName == "mul") {
-      return isFloatingPoint ? getFactory<LLVM::FMulOp>()
-                             : getFactory<LLVM::MulOp>();
-    }
-
-    return AccumulatorFactory();
-  }
-
-  /// Returns an accumulator factory that creates an op of type T.
-  template <typename T> AccumulatorFactory getFactory() const {
-    return [](Location loc, Value lhs, Value rhs,
-              ConversionPatternRewriter &rewriter) {
-      return rewriter.create<T>(loc, lhs.getType(), lhs, rhs);
-    };
-  }
-
-  /// Creates an all_reduce across the block.
-  ///
-  /// First reduce the elements within a warp. The first thread of each warp
-  /// writes the intermediate result to shared memory. After synchronizing the
-  /// block, the first warp reduces the values from shared memory. The result
-  /// is broadcast to all threads through shared memory.
-  ///
-  ///     %warp_reduce = `createWarpReduce(%operand)`
-  ///     %shared_mem_ptr = llvm.mlir.addressof @reduce_buffer
-  ///     %zero = llvm.mlir.constant(0 : i32) : !llvm.i32
-  ///     %lane_id = nvvm.read.ptx.sreg.laneid : !llvm.i32
-  ///     %is_first_lane = llvm.icmp "eq" %lane_id, %zero : !llvm.i1
-  ///     %thread_idx = `getLinearThreadIndex()` : !llvm.i32
-  ///     llvm.cond_br %is_first_lane, ^then1, ^continue1
-  ///   ^then1:
-  ///     %warp_id = `getWarpId()`
-  ///     %store_dst = llvm.getelementptr %shared_mem_ptr[%zero, %warp_id]
-  ///     llvm.store %warp_reduce, %store_dst
-  ///     llvm.br ^continue1
-  ///   ^continue1:
-  ///     nvvm.barrier0
-  ///     %num_warps = `getNumWarps()` : !llvm.i32
-  ///     %is_valid_warp = llvm.icmp "slt" %thread_idx, %num_warps
-  ///     %result_ptr = llvm.getelementptr %shared_mem_ptr[%zero, %zero]
-  ///     llvm.cond_br %is_valid_warp, ^then2, ^continue2
-  ///   ^then2:
-  ///     %load_src = llvm.getelementptr %shared_mem_ptr[%zero, %thread_idx]
-  ///     %value = llvm.load %load_src
-  ///     %result = `createWarpReduce(%value)`
-  ///     llvm.store %result, %result_ptr
-  ///     llvm.br ^continue2
-  ///   ^continue2:
-  ///     nvvm.barrier0
-  ///     %result = llvm.load %result_ptr
-  ///     return %result
-  ///
-  Value createBlockReduce(Location loc, Value operand,
-                          AccumulatorFactory &accumFactory,
-                          ConversionPatternRewriter &rewriter) const {
-    auto type = operand.getType().cast<LLVM::LLVMType>();
-
-    // Create shared memory array to store the warp reduction.
-    auto module = operand.getDefiningOp()->getParentOfType<gpu::GPUModuleOp>();
-    assert(module && "op must belong to a module");
-    Value sharedMemPtr =
-        createSharedMemoryArray(loc, module, type, kWarpSize, rewriter);
-
-    Value zero = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(0u));
-    Value laneId = rewriter.create<NVVM::LaneIdOp>(loc, int32Type);
-    Value isFirstLane = rewriter.create<LLVM::ICmpOp>(
-        loc, LLVM::ICmpPredicate::eq, laneId, zero);
-    Value threadIdx = getLinearThreadIndex(loc, rewriter);
-    Value blockSize = getBlockSize(loc, rewriter);
-    Value activeWidth = getActiveWidth(loc, threadIdx, blockSize, rewriter);
-
-    // Reduce elements within each warp to produce the intermediate results.
-    Value warpReduce = createWarpReduce(loc, activeWidth, laneId, operand,
-                                        accumFactory, rewriter);
-
-    // Write the intermediate results to shared memory, using the first lane of
-    // each warp.
-    createPredicatedBlock(loc, rewriter, isFirstLane, [&] {
-      Value warpId = getDivideByWarpSize(threadIdx, rewriter);
-      Value storeDst = rewriter.create<LLVM::GEPOp>(
-          loc, type, sharedMemPtr, ArrayRef<Value>({zero, warpId}));
-      rewriter.create<LLVM::StoreOp>(loc, warpReduce, storeDst);
-    });
-    rewriter.create<NVVM::Barrier0Op>(loc);
-
-    Value numWarps = getNumWarps(loc, blockSize, rewriter);
-    Value isValidWarp = rewriter.create<LLVM::ICmpOp>(
-        loc, LLVM::ICmpPredicate::slt, threadIdx, numWarps);
-    Value resultPtr = rewriter.create<LLVM::GEPOp>(
-        loc, type, sharedMemPtr, ArrayRef<Value>({zero, zero}));
-
-    // Use the first numWarps threads to reduce the intermediate results from
-    // shared memory. The final result is written to shared memory again.
-    createPredicatedBlock(loc, rewriter, isValidWarp, [&] {
-      Value loadSrc = rewriter.create<LLVM::GEPOp>(
-          loc, type, sharedMemPtr, ArrayRef<Value>({zero, threadIdx}));
-      Value value = rewriter.create<LLVM::LoadOp>(loc, type, loadSrc);
-      Value result = createWarpReduce(loc, numWarps, laneId, value,
-                                      accumFactory, rewriter);
-      rewriter.create<LLVM::StoreOp>(loc, result, resultPtr);
-    });
-    rewriter.create<NVVM::Barrier0Op>(loc);
-
-    // Load and return result from shared memory.
-    Value result = rewriter.create<LLVM::LoadOp>(loc, type, resultPtr);
-    return result;
-  }
-
-  /// Creates an if-block skeleton and calls the two factories to generate the
-  /// ops in the `then` and `else` block.
-  ///
-  ///     llvm.cond_br %condition, ^then, ^else
-  ///   ^then:
-  ///     %then_operands = `thenOpsFactory()`
-  ///     llvm.br ^continue(%then_operands)
-  ///   ^else:
-  ///     %else_operands = `elseOpsFactory()`
-  ///     llvm.br ^continue(%else_operands)
-  ///   ^continue(%block_operands):
-  ///
-  template <typename ThenOpsFactory, typename ElseOpsFactory>
-  void createIf(Location loc, ConversionPatternRewriter &rewriter,
-                Value condition, ThenOpsFactory &&thenOpsFactory,
-                ElseOpsFactory &&elseOpsFactory) const {
-    Block *currentBlock = rewriter.getInsertionBlock();
-    auto currentPoint = rewriter.getInsertionPoint();
-
-    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
-    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
-    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());
-
-    rewriter.setInsertionPointToEnd(currentBlock);
-    rewriter.create<LLVM::CondBrOp>(loc, condition, thenBlock, elseBlock);
-
-    auto addBranch = [&](ValueRange operands) {
-      rewriter.create<LLVM::BrOp>(loc, operands, continueBlock);
-    };
-
-    rewriter.setInsertionPointToStart(thenBlock);
-    auto thenOperands = thenOpsFactory();
-    addBranch(thenOperands);
-
-    rewriter.setInsertionPointToStart(elseBlock);
-    auto elseOperands = elseOpsFactory();
-    addBranch(elseOperands);
-
-    assert(thenOperands.size() == elseOperands.size());
-    rewriter.setInsertionPointToStart(continueBlock);
-    for (auto operand : thenOperands)
-      continueBlock->addArgument(operand.getType());
-  }
-
-  /// Shortcut for createIf with empty else block and no block operands.
-  template <typename Factory>
-  void createPredicatedBlock(Location loc, ConversionPatternRewriter &rewriter,
-                             Value condition,
-                             Factory &&predicatedOpsFactory) const {
-    createIf(
-        loc, rewriter, condition,
-        [&] {
-          predicatedOpsFactory();
-          return ArrayRef<Value>();
-        },
-        [&] { return ArrayRef<Value>(); });
-  }
-
-  /// Creates a reduction across the first activeWidth lanes of a warp.
-  /// The first lane returns the result; all other lanes' return values are
-  /// undefined.
-  Value createWarpReduce(Location loc, Value activeWidth, Value laneId,
-                         Value operand, AccumulatorFactory accumFactory,
-                         ConversionPatternRewriter &rewriter) const {
-    Value warpSize = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
-    Value isPartialWarp = rewriter.create<LLVM::ICmpOp>(
-        loc, LLVM::ICmpPredicate::slt, activeWidth, warpSize);
-    auto type = operand.getType().cast<LLVM::LLVMType>();
-
-    createIf(
-        loc, rewriter, isPartialWarp,
-        // Generate reduction over a (potentially) partial warp.
-        [&] {
-          Value value = operand;
-          Value one = rewriter.create<LLVM::ConstantOp>(
-              loc, int32Type, rewriter.getI32IntegerAttr(1));
-          // Bit mask of active lanes: `(1 << activeWidth) - 1`.
-          Value activeMask = rewriter.create<LLVM::SubOp>(
-              loc, int32Type,
-              rewriter.create<LLVM::ShlOp>(loc, int32Type, one, activeWidth),
-              one);
-          // Clamp lane: `activeWidth - 1`.
-          Value maskAndClamp =
-              rewriter.create<LLVM::SubOp>(loc, int32Type, activeWidth, one);
-          auto dialect = typeConverter.getDialect();
-          auto predTy = LLVM::LLVMType::getInt1Ty(dialect);
-          auto shflTy = LLVM::LLVMType::getStructTy(dialect, {type, predTy});
-          auto returnValueAndIsValidAttr = rewriter.getUnitAttr();
-
-          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if
-          // source lane is within the active range. All lanes contain the
-          // final result, but only the first lane's result is used.
-          for (int i = 1; i < kWarpSize; i <<= 1) {
-            Value offset = rewriter.create<LLVM::ConstantOp>(
-                loc, int32Type, rewriter.getI32IntegerAttr(i));
-            Value shfl = rewriter.create<NVVM::ShflBflyOp>(
-                loc, shflTy, activeMask, value, offset, maskAndClamp,
-                returnValueAndIsValidAttr);
-            Value isActiveSrcLane = rewriter.create<LLVM::ExtractValueOp>(
-                loc, predTy, shfl, rewriter.getIndexArrayAttr(1));
-            // Skip the accumulation if the shuffle op read from a lane outside
-            // of the active range.
-            createIf(
-                loc, rewriter, isActiveSrcLane,
-                [&] {
-                  Value shflValue = rewriter.create<LLVM::ExtractValueOp>(
-                      loc, type, shfl, rewriter.getIndexArrayAttr(0));
-                  return SmallVector<Value, 1>{
-                      accumFactory(loc, value, shflValue, rewriter)};
-                },
-                [&] { return llvm::makeArrayRef(value); });
-            value = rewriter.getInsertionBlock()->getArgument(0);
-          }
-          return SmallVector<Value, 1>{value};
-        },
-        // Generate a reduction over the entire warp. This is a specialization
-        // of the above reduction with unconditional accumulation.
-        [&] {
-          Value value = operand;
-          Value activeMask = rewriter.create<LLVM::ConstantOp>(
-              loc, int32Type, rewriter.getI32IntegerAttr(~0u));
-          Value maskAndClamp = rewriter.create<LLVM::ConstantOp>(
-              loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
-          for (int i = 1; i < kWarpSize; i <<= 1) {
-            Value offset = rewriter.create<LLVM::ConstantOp>(
-                loc, int32Type, rewriter.getI32IntegerAttr(i));
-            Value shflValue = rewriter.create<NVVM::ShflBflyOp>(
-                loc, type, activeMask, value, offset, maskAndClamp,
-                /*return_value_and_is_valid=*/UnitAttr());
-            value = accumFactory(loc, value, shflValue, rewriter);
-          }
-          return SmallVector<Value, 1>{value};
-        });
-    return rewriter.getInsertionBlock()->getArgument(0);
-  }
-
-  /// Creates a global array stored in shared memory.
-  Value createSharedMemoryArray(Location loc, gpu::GPUModuleOp module,
-                                LLVM::LLVMType elementType, int numElements,
-                                ConversionPatternRewriter &rewriter) const {
-    OpBuilder builder(module.body());
-
-    auto arrayType = LLVM::LLVMType::getArrayTy(elementType, numElements);
-    StringRef name = "reduce_buffer";
-    auto globalOp = builder.create<LLVM::GlobalOp>(
-        loc, arrayType.cast<LLVM::LLVMType>(),
-        /*isConstant=*/false, LLVM::Linkage::Internal, name,
-        /*value=*/Attribute(), gpu::GPUDialect::getWorkgroupAddressSpace());
-
-    return rewriter.create<LLVM::AddressOfOp>(loc, globalOp);
-  }
-
-  /// Returns the index of the thread within the block.
-  Value getLinearThreadIndex(Location loc,
-                             ConversionPatternRewriter &rewriter) const {
-    Value dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
-    Value dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
-    Value idX = rewriter.create<NVVM::ThreadIdXOp>(loc, int32Type);
-    Value idY = rewriter.create<NVVM::ThreadIdYOp>(loc, int32Type);
-    Value idZ = rewriter.create<NVVM::ThreadIdZOp>(loc, int32Type);
-    Value tmp1 = rewriter.create<LLVM::MulOp>(loc, int32Type, idZ, dimY);
-    Value tmp2 = rewriter.create<LLVM::AddOp>(loc, int32Type, tmp1, idY);
-    Value tmp3 = rewriter.create<LLVM::MulOp>(loc, int32Type, tmp2, dimX);
-    return rewriter.create<LLVM::AddOp>(loc, int32Type, tmp3, idX);
-  }
-
-  /// Returns the number of threads in the block.
-  Value getBlockSize(Location loc, ConversionPatternRewriter &rewriter) const {
-    Value dimX = rewriter.create<NVVM::BlockDimXOp>(loc, int32Type);
-    Value dimY = rewriter.create<NVVM::BlockDimYOp>(loc, int32Type);
-    Value dimZ = rewriter.create<NVVM::BlockDimZOp>(loc, int32Type);
-    Value dimXY = rewriter.create<LLVM::MulOp>(loc, int32Type, dimX, dimY);
-    return rewriter.create<LLVM::MulOp>(loc, int32Type, dimXY, dimZ);
-  }
-
-  /// Returns the number of warps in the block.
-  Value getNumWarps(Location loc, Value blockSize,
-                    ConversionPatternRewriter &rewriter) const {
-    auto warpSizeMinusOne = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize - 1));
-    auto biasedBlockSize = rewriter.create<LLVM::AddOp>(
-        loc, int32Type, blockSize, warpSizeMinusOne);
-    return getDivideByWarpSize(biasedBlockSize, rewriter);
-  }
-
-  /// Returns the number of active threads in the warp, not clamped to 32.
-  Value getActiveWidth(Location loc, Value threadIdx, Value blockSize,
-                       ConversionPatternRewriter &rewriter) const {
-    Value threadIdxMask = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(~(kWarpSize - 1)));
-    Value numThreadsWithSmallerWarpId =
-        rewriter.create<LLVM::AndOp>(loc, threadIdx, threadIdxMask);
-    return rewriter.create<LLVM::SubOp>(loc, blockSize,
-                                        numThreadsWithSmallerWarpId);
-  }
-
-  /// Returns value divided by the warp size (i.e. 32).
-  Value getDivideByWarpSize(Value value,
-                            ConversionPatternRewriter &rewriter) const {
-    auto loc = value.getLoc();
-    auto warpSize = rewriter.create<LLVM::ConstantOp>(
-        loc, int32Type, rewriter.getI32IntegerAttr(kWarpSize));
-    return rewriter.create<LLVM::SDivOp>(loc, int32Type, value, warpSize);
-  }
-
-  LLVM::LLVMType int32Type;
-
-  static constexpr int kWarpSize = 32;
-};
 
 struct GPUShuffleOpLowering : public ConvertToLLVMPattern {
   explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_)
@@ -671,6 +263,14 @@
       });
 
   OwningRewritePatternList patterns;
+
+  // Apply in-dialect lowering first. In-dialect lowering will replace ops
+  // which need to be lowered further, which is not supported by a single
+  // conversion pass.
+  populateGpuRewritePatterns(m.getContext(), patterns);
+  applyPatternsGreedily(m, patterns);
+  patterns.clear();
+
   populateStdToLLVMConversionPatterns(converter, patterns);
   populateGpuToNVVMConversionPatterns(converter, patterns);
   LLVMConversionTarget target(getContext());
@@ -702,8 +302,8 @@
                                   NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
       GPUIndexIntrinsicOpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
                                   NVVM::GridDimYOp, NVVM::GridDimZOp>,
-      GPUAllReduceOpLowering, GPUShuffleOpLowering, GPUFuncOpLowering,
-      GPUReturnOpLowering>(converter);
+      GPUShuffleOpLowering, GPUFuncOpLowering, GPUReturnOpLowering>(
+      converter);
   patterns.insert<OpToFuncCallLowering<AbsFOp>>(converter, "__nv_fabsf",
                                                 "__nv_fabs");
   patterns.insert<OpToFuncCallLowering<CeilFOp>>(converter, "__nv_ceilf",
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -3,7 +3,8 @@
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_index_ops()
   func @gpu_index_ops()
-      attributes { gpu.kernel } {
+      -> (index, index, index, index, index, index,
+          index, index, index, index, index, index) {
     // CHECK: = nvvm.read.ptx.sreg.tid.x : !llvm.i32
     %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.tid.y : !llvm.i32
@@ -32,7 +33,10 @@
     // CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
     %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)
 
-    std.return
+    std.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ,
+               %bIdX, %bIdY, %bIdZ, %gDimX, %gDimY, %gDimZ
+        : index, index, index, index, index, index,
+          index, index, index, index, index, index
   }
 }
 
@@ -40,8 +44,7 @@
 
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_all_reduce_op()
-  func @gpu_all_reduce_op()
-      attributes { gpu.kernel } {
+  gpu.func @gpu_all_reduce_op() {
     %arg0 = constant 1.0 : f32
     // TODO(csigg): Check full IR expansion once lowering has settled.
     // CHECK: nvvm.shfl.sync.bfly
@@ -49,7 +52,7 @@
     // CHECK: llvm.fadd
     %result = "gpu.all_reduce"(%arg0) ({}) {op = "add"} : (f32) -> (f32)
 
-    std.return
+    gpu.return
   }
 }
 
@@ -57,8 +60,7 @@
 
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_all_reduce_region()
-  func @gpu_all_reduce_region()
-      attributes { gpu.kernel } {
+  gpu.func @gpu_all_reduce_region() {
     %arg0 = constant 1 : i32
     // TODO(csigg): Check full IR expansion once lowering has settled.
     // CHECK: nvvm.shfl.sync.bfly
@@ -68,7 +70,7 @@
       %xor = xor %lhs, %rhs : i32
       "gpu.yield"(%xor) : (i32) -> ()
     }) : (i32) -> (i32)
-    std.return
+    gpu.return
   }
 }
 
@@ -76,8 +78,7 @@
 
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_shuffle()
-  func @gpu_shuffle()
-      attributes { gpu.kernel } {
+  func @gpu_shuffle() -> (f32) {
     // CHECK: %[[#VALUE:]] = llvm.mlir.constant(1.000000e+00 : f32) : !llvm.float
     %arg0 = constant 1.0 : f32
     // CHECK: %[[#OFFSET:]] = llvm.mlir.constant(4 : i32) : !llvm.i32
@@ -93,7 +94,7 @@
     // CHECK: llvm.extractvalue %[[#SHFL]][1 : index] : !llvm<"{ float, i1 }">
     %shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = "xor" } : (f32, i32, i32) -> (f32, i1)
 
-    std.return
+    std.return %shfl : f32
   }
 }
 
@@ -101,8 +102,7 @@
 
 gpu.module @test_module {
   // CHECK-LABEL: func @gpu_sync()
-  func @gpu_sync()
-      attributes { gpu.kernel } {
+  func @gpu_sync() {
     // CHECK: nvvm.barrier0
     gpu.barrier
     std.return
@@ -115,12 +115,12 @@
   // CHECK: llvm.func @__nv_fabsf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_fabs(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_fabs
-  func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) {
+  func @gpu_fabs(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
     %result32 = std.absf %arg_f32 : f32
     // CHECK: llvm.call @__nv_fabsf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.absf %arg_f64 : f64
     // CHECK: llvm.call @__nv_fabs(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -130,12 +130,12 @@
   // CHECK: llvm.func @__nv_ceilf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_ceil(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_ceil
-  func @gpu_ceil(%arg_f32 : f32, %arg_f64 : f64) {
+  func @gpu_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
     %result32 = std.ceilf %arg_f32 : f32
     // CHECK: llvm.call @__nv_ceilf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.ceilf %arg_f64 : f64
     // CHECK: llvm.call @__nv_ceil(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
  }
 }
 
@@ -145,12 +145,12 @@
   // CHECK: llvm.func @__nv_cosf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_cos(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_cos
-  func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) {
+  func @gpu_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
     %result32 = std.cos %arg_f32 : f32
     // CHECK: llvm.call @__nv_cosf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.cos %arg_f64 : f64
     // CHECK: llvm.call @__nv_cos(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -159,14 +159,12 @@
   // CHECK: llvm.func @__nv_expf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_exp(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_exp
-  func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) {
-    %exp_f32 = std.exp %arg_f32 : f32
-    // CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
-    %result_f32 = std.exp %exp_f32 : f32
+  func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = std.exp %arg_f32 : f32
     // CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.exp %arg_f64 : f64
     // CHECK: llvm.call @__nv_exp(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -176,12 +174,12 @@
   // CHECK: llvm.func @__nv_logf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_log(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_log
-  func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) {
+  func @gpu_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
     %result32 = std.log %arg_f32 : f32
     // CHECK: llvm.call @__nv_logf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.log %arg_f64 : f64
     // CHECK: llvm.call @__nv_log(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -191,12 +189,12 @@
   // CHECK: llvm.func @__nv_log10f(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_log10(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_log10
-  func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) {
+  func @gpu_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
     %result32 = std.log10 %arg_f32 : f32
     // CHECK: llvm.call @__nv_log10f(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.log10 %arg_f64 : f64
     // CHECK: llvm.call @__nv_log10(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -206,12 +204,12 @@
   // CHECK: llvm.func @__nv_log2f(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_log2(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_log2
-  func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) {
+  func @gpu_log2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
     %result32 = std.log2 %arg_f32 : f32
     // CHECK: llvm.call @__nv_log2f(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.log2 %arg_f64 : f64
     // CHECK: llvm.call @__nv_log2(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -221,12 +219,12 @@
   // CHECK: llvm.func @__nv_tanhf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_tanh(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_tanh
-  func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) {
+  func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
     %result32 = std.tanh %arg_f32 : f32
     // CHECK: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.tanh %arg_f64 : f64
     // CHECK: llvm.call @__nv_tanh(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
 }
 
@@ -239,14 +237,12 @@
   // CHECK: llvm.func @__nv_expf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_exp(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_exp
-  func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) {
-    %exp_f32 = std.exp %arg_f32 : f32
-    // CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
-    %result_f32 = std.exp %exp_f32 : f32
+  func @gpu_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = std.exp %arg_f32 : f32
     // CHECK: llvm.call @__nv_expf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.exp %arg_f64 : f64
     // CHECK: llvm.call @__nv_exp(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return
+    std.return %result32, %result64 : f32, f64
   }
   "test.finish" () : () -> ()
 }) : () -> ()
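
For reference, the pass body after this change applies patterns in two phases, as spelled out by the new comment in the pass: a greedy in-dialect rewrite that expands gpu.all_reduce into simpler ops, followed by the one-shot conversion to LLVM/NVVM. A minimal sketch of the resulting shape (not part of the patch; the applyPartialConversion/signalPassFailure tail is assumed from the usual MLIR conversion-pass boilerplate of this vintage and sits outside the hunks, and `m`/`converter` are the module and type converter already in scope in runOnModule()):

  OwningRewritePatternList patterns;

  // Phase 1: greedy in-dialect lowering. gpu.all_reduce is expanded into
  // simpler gpu/std ops that the conversion patterns below can handle; a
  // single conversion pass cannot lower an op into ops that themselves
  // still need conversion.
  populateGpuRewritePatterns(m.getContext(), patterns);
  applyPatternsGreedily(m, patterns);
  patterns.clear();

  // Phase 2: one-shot dialect conversion of the remaining ops to LLVM/NVVM.
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateGpuToNVVMConversionPatterns(converter, patterns);
  LLVMConversionTarget target(getContext());
  if (failed(applyPartialConversion(m, target, patterns, &converter)))
    signalPassFailure();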