Changeset View
Standalone View
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
Show First 20 Lines • Show All 859 Lines • ▼ Show 20 Lines | for (UnrankedMemRefDescriptor desc : values) { | ||||
// Total allocation size. | // Total allocation size. | ||||
Value allocationSize = builder.create<LLVM::AddOp>( | Value allocationSize = builder.create<LLVM::AddOp>( | ||||
loc, indexType, doublePointerSize, rankIndexSize); | loc, indexType, doublePointerSize, rankIndexSize); | ||||
sizes.push_back(allocationSize); | sizes.push_back(allocationSize); | ||||
} | } | ||||
} | } | ||||
/// Emits IR that reads the allocated pointer out of the ranked descriptor
/// addressed by `memRefDescPtr`.
///
/// The descriptor pointer is bitcast to `elemPtrPtrType` and the value stored
/// at that address (i.e. at index 0) is loaded and returned.
Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc,
                                             Value memRefDescPtr,
                                             LLVM::LLVMType elemPtrPtrType) {
  // No GEP is needed: the allocated pointer is read from the very start of
  // the descriptor.
  Value allocatedPtrAddr =
      builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
  Value loadedPtr = builder.create<LLVM::LoadOp>(loc, allocatedPtrAddr);
  return loadedPtr;
}
/// Emits IR that writes `allocatedPtr` to the start of the ranked descriptor
/// addressed by `memRefDescPtr`.
///
/// The descriptor pointer is bitcast to `elemPtrPtrType` so that the store
/// targets a pointer-typed location.
void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
                                               Value memRefDescPtr,
                                               LLVM::LLVMType elemPtrPtrType,
                                               Value allocatedPtr) {
  Value allocatedPtrAddr =
      builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
  builder.create<LLVM::StoreOp>(loc, allocatedPtr, allocatedPtrAddr);
}
/// Emits IR that reads the aligned pointer out of the ranked descriptor
/// addressed by `memRefDescPtr`.
///
/// The descriptor is viewed through `elemPtrPtrType`, advanced by one element
/// with a GEP, and the value at that location is loaded and returned.
Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc,
                                           LLVMTypeConverter &typeConverter,
                                           Value memRefDescPtr,
                                           LLVM::LLVMType elemPtrPtrType) {
  Value descAsPtrArray =
      builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
  // Index 1 in pointer-sized steps selects the slot after the allocated
  // pointer.
  Value slotIndex =
      createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
  Value alignedPtrAddr = builder.create<LLVM::GEPOp>(
      loc, elemPtrPtrType, descAsPtrArray, ValueRange{slotIndex});
  return builder.create<LLVM::LoadOp>(loc, alignedPtrAddr);
}
/// Emits IR that writes `alignedPtr` into the ranked descriptor addressed by
/// `memRefDescPtr`.
///
/// The target location is one `elemPtrPtrType` element past the start of the
/// descriptor.
void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
                                             LLVMTypeConverter &typeConverter,
                                             Value memRefDescPtr,
                                             LLVM::LLVMType elemPtrPtrType,
                                             Value alignedPtr) {
  Value descAsPtrArray =
      builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
  // Index 1 in pointer-sized steps selects the slot after the allocated
  // pointer.
  Value slotIndex =
      createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1);
  Value alignedPtrAddr = builder.create<LLVM::GEPOp>(
      loc, elemPtrPtrType, descAsPtrArray, ValueRange{slotIndex});
  builder.create<LLVM::StoreOp>(loc, alignedPtr, alignedPtrAddr);
}
/// Emits IR that reads the offset out of the ranked descriptor addressed by
/// `memRefDescPtr`.
///
/// The offset location is computed as two `elemPtrPtrType` elements past the
/// start of the descriptor; that address is then reinterpreted as a pointer to
/// the index type before loading.
///
/// TODO: this addressing steps over the two leading pointers in pointer-sized
/// units and does not consult the target data layout. NOTE(review): presumably
/// fine for common targets; if padding between descriptor fields ever matters,
/// compute byte offsets from the `llvm::DataLayout` instead.
Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc,
                                       LLVMTypeConverter &typeConverter,
                                       Value memRefDescPtr,
                                       LLVM::LLVMType elemPtrPtrType) {
  Value descAsPtrArray =
      builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
  Value slotIndex =
      createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
  Value offsetSlot = builder.create<LLVM::GEPOp>(
      loc, elemPtrPtrType, descAsPtrArray, ValueRange{slotIndex});
  // The slot holds an index value, not a pointer: retype the address.
  Value offsetAddr = builder.create<LLVM::BitcastOp>(
      loc, typeConverter.getIndexType().getPointerTo(), offsetSlot);
  return builder.create<LLVM::LoadOp>(loc, offsetAddr);
}
/// Emits IR that writes `offset` into the ranked descriptor addressed by
/// `memRefDescPtr`.
///
/// The target location is two `elemPtrPtrType` elements past the start of the
/// descriptor, reinterpreted as a pointer to the index type.
///
/// TODO: this addressing steps over the two leading pointers in pointer-sized
/// units and does not consult the target data layout. NOTE(review): presumably
/// fine for common targets; if padding between descriptor fields ever matters,
/// compute byte offsets from the `llvm::DataLayout` instead.
void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
                                         LLVMTypeConverter &typeConverter,
                                         Value memRefDescPtr,
                                         LLVM::LLVMType elemPtrPtrType,
                                         Value offset) {
  Value descAsPtrArray =
      builder.create<LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
  Value slotIndex =
      createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2);
  Value offsetSlot = builder.create<LLVM::GEPOp>(
      loc, elemPtrPtrType, descAsPtrArray, ValueRange{slotIndex});
  // The slot holds an index value, not a pointer: retype the address.
  Value offsetAddr = builder.create<LLVM::BitcastOp>(
      loc, typeConverter.getIndexType().getPointerTo(), offsetSlot);
  builder.create<LLVM::StoreOp>(loc, offset, offsetAddr);
}
/// Emits IR that computes the address of field 3 of the ranked descriptor
/// addressed by `memRefDescPtr`, typed as a pointer to the index type.
///
/// The descriptor is viewed as a pointer to the struct
/// `{elemPtr, elemPtr, index, index}`, so field 3 is the first index-typed
/// slot after the offset. Callers use the result as the base address for the
/// per-dimension sizes.
Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc,
                                            LLVMTypeConverter &typeConverter,
                                            Value memRefDescPtr,
                                            LLVM::LLVMType elemPtrPtrType) {
  LLVM::LLVMType elemPtrTy = elemPtrPtrType.getPointerElementTy();
  LLVM::LLVMType indexTy = typeConverter.getIndexType();
  LLVM::LLVMType descPrefixTy =
      LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy);
  LLVM::LLVMType descPrefixPtrTy = descPrefixTy.getPointerTo();

  Value descPrefixPtr =
      builder.create<LLVM::BitcastOp>(loc, descPrefixPtrTy, memRefDescPtr);

  LLVM::LLVMType i32Ty =
      unwrap(typeConverter.convertType(builder.getI32Type()));

  // GEP indices: 0 dereferences the struct pointer, 3 selects the field.
  // LLVM requires struct field indices in a GEP to be i32 constants, hence
  // the field index is not built with the index type.
  Value derefIndex = createIndexAttrConstant(builder, loc, indexTy, 0);
  Value fieldIndex = builder.create<LLVM::ConstantOp>(
      loc, i32Ty, builder.getI32IntegerAttr(3));
  return builder.create<LLVM::GEPOp>(loc, indexTy.getPointerTo(),
                                     descPrefixPtr,
                                     ValueRange({derefIndex, fieldIndex}));
}
/// Emits IR that loads the index value located `index` elements past
/// `sizeBasePtr`.
///
/// NOTE(review): `typeConverter` is taken by value here, whereas sibling
/// accessors such as `sizeBasePtr` take it by reference. Switching to a
/// reference would need a matching change to the declaration.
Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc,
                                     LLVMTypeConverter typeConverter,
                                     Value sizeBasePtr, Value index) {
  LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
  Value sizeAddr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
                                               ValueRange{index});
  Value loadedSize = builder.create<LLVM::LoadOp>(loc, sizeAddr);
  return loadedSize;
}
/// Emits IR that stores `size` at the location `index` elements past
/// `sizeBasePtr`.
///
/// NOTE(review): `typeConverter` is taken by value here, whereas sibling
/// accessors such as `sizeBasePtr` take it by reference. Switching to a
/// reference would need a matching change to the declaration.
void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc,
                                       LLVMTypeConverter typeConverter,
                                       Value sizeBasePtr, Value index,
                                       Value size) {
  LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
  Value sizeAddr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizeBasePtr,
                                               ValueRange{index});
  builder.create<LLVM::StoreOp>(loc, size, sizeAddr);
}
/// Emits IR that computes the address `rank` index-typed elements past
/// `sizeBasePtr`.
///
/// Callers use the result as the base address for the per-dimension strides,
/// i.e. the strides start right after `rank` sizes.
Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc,
                                              LLVMTypeConverter &typeConverter,
                                              Value sizeBasePtr, Value rank) {
  LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
  Value stridesStart = builder.create<LLVM::GEPOp>(
      loc, indexPtrTy, sizeBasePtr, ValueRange{rank});
  return stridesStart;
}
/// Emits IR that loads the index value located `index` elements past
/// `strideBasePtr`.
///
/// NOTE(review): the trailing `stride` parameter is never read by this getter;
/// it looks like a leftover copied from `setStride`. It is kept so that the
/// definition still matches its declaration; removing it requires updating
/// the declaration and all callers.
///
/// NOTE(review): `typeConverter` is taken by value here, whereas
/// `strideBasePtr` takes it by reference.
Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc,
                                       LLVMTypeConverter typeConverter,
                                       Value strideBasePtr, Value index,
                                       Value stride) {
  LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
  Value strideAddr = builder.create<LLVM::GEPOp>(
      loc, indexPtrTy, strideBasePtr, ValueRange{index});
  Value loadedStride = builder.create<LLVM::LoadOp>(loc, strideAddr);
  return loadedStride;
}
/// Emits IR that stores `stride` at the location `index` elements past
/// `strideBasePtr`.
///
/// NOTE(review): `typeConverter` is taken by value here, whereas
/// `strideBasePtr` takes it by reference. Switching to a reference would need
/// a matching change to the declaration.
void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc,
                                         LLVMTypeConverter typeConverter,
                                         Value strideBasePtr, Value index,
                                         Value stride) {
  LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo();
  Value strideAddr = builder.create<LLVM::GEPOp>(
      loc, indexPtrTy, strideBasePtr, ValueRange{index});
  builder.create<LLVM::StoreOp>(loc, stride, strideAddr);
}
/// Returns the LLVM dialect object held by this pattern's type converter.
LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const {
  auto &llvmDialect = *typeConverter.getDialect();
  return llvmDialect;
}
/// Returns the LLVM type that this pattern's type converter uses for `index`.
LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const {
  LLVM::LLVMType indexTy = typeConverter.getIndexType();
  return indexTy;
}
▲ Show 20 Lines • Show All 1,536 Lines • ▼ Show 20 Lines | if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) { | ||||
auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr); | auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr); | ||||
rewriter.replaceOp(op, loadOp.getResult()); | rewriter.replaceOp(op, loadOp.getResult()); | ||||
} else { | } else { | ||||
llvm_unreachable("Unsupported unranked memref to unranked memref cast"); | llvm_unreachable("Unsupported unranked memref to unranked memref cast"); | ||||
} | } | ||||
} | } | ||||
}; | }; | ||||
/// Extracts allocated, aligned pointers and offset from a ranked or unranked | |||||
ftynse: Nit: could we have a doc for this function? | |||||
/// memref type. In unranked case, the fields are extracted from the underlying | |||||
/// ranked descriptor. | |||||
static void extractPointersAndOffset(Location loc, | |||||
ConversionPatternRewriter &rewriter, | |||||
LLVMTypeConverter &typeConverter, | |||||
Value originalOperand, | |||||
Value convertedOperand, | |||||
Value *allocatedPtr, Value *alignedPtr, | |||||
Value *offset = nullptr) { | |||||
Type operandType = originalOperand.getType(); | |||||
if (operandType.isa<MemRefType>()) { | |||||
MemRefDescriptor desc(convertedOperand); | |||||
*allocatedPtr = desc.allocatedPtr(rewriter, loc); | |||||
*alignedPtr = desc.alignedPtr(rewriter, loc); | |||||
if (offset != nullptr) | |||||
*offset = desc.offset(rewriter, loc); | |||||
return; | |||||
} | |||||
unsigned memorySpace = | |||||
operandType.cast<UnrankedMemRefType>().getMemorySpace(); | |||||
Type elementType = operandType.cast<UnrankedMemRefType>().getElementType(); | |||||
LLVM::LLVMType llvmElementType = | |||||
unwrap(typeConverter.convertType(elementType)); | |||||
LLVM::LLVMType elementPtrPtrType = | |||||
llvmElementType.getPointerTo(memorySpace).getPointerTo(); | |||||
// Extract pointer to the underlying ranked memref descriptor and cast it to | |||||
// ElemType**. | |||||
UnrankedMemRefDescriptor unrankedDesc(convertedOperand); | |||||
Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc); | |||||
*allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr( | |||||
rewriter, loc, underlyingDescPtr, elementPtrPtrType); | |||||
*alignedPtr = UnrankedMemRefDescriptor::alignedPtr( | |||||
rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); | |||||
Nit: would it be reasonable to have these things as additional accessors on UnrankedMemRefDescriptor? ftynse: Nit: would it be reasonable to have these things as additional accessors on… | |||||
if (offset != nullptr) { | |||||
*offset = UnrankedMemRefDescriptor::offset( | |||||
rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); | |||||
} | |||||
} | |||||
struct MemRefReinterpretCastOpLowering | struct MemRefReinterpretCastOpLowering | ||||
: public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> { | : public ConvertOpToLLVMPattern<MemRefReinterpretCastOp> { | ||||
using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern; | using ConvertOpToLLVMPattern<MemRefReinterpretCastOp>::ConvertOpToLLVMPattern; | ||||
LogicalResult | LogicalResult | ||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands, | matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||||
ConversionPatternRewriter &rewriter) const override { | ConversionPatternRewriter &rewriter) const override { | ||||
auto castOp = cast<MemRefReinterpretCastOp>(op); | auto castOp = cast<MemRefReinterpretCastOp>(op); | ||||
MemRefReinterpretCastOp::Adaptor adaptor(operands, op->getAttrDictionary()); | MemRefReinterpretCastOp::Adaptor adaptor(operands, op->getAttrDictionary()); | ||||
ftynse: I wonder whether we could run into issues because of alignment properties on some unusual architectures. The descriptor structure is not packed, so it is subject to the alignment rules between elements as defined by LLVM's data layout. I don't have an example offhand, so it may always be fine, but I'd appreciate a comment arguing why this is safe. Otherwise, LLVMTypeConverter contains the llvm::DataLayout that we are targeting, which can be used to get the proper byte offsets of elements and to do all indexing arithmetic after bitcasting to i8*.
pifon2a: I am not sure how many architectures, let alone unusual ones, use unranked code generation. I would leave it as it is for now and switch to llvm::DataLayout later if needed.
ftynse: Let's keep a TODO comment then. If we ever run into a problem, it will be easier to debug.
Type srcType = castOp.source().getType(); | Type srcType = castOp.source().getType(); | ||||
Value descriptor; | Value descriptor; | ||||
if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp, | if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp, | ||||
adaptor, &descriptor))) | adaptor, &descriptor))) | ||||
return failure(); | return failure(); | ||||
rewriter.replaceOp(op, {descriptor}); | rewriter.replaceOp(op, {descriptor}); | ||||
return success(); | return success(); | ||||
Show All 13 Lines | if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) | ||||
return failure(); | return failure(); | ||||
// Create descriptor. | // Create descriptor. | ||||
Location loc = castOp.getLoc(); | Location loc = castOp.getLoc(); | ||||
auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); | auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); | ||||
// Set allocated and aligned pointers. | // Set allocated and aligned pointers. | ||||
Value allocatedPtr, alignedPtr; | Value allocatedPtr, alignedPtr; | ||||
extractPointers(loc, rewriter, castOp.source(), adaptor.source(), | extractPointersAndOffset(loc, rewriter, typeConverter, castOp.source(), | ||||
&allocatedPtr, &alignedPtr); | adaptor.source(), &allocatedPtr, &alignedPtr); | ||||
desc.setAllocatedPtr(rewriter, loc, allocatedPtr); | desc.setAllocatedPtr(rewriter, loc, allocatedPtr); | ||||
desc.setAlignedPtr(rewriter, loc, alignedPtr); | desc.setAlignedPtr(rewriter, loc, alignedPtr); | ||||
// Set offset. | // Set offset. | ||||
if (castOp.isDynamicOffset(0)) | if (castOp.isDynamicOffset(0)) | ||||
desc.setOffset(rewriter, loc, adaptor.offsets()[0]); | desc.setOffset(rewriter, loc, adaptor.offsets()[0]); | ||||
else | else | ||||
desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0)); | desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0)); | ||||
Show All 10 Lines | for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) { | ||||
if (castOp.isDynamicStride(i)) | if (castOp.isDynamicStride(i)) | ||||
desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]); | desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]); | ||||
else | else | ||||
desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i)); | desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i)); | ||||
} | } | ||||
*descriptor = desc; | *descriptor = desc; | ||||
return success(); | return success(); | ||||
} | } | ||||
}; | |||||
void extractPointers(Location loc, ConversionPatternRewriter &rewriter, | struct MemRefReshapeOpLowering | ||||
Value originalOperand, Value convertedOperand, | : public ConvertOpToLLVMPattern<MemRefReshapeOp> { | ||||
Value *allocatedPtr, Value *alignedPtr) const { | using ConvertOpToLLVMPattern<MemRefReshapeOp>::ConvertOpToLLVMPattern; | ||||
Type operandType = originalOperand.getType(); | |||||
if (operandType.isa<MemRefType>()) { | LogicalResult | ||||
MemRefDescriptor desc(convertedOperand); | matchAndRewrite(Operation *op, ArrayRef<Value> operands, | ||||
*allocatedPtr = desc.allocatedPtr(rewriter, loc); | ConversionPatternRewriter &rewriter) const override { | ||||
*alignedPtr = desc.alignedPtr(rewriter, loc); | auto reshapeOp = cast<MemRefReshapeOp>(op); | ||||
return; | |||||
MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary()); | |||||
Nit: LLVM::LLVMType was fine here ftynse: Nit: LLVM::LLVMType was fine here | |||||
Type srcType = reshapeOp.source().getType(); | |||||
Please fix ftynse: Please fix | |||||
Value descriptor; | |||||
if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp, | |||||
adaptor, &descriptor))) | |||||
return failure(); | |||||
rewriter.replaceOp(op, {descriptor}); | |||||
return success(); | |||||
} | } | ||||
unsigned memorySpace = | private: | ||||
operandType.cast<UnrankedMemRefType>().getMemorySpace(); | LogicalResult | ||||
Type elementType = operandType.cast<UnrankedMemRefType>().getElementType(); | convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter, | ||||
Type srcType, MemRefReshapeOp reshapeOp, | |||||
MemRefReshapeOp::Adaptor adaptor, | |||||
Value *descriptor) const { | |||||
// Conversion for statically-known shape args is performed via | |||||
// `memref_reinterpret_cast`. | |||||
auto shapeMemRefType = reshapeOp.shape().getType().cast<MemRefType>(); | |||||
if (shapeMemRefType.hasStaticShape()) | |||||
return failure(); | |||||
// The shape is a rank-1 tensor with unknown length. | |||||
Location loc = reshapeOp.getLoc(); | |||||
MemRefDescriptor shapeDesc(adaptor.shape()); | |||||
Value resultRank = shapeDesc.size(rewriter, loc, 0); | |||||
// Extract address space and element type. | |||||
Can't we use createIndexConstant instead here? Or at least not hardcode i32, I think there was some option controlling the bit size of the address arithmetic ftynse: Can't we use createIndexConstant instead here? Or at least not hardcode i32, I think there was… | |||||
auto targetType = | |||||
reshapeOp.getResult().getType().cast<UnrankedMemRefType>(); | |||||
unsigned addressSpace = targetType.getMemorySpace(); | |||||
Type elementType = targetType.getElementType(); | |||||
// Create the unranked memref descriptor that holds the ranked one. The | |||||
// inner descriptor is allocated on stack. | |||||
auto targetDesc = UnrankedMemRefDescriptor::undef( | |||||
rewriter, loc, unwrap(typeConverter.convertType(targetType))); | |||||
Putting these addressing tricks (which are cool, I must admit!) into the UnrankedMemRefDescriptor sounds even more appealing to me. Since we seem to always need the triple allocated/aligned/store, we can have a function for those. ftynse: Putting these addressing tricks (which are cool, I must admit!) into the… | |||||
I added setters/getters to UnrankedMemRefDescriptor. At first I intended to have a separate PR that adds them, but it looks like the current PR is the best way to actually test them. pifon2a: I added setters/getters to `UnrankedMemRefDescriptor`. At first I intended to have a separate… | |||||
targetDesc.setRank(rewriter, loc, resultRank); | |||||
SmallVector<Value, 4> sizes; | |||||
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter, | |||||
targetDesc, sizes); | |||||
Value underlyingDescPtr = rewriter.create<LLVM::AllocaOp>( | |||||
loc, getVoidPtrType(), sizes.front(), llvm::None); | |||||
targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); | |||||
// Extract pointers and offset from the source memref. | |||||
Value allocatedPtr, alignedPtr, offset; | |||||
extractPointersAndOffset(loc, rewriter, typeConverter, reshapeOp.source(), | |||||
adaptor.source(), &allocatedPtr, &alignedPtr, | |||||
&offset); | |||||
// Set pointers and offset. | |||||
LLVM::LLVMType llvmElementType = | LLVM::LLVMType llvmElementType = | ||||
typeConverter.convertType(elementType).cast<LLVM::LLVMType>(); | unwrap(typeConverter.convertType(elementType)); | ||||
LLVM::LLVMType elementPtrPtrType = | LLVM::LLVMType elementPtrPtrType = | ||||
llvmElementType.getPointerTo(memorySpace).getPointerTo(); | llvmElementType.getPointerTo(addressSpace).getPointerTo(); | ||||
UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, | |||||
elementPtrPtrType, allocatedPtr); | |||||
UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, typeConverter, | |||||
addArgument is invalid in conversion patterns, similarly to other in-place updates. rewriter.createBlock creates block with arguments, but you'd need to clone the remaining operations in there. Maybe splitBlock can be extended to also add arguments to the newly created block. (I suspect the addArgument will most likely just work even if we rollback the change, but it may run into some bad use-def loop) ftynse: `addArgument` is invalid in conversion patterns, similarly to other in-place updates. `rewriter. | |||||
extending splitBlock looks much harder than just using createBlock and cloning the remaining ops. I ll do that tomorrow. pifon2a: extending `splitBlock` looks much harder than just using `createBlock` and cloning the… | |||||
underlyingDescPtr, | |||||
elementPtrPtrType, alignedPtr); | |||||
UnrankedMemRefDescriptor::setOffset(rewriter, loc, typeConverter, | |||||
underlyingDescPtr, elementPtrPtrType, | |||||
offset); | |||||
// Use the offset pointer as base for further addressing. Copy over the new | |||||
// shape and compute strides. For this, we create a loop from rank-1 to 0. | |||||
Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr( | |||||
rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); | |||||
Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr( | |||||
rewriter, loc, typeConverter, targetSizesBase, resultRank); | |||||
Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); | |||||
Value oneIndex = createIndexConstant(rewriter, loc, 1); | |||||
Value resultRankMinusOne = | |||||
rewriter.create<LLVM::SubOp>(loc, resultRank, oneIndex); | |||||
// Extract pointer to the underlying ranked memref descriptor and cast it to | Block *initBlock = rewriter.getInsertionBlock(); | ||||
// ElemType**. | LLVM::LLVMType indexType = typeConverter.getIndexType(); | ||||
UnrankedMemRefDescriptor unrankedDesc(convertedOperand); | Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint()); | ||||
Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc); | |||||
Value elementPtrPtr = rewriter.create<LLVM::BitcastOp>( | Block *condBlock = rewriter.createBlock(initBlock->getParent(), {}, | ||||
loc, elementPtrPtrType, underlyingDescPtr); | {indexType, indexType}); | ||||
(Beyond the scope): I wonder if we could later refactor this and the SCF-to-std lowering in a createLoop(function_ref bodyBuilder, function_ref conditionBuilder) that produces std control flow; the pattern infra will then lower the std ops into LLVM automatically. ftynse: (Beyond the scope): I wonder if we could later refactor this and the SCF-to-std lowering in a… | |||||
Yes, that would be much more readable. pifon2a: Yes, that would be much more readable. | |||||
// Iterate over the remaining ops in initBlock and move them to condBlock. | |||||
BlockAndValueMapping map; | |||||
for (auto it = remainingOpsIt, e = initBlock->end(); it != e; ++it) { | |||||
rewriter.clone(*it, map); | |||||
rewriter.eraseOp(&*it); | |||||
} | |||||
LLVM::LLVMType int32Type = | rewriter.setInsertionPointToEnd(initBlock); | ||||
typeConverter.convertType(rewriter.getI32Type()).cast<LLVM::LLVMType>(); | rewriter.create<LLVM::BrOp>(loc, ValueRange({resultRankMinusOne, oneIndex}), | ||||
condBlock); | |||||
rewriter.setInsertionPointToStart(condBlock); | |||||
Value indexArg = condBlock->getArgument(0); | |||||
Value strideArg = condBlock->getArgument(1); | |||||
Value zeroIndex = createIndexConstant(rewriter, loc, 0); | |||||
Value pred = rewriter.create<LLVM::ICmpOp>( | |||||
loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()), | |||||
LLVM::ICmpPredicate::sge, indexArg, zeroIndex); | |||||
Block *bodyBlock = | |||||
rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); | |||||
rewriter.setInsertionPointToStart(bodyBlock); | |||||
// Copy size from shape to descriptor. | |||||
LLVM::LLVMType llvmIndexPtrType = indexType.getPointerTo(); | |||||
Value sizeLoadGep = rewriter.create<LLVM::GEPOp>( | |||||
loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg}); | |||||
Value size = rewriter.create<LLVM::LoadOp>(loc, sizeLoadGep); | |||||
UnrankedMemRefDescriptor::setSize(rewriter, loc, typeConverter, | |||||
targetSizesBase, indexArg, size); | |||||
// Write stride value and compute next one. | |||||
UnrankedMemRefDescriptor::setStride(rewriter, loc, typeConverter, | |||||
targetStridesBase, indexArg, strideArg); | |||||
Value nextStride = rewriter.create<LLVM::MulOp>(loc, strideArg, size); | |||||
// Decrement loop counter and branch back. | |||||
Value decrement = rewriter.create<LLVM::SubOp>(loc, indexArg, oneIndex); | |||||
rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, nextStride}), | |||||
condBlock); | |||||
Block *remainder = | |||||
rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint()); | |||||
// Hook up the cond exit to the remainder. | |||||
rewriter.setInsertionPointToEnd(condBlock); | |||||
rewriter.create<LLVM::CondBrOp>(loc, pred, bodyBlock, llvm::None, remainder, | |||||
llvm::None); | |||||
// Extract and set allocated pointer. | // Reset position to beginning of new remainder block. | ||||
*allocatedPtr = rewriter.create<LLVM::LoadOp>(loc, elementPtrPtr); | rewriter.setInsertionPointToStart(remainder); | ||||
// Extract and set aligned pointer. | *descriptor = targetDesc; | ||||
Value one = rewriter.create<LLVM::ConstantOp>( | return success(); | ||||
loc, int32Type, rewriter.getI32IntegerAttr(1)); | |||||
Value alignedGep = rewriter.create<LLVM::GEPOp>( | |||||
loc, elementPtrPtrType, elementPtrPtr, ValueRange({one})); | |||||
*alignedPtr = rewriter.create<LLVM::LoadOp>(loc, alignedGep); | |||||
} | } | ||||
}; | }; | ||||
struct DialectCastOpLowering | struct DialectCastOpLowering | ||||
: public ConvertOpToLLVMPattern<LLVM::DialectCastOp> { | : public ConvertOpToLLVMPattern<LLVM::DialectCastOp> { | ||||
using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern; | using ConvertOpToLLVMPattern<LLVM::DialectCastOp>::ConvertOpToLLVMPattern; | ||||
LogicalResult | LogicalResult | ||||
▲ Show 20 Lines • Show All 1,104 Lines • ▼ Show 20 Lines | void mlir::populateStdToLLVMMemoryConversionPatterns( | ||||
// clang-format off | // clang-format off | ||||
patterns.insert< | patterns.insert< | ||||
AssumeAlignmentOpLowering, | AssumeAlignmentOpLowering, | ||||
DeallocOpLowering, | DeallocOpLowering, | ||||
DimOpLowering, | DimOpLowering, | ||||
LoadOpLowering, | LoadOpLowering, | ||||
MemRefCastOpLowering, | MemRefCastOpLowering, | ||||
MemRefReinterpretCastOpLowering, | MemRefReinterpretCastOpLowering, | ||||
MemRefReshapeOpLowering, | |||||
RankOpLowering, | RankOpLowering, | ||||
StoreOpLowering, | StoreOpLowering, | ||||
SubViewOpLowering, | SubViewOpLowering, | ||||
TransposeOpLowering, | TransposeOpLowering, | ||||
ViewOpLowering>(converter); | ViewOpLowering>(converter); | ||||
// clang-format on | // clang-format on | ||||
if (converter.getOptions().useAlignedAlloc) | if (converter.getOptions().useAlignedAlloc) | ||||
patterns.insert<AlignedAllocOpLowering>(converter); | patterns.insert<AlignedAllocOpLowering>(converter); | ||||
▲ Show 20 Lines • Show All 161 Lines • Show Last 20 Lines |
Nit: could we have a doc for this function?