diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -115,6 +115,18 @@
                          Value allocatedPtr, Value alignedPtr,
                          ArrayRef<Value> sizes, ArrayRef<Value> strides,
                          ConversionPatternRewriter &rewriter) const;
+
+  /// Copies the memory descriptor for any operands that were unranked
+  /// descriptors originally to heap-allocated memory (if toDynamic is true) or
+  /// to stack-allocated memory (otherwise). In the latter case, the source
+  /// memory is expected to be heap-allocated and is freed after the copy.
+  /// `operands` is updated in place with freshly created descriptors. Returns
+  /// failure, before emitting any IR, if the type of an unranked descriptor
+  /// cannot be converted.
+  LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
+                                        TypeRange origTypes,
+                                        SmallVectorImpl<Value> &operands,
+                                        bool toDynamic) const;
 };
 
 /// Utility class for operation conversions targeting the LLVM dialect that
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/AffineMap.h"
@@ -224,6 +225,93 @@
   return memRefDescriptor;
 }
 
+LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
+    OpBuilder &builder, Location loc, TypeRange origTypes,
+    SmallVectorImpl<Value> &operands, bool toDynamic) const {
+  assert(origTypes.size() == operands.size() &&
+         "expected as many original types as operands");
+
+  // Find operands of unranked memref type and convert their descriptor types
+  // up front: checking convertibility before any IR is emitted ensures that a
+  // failure leaves neither partially-built IR nor partially-updated operands.
+  SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
+  SmallVector<Type, 4> descriptorTypes;
+  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+    Type type = origTypes[i];
+    if (!type.isa<UnrankedMemRefType>())
+      continue;
+    Type descriptorType = getTypeConverter()->convertType(type);
+    if (!descriptorType)
+      return failure();
+    unrankedMemrefs.emplace_back(operands[i]);
+    descriptorTypes.push_back(descriptorType);
+  }
+
+  if (unrankedMemrefs.empty())
+    return success();
+
+  // Compute allocation sizes.
+  SmallVector<Value, 4> sizes;
+  UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
+                                         unrankedMemrefs, sizes);
+
+  // Get frequently used types.
+  MLIRContext *context = builder.getContext();
+  Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
+  auto i1Type = IntegerType::get(context, 1);
+  Type indexType = getTypeConverter()->getIndexType();
+
+  // Only one of malloc (heap copy) or free (stack copy) is needed; find it, or
+  // declare it if necessary.
+  auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
+  LLVM::LLVMFuncOp freeFunc, mallocFunc;
+  if (toDynamic)
+    mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
+  else
+    freeFunc = LLVM::lookupOrCreateFreeFn(module);
+
+  // Initialize shared constants: the memcpy calls below are not volatile.
+  Value isVolatile =
+      builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false));
+
+  unsigned unrankedMemrefPos = 0;
+  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+    if (!origTypes[i].isa<UnrankedMemRefType>())
+      continue;
+    Value allocationSize = sizes[unrankedMemrefPos];
+    Type descriptorType = descriptorTypes[unrankedMemrefPos];
+    ++unrankedMemrefPos;
+    UnrankedMemRefDescriptor desc(operands[i]);
+
+    // Allocate memory, copy, and free the source if necessary.
+    Value memory =
+        toDynamic
+            ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
+                  .getResult(0)
+            : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
+                                             /*alignment=*/0);
+    Value source = desc.memRefDescPtr(builder, loc);
+    builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize,
+                                   isVolatile);
+    if (!toDynamic)
+      builder.create<LLVM::CallOp>(loc, freeFunc, source);
+
+    // Create a new descriptor. The same descriptor can be returned multiple
+    // times, attempting to modify its pointer can lead to memory leaks
+    // (allocated twice and overwritten) or double frees (the caller does not
+    // know if the descriptor points to the same memory).
+    auto updatedDesc =
+        UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
+    Value rank = desc.rank(builder, loc);
+    updatedDesc.setRank(builder, loc, rank);
+    updatedDesc.setMemRefDescPtr(builder, loc, memory);
+
+    operands[i] = updatedDesc;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Detail methods
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -719,89 +719,6 @@
   }
 };
 
