Please use GitHub pull requests for new patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/lib/Conversion/LLVMCommon/Pattern.cpp
Show First 20 Lines • Show All 42 Lines • ▼ Show 20 Lines | return IntegerType::get(&getTypeConverter()->getContext(), | ||||
getTypeConverter()->getPointerBitwidth(addressSpace)); | getTypeConverter()->getPointerBitwidth(addressSpace)); | ||||
} | } | ||||
Type ConvertToLLVMPattern::getVoidType() const { | Type ConvertToLLVMPattern::getVoidType() const { | ||||
return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext()); | return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext()); | ||||
} | } | ||||
Type ConvertToLLVMPattern::getVoidPtrType() const { | Type ConvertToLLVMPattern::getVoidPtrType() const { | ||||
return LLVM::LLVMPointerType::get( | return getTypeConverter()->getPointerType( | ||||
IntegerType::get(&getTypeConverter()->getContext(), 8)); | IntegerType::get(&getTypeConverter()->getContext(), 8)); | ||||
gysit: nit: Would `getTypeConverter()->getPointer(IntegerType::get(&getTypeConverter()->getContext()… | |||||
} | } | ||||
Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, | Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, | ||||
Location loc, | Location loc, | ||||
Type resultType, | Type resultType, | ||||
int64_t value) { | int64_t value) { | ||||
return builder.create<LLVM::ConstantOp>(loc, resultType, | return builder.create<LLVM::ConstantOp>(loc, resultType, | ||||
builder.getIndexAttr(value)); | builder.getIndexAttr(value)); | ||||
Show All 27 Lines | if (strides[i] != 1) { // Skip if stride is 1. | ||||
: createIndexConstant(rewriter, loc, strides[i]); | : createIndexConstant(rewriter, loc, strides[i]); | ||||
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride); | increment = rewriter.create<LLVM::MulOp>(loc, increment, stride); | ||||
} | } | ||||
index = | index = | ||||
index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment; | index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment; | ||||
} | } | ||||
Type elementPtrType = memRefDescriptor.getElementPtrType(); | Type elementPtrType = memRefDescriptor.getElementPtrType(); | ||||
return index ? rewriter.create<LLVM::GEPOp>(loc, elementPtrType, base, index) | return index ? rewriter.create<LLVM::GEPOp>( | ||||
loc, elementPtrType, | |||||
getTypeConverter()->convertType(type.getElementType()), | |||||
base, index) | |||||
: base; | : base; | ||||
} | } | ||||
// Check if the MemRefType `type` is supported by the lowering. We currently | // Check if the MemRefType `type` is supported by the lowering. We currently | ||||
// only support memrefs with identity maps. | // only support memrefs with identity maps. | ||||
bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps( | bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps( | ||||
MemRefType type) const { | MemRefType type) const { | ||||
if (!typeConverter->convertType(type.getElementType())) | if (!typeConverter->convertType(type.getElementType())) | ||||
return false; | return false; | ||||
return type.getLayout().isIdentity(); | return type.getLayout().isIdentity(); | ||||
} | } | ||||
Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { | Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { | ||||
auto elementType = type.getElementType(); | auto elementType = type.getElementType(); | ||||
auto structElementType = typeConverter->convertType(elementType); | auto structElementType = typeConverter->convertType(elementType); | ||||
return LLVM::LLVMPointerType::get(structElementType, | return getTypeConverter()->getPointerType(structElementType, | ||||
type.getMemorySpaceAsInt()); | type.getMemorySpaceAsInt()); | ||||
nit: Could this also be a use case for getPointer()? gysit: nit: Could this also be a use case for `getPointer()`? | |||||
} | } | ||||
void ConvertToLLVMPattern::getMemRefDescriptorSizes( | void ConvertToLLVMPattern::getMemRefDescriptorSizes( | ||||
Location loc, MemRefType memRefType, ValueRange dynamicSizes, | Location loc, MemRefType memRefType, ValueRange dynamicSizes, | ||||
ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes, | ConversionPatternRewriter &rewriter, SmallVectorImpl<Value> &sizes, | ||||
SmallVectorImpl<Value> &strides, Value &sizeBytes) const { | SmallVectorImpl<Value> &strides, Value &sizeBytes) const { | ||||
assert(isConvertibleAndHasIdentityMaps(memRefType) && | assert(isConvertibleAndHasIdentityMaps(memRefType) && | ||||
"layout maps must have been normalized away"); | "layout maps must have been normalized away"); | ||||
Show All 30 Lines | for (auto i = memRefType.getRank(); i-- > 0;) { | ||||
else if (stride == ShapedType::kDynamic) | else if (stride == ShapedType::kDynamic) | ||||
runningStride = | runningStride = | ||||
rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]); | rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]); | ||||
else | else | ||||
runningStride = createIndexConstant(rewriter, loc, stride); | runningStride = createIndexConstant(rewriter, loc, stride); | ||||
} | } | ||||
// Buffer size in bytes. | // Buffer size in bytes. | ||||
Type elementPtrType = getElementPtrType(memRefType); | Type elementType = typeConverter->convertType(memRefType.getElementType()); | ||||
Type elementPtrType = getTypeConverter()->getPointerType(elementType); | |||||
Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType); | Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType); | ||||
Value gepPtr = | Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, elementType, | ||||
rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr, runningStride); | nullPtr, runningStride); | ||||
sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); | sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); | ||||
} | } | ||||
Value ConvertToLLVMPattern::getSizeInBytes( | Value ConvertToLLVMPattern::getSizeInBytes( | ||||
Location loc, Type type, ConversionPatternRewriter &rewriter) const { | Location loc, Type type, ConversionPatternRewriter &rewriter) const { | ||||
// Compute the size of an individual element. This emits the MLIR equivalent | // Compute the size of an individual element. This emits the MLIR equivalent | ||||
// of the following sizeof(...) implementation in LLVM IR: | // of the following sizeof(...) implementation in LLVM IR: | ||||
// %0 = getelementptr %elementType* null, %indexType 1 | // %0 = getelementptr %elementType* null, %indexType 1 | ||||
// %1 = ptrtoint %elementType* %0 to %indexType | // %1 = ptrtoint %elementType* %0 to %indexType | ||||
// which is a common pattern of getting the size of a type in bytes. | // which is a common pattern of getting the size of a type in bytes. | ||||
auto convertedPtrType = | Type llvmType = typeConverter->convertType(type); | ||||
LLVM::LLVMPointerType::get(typeConverter->convertType(type)); | auto convertedPtrType = getTypeConverter()->getPointerType(llvmType); | ||||
auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType); | auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType); | ||||
auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, nullPtr, | auto gep = rewriter.create<LLVM::GEPOp>(loc, convertedPtrType, llvmType, | ||||
ArrayRef<LLVM::GEPArg>{1}); | nullPtr, ArrayRef<LLVM::GEPArg>{1}); | ||||
return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); | return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep); | ||||
} | } | ||||
Value ConvertToLLVMPattern::getNumElements( | Value ConvertToLLVMPattern::getNumElements( | ||||
Location loc, ArrayRef<Value> shape, | Location loc, ArrayRef<Value> shape, | ||||
ConversionPatternRewriter &rewriter) const { | ConversionPatternRewriter &rewriter) const { | ||||
// Compute the total number of memref elements. | // Compute the total number of memref elements. | ||||
Value numElements = | Value numElements = | ||||
▲ Show 20 Lines • Show All 49 Lines • ▼ Show 20 Lines | LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( | ||||
// Compute allocation sizes. | // Compute allocation sizes. | ||||
SmallVector<Value, 4> sizes; | SmallVector<Value, 4> sizes; | ||||
UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(), | UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(), | ||||
unrankedMemrefs, sizes); | unrankedMemrefs, sizes); | ||||
// Get frequently used types. | // Get frequently used types. | ||||
MLIRContext *context = builder.getContext(); | MLIRContext *context = builder.getContext(); | ||||
Type voidPtrType = LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); | |||||
auto i1Type = IntegerType::get(context, 1); | auto i1Type = IntegerType::get(context, 1); | ||||
Type indexType = getTypeConverter()->getIndexType(); | Type indexType = getTypeConverter()->getIndexType(); | ||||
// Find the malloc and free, or declare them if necessary. | // Find the malloc and free, or declare them if necessary. | ||||
auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>(); | auto module = builder.getInsertionPoint()->getParentOfType<ModuleOp>(); | ||||
LLVM::LLVMFuncOp freeFunc, mallocFunc; | LLVM::LLVMFuncOp freeFunc, mallocFunc; | ||||
if (toDynamic) | if (toDynamic) | ||||
mallocFunc = LLVM::lookupOrCreateMallocFn(module, indexType); | mallocFunc = LLVM::lookupOrCreateMallocFn( | ||||
module, indexType, getTypeConverter()->useOpaquePointers()); | |||||
if (!toDynamic) | if (!toDynamic) | ||||
freeFunc = LLVM::lookupOrCreateFreeFn(module); | freeFunc = LLVM::lookupOrCreateFreeFn( | ||||
module, getTypeConverter()->useOpaquePointers()); | |||||
// Initialize shared constants. | // Initialize shared constants. | ||||
Value zero = | Value zero = | ||||
builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false)); | builder.create<LLVM::ConstantOp>(loc, i1Type, builder.getBoolAttr(false)); | ||||
unsigned unrankedMemrefPos = 0; | unsigned unrankedMemrefPos = 0; | ||||
for (unsigned i = 0, e = operands.size(); i < e; ++i) { | for (unsigned i = 0, e = operands.size(); i < e; ++i) { | ||||
Type type = origTypes[i]; | Type type = origTypes[i]; | ||||
if (!type.isa<UnrankedMemRefType>()) | if (!type.isa<UnrankedMemRefType>()) | ||||
continue; | continue; | ||||
Value allocationSize = sizes[unrankedMemrefPos++]; | Value allocationSize = sizes[unrankedMemrefPos++]; | ||||
UnrankedMemRefDescriptor desc(operands[i]); | UnrankedMemRefDescriptor desc(operands[i]); | ||||
// Allocate memory, copy, and free the source if necessary. | // Allocate memory, copy, and free the source if necessary. | ||||
Value memory = | Value memory = | ||||
toDynamic | toDynamic | ||||
? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize) | ? builder.create<LLVM::CallOp>(loc, mallocFunc, allocationSize) | ||||
.getResult() | .getResult() | ||||
: builder.create<LLVM::AllocaOp>(loc, voidPtrType, allocationSize, | : builder.create<LLVM::AllocaOp>(loc, getVoidPtrType(), | ||||
getVoidType(), allocationSize, | |||||
/*alignment=*/0); | /*alignment=*/0); | ||||
Value source = desc.memRefDescPtr(builder, loc); | Value source = desc.memRefDescPtr(builder, loc); | ||||
builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero); | builder.create<LLVM::MemcpyOp>(loc, memory, source, allocationSize, zero); | ||||
if (!toDynamic) | if (!toDynamic) | ||||
builder.create<LLVM::CallOp>(loc, freeFunc, source); | builder.create<LLVM::CallOp>(loc, freeFunc, source); | ||||
// Create a new descriptor. The same descriptor can be returned multiple | // Create a new descriptor. The same descriptor can be returned multiple | ||||
// times, attempting to modify its pointer can lead to memory leaks | // times, attempting to modify its pointer can lead to memory leaks | ||||
▲ Show 20 Lines • Show All 59 Lines • Show Last 20 Lines |
nit: Would getTypeConverter()->getPointer(IntegerType::get(&getTypeConverter()->getContext(), 8)) work here as well?