Please use GitHub pull requests for new patches. Avoid migrating existing patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp
Show All 17 Lines | |||||
// with SymbolTable trait instead of ModuleOp and make similar change here. This | // with SymbolTable trait instead of ModuleOp and make similar change here. This | ||||
// allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead | // allows call sites to use getParentWithTrait<OpTrait::SymbolTable> instead | ||||
// of getParentOfType<ModuleOp> to pass down the operation. | // of getParentOfType<ModuleOp> to pass down the operation. | ||||
LLVM::LLVMFuncOp getNotalignedAllocFn(LLVMTypeConverter *typeConverter, | LLVM::LLVMFuncOp getNotalignedAllocFn(LLVMTypeConverter *typeConverter, | ||||
ModuleOp module, Type indexType) { | ModuleOp module, Type indexType) { | ||||
bool useGenericFn = typeConverter->getOptions().useGenericFunctions; | bool useGenericFn = typeConverter->getOptions().useGenericFunctions; | ||||
if (useGenericFn) | if (useGenericFn) | ||||
return LLVM::lookupOrCreateGenericAllocFn(module, indexType); | return LLVM::lookupOrCreateGenericAllocFn( | ||||
module, indexType, typeConverter->useOpaquePointers()); | |||||
return LLVM::lookupOrCreateMallocFn(module, indexType); | return LLVM::lookupOrCreateMallocFn(module, indexType, | ||||
typeConverter->useOpaquePointers()); | |||||
} | } | ||||
LLVM::LLVMFuncOp getAlignedAllocFn(LLVMTypeConverter *typeConverter, | LLVM::LLVMFuncOp getAlignedAllocFn(LLVMTypeConverter *typeConverter, | ||||
ModuleOp module, Type indexType) { | ModuleOp module, Type indexType) { | ||||
bool useGenericFn = typeConverter->getOptions().useGenericFunctions; | bool useGenericFn = typeConverter->getOptions().useGenericFunctions; | ||||
if (useGenericFn) | if (useGenericFn) | ||||
return LLVM::lookupOrCreateGenericAlignedAllocFn(module, indexType); | return LLVM::lookupOrCreateGenericAlignedAllocFn( | ||||
module, indexType, typeConverter->useOpaquePointers()); | |||||
return LLVM::lookupOrCreateAlignedAllocFn(module, indexType); | return LLVM::lookupOrCreateAlignedAllocFn(module, indexType, | ||||
typeConverter->useOpaquePointers()); | |||||
} | } | ||||
} // end namespace | } // end namespace | ||||
Value AllocationOpLLVMLowering::createAligned( | Value AllocationOpLLVMLowering::createAligned( | ||||
ConversionPatternRewriter &rewriter, Location loc, Value input, | ConversionPatternRewriter &rewriter, Location loc, Value input, | ||||
Value alignment) { | Value alignment) { | ||||
Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1); | Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1); | ||||
Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one); | Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one); | ||||
Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump); | Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump); | ||||
Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment); | Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment); | ||||
return rewriter.create<LLVM::SubOp>(loc, bumped, mod); | return rewriter.create<LLVM::SubOp>(loc, bumped, mod); | ||||
} | } | ||||
static Value castAllocFuncResult(ConversionPatternRewriter &rewriter, | |||||
Location loc, Value allocatedPtr, | |||||
MemRefType memRefType, Type elementPtrType, | |||||
LLVMTypeConverter &typeConverter) { | |||||
auto allocatedPtrTy = allocatedPtr.getType().cast<LLVM::LLVMPointerType>(); | |||||
if (allocatedPtrTy.getAddressSpace() != memRefType.getMemorySpaceAsInt()) | |||||
allocatedPtr = rewriter.create<LLVM::AddrSpaceCastOp>( | |||||
loc, | |||||
typeConverter.getPointer(allocatedPtrTy.getElementType(), | |||||
memRefType.getMemorySpaceAsInt()), | |||||
allocatedPtr); | |||||
if (!typeConverter.useOpaquePointers()) | |||||
allocatedPtr = | |||||
rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr); | |||||
return allocatedPtr; | |||||
} | |||||
std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign( | std::tuple<Value, Value> AllocationOpLLVMLowering::allocateBufferManuallyAlign( | ||||
ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, | ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, | ||||
Operation *op, Value alignment) const { | Operation *op, Value alignment) const { | ||||
if (alignment) { | if (alignment) { | ||||
// Adjust the allocation size to consider alignment. | // Adjust the allocation size to consider alignment. | ||||
sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment); | sizeBytes = rewriter.create<LLVM::AddOp>(loc, sizeBytes, alignment); | ||||
} | } | ||||
MemRefType memRefType = getMemRefResultType(op); | MemRefType memRefType = getMemRefResultType(op); | ||||
// Allocate the underlying buffer. | // Allocate the underlying buffer. | ||||
Type elementPtrType = this->getElementPtrType(memRefType); | Type elementPtrType = this->getElementPtrType(memRefType); | ||||
LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn( | LLVM::LLVMFuncOp allocFuncOp = getNotalignedAllocFn( | ||||
getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType()); | getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType()); | ||||
auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes); | auto results = rewriter.create<LLVM::CallOp>(loc, allocFuncOp, sizeBytes); | ||||
Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, | |||||
results.getResult()); | Value allocatedPtr = | ||||
castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, | |||||
elementPtrType, *getTypeConverter()); | |||||
Value alignedPtr = allocatedPtr; | Value alignedPtr = allocatedPtr; | ||||
if (alignment) { | if (alignment) { | ||||
// Compute the aligned pointer. | // Compute the aligned pointer. | ||||
Value allocatedInt = | Value allocatedInt = | ||||
rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr); | rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), allocatedPtr); | ||||
Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment); | Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment); | ||||
alignedPtr = | alignedPtr = | ||||
▲ Show 20 Lines • Show All 44 Lines • ▼ Show 20 Lines | Value AllocationOpLLVMLowering::allocateBufferAutoAlign( | ||||
if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout)) | if (!isMemRefSizeMultipleOf(memRefType, alignment, op, defaultLayout)) | ||||
sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); | sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment); | ||||
Type elementPtrType = this->getElementPtrType(memRefType); | Type elementPtrType = this->getElementPtrType(memRefType); | ||||
LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn( | LLVM::LLVMFuncOp allocFuncOp = getAlignedAllocFn( | ||||
getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType()); | getTypeConverter(), op->getParentOfType<ModuleOp>(), getIndexType()); | ||||
auto results = rewriter.create<LLVM::CallOp>( | auto results = rewriter.create<LLVM::CallOp>( | ||||
loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes})); | loc, allocFuncOp, ValueRange({allocAlignment, sizeBytes})); | ||||
Value allocatedPtr = rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, | |||||
results.getResult()); | |||||
return allocatedPtr; | return castAllocFuncResult(rewriter, loc, results.getResult(), memRefType, | ||||
elementPtrType, *getTypeConverter()); | |||||
} | } | ||||
LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite( | LogicalResult AllocLikeOpLLVMLowering::matchAndRewrite( | ||||
Operation *op, ArrayRef<Value> operands, | Operation *op, ArrayRef<Value> operands, | ||||
ConversionPatternRewriter &rewriter) const { | ConversionPatternRewriter &rewriter) const { | ||||
MemRefType memRefType = getMemRefResultType(op); | MemRefType memRefType = getMemRefResultType(op); | ||||
if (!isConvertibleAndHasIdentityMaps(memRefType)) | if (!isConvertibleAndHasIdentityMaps(memRefType)) | ||||
return rewriter.notifyMatchFailure(op, "incompatible memref type"); | return rewriter.notifyMatchFailure(op, "incompatible memref type"); | ||||
Show All 23 Lines |