diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -51,21 +51,21 @@
 // NVVM intrinsic operations
 //===----------------------------------------------------------------------===//
 
-class NVVM_IntrOp<string mnem, list<int> overloadedResults,
-                  list<int> overloadedOperands, list<Trait> traits,
+class NVVM_IntrOp<string mnem, list<Trait> traits,
                   int numResults>
   : LLVM_IntrOpBase<NVVM_Dialect, mnem, "nvvm_" # !subst(".", "_", mnem),
-                    overloadedResults, overloadedOperands, traits, numResults>;
+                    /*list<int> overloadedResults=*/[],
+                    /*list<int> overloadedOperands=*/[],
+                    traits, numResults>;
 
 //===----------------------------------------------------------------------===//
 // NVVM special register op definitions
 //===----------------------------------------------------------------------===//
 
-class NVVM_SpecialRegisterOp<string mnemonic,
-                             list<Trait> traits = []> :
-  NVVM_IntrOp<mnemonic, [], [], !listconcat(traits, [NoSideEffect]), 1>,
-  Arguments<(ins)> {
+class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
+  NVVM_IntrOp<mnemonic, !listconcat(traits, [NoSideEffect]), 1> {
+  let arguments = (ins);
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
@@ -92,6 +92,16 @@
 def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
 def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
 
+//===----------------------------------------------------------------------===//
+// NVVM approximate op definitions
+//===----------------------------------------------------------------------===//
+
+def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [NoSideEffect], 1> {
+  let arguments = (ins F32:$arg);
+  let results = (outs F32:$res);
+  let assemblyFormat = "$arg attr-dict `:` type($res)";
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM synchronization op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -148,6 +148,62 @@
   }
 };
 
+// Replaces fdiv on fp16 with fp32 multiplication by the reciprocal plus one
+// (conditional) Newton iteration.
+//
+// This is as accurate as promoting the division to fp32 in the NVPTX backend,
+// but faster because it performs fewer Newton iterations, avoids the slow path
+// for e.g. denormals, and allows reuse of the reciprocal for multiple divisions
+// by the same divisor.
+struct ExpandDivF16 : public ConvertOpToLLVMPattern<LLVM::FDivOp> {
+  using ConvertOpToLLVMPattern<LLVM::FDivOp>::ConvertOpToLLVMPattern;
+
+private:
+  LogicalResult
+  matchAndRewrite(LLVM::FDivOp op, LLVM::FDivOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op.getType().isF16())
+      return rewriter.notifyMatchFailure(op, "not f16");
+    Location loc = op.getLoc();
+
+    Type f32Type = rewriter.getF32Type();
+    Type i32Type = rewriter.getI32Type();
+
+    // Extend lhs and rhs to fp32.
+    Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, adaptor.getLhs());
+    Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, adaptor.getRhs());
+
+    // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp.
+    Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs);
+    Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp);
+
+    // Refine the approximation with one Newton iteration:
+    // float refined = approx + (lhs - approx * rhs) * rcp;
+    Value err = rewriter.create<LLVM::FMAOp>(
+        loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs);
+    Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx);
+
+    // Use the refined value if approx is normal (exponent not all 0s or all 1s).
+    Value mask = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000));
+    Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx);
+    Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask);
+    Value zero = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Type, rewriter.getUI32IntegerAttr(0));
+    Value pred = rewriter.create<LLVM::OrOp>(
+        loc,
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero),
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask));
+    Value result =
+        rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined);
+
+    // Replace with a truncation back to fp16.
+    rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result);
+
+    return success();
+  }
+};
+
 /// Import the GPU Ops to NVVM Patterns.
 #include "GPUToNVVM.cpp.inc"
 
@@ -222,6 +278,10 @@
                       LLVM::FCeilOp, LLVM::FFloorOp, LLVM::LogOp, LLVM::Log10Op,
                       LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp, LLVM::SqrtOp>();
 
+  // Expand fdiv on fp16 to faster code than the NVPTX backend's fp32 promotion.
+  target.addDynamicallyLegalOp<LLVM::FDivOp>(
+      [&](LLVM::FDivOp op) { return !op.getType().isF16(); });
+
   // TODO: Remove once we support replacing non-root ops.
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
 }
@@ -241,6 +301,8 @@
                GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
       converter);
+  patterns.add<ExpandDivF16>(converter);
+
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default
   // memory space and does not support `alloca`s with addrspace(5).
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -488,3 +488,30 @@
   }
 }
 
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_divf_fp16
+  func.func @gpu_divf_fp16(%arg0 : f16, %arg1 : f16) -> f16 {
+    // CHECK: %[[lhs:.*]] = llvm.fpext %arg0 : f16 to f32
+    // CHECK: %[[rhs:.*]] = llvm.fpext %arg1 : f16 to f32
+    // CHECK: %[[rcp:.*]] = nvvm.rcp.approx.ftz.f %[[rhs]] : f32
+    // CHECK: %[[approx:.*]] = llvm.fmul %[[lhs]], %[[rcp]] : f32
+    // CHECK: %[[neg:.*]] = llvm.fneg %[[rhs]] : f32
+    // CHECK: %[[err:.*]] = "llvm.intr.fma"(%[[approx]], %[[neg]], %[[lhs]]) : (f32, f32, f32) -> f32
+    // CHECK: %[[refined:.*]] = "llvm.intr.fma"(%[[err]], %[[rcp]], %[[approx]]) : (f32, f32, f32) -> f32
+    // CHECK: %[[mask:.*]] = llvm.mlir.constant(2139095040 : ui32) : i32
+    // CHECK: %[[cast:.*]] = llvm.bitcast %[[approx]] : f32 to i32
+    // CHECK: %[[exp:.*]] = llvm.and %[[cast]], %[[mask]] : i32
+    // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : ui32) : i32
+    // CHECK: %[[is_zero:.*]] = llvm.icmp "eq" %[[exp]], %[[c0]] : i32
+    // CHECK: %[[is_mask:.*]] = llvm.icmp "eq" %[[exp]], %[[mask]] : i32
+    // CHECK: %[[pred:.*]] = llvm.or %[[is_zero]], %[[is_mask]] : i1
+    // CHECK: %[[select:.*]] = llvm.select %[[pred]], %[[approx]], %[[refined]] : i1, f32
+    // CHECK: %[[result:.*]] = llvm.fptrunc %[[select]] : f32 to f16
+    %result = arith.divf %arg0, %arg1 : f16
+    // CHECK: llvm.return %[[result]] : f16
+    func.return %result : f16
+  }
+}
+
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -29,6 +29,13 @@
   llvm.return %0 : i32
 }
 
+// CHECK-LABEL: @nvvm_rcp
+func.func @nvvm_rcp(%arg0: f32) -> f32 {
+  // CHECK: nvvm.rcp.approx.ftz.f %arg0 : f32
+  %0 = nvvm.rcp.approx.ftz.f %arg0 : f32
+  llvm.return %0 : f32
+}
+
 // CHECK-LABEL: @llvm_nvvm_barrier0
 func.func @llvm_nvvm_barrier0() {
   // CHECK: nvvm.barrier0
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -33,6 +33,13 @@
   llvm.return %1 : i32
 }
 
+// CHECK-LABEL: @nvvm_rcp
+llvm.func @nvvm_rcp(%0: f32) -> f32 {
+  // CHECK: call float @llvm.nvvm.rcp.approx.ftz.f
+  %1 = nvvm.rcp.approx.ftz.f %0 : f32
+  llvm.return %1 : f32
+}
+
 // CHECK-LABEL: @llvm_nvvm_barrier0
 llvm.func @llvm_nvvm_barrier0() {
   // CHECK: call void @llvm.nvvm.barrier0()
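
Reviewer note (not part of the patch): below is a minimal, self-contained C++ sketch of
the same expansion scheme, useful for experimenting with the numerics on the host. The
function and variable names are invented for illustration, and `1.0f / rhs` merely
stands in for the hardware `rcp.approx.ftz.f32` approximation.

#include <cmath>
#include <cstdint>
#include <cstring>

// Mirrors the ExpandDivF16 lowering: multiply by an approximate reciprocal,
// refine with one Newton step, and keep the unrefined product whenever its
// exponent field is all zeros or all ones (zero/denormal/inf/NaN).
float divF16ViaRcp(float lhs, float rhs) {
  float rcp = 1.0f / rhs;                      // stand-in for rcp.approx.ftz.f32(rhs)
  float approx = lhs * rcp;                    // initial quotient estimate
  float err = std::fma(approx, -rhs, lhs);     // lhs - approx * rhs
  float refined = std::fma(err, rcp, approx);  // one Newton iteration
  std::uint32_t bits;
  std::memcpy(&bits, &approx, sizeof(bits));   // inspect the fp32 bit pattern
  std::uint32_t exp = bits & 0x7f800000u;      // biased exponent field
  return (exp == 0 || exp == 0x7f800000u) ? approx : refined;
}

The final conditional matches the `llvm.select` emitted by the pattern: the Newton step
is only trusted when `approx` is a normal number; for zero, denormal, infinite, or NaN
products the unrefined value is returned unchanged.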