Index: mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h
===================================================================
--- mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h
+++ mlir/include/mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h
@@ -16,14 +16,11 @@
 class RewritePatternSet;
 class Pass;
 
-#define GEN_PASS_DECL_CONVERTNVGPUTONVVM
+#define GEN_PASS_DECL_CONVERTNVGPUTONVVMPASS
 #include "mlir/Conversion/Passes.h.inc"
 
 void populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                            RewritePatternSet &patterns);
-
-std::unique_ptr<Pass> createConvertNVGPUToNVVMPass();
-
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_NVGPUTONVVM_NVGPUTONVVMPASS_H_
Index: mlir/include/mlir/Conversion/Passes.td
===================================================================
--- mlir/include/mlir/Conversion/Passes.td
+++ mlir/include/mlir/Conversion/Passes.td
@@ -683,15 +683,20 @@
 // NVGPUToNVVM
 //===----------------------------------------------------------------------===//
 
-def ConvertNVGPUToNVVM : Pass<"convert-nvgpu-to-nvvm"> {
+def ConvertNVGPUToNVVMPass : Pass<"convert-nvgpu-to-nvvm"> {
   let summary = "Convert NVGPU dialect to NVVM dialect";
   let description = [{
     This pass converts supported NVGPU ops to NVVM dialect intrinsics.
   }];
-  let constructor = "mlir::createConvertNVGPUToNVVMPass()";
+
   let dependentDialects = [
     "NVVM::NVVMDialect",
   ];
+  let options = [
+    Option<"useOpaquePointers", "use-opaque-pointers", "bool",
+                   /*default=*/"false", "Generate LLVM IR using opaque pointers "
+                   "instead of typed pointers">
+  ];
 }
 
 
Index: mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
===================================================================
--- mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -18,7 +18,7 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
-#define GEN_PASS_DEF_CONVERTNVGPUTONVVM
+#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
 #include "mlir/Conversion/Passes.h.inc"
 } // namespace mlir
 
@@ -338,12 +338,14 @@
 };
 
 struct ConvertNVGPUToNVVMPass
