diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -164,12 +164,20 @@ /// Convert a memref type into an LLVM type that captures the relevant data. Type convertMemRefType(MemRefType type); - /// Convert a memref type into a list of non-aggregate LLVM IR types that - /// contain all the relevant data. In particular, the list will contain: + /// Convert a memref type into a list of LLVM IR types that will form the + /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides` + /// arrays in the descriptors are unpacked to individual index-typed elements, + /// else they are kept as rank-sized arrays of index type. In particular, + /// the list will contain: /// - two pointers to the memref element type, followed by - /// - an integer offset, followed by - /// - one integer size per dimension of the memref, followed by - /// - one integer stride per dimension of the memref. + /// - an index-typed offset, followed by + /// - (if unpackAggregates = true) + /// - one index-typed size per dimension of the memref, followed by + /// - one index-typed stride per dimension of the memref. + /// - (if unpackAggregates = false) + /// - one rank-sized array of index-type for the size of each dimension + /// - one rank-sized array of index-type for the stride of each dimension + /// /// For example, memref is converted to the following list: /// - `!llvm<"float*">` (allocated pointer), /// - `!llvm<"float*">` (aligned pointer), @@ -177,17 +185,19 @@ /// - `!llvm.i64`, `!llvm.i64` (sizes), /// - `!llvm.i64`, `!llvm.i64` (strides). /// These types can be recomposed to a memref descriptor struct. 
- SmallVector convertMemRefSignature(MemRefType type); + SmallVector + getMemRefDescriptorFields(MemRefType type, bool unpackAggregates); /// Convert an unranked memref type into a list of non-aggregate LLVM IR types - /// that contain all the relevant data. In particular, this list contains: + /// that will form the unranked memref descriptor. In particular, this list + /// contains: /// - an integer rank, followed by /// - a pointer to the memref descriptor struct. /// For example, memref<*xf32> is converted to the following list: /// !llvm.i64 (rank) /// !llvm<"i8*"> (type-erased pointer). /// These types can be recomposed to a unranked memref descriptor struct. - SmallVector convertUnrankedMemRefSignature(); + SmallVector getUnrankedMemRefDescriptorFields(); // Convert an unranked memref type to an LLVM type that captures the // runtime rank and a pointer to the static ranked memref desc diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -61,14 +61,17 @@ Type type, SmallVectorImpl &result) { if (auto memref = type.dyn_cast()) { - auto converted = converter.convertMemRefSignature(memref); + // In signatures, Memref descriptors are expanded into lists of + // non-aggregate values. + auto converted = + converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true); if (converted.empty()) return failure(); result.append(converted.begin(), converted.end()); return success(); } if (type.isa()) { - auto converted = converter.convertUnrankedMemRefSignature(); + auto converted = converter.getUnrankedMemRefDescriptorFields(); if (converted.empty()) return failure(); result.append(converted.begin(), converted.end()); @@ -216,32 +219,6 @@ return converted.getPointerTo(); } -/// In signatures, MemRef descriptors are expanded into lists of non-aggregate -/// values. 
-SmallVector -LLVMTypeConverter::convertMemRefSignature(MemRefType type) { - SmallVector results; - assert(isStrided(type) && - "Non-strided layout maps must have been normalized away"); - - LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); - if (!elementType) - return {}; - auto indexTy = getIndexType(); - - results.insert(results.begin(), 2, - elementType.getPointerTo(type.getMemorySpace())); - results.push_back(indexTy); - auto rank = type.getRank(); - results.insert(results.end(), 2 * rank, indexTy); - return results; -} - -/// In signatures, unranked MemRef descriptors are expanded into a pair "rank, -/// pointer to descriptor". -SmallVector LLVMTypeConverter::convertUnrankedMemRefSignature() { - return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())}; -} // Function types are converted to LLVM Function types by recursively converting // argument and result types. If MLIR Function has zero results, the LLVM @@ -305,69 +282,92 @@ return LLVM::LLVMType::getFunctionTy(resultType, inputs, false); } -// Convert a MemRef to an LLVM type. The result is a MemRef descriptor which -// contains: -// 1. the pointer to the data buffer, followed by -// 2. a lowered `index`-type integer containing the distance between the -// beginning of the buffer and the first element to be accessed through the -// view, followed by -// 3. an array containing as many `index`-type integers as the rank of the -// MemRef: the array represents the size, in number of elements, of the memref -// along the given dimension. For constant MemRef dimensions, the -// corresponding size entry is a constant whose runtime value must match the -// static value, followed by -// 4. a second array containing as many `index`-type integers as the rank of -// the MemRef: the second array represents the "stride" (in tensor abstraction -// sense), i.e. the number of consecutive elements of the underlying buffer. -// TODO: add assertions for the static cases. 
-// -// template -// struct { -// Elem *allocatedPtr; -// Elem *alignedPtr; -// int64_t offset; -// int64_t sizes[Rank]; // omitted when rank == 0 -// int64_t strides[Rank]; // omitted when rank == 0 -// }; +/// Convert a memref type into a list of LLVM IR types that will form the +/// memref descriptor. The result contains the following types: +/// 1. The pointer to the allocated data buffer, followed by +/// 2. The pointer to the aligned data buffer, followed by +/// 3. A lowered `index`-type integer containing the distance between the +/// beginning of the buffer and the first element to be accessed through the +/// view, followed by +/// 4. An array containing as many `index`-type integers as the rank of the +/// MemRef: the array represents the size, in number of elements, of the memref +/// along the given dimension. For constant MemRef dimensions, the +/// corresponding size entry is a constant whose runtime value must match the +/// static value, followed by +/// 5. A second array containing as many `index`-type integers as the rank of +/// the MemRef: the second array represents the "stride" (in tensor abstraction +/// sense), i.e. the number of consecutive elements of the underlying buffer. +/// TODO: add assertions for the static cases. +/// +/// If `unpackAggregates` is set to true, the arrays described in (4) and (5) +/// are expanded into individual index-type elements. 
+/// +/// template +/// struct { +/// Elem *allocatedPtr; +/// Elem *alignedPtr; +/// Index offset; +/// Index sizes[Rank]; // omitted when rank == 0 +/// Index strides[Rank]; // omitted when rank == 0 +/// }; static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0; static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1; static constexpr unsigned kOffsetPosInMemRefDescriptor = 2; static constexpr unsigned kSizePosInMemRefDescriptor = 3; static constexpr unsigned kStridePosInMemRefDescriptor = 4; -Type LLVMTypeConverter::convertMemRefType(MemRefType type) { - int64_t offset; - SmallVector strides; - bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset)); - assert(strideSuccess && + +SmallVector +LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type, + bool unpackAggregates) { + assert(isStrided(type) && "Non-strided layout maps must have been normalized away"); - (void)strideSuccess; + LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; auto ptrTy = elementType.getPointerTo(type.getMemorySpace()); auto indexTy = getIndexType(); + + SmallVector results = {ptrTy, ptrTy, indexTy}; auto rank = type.getRank(); - if (rank > 0) { - auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, type.getRank()); - return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy, arrayTy, arrayTy); - } - return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy); -} + if (rank == 0) + return results; -// Converts UnrankedMemRefType to LLVMType. The result is a descriptor which -// contains: -// 1. int64_t rank, the dynamic rank of this MemRef -// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be -// stack allocated (alloca) copy of a MemRef descriptor that got casted to -// be unranked. 
+ if (unpackAggregates) + results.insert(results.end(), 2 * rank, indexTy); + else + results.insert(results.end(), 2, LLVM::LLVMType::getArrayTy(indexTy, rank)); + return results; +} +/// Converts MemRefType to an LLVM struct packing `getMemRefDescriptorFields`, +/// keeping `sizes`/`strides` as arrays. Returns null if conversion fails. +Type LLVMTypeConverter::convertMemRefType(MemRefType type) { + SmallVector types = + getMemRefDescriptorFields(type, /*unpackAggregates=*/false); + if (types.empty()) + return {}; + return LLVM::LLVMType::getStructTy(&getContext(), types); +} + +/// Convert an unranked memref type into a list of non-aggregate LLVM IR types +/// that will form the unranked memref descriptor. In particular, the fields +/// for an unranked memref descriptor are: +/// 1. index-typed rank, the dynamic rank of this MemRef +/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be +/// stack allocated (alloca) copy of a MemRef descriptor that was cast to +/// be unranked. static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0; static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1; +SmallVector +LLVMTypeConverter::getUnrankedMemRefDescriptorFields() { + return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())}; +} + Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) { - auto rankTy = getIndexType(); - auto ptrTy = LLVM::LLVMType::getInt8PtrTy(&getContext()); - return LLVM::LLVMType::getStructTy(rankTy, ptrTy); + return LLVM::LLVMType::getStructTy(&getContext(), + getUnrankedMemRefDescriptorFields()); } /// Convert a memref type to a bare pointer to the memref element type.