-/// Copies the shaped descriptor part to (if `toDynamic` is set) or from
-/// (otherwise) the dynamically allocated memory for any operands that were
-/// unranked descriptors originally.
-static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
-                                             LLVMTypeConverter &typeConverter,
-                                             TypeRange origTypes,
-                                             SmallVectorImpl<Value> &operands,
-                                             bool toDynamic) {
-  assert(origTypes.size() == operands.size() &&
-         "expected as may original types as operands");
-
-  // Find operands of unranked memref type and store them.
-  SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
-  for (unsigned i = 0, e = operands.size(); i < e; ++i)
-    if (origTypes[i].isa<UnrankedMemRefType>())
-      unrankedMemrefs.emplace_back(operands[i]);
-
-  if (unrankedMemrefs.empty())
-    return success();
-
-  // Compute allocation sizes.
-  SmallVector<Value, 4> sizes;
-  UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter,
-                                         unrankedMemrefs, sizes);
-
-  // Get frequently used types.
-  MLIRContext *context = builder.getContext();
-  Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
-  auto i1Type = IntegerType::get(context, 1);
-  Type indexType = typeConverter.getIndexType();
-
-  // Find the malloc and free, or declare them if necessary.
-  auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>();
-  LLVM::LLVMFuncOp freeFunc, mallocFunc;
-  if (toDynamic)
-    mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType);
-  if (!toDynamic)
-    freeFunc = LLVM::lookupOrCreateFreeFn(module);
-
-  // Initialize shared constants.
-  Value zero =
-      builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false));
-
-  unsigned unrankedMemrefPos = 0;
-  for (unsigned i = 0, e = operands.size(); i < e; ++i) {
-    Type type = origTypes[i];
-    if (!type.isa<UnrankedMemRefType>())
-      continue;
-    Value allocationSize = sizes[unrankedMemrefPos++];
-    UnrankedMemRefDescriptor desc(operands[i]);
-
-    // Allocate memory, copy, and free the source if necessary.
-    Value memory =
-        toDynamic
-            ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize)
-                  .getResult(0)
-            : builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize,
-                                             /*alignment=*/0);
-
-    Value source = desc.memRefDescPtr(builder, loc);
-    builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero);
-    if (!toDynamic)
-      builder.create<LLVM::CallOp>(loc, freeFunc, source);
-
-    // Create a new descriptor. The same descriptor can be returned multiple
-    // times, attempting to modify its pointer can lead to memory leaks
-    // (allocated twice and overwritten) or double frees (the caller does not
-    // know if the descriptor points to the same memory).
-    Type descriptorType = typeConverter.convertType(type);
-    if (!descriptorType)
-      return failure();
-    auto updatedDesc =
-        UnrankedMemRefDescriptor::undef(builder, loc, descriptorType);
-    Value rank = desc.rank(builder, loc);
-    updatedDesc.setRank(builder, loc, rank);
-    updatedDesc.setMemRefDescPtr(builder, loc, memory);
-
-    operands[i] = updatedDesc;
-  }
-
-  return success();
-}
-
 // A CallOp automatically promotes MemRefType to a sequence of alloca/store and
 // passes the pointer to the MemRef across function boundaries.
 template <typename CallOpType>
@@ -857,10 +774,9 @@
              "The number of arguments and types doesn't match");
       this->getTypeConverter()->promoteBarePtrsToDescriptors(
           rewriter, callOp.getLoc(), resultTypes, results);
-    } else if (failed(copyUnrankedDescriptors(rewriter, callOp.getLoc(),
-                                              *this->getTypeConverter(),
-                                              resultTypes, results,
-                                              /*toDynamic=*/false))) {
+    } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
+                                                    resultTypes, results,
+                                                    /*toDynamic=*/false))) {
       return failure();
     }
 
@@ -1955,8 +1871,9 @@
       }
     } else {
       updatedOperands = llvm::to_vector<4>(operands);
-      (void)copyUnrankedDescriptors(rewriter, loc, *getTypeConverter(),
-                                    op.getOperands().getTypes(),
-                                    updatedOperands,
-                                    /*toDynamic=*/true);
+      if (failed(copyUnrankedDescriptors(rewriter, loc,
+                                         op.getOperands().getTypes(),
+                                         updatedOperands,
+                                         /*toDynamic=*/true)))
+        return failure();
     }