-    : public impl::ConvertNVGPUToNVVMBase<ConvertNVGPUToNVVMPass> {
-  ConvertNVGPUToNVVMPass() = default;
+    : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> {
+  using Base::Base;
 
   void runOnOperation() override {
+    LowerToLLVMOptions options(&getContext());
+    options.useOpaquePointers = useOpaquePointers;
     RewritePatternSet patterns(&getContext());
-    LLVMTypeConverter converter(&getContext());
+    LLVMTypeConverter converter(&getContext(), options);
     /// device-side async tokens cannot be materialized in nvvm. We just convert
     /// them to a dummy i32 type in order to easily drop them during conversion.
     converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type {
@@ -577,8 +579,10 @@
     if (failed(dstAddressSpace))
       return rewriter.notifyMatchFailure(
           loc, "destination memref address space not convertible to integer");
-    auto dstPointerType = LLVM::LLVMPointerType::get(i8Ty, *dstAddressSpace);
-    dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
+    auto dstPointerType =
+        getTypeConverter()->getPointerType(i8Ty, *dstAddressSpace);
+    if (!getTypeConverter()->useOpaquePointers())
+      dstPtr = rewriter.create<LLVM::BitcastOp>(loc, dstPointerType, dstPtr);
 
     auto srcMemrefType = op.getSrc().getType().cast<MemRefType>();
     FailureOr<unsigned> srcAddressSpace =
@@ -589,10 +593,12 @@
 
     Value scrPtr = getStridedElementPtr(loc, srcMemrefType, adaptor.getSrc(),
                                         adaptor.getSrcIndices(), rewriter);
-    auto srcPointerType = LLVM::LLVMPointerType::get(i8Ty, *srcAddressSpace);
-    scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
+    auto srcPointerType =
+        getTypeConverter()->getPointerType(i8Ty, *srcAddressSpace);
+    if (!getTypeConverter()->useOpaquePointers())
+      scrPtr = rewriter.create<LLVM::BitcastOp>(loc, srcPointerType, scrPtr);
     // Intrinsics takes a global pointer so we need an address space cast.
-    auto srcPointerGlobalType = LLVM::LLVMPointerType::get(
+    auto srcPointerGlobalType = getTypeConverter()->getPointerType(
         i8Ty, NVVM::NVVMMemorySpace::kGlobalMemorySpace);
     scrPtr = rewriter.create<LLVM::AddrSpaceCastOp>(loc, srcPointerGlobalType,
                                                     scrPtr);
@@ -675,7 +681,3 @@
                NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
                NVGPUMmaSparseSyncLowering>(converter);
 }
-
-std::unique_ptr<Pass> mlir::createConvertNVGPUToNVVMPass() {
-  return std::make_unique<ConvertNVGPUToNVVMPass>();
-}
Index: mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
===================================================================
--- mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --convert-nvgpu-to-nvvm --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --convert-nvgpu-to-nvvm='use-opaque-pointers=1' --split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @m16n8k16_fp16
 func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
@@ -244,23 +244,21 @@
 func.func @async_cp(
   %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index) {
   // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
-  // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)>
   // CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(2048 : index) : i64
   // CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX1]], %[[S0]] : i64
   // CHECK-DAG: %[[S1:.*]] = llvm.mlir.constant(128 : index) : i64
   // CHECK-DAG: %[[FI0:.*]] = llvm.mul %[[IDX1]], %[[S1]] : i64
   // CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI]], %[[FI0]] : i64
   // CHECK-DAG: %[[FI2:.*]] = llvm.add %[[FI1]], %[[IDX1]] : i64
-  // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI2]]] : (!llvm.ptr<f32, 3>, i64) -> !llvm.ptr<f32, 3>
-  // CHECK-DAG: %[[CAST0:.*]] = llvm.bitcast %[[ADDRESSDST]] : !llvm.ptr<f32, 3> to !llvm.ptr<i8, 3>
-  // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI2]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>
+  // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK-DAG: %[[S3:.*]] = llvm.mlir.constant(128 : index) : i64
   // CHECK-DAG: %[[FI3:.*]] = llvm.mul %[[IDX1]], %[[S3]] : i64
   // CHECK-DAG: %[[FI4:.*]] = llvm.add %[[FI3]], %[[IDX1]] : i64
-  // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI4]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-  // CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-  // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr<i8> to !llvm.ptr<i8, 1>
-  // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16
+  // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI4]]] : (!llvm.ptr, i64) -> !llvm.ptr
+  // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[ADDRESSSRC]] : !llvm.ptr to !llvm.ptr<1>
+  // CHECK-DAG: nvvm.cp.async.shared.global %[[ADDRESSDST]], %[[CAST2]], 16
   %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4 : memref<128x128xf32> to memref<3x16x128xf32, 3>
   // CHECK: nvvm.cp.async.commit.group
   %1 = nvgpu.device_async_create_group %0
@@ -279,20 +277,18 @@
 func.func @async_cp_i4(
   %src: memref<128x64xi4>, %dst: memref<128x128xi4, 3>, %i : index) -> !nvgpu.device.async.token {
   // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
-  // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i4, 3>, ptr<i4, 3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(128 : index) : i64
   // CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX1]], %[[S0]] : i64
   // CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI]], %[[IDX1]] : i64
-  // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI1]]] : (!llvm.ptr<i4, 3>, i64) -> !llvm.ptr<i4, 3>
-  // CHECK-DAG: %[[CAST0:.*]] = llvm.bitcast %[[ADDRESSDST]] : !llvm.ptr<i4, 3> to !llvm.ptr<i8, 3>
-  // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i4>, ptr<i4>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI1]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>
+  // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
   // CHECK-DAG: %[[S2:.*]] = llvm.mlir.constant(64 : index) : i64
   // CHECK-DAG: %[[FI2:.*]] = llvm.mul %[[IDX1]], %[[S2]] : i64
   // CHECK-DAG: %[[FI3:.*]] = llvm.add %[[FI2]], %[[IDX1]] : i64
-  // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI3]]] : (!llvm.ptr<i4>, i64) -> !llvm.ptr<i4>
-  // CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr<i4> to !llvm.ptr<i8>
-  // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr<i8> to !llvm.ptr<i8, 1>
-  // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16
+  // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI3]]] : (!llvm.ptr, i64) -> !llvm.ptr
+  // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[ADDRESSSRC]] : !llvm.ptr to !llvm.ptr<1>
+  // CHECK-DAG: nvvm.cp.async.shared.global %[[ADDRESSDST]], %[[CAST2]], 16
   %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i], 32 : memref<128x64xi4> to memref<128x128xi4, 3>
   return %0 : !nvgpu.device.async.token
 }
@@ -304,7 +300,7 @@
 
 func.func @async_cp_zfill(
   %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
-  // CHECK-DAG: lvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES:.*]], %[[SRCBYTES:.*]] : (!llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>, i32, i32) -> !llvm.void
+  // CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES:.*]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
   %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
   // CHECK: nvvm.cp.async.commit.group
   %1 = nvgpu.device_async_create_group %0
Index: mlir/test/Conversion/NVGPUToNVVM/typed-pointers.mlir
===================================================================
--- /dev/null
+++ mlir/test/Conversion/NVGPUToNVVM/typed-pointers.mlir
@@ -0,0 +1,69 @@
+// RUN: mlir-opt --convert-nvgpu-to-nvvm='use-opaque-pointers=0' --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @async_cp(
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index)
+func.func @async_cp(
+  %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index) {
+  // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+  // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32, 3>, ptr<f32, 3>, i64, array<3 x i64>, array<3 x i64>)>
+  // CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(2048 : index) : i64
+  // CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX1]], %[[S0]] : i64
+  // CHECK-DAG: %[[S1:.*]] = llvm.mlir.constant(128 : index) : i64
+  // CHECK-DAG: %[[FI0:.*]] = llvm.mul %[[IDX1]], %[[S1]] : i64
+  // CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI]], %[[FI0]] : i64
+  // CHECK-DAG: %[[FI2:.*]] = llvm.add %[[FI1]], %[[IDX1]] : i64
+  // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI2]]] : (!llvm.ptr<f32, 3>, i64) -> !llvm.ptr<f32, 3>
+  // CHECK-DAG: %[[CAST0:.*]] = llvm.bitcast %[[ADDRESSDST]] : !llvm.ptr<f32, 3> to !llvm.ptr<i8, 3>
+  // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[S3:.*]] = llvm.mlir.constant(128 : index) : i64
+  // CHECK-DAG: %[[FI3:.*]] = llvm.mul %[[IDX1]], %[[S3]] : i64
+  // CHECK-DAG: %[[FI4:.*]] = llvm.add %[[FI3]], %[[IDX1]] : i64
+  // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI4]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+  // CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+  // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr<i8> to !llvm.ptr<i8, 1>
+  // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16
+  %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4 : memref<128x128xf32> to memref<3x16x128xf32, 3>
+  // CHECK: nvvm.cp.async.commit.group
+  %1 = nvgpu.device_async_create_group %0
+  // CHECK: nvvm.cp.async.wait.group 1
+  nvgpu.device_async_wait %1 { numGroups = 1 : i32 }
+
+  // CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16 {bypass_l1}
+  %2 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4 {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @async_cp_i4(
+// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index)
+func.func @async_cp_i4(
+  %src: memref<128x64xi4>, %dst: memref<128x128xi4, 3>, %i : index) -> !nvgpu.device.async.token {
+  // CHECK: %[[IDX1:.*]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i64
+  // CHECK-DAG: %[[BASEDST:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i4, 3>, ptr<i4, 3>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[S0:.*]] = llvm.mlir.constant(128 : index) : i64
+  // CHECK-DAG: %[[LI:.*]] = llvm.mul %[[IDX1]], %[[S0]] : i64
+  // CHECK-DAG: %[[FI1:.*]] = llvm.add %[[LI]], %[[IDX1]] : i64
+  // CHECK-DAG: %[[ADDRESSDST:.*]] = llvm.getelementptr %[[BASEDST]][%[[FI1]]] : (!llvm.ptr<i4, 3>, i64) -> !llvm.ptr<i4, 3>
+  // CHECK-DAG: %[[CAST0:.*]] = llvm.bitcast %[[ADDRESSDST]] : !llvm.ptr<i4, 3> to !llvm.ptr<i8, 3>
+  // CHECK-DAG: %[[BASESRC:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<i4>, ptr<i4>, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK-DAG: %[[S2:.*]] = llvm.mlir.constant(64 : index) : i64
+  // CHECK-DAG: %[[FI2:.*]] = llvm.mul %[[IDX1]], %[[S2]] : i64
+  // CHECK-DAG: %[[FI3:.*]] = llvm.add %[[FI2]], %[[IDX1]] : i64
+  // CHECK-DAG: %[[ADDRESSSRC:.*]] = llvm.getelementptr %[[BASESRC]][%[[FI3]]] : (!llvm.ptr<i4>, i64) -> !llvm.ptr<i4>
+  // CHECK-DAG: %[[CAST1:.*]] = llvm.bitcast %[[ADDRESSSRC]] : !llvm.ptr<i4> to !llvm.ptr<i8>
+  // CHECK-DAG: %[[CAST2:.*]] = llvm.addrspacecast %[[CAST1]] : !llvm.ptr<i8> to !llvm.ptr<i8, 1>
+  // CHECK-DAG: nvvm.cp.async.shared.global %[[CAST0]], %[[CAST2]], 16
+  %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i], 32 : memref<128x64xi4> to memref<128x128xi4, 3>
+  return %0 : !nvgpu.device.async.token
+}
+
+// -----
+
+// CHECK-LABEL: @async_cp_zfill(
+func.func @async_cp_zfill(
+  %src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : (!llvm.ptr<i8, 3>, !llvm.ptr<i8, 1>, i32, i32) -> !llvm.void
+  %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
+  return
+}