Please use GitHub pull requests for new patches. Avoid migrating existing patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
Show First 20 Lines • Show All 126 Lines • ▼ Show 20 Lines | |||||
/// Builds IR extracting the pos-th size from the descriptor. | /// Builds IR extracting the pos-th size from the descriptor. | ||||
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { | Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { | ||||
return builder.create<LLVM::ExtractValueOp>( | return builder.create<LLVM::ExtractValueOp>( | ||||
loc, value, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); | loc, value, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); | ||||
} | } | ||||
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, | Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, | ||||
int64_t rank) { | int64_t rank) { | ||||
auto indexPtrTy = LLVM::LLVMPointerType::get(indexType); | |||||
auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank); | auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank); | ||||
auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy); | |||||
LLVM::LLVMPointerType indexPtrTy; | |||||
LLVM::LLVMPointerType arrayPtrTy; | |||||
if (useOpaquePointers()) { | |||||
arrayPtrTy = indexPtrTy = LLVM::LLVMPointerType::get(builder.getContext()); | |||||
} else { | |||||
indexPtrTy = LLVM::LLVMPointerType::get(indexType); | |||||
arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy); | |||||
} | |||||
// Copy size values to stack-allocated memory. | // Copy size values to stack-allocated memory. | ||||
auto one = createIndexAttrConstant(builder, loc, indexType, 1); | auto one = createIndexAttrConstant(builder, loc, indexType, 1); | ||||
auto sizes = builder.create<LLVM::ExtractValueOp>( | auto sizes = builder.create<LLVM::ExtractValueOp>( | ||||
loc, value, llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor})); | loc, value, llvm::ArrayRef<int64_t>({kSizePosInMemRefDescriptor})); | ||||
auto sizesPtr = | auto sizesPtr = builder.create<LLVM::AllocaOp>(loc, arrayPtrTy, arrayTy, one, | ||||
builder.create<LLVM::AllocaOp>(loc, arrayPtrTy, one, /*alignment=*/0); | /*alignment=*/0); | ||||
builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr); | builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr); | ||||
// Load an return size value of interest. | // Load an return size value of interest. | ||||
auto resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr, | auto resultPtr = builder.create<LLVM::GEPOp>( | ||||
ArrayRef<LLVM::GEPArg>{0, pos}); | loc, indexPtrTy, arrayTy, sizesPtr, ArrayRef<LLVM::GEPArg>{0, pos}); | ||||
return builder.create<LLVM::LoadOp>(loc, resultPtr); | return builder.create<LLVM::LoadOp>(loc, indexType, resultPtr); | ||||
} | } | ||||
/// Builds IR inserting the pos-th size into the descriptor | /// Builds IR inserting the pos-th size into the descriptor | ||||
void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, | void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, | ||||
Value size) { | Value size) { | ||||
value = builder.create<LLVM::InsertValueOp>( | value = builder.create<LLVM::InsertValueOp>( | ||||
loc, value, size, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); | loc, value, size, ArrayRef<int64_t>({kSizePosInMemRefDescriptor, pos})); | ||||
} | } | ||||
▲ Show 20 Lines • Show All 78 Lines • ▼ Show 20 Lines | |||||
/// Returns the number of non-aggregate values that would be produced by | /// Returns the number of non-aggregate values that would be produced by | ||||
/// `unpack`. | /// `unpack`. | ||||
unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) { | unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) { | ||||
// Two pointers, offset, <rank> sizes, <rank> shapes. | // Two pointers, offset, <rank> sizes, <rank> shapes. | ||||
return 3 + 2 * type.getRank(); | return 3 + 2 * type.getRank(); | ||||
} | } | ||||
bool MemRefDescriptor::useOpaquePointers() { | |||||
return getElementPtrType().isOpaque(); | |||||
} | |||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// MemRefDescriptorView implementation. | // MemRefDescriptorView implementation. | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
MemRefDescriptorView::MemRefDescriptorView(ValueRange range) | MemRefDescriptorView::MemRefDescriptorView(ValueRange range) | ||||
: rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {} | : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {} | ||||
Value MemRefDescriptorView::allocatedPtr() { | Value MemRefDescriptorView::allocatedPtr() { | ||||
▲ Show 20 Lines • Show All 114 Lines • ▼ Show 20 Lines | for (UnrankedMemRefDescriptor desc : values) { | ||||
// Total allocation size. | // Total allocation size. | ||||
Value allocationSize = builder.create<LLVM::AddOp>( | Value allocationSize = builder.create<LLVM::AddOp>( | ||||
loc, indexType, doublePointerSize, rankIndexSize); | loc, indexType, doublePointerSize, rankIndexSize); | ||||
sizes.push_back(allocationSize); | sizes.push_back(allocationSize); | ||||
} | } | ||||
} | } | ||||
Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc, | Value UnrankedMemRefDescriptor::allocatedPtr( | ||||
Value memRefDescPtr, | OpBuilder &builder, Location loc, Value memRefDescPtr, | ||||
Type elemPtrPtrType) { | LLVM::LLVMPointerType elemPtrType) { | ||||
Value elementPtrPtr; | |||||
if (elemPtrType.isOpaque()) | |||||
elementPtrPtr = memRefDescPtr; | |||||
else | |||||
elementPtrPtr = builder.create<LLVM::BitcastOp>( | |||||
loc, LLVM::LLVMPointerType::get(elemPtrType), memRefDescPtr); | |||||
return builder.create<LLVM::LoadOp>(loc, elemPtrType, elementPtrPtr); | |||||
} | |||||
void UnrankedMemRefDescriptor::setAllocatedPtr( | |||||
OpBuilder &builder, Location loc, Value memRefDescPtr, | |||||
LLVM::LLVMPointerType elemPtrType, Value allocatedPtr) { | |||||
Value elementPtrPtr; | |||||
if (elemPtrType.isOpaque()) | |||||
elementPtrPtr = memRefDescPtr; | |||||
else | |||||
elementPtrPtr = builder.create<LLVM::BitcastOp>( | |||||
loc, LLVM::LLVMPointerType::get(elemPtrType), memRefDescPtr); | |||||
Value elementPtrPtr = | builder.create<LLVM::StoreOp>(loc, allocatedPtr, elementPtrPtr); | ||||
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr); | |||||
return builder.create<LLVM::LoadOp>(loc, elementPtrPtr); | |||||
} | } | ||||
void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, | static std::pair<Value, Type> | ||||
Value memRefDescPtr, | castToElemPtrPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, | ||||
Type elemPtrPtrType, | LLVM::LLVMPointerType elemPtrType) { | ||||
Value allocatedPtr) { | Value elementPtrPtr; | ||||
Value elementPtrPtr = | Type elemPtrPtrType; | ||||
if (elemPtrType.isOpaque()) { | |||||
elementPtrPtr = memRefDescPtr; | |||||
elemPtrPtrType = LLVM::LLVMPointerType::get(builder.getContext()); | |||||
} else { | |||||
elemPtrPtrType = LLVM::LLVMPointerType::get(elemPtrType); | |||||
elementPtrPtr = | |||||
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr); | builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr); | ||||
builder.create<LLVM::StoreOp>(loc, allocatedPtr, elementPtrPtr); | } | ||||
return {elementPtrPtr, elemPtrPtrType}; | |||||
} | } | ||||
Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc, | Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc, | ||||
LLVMTypeConverter &typeConverter, | LLVMTypeConverter &typeConverter, | ||||
Value memRefDescPtr, | Value memRefDescPtr, | ||||
Type elemPtrPtrType) { | LLVM::LLVMPointerType elemPtrType) { | ||||
Value elementPtrPtr = | auto [elementPtrPtr, elemPtrPtrType] = | ||||
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr); | castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); | ||||
Value alignedGep = builder.create<LLVM::GEPOp>( | Value alignedGep = | ||||
loc, elemPtrPtrType, elementPtrPtr, ArrayRef<LLVM::GEPArg>{1}); | builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, | ||||
return builder.create<LLVM::LoadOp>(loc, alignedGep); | elementPtrPtr, ArrayRef<LLVM::GEPArg>{1}); | ||||
return builder.create<LLVM::LoadOp>(loc, elemPtrType, alignedGep); | |||||
} | } | ||||
void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, | void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, | ||||
LLVMTypeConverter &typeConverter, | LLVMTypeConverter &typeConverter, | ||||
Value memRefDescPtr, | Value memRefDescPtr, | ||||
Type elemPtrPtrType, | LLVM::LLVMPointerType elemPtrType, | ||||
Value alignedPtr) { | Value alignedPtr) { | ||||
Value elementPtrPtr = | auto [elementPtrPtr, elemPtrPtrType] = | ||||
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr); | castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); | ||||
Value alignedGep = builder.create<LLVM::GEPOp>( | Value alignedGep = | ||||
loc, elemPtrPtrType, elementPtrPtr, ArrayRef<LLVM::GEPArg>{1}); | builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, | ||||
elementPtrPtr, ArrayRef<LLVM::GEPArg>{1}); | |||||
builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep); | builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedGep); | ||||
} | } | ||||
Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, | Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, | ||||
LLVMTypeConverter &typeConverter, | LLVMTypeConverter &typeConverter, | ||||
Value memRefDescPtr, | Value memRefDescPtr, | ||||
Type elemPtrPtrType) { | LLVM::LLVMPointerType elemPtrType) { | ||||
Value elementPtrPtr = | auto [elementPtrPtr, elemPtrPtrType] = | ||||
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr); | castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); | ||||
Value offsetGep = | |||||
builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, | |||||
elementPtrPtr, ArrayRef<LLVM::GEPArg>{2}); | |||||
Value offsetGep = builder.create<LLVM::GEPOp>( | if (!elemPtrType.isOpaque()) { | ||||
loc, elemPtrPtrType, elementPtrPtr, ArrayRef<LLVM::GEPArg>{2}); | |||||
offsetGep = builder.create<LLVM::BitcastOp>( | offsetGep = builder.create<LLVM::BitcastOp>( | ||||
loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep); | loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), | ||||
return builder.create<LLVM::LoadOp>(loc, offsetGep); | offsetGep); | ||||
} | |||||
return builder.create<LLVM::LoadOp>(loc, typeConverter.getIndexType(), | |||||
offsetGep); | |||||
} | } | ||||
void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, | void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, | ||||
LLVMTypeConverter &typeConverter, | LLVMTypeConverter &typeConverter, | ||||
Value memRefDescPtr, | Value memRefDescPtr, | ||||
Type elemPtrPtrType, Value offset) { | LLVM::LLVMPointerType elemPtrType, | ||||
Value elementPtrPtr = | Value offset) { | ||||
builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr); | auto [elementPtrPtr, elemPtrPtrType] = | ||||
castToElemPtrPtr(builder, loc, memRefDescPtr, elemPtrType); | |||||
Value offsetGep = builder.create<LLVM::GEPOp>( | Value offsetGep = | ||||
loc, elemPtrPtrType, elementPtrPtr, ArrayRef<LLVM::GEPArg>{2}); | builder.create<LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType, | ||||
elementPtrPtr, ArrayRef<LLVM::GEPArg>{2}); | |||||
if (!elemPtrType.isOpaque()) { | |||||
offsetGep = builder.create<LLVM::BitcastOp>( | offsetGep = builder.create<LLVM::BitcastOp>( | ||||
loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep); | loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), | ||||
offsetGep); | |||||
} | |||||
builder.create<LLVM::StoreOp>(loc, offset, offsetGep); | builder.create<LLVM::StoreOp>(loc, offset, offsetGep); | ||||
} | } | ||||
Value UnrankedMemRefDescriptor::sizeBasePtr( | Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc, | ||||
OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, | LLVMTypeConverter &typeConverter, | ||||
Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) { | Value memRefDescPtr, | ||||
Type elemPtrTy = elemPtrPtrType.getElementType(); | LLVM::LLVMPointerType elemPtrType) { | ||||
Type indexTy = typeConverter.getIndexType(); | Type indexTy = typeConverter.getIndexType(); | ||||
Type structPtrTy = | Type structTy = LLVM::LLVMStructType::getLiteral( | ||||
LLVM::LLVMPointerType::get(LLVM::LLVMStructType::getLiteral( | indexTy.getContext(), {elemPtrType, elemPtrType, indexTy, indexTy}); | ||||
indexTy.getContext(), {elemPtrTy, elemPtrTy, indexTy, indexTy})); | Value structPtr; | ||||
Value structPtr = | if (elemPtrType.isOpaque()) { | ||||
structPtr = memRefDescPtr; | |||||
} else { | |||||
Type structPtrTy = LLVM::LLVMPointerType::get(structTy); | |||||
structPtr = | |||||
builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr); | builder.create<LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr); | ||||
} | |||||
return builder.create<LLVM::GEPOp>(loc, LLVM::LLVMPointerType::get(indexTy), | auto resultType = elemPtrType.isOpaque() | ||||
structPtr, ArrayRef<LLVM::GEPArg>{0, 3}); | ? LLVM::LLVMPointerType::get(indexTy.getContext()) | ||||
: LLVM::LLVMPointerType::get(indexTy); | |||||
return builder.create<LLVM::GEPOp>(loc, resultType, structTy, structPtr, | |||||
ArrayRef<LLVM::GEPArg>{0, 3}); | |||||
} | } | ||||
Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, | Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, | ||||
LLVMTypeConverter &typeConverter, | LLVMTypeConverter &typeConverter, | ||||
Value sizeBasePtr, Value index) { | Value sizeBasePtr, Value index) { | ||||
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); | |||||
Type indexTy = typeConverter.getIndexType(); | |||||
Type indexPtrTy = typeConverter.getPointer(indexTy); | |||||
Value sizeStoreGep = | Value sizeStoreGep = | ||||
builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr, index); | builder.create<LLVM::GEPOp>(loc, indexPtrTy, indexTy, sizeBasePtr, index); | ||||
return builder.create<LLVM::LoadOp>(loc, sizeStoreGep); | return builder.create<LLVM::LoadOp>(loc, indexTy, sizeStoreGep); | ||||
} | } | ||||
void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, | void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, | ||||
LLVMTypeConverter &typeConverter, | LLVMTypeConverter &typeConverter, | ||||
Value sizeBasePtr, Value index, | Value sizeBasePtr, Value index, | ||||
Value size) { | Value size) { | ||||
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); | Type indexTy = typeConverter.getIndexType(); | ||||
Type indexPtrTy = typeConverter.getPointer(indexTy); | |||||
Value sizeStoreGep = | Value sizeStoreGep = | ||||
builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr, index); | builder.create<LLVM::GEPOp>(loc, indexPtrTy, indexTy, sizeBasePtr, index); | ||||
builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep); | builder.create<LLVM::StoreOp>(loc, size, sizeStoreGep); | ||||
} | } | ||||
Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc, | Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc, | ||||
LLVMTypeConverter &typeConverter, | LLVMTypeConverter &typeConverter, | ||||
Value sizeBasePtr, Value rank) { | Value sizeBasePtr, Value rank) { | ||||
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); | Type indexTy = typeConverter.getIndexType(); | ||||
return builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr, rank); | Type indexPtrTy = typeConverter.getPointer(indexTy); | ||||
return builder.create<LLVM::GEPOp>(loc, indexPtrTy, indexTy, sizeBasePtr, | |||||
rank); | |||||
} | } | ||||
Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, | Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, | ||||
LLVMTypeConverter &typeConverter, | LLVMTypeConverter &typeConverter, | ||||
Value strideBasePtr, Value index, | Value strideBasePtr, Value index, | ||||
Value stride) { | Value stride) { | ||||
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); | Type indexTy = typeConverter.getIndexType(); | ||||
Value strideStoreGep = | Type indexPtrTy = typeConverter.getPointer(indexTy); | ||||
builder.create<LLVM::GEPOp>(loc, indexPtrTy, strideBasePtr, index); | |||||
return builder.create<LLVM::LoadOp>(loc, strideStoreGep); | Value strideStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, indexTy, | ||||
strideBasePtr, index); | |||||
return builder.create<LLVM::LoadOp>(loc, indexTy, strideStoreGep); | |||||
} | } | ||||
void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, | void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, | ||||
LLVMTypeConverter &typeConverter, | LLVMTypeConverter &typeConverter, | ||||
Value strideBasePtr, Value index, | Value strideBasePtr, Value index, | ||||
Value stride) { | Value stride) { | ||||
Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); | Type indexTy = typeConverter.getIndexType(); | ||||
Value strideStoreGep = | Type indexPtrTy = typeConverter.getPointer(indexTy); | ||||
builder.create<LLVM::GEPOp>(loc, indexPtrTy, strideBasePtr, index); | |||||
Value strideStoreGep = builder.create<LLVM::GEPOp>(loc, indexPtrTy, indexTy, | |||||
strideBasePtr, index); | |||||
builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep); | builder.create<LLVM::StoreOp>(loc, stride, strideStoreGep); | ||||
} | } |