diff --git a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h --- a/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h +++ b/mlir/include/mlir/Conversion/GPUCommon/GPUCommonPass.h @@ -50,13 +50,15 @@ /// This pass does not generate code to call GPU runtime APIs directly but /// instead uses a small wrapper library that exports a stable and conveniently /// typed ABI on top of GPU runtimes such as CUDA or ROCm (HIP). -std::unique_ptr> createGpuToLLVMConversionPass(); +std::unique_ptr> +createGpuToLLVMConversionPass(bool kernelBarePtrCallConv = false); /// Collect a set of patterns to convert from the GPU dialect to LLVM and /// populate converter for gpu types. void populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, - StringRef gpuBinaryAnnotation = {}); + StringRef gpuBinaryAnnotation = {}, + bool kernelBarePtrCallConv = false); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h --- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h +++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h @@ -41,6 +41,7 @@ createLowerGpuOpsToROCDLOpsPass( const std::string &chipset = "gfx900", unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout, + bool useBarePtrCallConv = false, gpu::amd::Runtime runtime = gpu::amd::Runtime::Unknown); } // namespace mlir diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -80,6 +80,13 @@ const LowerToLLVMOptions &getOptions() const { return options; } + /// Set the lowering options to `newOptions`. Note: using this after + /// some conversions have been performed can lead to inconsistencies in the + /// IR. 
+ void dangerousSetOptions(LowerToLLVMOptions newOptions) { + options = std::move(newOptions); + } + /// Promote the LLVM representation of all operands including promoting MemRef /// descriptors to stack and use pointers to struct to avoid the complexity /// of the platform-specific C/C++ ABI lowering related to struct argument @@ -126,7 +133,7 @@ const DataLayout &layout); /// Check if a memref type can be converted to a bare pointer. - bool canConvertToBarePtr(BaseMemRefType type); + static bool canConvertToBarePtr(BaseMemRefType type); protected: /// Pointer to the LLVM dialect. diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -373,6 +373,10 @@ Option<"indexBitwidth", "index-bitwidth", "unsigned", /*default=kDeriveIndexBitwidthFromDataLayout*/"0", "Bitwidth of the index type, 0 to use size of machine word">, + Option<"useBarePtrCallConv", "use-bare-ptr-memref-call-conv", "bool", + /*default=*/"false", + "Replace memref arguments in GPU functions with bare pointers." 
+ "All memrefs must have static shape">, Option<"runtime", "runtime", "::mlir::gpu::amd::Runtime", "::mlir::gpu::amd::Runtime::Unknown", "Runtime code will be run on (default is Unknown, can also use HIP or OpenCl)", diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -9,6 +9,7 @@ #include "GPUOpsLowering.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Builders.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" using namespace mlir; @@ -137,6 +138,34 @@ &signatureConversion))) return failure(); + // If bare memref pointers are being used, remap them back to memref + // descriptors. This must be done after signature conversion to get rid of the + // unrealized casts. + if (getTypeConverter()->getOptions().useBarePtrCallConv) { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front()); + for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) { + auto memrefTy = en.value().dyn_cast(); + if (!memrefTy) + continue; + assert(memrefTy.hasStaticShape() && + "Bare pointer convertion used with dynamically-shaped memrefs"); + // Use a placeholder when replacing uses of the memref argument to prevent + // circular replacements. 
+ auto remapping = signatureConversion.getInputMapping(en.index()); + assert(remapping && remapping->size == 1 && + "Type converter should produce 1-to-1 mapping for bare memrefs"); + BlockArgument newArg = + llvmFuncOp.getBody().getArgument(remapping->inputNo); + auto placeholder = rewriter.create( + loc, getTypeConverter()->convertType(memrefTy)); + rewriter.replaceUsesOfBlockArgument(newArg, placeholder); + Value desc = MemRefDescriptor::fromStaticShape( + rewriter, loc, *getTypeConverter(), memrefTy, newArg); + rewriter.replaceOp(placeholder, {desc}); + } + } + rewriter.eraseOp(gpuFuncOp); return success(); } diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -49,6 +49,12 @@ public: GpuToLLVMConversionPass() = default; + GpuToLLVMConversionPass(bool kernelBarePtrCallConv) + : GpuToLLVMConversionPass() { + if (this->kernelBarePtrCallConv.getNumOccurrences() == 0) + this->kernelBarePtrCallConv = kernelBarePtrCallConv; + } + GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other) : GpuToLLVMConversionPassBase(other) {} @@ -60,6 +66,11 @@ *this, "gpu-binary-annotation", llvm::cl::desc("Annotation attribute string for GPU binary"), llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())}; + Option kernelBarePtrCallConv{ + *this, "use-bare-pointers-for-kernels", + llvm::cl::desc("Use bare pointers to pass memref arguments to kernels. 
" + "The kernel must use the same setting for this option."), + llvm::cl::init(false)}; }; struct FunctionCallBuilder { @@ -290,9 +301,11 @@ : public ConvertOpToGpuRuntimeCallPattern { public: ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter, - StringRef gpuBinaryAnnotation) + StringRef gpuBinaryAnnotation, + bool kernelBarePtrCallConv) : ConvertOpToGpuRuntimeCallPattern(typeConverter), - gpuBinaryAnnotation(gpuBinaryAnnotation) {} + gpuBinaryAnnotation(gpuBinaryAnnotation), + kernelBarePtrCallConv(kernelBarePtrCallConv) {} private: Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, @@ -305,6 +318,7 @@ ConversionPatternRewriter &rewriter) const override; llvm::SmallString<32> gpuBinaryAnnotation; + bool kernelBarePtrCallConv; }; class EraseGpuModuleOpPattern : public OpRewritePattern { @@ -377,7 +391,8 @@ populateFuncToLLVMConversionPatterns(converter, patterns); populateAsyncStructuralTypeConversionsAndLegality(converter, patterns, target); - populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation); + populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation, + kernelBarePtrCallConv); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) @@ -635,9 +650,24 @@ gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const { auto loc = launchOp.getLoc(); auto numKernelOperands = launchOp.getNumKernelOperands(); - auto arguments = getTypeConverter()->promoteOperands( - loc, launchOp.getOperands().take_back(numKernelOperands), - adaptor.getOperands().take_back(numKernelOperands), builder); + SmallVector arguments; + if (kernelBarePtrCallConv) { + // Hack the bare pointer value on just for the argument promotion + LLVMTypeConverter *converter = getTypeConverter(); + LowerToLLVMOptions options = converter->getOptions(); + LowerToLLVMOptions overrideToMatchKernelOpts = options; + overrideToMatchKernelOpts.useBarePtrCallConv = true; + 
converter->dangerousSetOptions(overrideToMatchKernelOpts); + arguments = converter->promoteOperands( + loc, launchOp.getOperands().take_back(numKernelOperands), + adaptor.getOperands().take_back(numKernelOperands), builder); + converter->dangerousSetOptions(options); + } else { + arguments = getTypeConverter()->promoteOperands( + loc, launchOp.getOperands().take_back(numKernelOperands), + adaptor.getOperands().take_back(numKernelOperands), builder); + } + auto numArguments = arguments.size(); SmallVector argumentTypes; argumentTypes.reserve(numArguments); @@ -870,13 +900,14 @@ } std::unique_ptr> -mlir::createGpuToLLVMConversionPass() { - return std::make_unique(); +mlir::createGpuToLLVMConversionPass(bool kernelBarePtrCallConv) { + return std::make_unique(kernelBarePtrCallConv); } void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, - StringRef gpuBinaryAnnotation) { + StringRef gpuBinaryAnnotation, + bool kernelBarePtrCallConv) { converter.addConversion( [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type { return LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); @@ -890,7 +921,7 @@ ConvertWaitAsyncOpToGpuRuntimeCallPattern, ConvertWaitOpToGpuRuntimeCallPattern, ConvertAsyncYieldToGpuRuntimeCallPattern>(converter); - patterns.add(converter, - gpuBinaryAnnotation); + patterns.add( + converter, gpuBinaryAnnotation, kernelBarePtrCallConv); patterns.add(&converter.getContext()); } diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -41,6 +41,16 @@ using namespace mlir; +/// Returns true if the given `gpu.func` can be safely called using the bare +/// pointer calling convention. 
+static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) { + bool canBeBare = true; + for (Type type : func.getArgumentTypes()) + if (auto memrefTy = type.dyn_cast()) + canBeBare &= LLVMTypeConverter::canConvertToBarePtr(memrefTy); + return canBeBare; +} + namespace { /// Import the GPU Ops to ROCDL Patterns. @@ -55,10 +65,16 @@ : public ConvertGpuOpsToROCDLOpsBase { LowerGpuOpsToROCDLOpsPass() = default; LowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth, + bool useBarePtrCallConv, gpu::amd::Runtime runtime) { - this->chipset = chipset; - this->indexBitwidth = indexBitwidth; - this->runtime = runtime; + if (this->chipset.getNumOccurrences() == 0) + this->chipset = chipset; + if (this->indexBitwidth.getNumOccurrences() == 0) + this->indexBitwidth = indexBitwidth; + if (this->useBarePtrCallConv.getNumOccurrences() == 0) + this->useBarePtrCallConv = useBarePtrCallConv; + if (this->runtime.getNumOccurrences() == 0) + this->runtime = runtime; } void runOnOperation() override { @@ -82,6 +98,23 @@ ctx, DataLayout(cast(m.getOperation()))); if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) options.overrideIndexBitwidth(indexBitwidth); + + if (useBarePtrCallConv) { + options.useBarePtrCallConv = true; + WalkResult canUseBarePointers = + m.walk([](gpu::GPUFuncOp func) -> WalkResult { + if (canBeCalledWithBarePointers(func)) + return WalkResult::advance(); + return WalkResult::interrupt(); + }); + if (canUseBarePointers.wasInterrupted()) { + emitError(UnknownLoc::get(ctx), + "bare pointer calling convention requires all memrefs to " + "have static shape and use the identity map"); + return signalPassFailure(); + } + } + LLVMTypeConverter converter(ctx, options); RewritePatternSet patterns(ctx); @@ -189,7 +222,8 @@ std::unique_ptr> mlir::createLowerGpuOpsToROCDLOpsPass(const std::string &chipset, unsigned indexBitwidth, + bool useBarePtrCallConv, gpu::amd::Runtime runtime) { - return std::make_unique(chipset, indexBitwidth, - runtime); + 
return std::make_unique( + chipset, indexBitwidth, useBarePtrCallConv, runtime); } diff --git a/mlir/test/Conversion/GPUToROCDL/memref.mlir b/mlir/test/Conversion/GPUToROCDL/memref.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/GPUToROCDL/memref.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s +// RUN: mlir-opt %s \ +// RUN: -convert-gpu-to-rocdl=use-bare-ptr-memref-call-conv=true \ +// RUN: -split-input-file \ +// RUN: | FileCheck %s --check-prefix=BARE + +gpu.module @memref_conversions { + // CHECK: llvm.func @kern + // CHECK-SAME: (%{{.*}}: !llvm.ptr, %{{.*}}: !llvm.ptr, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64) + // BARE: llvm.func @kern + // BARE-SAME: (%{{.*}}: !llvm.ptr) + gpu.func @kern(%arg0: memref<8xf32>) kernel { + gpu.return + } +} diff --git a/mlir/test/Integration/GPU/ROCM/vecadd.mlir b/mlir/test/Integration/GPU/ROCM/vecadd.mlir --- a/mlir/test/Integration/GPU/ROCM/vecadd.mlir +++ b/mlir/test/Integration/GPU/ROCM/vecadd.mlir @@ -1,24 +1,24 @@ // RUN: mlir-opt %s \ // RUN: -convert-scf-to-cf \ // RUN: -gpu-kernel-outlining \ -// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-rocdl,gpu-to-hsaco{chip=%chip})' \ -// RUN: -gpu-to-llvm \ +// RUN: -pass-pipeline='gpu.module(strip-debuginfo,convert-gpu-to-rocdl{use-bare-ptr-memref-call-conv=true},gpu-to-hsaco{chip=%chip})' \ +// RUN: -gpu-to-llvm=use-bare-pointers-for-kernels=true \ // RUN: | mlir-cpu-runner \ // RUN: --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext \ // RUN: --shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \ // RUN: --entry-point-result=void \ // RUN: | FileCheck %s -func.func @vecadd(%arg0 : memref, %arg1 : memref, %arg2 : memref) { +func.func @vecadd(%arg0 : memref<5xf32>, %arg1 : memref<5xf32>, %arg2 : memref<5xf32>) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index - %block_dim = memref.dim %arg0, %c0 : memref + %block_dim = arith.constant 5 : index 
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1) threads(%tx, %ty, %tz) in (%block_x = %block_dim, %block_y = %c1, %block_z = %c1) { - %a = memref.load %arg0[%tx] : memref - %b = memref.load %arg1[%tx] : memref + %a = memref.load %arg0[%tx] : memref<5xf32> + %b = memref.load %arg1[%tx] : memref<5xf32> %c = arith.addf %a, %b : f32 - memref.store %c, %arg2[%tx] : memref + memref.store %c, %arg2[%tx] : memref<5xf32> gpu.terminator } return @@ -49,8 +49,11 @@ %9 = call @mgpuMemGetDeviceMemRef1dFloat(%3) : (memref) -> (memref) %10 = call @mgpuMemGetDeviceMemRef1dFloat(%4) : (memref) -> (memref) %11 = call @mgpuMemGetDeviceMemRef1dFloat(%5) : (memref) -> (memref) + %12 = memref.cast %9 : memref to memref<5xf32> + %13 = memref.cast %10 : memref to memref<5xf32> + %14 = memref.cast %11 : memref to memref<5xf32> - call @vecadd(%9, %10, %11) : (memref, memref, memref) -> () + call @vecadd(%12, %13, %14) : (memref<5xf32>, memref<5xf32>, memref<5xf32>) -> () call @printMemrefF32(%8) : (memref<*xf32>) -> () return }