// Patch summary (commentary placed ahead of the first `diff --git` header).
//
// Hunk 1 - mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (@@ -374,7 +374,7 @@)
//   Reorders the arithmetic that produces `sizeInBytes`:
//     before: (dstMemrefType.getElementTypeBitWidth() / 8) * numElements
//     after:  (dstMemrefType.getElementTypeBitWidth() * numElements) / 8
//   With integer division the old form divides the per-element bit width by 8
//   first, so an element narrower than 8 bits (e.g. 4 / 8) contributes 0 and
//   the whole product becomes 0. Multiplying before dividing gives
//   4 * 32 / 8 == 16 for the i4 case exercised by the new test.
//   The context lines show that `sizeInBytes == 16` decides whether the
//   `bypassL1` attribute is kept or replaced by an empty `UnitAttr()`.
//   NOTE(review): the new form still truncates when bitWidth * numElements is
//   not a multiple of 8 - confirm whether such copies can reach this code or
//   need an explicit diagnostic.
//
// Hunk 2 - mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (@@ -219,3 +219,28 @@)
//   Appends a new split (`// -----`) containing `@async_cp_i4`, which issues
//   `nvgpu.device_async_copy` of 32 elements from `memref<128x64xi4>` to
//   `memref<128x128xi4, 3>` and expects
//   `nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16`.
//   NOTE(review): the CHECK lines expect `llvm.bitcast` / `llvm.addrspacecast`
//   from `!llvm.ptr` to `!llvm.ptr` with no element type or address space
//   spelled out - presumably the pointer type parameters were lost when this
//   patch text was captured; verify against the upstream test file.
//
// NOTE(review): the patch body below appears to have had its original line
// breaks collapsed; it is kept verbatim rather than re-flowed, because the
// leading whitespace of diff context lines cannot be recovered from here.
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -374,7 +374,7 @@ scrPtr); int64_t numElements = adaptor.numElements().getZExtValue(); int64_t sizeInBytes = - (dstMemrefType.getElementTypeBitWidth() / 8) * numElements; + (dstMemrefType.getElementTypeBitWidth() * numElements) / 8; // bypass L1 is only supported for byte sizes of 16, we drop the hint // otherwise. UnitAttr bypassL1 = sizeInBytes == 16 ? adaptor.bypassL1Attr() : UnitAttr(); diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -219,3 +219,28 @@ return } +// ----- + +// CHECK-LABEL: @async_cp_i4( +// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index) +func.func @async_cp_i4( + %src: memref<128x64xi4>, %dst: memref<128x128xi4, 3>, %i : index) -> !nvgpu.device.async.token { + // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64 + // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(128 : index) : i64 + // CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX1]], %[[S0]] : i64 + // CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI]], %[[IDX1]] : i64 + // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI1]]] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-DAG: %[[CAST0:.*]] = llvm.bitcast %[[ADDRESSDST]] : !llvm.ptr to !llvm.ptr + // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK-DAG: %[[S2:.*]] = llvm.mlir.constant(64 : index) : i64 + // CHECK-DAG: %[[FI2:.*]] = llvm.mul %[[IDX1]], %[[S2]] : i64 + // CHECK-DAG: %[[FI3:.*]] = llvm.add 
%[[FI2]], %[[IDX1]] : i64 + // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI3]]] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr to !llvm.ptr + // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr to !llvm.ptr + // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16 + %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i], 32 : memref<128x64xi4> to memref<128x128xi4, 3> + return %0 : !nvgpu.device.async.token +} +