diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
--- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h
@@ -20,6 +20,9 @@
 /// depending on the element type that Op operates upon. The function
 /// declaration is added in case it was not added before.
 ///
+/// If the input values are of f16 type, the value is first casted to f32, the
+/// function called and then the result casted back.
+///
 /// Example with NVVM:
 /// %exp_f32 = std.exp %arg_f32 : f32
 ///
@@ -44,21 +47,48 @@
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
 
-    LLVMType resultType = typeConverter.convertType(op->getResult(0).getType())
-                              .template cast<LLVMType>();
-    LLVMType funcType = getFunctionType(resultType, operands);
-    StringRef funcName = getFunctionName(resultType);
+    static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
+                                  SourceOp>::value,
+                  "expected op with same operand and result types");
+
+    SmallVector<Value, 1> castedOperands;
+    for (Value operand : operands)
+      castedOperands.push_back(maybeCast(operand, rewriter));
+
+    LLVMType resultType =
+        castedOperands.front().getType().cast<LLVM::LLVMType>();
+    LLVMType funcType = getFunctionType(resultType, castedOperands);
+    StringRef funcName = getFunctionName(funcType.getFunctionResultType());
     if (funcName.empty())
       return failure();
 
     LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op);
     auto callOp = rewriter.create<LLVM::CallOp>(
-        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), operands);
-    rewriter.replaceOp(op, {callOp.getResult(0)});
+        op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp),
+        castedOperands);
+
+    if (resultType == operands.front().getType()) {
+      rewriter.replaceOp(op, {callOp.getResult(0)});
+      return success();
+    }
+
+    Value truncated = rewriter.create<LLVM::FPTruncOp>(
+        op->getLoc(), operands.front().getType(), callOp.getResult(0));
+    rewriter.replaceOp(op, {truncated});
     return success();
   }
 
 private:
+  Value maybeCast(Value operand, PatternRewriter &rewriter) const {
+    LLVM::LLVMType type = operand.getType().cast<LLVM::LLVMType>();
+    if (!type.isHalfTy())
+      return operand;
+
+    return rewriter.create<LLVM::FPExtOp>(
+        operand.getLoc(), LLVM::LLVMType::getFloatTy(&type.getDialect()),
+        operand);
+  }
+
   LLVM::LLVMType getFunctionType(LLVM::LLVMType resultType,
                                  ArrayRef<Value> operands) const {
     using LLVM::LLVMType;
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -219,12 +219,16 @@
   // CHECK: llvm.func @__nv_tanhf(!llvm.float) -> !llvm.float
   // CHECK: llvm.func @__nv_tanh(!llvm.double) -> !llvm.double
   // CHECK-LABEL: func @gpu_tanh
-  func @gpu_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+  func @gpu_tanh(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
+    %result16 = std.tanh %arg_f16 : f16
+    // CHECK: llvm.fpext %{{.*}} : !llvm.half to !llvm.float
+    // CHECK-NEXT: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
+    // CHECK-NEXT: llvm.fptrunc %{{.*}} : !llvm.float to !llvm.half
     %result32 = std.tanh %arg_f32 : f32
     // CHECK: llvm.call @__nv_tanhf(%{{.*}}) : (!llvm.float) -> !llvm.float
     %result64 = std.tanh %arg_f64 : f64
     // CHECK: llvm.call @__nv_tanh(%{{.*}}) : (!llvm.double) -> !llvm.double
-    std.return %result32, %result64 : f32, f64
+    std.return %result16, %result32, %result64 : f16, f32, f64
   }
 }
 