diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -51,21 +51,21 @@
 // NVVM intrinsic operations
 //===----------------------------------------------------------------------===//
 
-class NVVM_IntrOp<string mnem, list<int> overloadedResults,
-                  list<int> overloadedOperands, list<Trait> traits,
+class NVVM_IntrOp<string mnem, list<Trait> traits,
                   int numResults>
   : LLVM_IntrOpBase<NVVM_Dialect, mnem, "nvvm_" # !subst(".", "_", mnem),
-                    overloadedResults, overloadedOperands, traits, numResults>;
+                    /*list<int> overloadedResults=*/[],
+                    /*list<int> overloadedOperands=*/[],
+                    traits, numResults>;
 
 //===----------------------------------------------------------------------===//
 // NVVM special register op definitions
 //===----------------------------------------------------------------------===//
 
-class NVVM_SpecialRegisterOp<string mnemonic,
-    list<Trait> traits = []> :
-  NVVM_IntrOp<mnemonic, [], [], !listconcat(traits, [NoSideEffect]), 1>,
-  Arguments<(ins)> {
+class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
+  NVVM_IntrOp<mnemonic, !listconcat(traits, [NoSideEffect]), 1> {
+  let arguments = (ins);
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
@@ -92,6 +92,16 @@
 def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
 def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
 
+//===----------------------------------------------------------------------===//
+// NVVM approximate op definitions
+//===----------------------------------------------------------------------===//
+
+def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [NoSideEffect], 1> {
+  let arguments = (ins F32:$arg);
+  let results = (outs F32:$res);
+  let assemblyFormat = "$arg attr-dict `:` type($res)";
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM synchronization op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -148,6 +148,62 @@
   }
 };
 
+// Replaces fdiv on fp16 with an fp32 multiplication by the reciprocal plus one
+// (conditional) Newton iteration.
+//
+// This is as accurate as promoting the division to fp32 in the NVPTX backend,
+// but faster because it performs fewer Newton iterations, avoids the slow path
+// for e.g. denormals, and allows reuse of the reciprocal for multiple divisions
+// by the same divisor.
+struct ExpandDivF16 : public ConvertOpToLLVMPattern<LLVM::FDivOp> {
+  using ConvertOpToLLVMPattern<LLVM::FDivOp>::ConvertOpToLLVMPattern;
+
+private:
+  LogicalResult
+  matchAndRewrite(LLVM::FDivOp op, LLVM::FDivOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op.getType().isF16())
+      return rewriter.notifyMatchFailure(op, "not f16");
+    Location loc = op.getLoc();
+
+    Type f32Type = rewriter.getF32Type();
+    Type i32Type = rewriter.getI32Type();
+
+    // Extend lhs and rhs to fp32.
+    Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, adaptor.getLhs());
+    Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, adaptor.getRhs());
+
+    // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp.
+    Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs);
+    Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp);
+
+    // Refine the approximation with one Newton iteration:
+    // float refined = approx + (lhs - approx * rhs) * rcp;
+    Value err = rewriter.create<LLVM::FMAOp>(
+        loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs);
+    Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx);
+
+    // Use the refined value if approx is normal (exponent bits are neither
+    // all 0s nor all 1s).
+    Value mask = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000));
+    Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx);
+    Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask);
+    Value zero = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Type, rewriter.getUI32IntegerAttr(0));
+    Value pred = rewriter.create<LLVM::OrOp>(
+        loc,
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero),
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask));
+    Value result =
+        rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined);
+
+    // Replace with truncation back to fp16.
+    rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result);
+
+    return success();
+  }
+};
+
 /// Import the GPU Ops to NVVM Patterns.
 #include "GPUToNVVM.cpp.inc"
 
@@ -222,6 +278,10 @@
                       LLVM::FCeilOp, LLVM::FFloorOp, LLVM::LogOp, LLVM::Log10Op,
                       LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp, LLVM::SqrtOp>();
 
+  // Expand fdiv on fp16 to faster code than the NVPTX backend's fp32 promotion.
+  target.addDynamicallyLegalOp<LLVM::FDivOp>(
+      [&](LLVM::FDivOp op) { return !op.getType().isF16(); });
+
   // TODO: Remove once we support replacing non-root ops.
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
 }
@@ -241,6 +301,8 @@
            GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
           converter);
+  patterns.add<ExpandDivF16>(converter);
+
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default
   // memory space and does not support `alloca`s with addrspace(5).
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -488,3 +488,30 @@
   }
 }
 
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_divf_fp16
+  func.func @gpu_divf_fp16(%arg0 : f16, %arg1 : f16) -> f16 {
+    // CHECK: %[[lhs:.*]] = llvm.fpext %arg0 : f16 to f32
+    // CHECK: %[[rhs:.*]] = llvm.fpext %arg1 : f16 to f32
+    // CHECK: %[[rcp:.*]] = nvvm.rcp.approx.ftz.f %[[rhs]] : f32
+    // CHECK: %[[approx:.*]] = llvm.fmul %[[lhs]], %[[rcp]] : f32
+    // CHECK: %[[neg:.*]] = llvm.fneg %[[rhs]] : f32
+    // CHECK: %[[err:.*]] = "llvm.intr.fma"(%[[approx]], %[[neg]], %[[lhs]]) : (f32, f32, f32) -> f32
+    // CHECK: %[[refined:.*]] = "llvm.intr.fma"(%[[err]], %[[rcp]], %[[approx]]) : (f32, f32, f32) -> f32
+    // CHECK: %[[mask:.*]] = llvm.mlir.constant(2139095040 : ui32) : i32
+    // CHECK: %[[cast:.*]] = llvm.bitcast %[[approx]] : f32 to i32
+    // CHECK: %[[exp:.*]] = llvm.and %[[cast]], %[[mask]] : i32
+    // CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : ui32) : i32
+    // CHECK: %[[is_zero:.*]] = llvm.icmp "eq" %[[exp]], %[[c0]] : i32
+    // CHECK: %[[is_mask:.*]] = llvm.icmp "eq" %[[exp]], %[[mask]] : i32
+    // CHECK: %[[pred:.*]] = llvm.or %[[is_zero]], %[[is_mask]] : i1
+    // CHECK: %[[select:.*]] = llvm.select %[[pred]], %[[approx]], %[[refined]] : i1, f32
+    // CHECK: %[[result:.*]] = llvm.fptrunc %[[select]] : f32 to f16
+    %result = arith.divf %arg0, %arg1 : f16
+    // CHECK: llvm.return %[[result]] : f16
+    func.return %result : f16
+  }
+}
+
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
 
+// CHECK-LABEL: @nvvm_special_regs
 func.func @nvvm_special_regs() -> i32 {
   // CHECK: nvvm.read.ptx.sreg.tid.x : i32
   %0 = nvvm.read.ptx.sreg.tid.x : i32
@@ -28,12 +29,21 @@
   llvm.return %0 : i32
 }
 
-func.func @llvm.nvvm.barrier0() {
+// CHECK-LABEL: @nvvm_rcp
+func.func @nvvm_rcp(%arg0: f32) -> f32 {
+  // CHECK: nvvm.rcp.approx.ftz.f %arg0 : f32
+  %0 = nvvm.rcp.approx.ftz.f %arg0 : f32
+  llvm.return %0 : f32
+}
+
+// CHECK-LABEL: @llvm_nvvm_barrier0
+func.func @llvm_nvvm_barrier0() {
   // CHECK: nvvm.barrier0
   nvvm.barrier0
   llvm.return
 }
 
+// CHECK-LABEL: @nvvm_shfl
 func.func @nvvm_shfl(
     %arg0 : i32, %arg1 : i32, %arg2 : i32,
     %arg3 : i32, %arg4 : f32) -> i32 {
@@ -50,6 +60,7 @@
   llvm.return %0 : i32
 }
 
+// CHECK-LABEL: @nvvm_shfl_pred
 func.func @nvvm_shfl_pred(
     %arg0 : i32, %arg1 : i32, %arg2 : i32,
     %arg3 : i32, %arg4 : f32) -> !llvm.struct<(i32, i1)> {
@@ -60,6 +71,7 @@
   llvm.return %0 : !llvm.struct<(i32, i1)>
 }
 
+// CHECK-LABEL: @nvvm_vote(
 func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
   // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
   %0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32
@@ -77,6 +89,7 @@
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m8n8k4_f16_f16
 func.func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                    %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                                    %c0 : vector<2xf16>, %c1 : vector<2xf16>, %c2 : vector<2xf16>, %c3 : vector<2xf16>) {
@@ -87,6 +100,7 @@
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8
 func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32, %c0 : i32, %c1 : i32) {
   // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)>
@@ -98,7 +112,8 @@
   llvm.return %0 : !llvm.struct<(i32, i32)>
 }
 
-func.func @nvvm_mma_m16n8k8_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
+// CHECK-LABEL: @nvvm_mma_m16n8k8_f16_f16
+func.func @nvvm_mma_m16n8k8_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                     %b0 : vector<2xf16>, %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
   // CHECK: nvvm.mma.sync A[%{{.*}}, %{{.*}}] B[%{{.*}}] C[%{{.*}}, %{{.*}}] {{{.*}}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
@@ -108,6 +123,7 @@
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f16
 func.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                      %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                      %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -119,6 +135,7 @@
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16
 func.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                      %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                      %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -130,6 +147,7 @@
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f32
 func.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                      %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                      %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -141,6 +159,7 @@
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f32
 func.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                      %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                      %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -152,7 +171,8 @@
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
-func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
+// CHECK-LABEL: @nvvm_mma_m16n8k4_tf32_f32
+func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
                                      %b0 : i32,
                                      %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
   // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
@@ -163,7 +183,8 @@
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
-func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
+// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_s8
+func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
                                    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
   // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
@@ -174,7 +195,8 @@
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
-func.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
+// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_u8
+func.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
                                    %b0 : i32,
                                    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
   // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow, layoutA = #nvvm.mma_layout, layoutB = #nvvm.mma_layout, multiplicandAPtxType = #nvvm.mma_type, multiplicandBPtxType = #nvvm.mma_type, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
@@ -186,6 +208,7 @@
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k256_b1_b1
 func.func @nvvm_mma_m16n8k256_b1_b1(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
                                     %b0 : i32, %b1 : i32,
                                     %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
@@ -197,6 +220,7 @@
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k128_b1_b1
 func.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
                                     %b0 : i32,
                                     %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {
@@ -243,6 +267,7 @@
   llvm.return %0 : !llvm.struct<(i32, i32, i32, i32)>
 }
 
+// CHECK-LABEL: @nvvm_wmma_mma
 func.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32,
                          %5 : i32, %6 : i32, %7 : i32, %8 : f32, %9 : f32,
                          %10 : f32, %11 : f32, %12 : f32, %13 : f32,
                          %14 : f32, %15 : f32)
@@ -255,6 +280,7 @@
   llvm.return %r : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
+// CHECK-LABEL: @cp_async
 llvm.func @cp_async(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
   // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16
   nvvm.cp.async.shared.global %arg0, %arg1, 16
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
+// CHECK-LABEL: @nvvm_special_regs
 llvm.func @nvvm_special_regs() -> i32 {
   // CHECK: %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
   %1 = nvvm.read.ptx.sreg.tid.x : i32
@@ -32,12 +33,21 @@
   llvm.return %1 : i32
 }
 
+// CHECK-LABEL: @nvvm_rcp
+llvm.func @nvvm_rcp(%0: f32) -> f32 {
+  // CHECK: call float @llvm.nvvm.rcp.approx.ftz.f
+  %1 = nvvm.rcp.approx.ftz.f %0 : f32
+  llvm.return %1 : f32
+}
+
+// CHECK-LABEL: @llvm_nvvm_barrier0
 llvm.func @llvm_nvvm_barrier0() {
   // CHECK: call void @llvm.nvvm.barrier0()
   nvvm.barrier0
   llvm.return
 }
 
+// CHECK-LABEL: @nvvm_shfl
 llvm.func @nvvm_shfl(
     %0 : i32, %1 : i32, %2 : i32,
     %3 : i32, %4 : f32) -> i32 {
@@ -60,6 +70,7 @@
   llvm.return %6 : i32
 }
 
+// CHECK-LABEL: @nvvm_shfl_pred
 llvm.func @nvvm_shfl_pred(
     %0 : i32, %1 : i32, %2 : i32,
     %3 : i32, %4 : f32) -> !llvm.struct<(i32, i1)> {
@@ -82,6 +93,7 @@
   llvm.return %6 : !llvm.struct<(i32, i1)>
 }
 
+// CHECK-LABEL: @nvvm_vote
 llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
   // CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
   %3 = nvvm.vote.ballot.sync %0, %1 : i32
@@ -99,6 +111,7 @@
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f16
 llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                      %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                      %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -111,6 +124,7 @@
 }
 
 // f32 return type, f16 accumulate type
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16
 llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                      %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                      %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -123,6 +137,7 @@
 }
 
 // f16 return type, f32 accumulate type
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f32
 llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                      %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                      %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -135,6 +150,7 @@
 }
 
 // f32 return type, f32 accumulate type
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f32
 llvm.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                      %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                      %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -146,7 +162,8 @@
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
-llvm.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32,
+// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_s8
+llvm.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32,
                                    %b0 : i32,
                                    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.s8
@@ -158,7 +175,8 @@
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
-llvm.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
+// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_u8
+llvm.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
                                    %b0 : i32,
                                    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.satfinite.s8.u8
@@ -170,7 +188,8 @@
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
-llvm.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
+// CHECK-LABEL: @nvvm_mma_m16n8k128_b1_b1
+llvm.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
                                     %b0 : i32,
                                     %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> {
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.xor.popc.m16n8k128.row.col.b1
@@ -181,6 +200,7 @@
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k32_s4_s4
 llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32,
                                    %b0 : i32,
                                    %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> {
@@ -193,6 +213,7 @@
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m8n8k4_f64_f64
 llvm.func @nvvm_mma_m8n8k4_f64_f64(%a0 : f64, %b0 : f64,
                                    %c0 : f64, %c1 : f64) -> !llvm.struct<(f64, f64)> {
@@ -203,6 +224,7 @@
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k4_tf32_f32
 llvm.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
                                      %b0 : i32,
                                      %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
@@ -228,6 +250,7 @@
 // The test below checks the correct mapping of the nvvm.wmma.*.store.* op to the correct intrinsic
 // in the LLVM NVPTX backend.
+// CHECK-LABEL: @gpu_wmma_store_op
 llvm.func @gpu_wmma_store_op(%arg0: !llvm.ptr, %arg1: i32,
                              %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
                              %arg4: vector<2 xf16>, %arg5: vector<2 x f16>) {
@@ -240,6 +263,7 @@
 // The test below checks the correct mapping of the nvvm.wmma.*.mma.* op to the correct intrinsic
 // in the LLVM NVPTX backend.
+// CHECK-LABEL: @gpu_wmma_mma_op
 llvm.func @gpu_wmma_mma_op(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
                            %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
                            %arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
@@ -261,6 +285,7 @@
   llvm.return
 }
 
+// CHECK-LABEL: @nvvm_wmma_load_tf32
 llvm.func @nvvm_wmma_load_tf32(%arg0: !llvm.ptr, %arg1 : i32) {
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0i32(i32* %{{.*}}, i32 %{{.*}})
   %0 = nvvm.wmma.load %arg0, %arg1
@@ -269,6 +294,7 @@
   llvm.return
 }
 
+// CHECK-LABEL: @nvvm_wmma_mma
 llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32,
                          %5 : i32, %6 : i32, %7 : i32, %8 : f32, %9 : f32,
                          %10 : f32, %11 : f32, %12 : f32, %13 : f32,
                          %14 : f32, %15 : f32) {
@@ -280,6 +306,7 @@
   llvm.return
 }
 
+// CHECK-LABEL: @cp_async
 llvm.func @cp_async(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
   // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
   nvvm.cp.async.shared.global %arg0, %arg1, 4
@@ -296,7 +323,7 @@
   llvm.return
 }
 
-// CHECK-LABEL: @ld_matrix(
+// CHECK-LABEL: @ld_matrix
 llvm.func @ld_matrix(%arg0: !llvm.ptr) {
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3i32(i32 addrspace(3)* %{{.*}})
   %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout} : (!llvm.ptr) -> i32
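
Note (not part of the patch): to sanity-check the numerics behind ExpandDivF16 outside of MLIR, the host-side C++ sketch below mirrors the generated sequence. approxRcp() is only a stand-in for PTX rcp.approx.ftz.f32 (the real instruction is approximate and flushes denormals), so this illustrates the structure of the expansion rather than the exact GPU results; the file name and helper names are made up for the example.

// divf16_expansion_check.cpp -- standalone sketch, not part of the patch.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Stand-in for nvvm.rcp.approx.ftz.f / PTX rcp.approx.ftz.f32.
static float approxRcp(float x) { return 1.0f / x; }

// Mirrors ExpandDivF16: lhs and rhs are the fp16 operands already extended to
// fp32; the caller truncates the result back to fp16.
static float expandDivF16(float lhs, float rhs) {
  float rcp = approxRcp(rhs);
  float approx = lhs * rcp;
  // One Newton iteration: refined = approx + (lhs - approx * rhs) * rcp.
  float err = std::fma(approx, -rhs, lhs);
  float refined = std::fma(err, rcp, approx);
  // Keep the unrefined value if approx is zero/denormal/inf/nan, i.e. if its
  // exponent bits are all 0s or all 1s.
  uint32_t bits;
  std::memcpy(&bits, &approx, sizeof(bits));
  uint32_t exp = bits & 0x7f800000u;
  return (exp == 0 || exp == 0x7f800000u) ? approx : refined;
}

int main() {
  float lhs = 1.0f, rhs = 3.0f;
  std::printf("%g / %g = %.9g (plain fp32 division: %.9g)\n", lhs, rhs,
              expandDivF16(lhs, rhs), lhs / rhs);
}

The final select presumably keeps the unrefined approximation for non-normal cases because a Newton step through an infinite, NaN, or flushed-to-zero intermediate can produce a worse value than the approximation itself.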