Index: mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -334,16 +334,6 @@ /// Returns the type pointed to by the pointer argument of this GEP. Type getSourceElementType(); - - /// Populates `indices` with positions of GEP indices that correspond to - /// LLVMStructTypes potentially nested in the given `sourceElementType`, - /// which is the type pointed to by the pointer argument of a GEP. If - /// `structSizes` is provided, it is populated with sizes of the indexed - /// structs for bounds verification purposes. - static void findKnownStructIndices( - Type sourceElementType, SmallVectorImpl &indices, - SmallVectorImpl *structSizes = nullptr); - }]; let hasFolder = 1; let hasVerifier = 1; Index: mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp =================================================================== --- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -29,6 +29,7 @@ #include "llvm/IR/Attributes.h" #include "llvm/IR/Function.h" #include "llvm/IR/Type.h" +#include "llvm/Support/Error.h" #include "llvm/Support/Mutex.h" #include "llvm/Support/SourceMgr.h" @@ -398,55 +399,135 @@ constexpr int GEPOp::kDynamicIndex; -/// Populates `indices` with positions of GEP indices that would correspond to -/// LLVMStructTypes potentially nested in the given type. The type currently -/// visited gets `currentIndex` and LLVM container types are visited -/// recursively. The recursion is bounded and takes care of recursive types by -/// means of the `visited` set. 
-static void recordStructIndices(Type type, unsigned currentIndex, - SmallVectorImpl &indices, - SmallVectorImpl *structSizes, - SmallPtrSet &visited) { - if (visited.contains(type)) - return; +namespace { +class IndexError : public llvm::ErrorInfo { +protected: + unsigned indexPos; + +public: + static char ID; + + std::error_code convertToErrorCode() const override { + return llvm::inconvertibleErrorCode(); + } + + explicit IndexError(unsigned pos) : indexPos(pos) {} +}; + +struct IndexOutOfBoundError + : public llvm::ErrorInfo { + static char ID; + + using ErrorInfo::ErrorInfo; + + void log(llvm::raw_ostream &OS) const override { + OS << "index " << indexPos << " indexing a struct is out of bounds"; + } +}; - visited.insert(type); +struct StaticIndexError : public llvm::ErrorInfo { + static char ID; - llvm::TypeSwitch(type) - .Case([&](LLVMStructType structType) { - indices.push_back(currentIndex); - if (structSizes) - structSizes->push_back(structType.getBody().size()); - for (Type elementType : structType.getBody()) - recordStructIndices(elementType, currentIndex + 1, indices, - structSizes, visited); + using ErrorInfo::ErrorInfo; + + void log(llvm::raw_ostream &OS) const override { + OS << "expected index " << indexPos << " indexing a struct " + << "to be constant"; + } +}; +} // end anonymous namespace + +char IndexError::ID = 0; +char IndexOutOfBoundError::ID = 0; +char StaticIndexError::ID = 0; + +/// For the given `structIndices` and `indices`, check that they comply +/// with `baseGEPType`, especially with the LLVMStructTypes nested within, +/// and refine/promote struct indices from `indices` to `updatedStructIndices` +/// if the latter argument is not null. 
+static llvm::Error +recordStructIndices(Type baseGEPType, unsigned indexPos, + ArrayRef structIndices, ValueRange indices, + SmallVectorImpl *updatedStructIndices, + SmallVectorImpl *remainingIndices) { + if (indexPos >= structIndices.size()) + // Stop searching + return llvm::Error::success(); + + int32_t gepIndex = structIndices[indexPos]; + bool isStaticIndex = gepIndex != GEPOp::kDynamicIndex; + + unsigned dynamicIndexPos = indexPos; + if (!isStaticIndex) + dynamicIndexPos = llvm::count(structIndices.take_front(indexPos + 1), + LLVM::GEPOp::kDynamicIndex) - 1; + + return llvm::TypeSwitch(baseGEPType) + .Case([&](LLVMStructType structType) -> llvm::Error { + // We don't always want to refine the index (e.g. when performing + // verification), so we only refine when updatedStructIndices is not + // null. + if (!isStaticIndex && updatedStructIndices) { + // Try to refine. + APInt staticIndexValue; + isStaticIndex = matchPattern(indices[dynamicIndexPos], + m_ConstantInt(&staticIndexValue)); + if (isStaticIndex) { + assert(staticIndexValue.getBitWidth() <= 64 && + llvm::isInt<32>(staticIndexValue.getLimitedValue()) && + "struct index can't fit within int32_t"); + gepIndex = static_cast(staticIndexValue.getSExtValue()); + } + } + if (!isStaticIndex) + return llvm::make_error(indexPos); + + ArrayRef elementTypes = structType.getBody(); + if (gepIndex < 0 || + static_cast(gepIndex) >= elementTypes.size()) + return llvm::make_error(indexPos); + + if (updatedStructIndices) + (*updatedStructIndices)[indexPos] = gepIndex; + + // Instead of recursively going into every child type, we only + // dive into the one indexed by gepIndex. 
+ return recordStructIndices(elementTypes[gepIndex], indexPos + 1, + structIndices, indices, updatedStructIndices, + remainingIndices); }) .Case([&](auto containerType) { - recordStructIndices(containerType.getElementType(), currentIndex + 1, - indices, structSizes, visited); - }); + LLVMArrayType>([&](auto containerType) -> llvm::Error { + // Currently we don't refine non-struct index even if it's static. + if (remainingIndices) + remainingIndices->push_back(indices[dynamicIndexPos]); + return recordStructIndices(containerType.getElementType(), indexPos + 1, + structIndices, indices, updatedStructIndices, + remainingIndices); + }) + .Default( + [](auto otherType) -> llvm::Error { return llvm::Error::success(); }); } -/// Populates `indices` with positions of GEP indices that correspond to -/// LLVMStructTypes potentially nested in the given `baseGEPType`, which must -/// be either an LLVMPointer type or a vector thereof. If `structSizes` is -/// provided, it is populated with sizes of the indexed structs for bounds -/// verification purposes. -void GEPOp::findKnownStructIndices(Type sourceElementType, - SmallVectorImpl &indices, - SmallVectorImpl *structSizes) { - SmallPtrSet visited; - recordStructIndices(sourceElementType, /*currentIndex=*/1, indices, - structSizes, visited); +/// Driver function around `recordStructIndices`. Note that we always check +/// from the second GEP index since the first one is always dynamic. +static llvm::Error +findStructIndices(Type baseGEPType, ArrayRef structIndices, + ValueRange indices, + SmallVectorImpl *updatedStructIndices = nullptr, + SmallVectorImpl *remainingIndices = nullptr) { + if (remainingIndices) + // The first GEP index is always dynamic. 
+ remainingIndices->push_back(indices[0]); + return recordStructIndices(baseGEPType, /*indexPos=*/1, structIndices, + indices, updatedStructIndices, remainingIndices); } void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, Value basePtr, ValueRange operands, ArrayRef attributes) { build(builder, result, resultType, basePtr, operands, - SmallVector(operands.size(), LLVM::GEPOp::kDynamicIndex), - attributes); + SmallVector(operands.size(), kDynamicIndex), attributes); } /// Returns the elemental type of any LLVM-compatible vector type or self. @@ -480,44 +561,9 @@ SmallVector remainingIndices; SmallVector updatedStructIndices(structIndices.begin(), structIndices.end()); - SmallVector structRelatedPositions; - findKnownStructIndices(elementType, structRelatedPositions); - - SmallVector operandsToErase; - for (unsigned pos : structRelatedPositions) { - // GEP may not be indexing as deep as some structs are located. - if (pos >= structIndices.size()) - continue; - - // If the index is already static, it's fine. - if (structIndices[pos] != kDynamicIndex) - continue; - - // Find the corresponding operand. - unsigned operandPos = - std::count(structIndices.begin(), std::next(structIndices.begin(), pos), - kDynamicIndex); - - // Extract the constant value from the operand and put it into the attribute - // instead. 
- APInt staticIndexValue; - bool matched = - matchPattern(indices[operandPos], m_ConstantInt(&staticIndexValue)); - (void)matched; - assert(matched && "index into a struct must be a constant"); - assert(staticIndexValue.sge(APInt::getSignedMinValue(/*numBits=*/32)) && - "struct index underflows 32-bit integer"); - assert(staticIndexValue.sle(APInt::getSignedMaxValue(/*numBits=*/32)) && - "struct index overflows 32-bit integer"); - auto staticIndex = static_cast(staticIndexValue.getSExtValue()); - updatedStructIndices[pos] = staticIndex; - operandsToErase.push_back(operandPos); - } - - for (unsigned i = 0, e = indices.size(); i < e; ++i) { - if (!llvm::is_contained(operandsToErase, i)) - remainingIndices.push_back(indices[i]); - } + if (auto err = findStructIndices(elementType, structIndices, indices, + &updatedStructIndices, &remainingIndices)) + llvm::report_fatal_error(StringRef(llvm::toString(std::move(err)))); assert(remainingIndices.size() == static_cast(llvm::count( updatedStructIndices, kDynamicIndex)) && @@ -582,24 +628,16 @@ getElemType()))) return failure(); - SmallVector indices; - SmallVector structSizes; - findKnownStructIndices(getSourceElementType(), indices, &structSizes); - DenseIntElementsAttr structIndices = getStructIndices(); - for (unsigned i : llvm::seq(0, indices.size())) { - unsigned index = indices[i]; - // GEP may not be indexing as deep as some structs nested in the type. 
- if (index >= structIndices.getNumElements()) - continue; - - int32_t staticIndex = structIndices.getValues()[index]; - if (staticIndex == LLVM::GEPOp::kDynamicIndex) - return emitOpError() << "expected index " << index - << " indexing a struct to be constant"; - if (staticIndex < 0 || static_cast(staticIndex) >= structSizes[i]) - return emitOpError() << "index " << index - << " indexing a struct is out of bounds"; - } + auto structIndexRange = getStructIndices().getValues(); + // structIndexRange is a kind of iterator, which cannot be converted + // to ArrayRef directly. + SmallVector structIndices(structIndexRange.size()); + for (unsigned i : llvm::seq(0, structIndexRange.size())) + structIndices[i] = structIndexRange[i]; + if (auto err = findStructIndices(getSourceElementType(), structIndices, + getIndices())) + return emitOpError() << llvm::toString(std::move(err)); + return success(); } Index: mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp =================================================================== --- mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -1002,33 +1002,26 @@ // FIXME: Support inbounds GEPs. 
llvm::GetElementPtrInst *gep = cast(inst); Value basePtr = processValue(gep->getOperand(0)); - SmallVector staticIndices; - SmallVector dynamicIndices; Type sourceElementType = processType(gep->getSourceElementType()); - SmallVector staticIndexPositions; - GEPOp::findKnownStructIndices(sourceElementType, staticIndexPositions); - - for (const auto &en : - llvm::enumerate(llvm::drop_begin(gep->operand_values()))) { - llvm::Value *operand = en.value(); - if (llvm::find(staticIndexPositions, en.index()) == - staticIndexPositions.end()) { - staticIndices.push_back(GEPOp::kDynamicIndex); - dynamicIndices.push_back(processValue(operand)); - if (!dynamicIndices.back()) - return failure(); - } else { - auto *constantInt = cast(operand); - staticIndices.push_back( - static_cast(constantInt->getValue().getZExtValue())); - } + + SmallVector indices; + for (llvm::Value *operand : llvm::drop_begin(gep->operand_values())) { + indices.push_back(processValue(operand)); + if (!indices.back()) + return failure(); } + // Treat every index as dynamic since GEPOp::build will refine those + // indices into static attributes later. One small downside of this + // approach is that many unused `llvm.mlir.constant` ops would be emitted + // in the first place. 
+ SmallVector structIndices(indices.size(), + LLVM::GEPOp::kDynamicIndex); Type type = processType(inst->getType()); if (!type) return failure(); instMap[inst] = b.create(loc, type, sourceElementType, basePtr, - dynamicIndices, staticIndices); + indices, structIndices); return success(); } case llvm::Instruction::InsertValue: { Index: mlir/test/Dialect/LLVMIR/dynamic-gep-index.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/LLVMIR/dynamic-gep-index.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-opt %s | FileCheck %s + +module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>, #dlti.dl_entry : vector<2xi32>>>} { + // CHECK: llvm.func @foo(%[[ARG0:.+]]: !llvm.ptr>, %[[ARG1:.+]]: i32) + llvm.func @foo(%arg0: !llvm.ptr, array<4 x i32>)>>, %arg1: i32) { + // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) + %0 = llvm.mlir.constant(0 : i32) : i32 + // CHECK: llvm.getelementptr %[[ARG0]][%[[C0]], 1, %[[ARG1]]] + %1 = "llvm.getelementptr"(%arg0, %0, %arg1) {structIndices = dense<[-2147483648, 1, -2147483648]> : tensor<3xi32>} : (!llvm.ptr, array<4 x i32>)>>, i32, i32) -> !llvm.ptr + llvm.return + } +} Index: mlir/test/Target/LLVMIR/Import/dynamic-gep-index.ll =================================================================== --- /dev/null +++ mlir/test/Target/LLVMIR/Import/dynamic-gep-index.ll @@ -0,0 +1,12 @@ +; RUN: mlir-translate --import-llvm %s | FileCheck %s + +%sub_struct = type { i32, i8 } +%my_struct = type { %sub_struct, [4 x i32] } + +; CHECK: llvm.func @foo(%[[ARG0:.+]]: !llvm.ptr>, %[[ARG1:.+]]: i32) +define void @foo(%my_struct* %arg, i32 %idx) { + ; CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) + ; CHECK: llvm.getelementptr %[[ARG0]][%[[C0]], 1, %[[ARG1]]] + %1 = getelementptr %my_struct, 
%my_struct* %arg, i32 0, i32 1, i32 %idx + ret void +}