diff --git a/mlir/docs/TargetLLVMIR.md b/mlir/docs/TargetLLVMIR.md
--- a/mlir/docs/TargetLLVMIR.md
+++ b/mlir/docs/TargetLLVMIR.md
@@ -553,6 +553,17 @@
 The "bare pointer" calling convention does not support unranked memrefs as
 their shape cannot be known at compile time.
 
+### Generic allocation and deallocation functions
+
+When converting the MemRef dialect, allocations and deallocations are
+converted into calls to `malloc` (`aligned_alloc` if aligned allocations are
+requested) and `free`. However, it is possible to convert them instead into
+calls to more generic functions that can be implemented by a runtime library,
+thus allowing custom allocation strategies and runtime profiling. When the
+conversion pass is instructed to perform such a conversion, the names of the
+callees are `_mlir_alloc`, `_mlir_aligned_alloc` and `_mlir_free`. Their
+signatures are the same as those of `malloc`, `aligned_alloc` and `free`.
+
 ### C-compatible wrapper emission
 
 In practical cases, it may be desirable to have externally-facing functions with
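The documentation above only pins down the callee names and signatures. As an illustration of that contract (not part of this patch), a runtime library providing the three hooks could be as small as the sketch below; `totalAllocatedBytes` is a made-up stand-in for the custom strategies and profiling the docs mention, and a 64-bit index type is assumed:

```cpp
#include <cstdint>
#include <cstdlib>

// Hypothetical profiling state, standing in for a real allocation strategy.
static std::uint64_t totalAllocatedBytes = 0;

// Same signature as malloc (the index type is assumed to be 64 bits wide).
extern "C" void *_mlir_alloc(std::uint64_t size) {
  totalAllocatedBytes += size;
  return std::malloc(size);
}

// Same signature as aligned_alloc: (alignment, size).
extern "C" void *_mlir_aligned_alloc(std::uint64_t alignment,
                                     std::uint64_t size) {
  totalAllocatedBytes += size;
  return std::aligned_alloc(alignment, size);
}

// Same signature as free.
extern "C" void _mlir_free(void *ptr) { std::free(ptr); }
```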
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h
@@ -48,6 +48,10 @@
   AllocLowering allocLowering = AllocLowering::Malloc;
 
+  /// Use generic allocation and deallocation functions instead of the classic
+  /// `malloc`, `aligned_alloc` and `free` functions.
+  bool useGenericFunctions = false;
+
   /// The data layout of the module to produce. This must be consistent with the
   /// data layout used in the upper levels of the lowering pipeline.
   // TODO: this should be replaced by MLIR data layout when one exists.
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -522,6 +522,11 @@
     Option<"indexBitwidth", "index-bitwidth", "unsigned",
            /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
            "Bitwidth of the index type, 0 to use size of machine word">,
+    Option<"useGenericFunctions", "use-generic-functions",
+           "bool",
+           /*default=*/"false",
+           "Use generic allocation and deallocation functions instead of the "
+           "classic 'malloc', 'aligned_alloc' and 'free' functions">
   ];
 }
diff --git a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
--- a/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
@@ -45,6 +45,11 @@
 LLVM::LLVMFuncOp lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
                                               Type indexType);
 LLVM::LLVMFuncOp lookupOrCreateFreeFn(ModuleOp moduleOp);
+LLVM::LLVMFuncOp lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+                                              Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+                                                     Type indexType);
+LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(ModuleOp moduleOp);
 LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
                                             Type unrankedDescriptorType);
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -35,6 +35,15 @@
       : AllocLikeOpLLVMLowering(memref::AllocOp::getOperationName(),
                                 converter) {}
 
+  LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
+    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
+
+    if (useGenericFn)
+      return LLVM::lookupOrCreateGenericAllocFn(module, getIndexType());
+
+    return LLVM::lookupOrCreateMallocFn(module, getIndexType());
+  }
+
   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                           Location loc, Value sizeBytes,
                                           Operation *op) const override {
@@ -61,8 +70,7 @@
     // Allocate the underlying buffer and store a pointer to it in the MemRef
     // descriptor.
     Type elementPtrType = this->getElementPtrType(memRefType);
-    auto allocFuncOp = LLVM::lookupOrCreateMallocFn(
-        allocOp->getParentOfType<ModuleOp>(), getIndexType());
+    auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
     auto results = createLLVMCall(rewriter, loc, allocFuncOp, {sizeBytes},
                                   getVoidPtrType());
     Value allocatedPtr =
@@ -135,6 +143,15 @@
                          llvm::PowerOf2Ceil(eltSizeBytes));
   }
 
+  LLVM::LLVMFuncOp getAllocFn(ModuleOp module) const {
+    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
+
+    if (useGenericFn)
+      return LLVM::lookupOrCreateGenericAlignedAllocFn(module, getIndexType());
+
+    return LLVM::lookupOrCreateAlignedAllocFn(module, getIndexType());
+  }
+
   std::tuple<Value, Value> allocateBuffer(ConversionPatternRewriter &rewriter,
                                           Location loc, Value sizeBytes,
                                           Operation *op) const override {
@@ -150,8 +167,7 @@
       sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);
 
     Type elementPtrType = this->getElementPtrType(memRefType);
-    auto allocFuncOp = LLVM::lookupOrCreateAlignedAllocFn(
-        allocOp->getParentOfType<ModuleOp>(), getIndexType());
+    auto allocFuncOp = getAllocFn(allocOp->getParentOfType<ModuleOp>());
     auto results =
         createLLVMCall(rewriter, loc, allocFuncOp, {allocAlignment, sizeBytes},
                        getVoidPtrType());
@@ -300,11 +316,20 @@
   explicit DeallocOpLowering(LLVMTypeConverter &converter)
       : ConvertOpToLLVMPattern<memref::DeallocOp>(converter) {}
 
+  LLVM::LLVMFuncOp getFreeFn(ModuleOp module) const {
+    bool useGenericFn = getTypeConverter()->getOptions().useGenericFunctions;
+
+    if (useGenericFn)
+      return LLVM::lookupOrCreateGenericFreeFn(module);
+
+    return LLVM::lookupOrCreateFreeFn(module);
+  }
+
   LogicalResult
   matchAndRewrite(memref::DeallocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Insert the `free` declaration if it is not already present.
-    auto freeFunc = LLVM::lookupOrCreateFreeFn(op->getParentOfType<ModuleOp>());
+    auto freeFunc = getFreeFn(op->getParentOfType<ModuleOp>());
     MemRefDescriptor memref(adaptor.getMemref());
     Value casted = rewriter.create<LLVM::BitcastOp>(
         op.getLoc(), getVoidPtrType(),
@@ -2047,6 +2072,9 @@
     options.allocLowering =
         (useAlignedAlloc ? LowerToLLVMOptions::AllocLowering::AlignedAlloc
                          : LowerToLLVMOptions::AllocLowering::Malloc);
+
+    options.useGenericFunctions = useGenericFunctions;
+
     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
       options.overrideIndexBitwidth(indexBitwidth);
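Besides the pass option registered in Passes.td above, downstream pipelines that configure the lowering in C++ can reach the new flag through `LowerToLLVMOptions`. A minimal sketch, assuming it runs inside a pass (the `getContext()` boilerplate around it is assumed, not added by this patch):

```cpp
// Sketch: enable the generic functions when building the patterns by hand.
LowerToLLVMOptions options(&getContext());
options.useGenericFunctions = true; // the field added to LoweringOptions.h

LLVMTypeConverter typeConverter(&getContext(), options);
RewritePatternSet patterns(&getContext());
populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
```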
diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp
@@ -35,6 +35,9 @@
 static constexpr llvm::StringRef kMalloc = "malloc";
 static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
 static constexpr llvm::StringRef kFree = "free";
+static constexpr llvm::StringRef kGenericAlloc = "_mlir_alloc";
+static constexpr llvm::StringRef kGenericAlignedAlloc = "_mlir_aligned_alloc";
+static constexpr llvm::StringRef kGenericFree = "_mlir_free";
 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
 
 /// Generic print function lookupOrCreate helper.
@@ -115,6 +118,28 @@
                              LLVM::LLVMVoidType::get(moduleOp->getContext()));
 }
 
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
+                                                          Type indexType) {
+  return LLVM::lookupOrCreateFn(
+      moduleOp, kGenericAlloc, indexType,
+      LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
+}
+
+LLVM::LLVMFuncOp
+mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
+                                                Type indexType) {
+  return LLVM::lookupOrCreateFn(
+      moduleOp, kGenericAlignedAlloc, {indexType, indexType},
+      LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
+}
+
+LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) {
+  return LLVM::lookupOrCreateFn(
+      moduleOp, kGenericFree,
+      LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)),
+      LLVM::LLVMVoidType::get(moduleOp->getContext()));
+}
+
 LLVM::LLVMFuncOp
 mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
                                        Type unrankedDescriptorType) {
diff --git a/mlir/test/Conversion/MemRefToLLVM/generic-functions.mlir b/mlir/test/Conversion/MemRefToLLVM/generic-functions.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/MemRefToLLVM/generic-functions.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -pass-pipeline="convert-memref-to-llvm{use-generic-functions=1}" -split-input-file %s \
+// RUN:   | FileCheck %s --check-prefix="CHECK-NOTALIGNED"
+
+// RUN: mlir-opt -pass-pipeline="convert-memref-to-llvm{use-generic-functions=1 use-aligned-alloc=1}" -split-input-file %s \
+// RUN:   | FileCheck %s --check-prefix="CHECK-ALIGNED"
+
+// CHECK-LABEL: func @zero_d_alloc()
+func.func @zero_d_alloc() -> memref<f32> {
+// CHECK-NOTALIGNED: llvm.call @_mlir_alloc(%{{.*}}) : (i64) -> !llvm.ptr<i8>
+// CHECK-ALIGNED: llvm.call @_mlir_aligned_alloc(%{{.*}}, %{{.*}}) : (i64, i64) -> !llvm.ptr<i8>
+  %0 = memref.alloc() : memref<f32>
+  return %0 : memref<f32>
+}
+
+// -----
+
+// CHECK-LABEL: func @dealloc()
+func.func @dealloc(%arg0: memref<f32>) {
+// CHECK-NOTALIGNED: llvm.call @_mlir_free(%{{.*}}) : (!llvm.ptr<i8>) -> ()
+// CHECK-ALIGNED: llvm.call @_mlir_free(%{{.*}}) : (!llvm.ptr<i8>) -> ()
+  memref.dealloc %arg0 : memref<f32>
+  return
+}
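The RUN lines above exercise the new option from the command line; the same textual pipeline can also be parsed and run from C++. A sketch, assuming an existing `context` and `module` and keeping error handling minimal:

```cpp
// Sketch: run the conversion with the new option via the same pipeline
// string used in the RUN lines of the test above.
PassManager pm(&context);
if (failed(parsePassPipeline(
        "convert-memref-to-llvm{use-generic-functions=1}", pm)))
  return failure();
if (failed(pm.run(module)))
  return failure();
```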