Index: mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -27,6 +27,7 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Support/ThreadLocalCache.h" +#include "llvm/ADT/PointerEmbeddedInt.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" @@ -72,6 +73,36 @@ #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc" +namespace mlir { +namespace LLVM { +template +class GEPIndicesAdaptor; + +/// Bit-width of a 'GEPConstantIndex' within GEPArg. +constexpr int kGEPConstantBitWidth = 29; +/// Wrapper around an int32_t for use in a PointerUnion. +using GEPConstantIndex = + llvm::PointerEmbeddedInt; + +/// Class used for building a 'llvm.getelementptr'. A single instance represents +/// a sum type that is either a 'Value' or a constant 'GEPConstantIndex' index. +/// The former represents a dynamic index in a GEP operation; the latter is +/// a constant index as is required for indices into struct types. +class GEPArg : public PointerUnion { + using BaseT = PointerUnion; + +public: + /// Constructs a GEPArg with a constant index. + /*implicit*/ GEPArg(int32_t integer) : BaseT(integer) {} + + /// Constructs a GEPArg with a dynamic index. + /*implicit*/ GEPArg(Value value) : BaseT(value) {} + + using BaseT::operator=; +}; +} // namespace LLVM +} // namespace mlir + ///// Ops ///// #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/LLVMOps.h.inc" @@ -82,6 +113,114 @@ namespace mlir { namespace LLVM { + +/// Class used for convenient access and iteration over GEP indices. +/// This class is templated to support not only retrieving the dynamic operands +/// of a GEP operation, but also as an adaptor during folding or conversion to +/// LLVM IR.
+/// +/// GEP indices may either be constant indices or dynamic indices. The +/// 'rawConstantIndices' is specially encoded by GEPOp and contains either the +/// constant index or the information that an index is a dynamic index. +/// +/// When an access to such an index is made it is done through the +/// 'DynamicRange' of this class. This way it can be used as getter in GEPOp via +/// 'GEPIndicesAdaptor' or during folding via +/// 'GEPIndicesAdaptor>'. +template +class GEPIndicesAdaptor { +public: + /// Return type of 'operator[]' and the iterator's 'operator*'. It depends + /// upon the value type of 'DynamicRange'. If 'DynamicRange' contains + /// Attributes or subclasses thereof, then value_type is 'Attribute'. In + /// all other cases it is a pointer union between the value type of + /// 'DynamicRange' and IntegerAttr. + using value_type = std::conditional_t< + std::is_base_of>::value, + Attribute, + PointerUnion>>; + + /// Constructs a GEPIndicesAdaptor with the raw constant indices of a GEPOp + /// and the range that is indexed into for retrieving dynamic indices. + GEPIndicesAdaptor(DenseI32ArrayAttr rawConstantIndices, DynamicRange values) + : rawConstantIndices(rawConstantIndices), values(std::move(values)) {} + + /// Returns the GEP index at the given position. Note that this operation has + /// linear complexity with respect to the accessed position. To iterate over + /// all indices, use the iterators. + /// + /// This operation is invalid if the index is out of bounds. + value_type operator[](size_t index) const { + assert(index < size() && "index out of bounds"); + return *std::next(begin(), index); + } + + /// Returns whether the GEP index at the given position is a dynamic index. + bool isDynamicIndex(size_t index) const { + return rawConstantIndices[index] == GEPOp::kDynamicIndex; + } + + /// Returns the number of indices of the GEPOp.
+ size_t size() const { return rawConstantIndices.size(); } + + /// Returns true if this GEPOp does not have any indices. + bool empty() const { return rawConstantIndices.empty(); } + + class iterator + : public llvm::iterator_facade_base { + public: + iterator(const GEPIndicesAdaptor *base, + ArrayRef::iterator rawConstantIter, + llvm::detail::IterOfRange valuesIter) + : base(base), rawConstantIter(rawConstantIter), valuesIter(valuesIter) { + } + + value_type operator*() const { + if (*rawConstantIter == GEPOp::kDynamicIndex) + return *valuesIter; + + return IntegerAttr::get( + ElementsAttr::getElementType(base->rawConstantIndices), + *rawConstantIter); + } + + iterator &operator++() { + if (*rawConstantIter == GEPOp::kDynamicIndex) + valuesIter++; + rawConstantIter++; + return *this; + } + + bool operator==(const iterator &rhs) const { + return base == rhs.base && rawConstantIter == rhs.rawConstantIter && + valuesIter == rhs.valuesIter; + } + + private: + const GEPIndicesAdaptor *base; + ArrayRef::const_iterator rawConstantIter; + llvm::detail::IterOfRange valuesIter; + }; + + /// Returns the begin iterator, iterating over all GEP indices. + iterator begin() const { + return iterator(this, rawConstantIndices.asArrayRef().begin(), + values.begin()); + } + + /// Returns the end iterator, iterating over all GEP indices. + iterator end() const { + return iterator(this, rawConstantIndices.asArrayRef().end(), values.end()); + } + +private: + DenseI32ArrayAttr rawConstantIndices; + DynamicRange values; +}; + /// Create an LLVM global containing the string "value" at the module containing /// surrounding the insertion point of builder. 
Obtain the address of that /// global and use it to compute the address of the first character in the Index: mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -423,51 +423,75 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> { let arguments = (ins LLVM_ScalarOrVectorOf:$base, - Variadic>:$indices, - I32ElementsAttr:$structIndices, + Variadic>:$dynamicIndices, + DenseI32ArrayAttr:$rawConstantIndices, OptionalAttr:$elem_type); let results = (outs LLVM_ScalarOrVectorOf:$res); let skipDefaultBuilders = 1; + + let description = [{ + This operation mirrors LLVM IRs 'getelementptr' operation that is used to + perform pointer arithmetic. + + Like in LLVM IR, it is possible to use both constants as well as SSA values + as indices. In the case of indexing within a structure, it is required to + either use constant indices directly, or supply a constant SSA value. 
+ + Examples: + + ```mlir + // GEP with an SSA value offset + %0 = llvm.getelementptr %1[%2] : (!llvm.ptr, i64) -> !llvm.ptr + + // GEP with a constant offset + %0 = llvm.getelementptr %1[3] : (!llvm.ptr) -> !llvm.ptr + + // GEP with constant offsets into a structure + %0 = llvm.getelementptr %1[0, 1] + : (!llvm.ptr) -> !llvm.ptr + ``` + }]; + let builders = [ - OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices, - CArg<"ArrayRef", "{}">:$attributes)>, - OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices, - "ArrayRef":$structIndices, - CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "Type":$resultType, "Type":$basePtrType, "Value":$basePtr, "ValueRange":$indices, CArg<"ArrayRef", "{}">:$attributes)>, - OpBuilder<(ins "Type":$resultType, "Type":$basePtrType, "Value":$basePtr, - "ValueRange":$indices, + OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices, + CArg<"ArrayRef", "{}">:$attributes)>, + OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ArrayRef":$indices, CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "Type":$resultType, "Type":$basePtrType, "Value":$basePtr, - "ValueRange":$indices, "ArrayRef":$structIndices, - CArg<"ArrayRef", "{}">:$attributes)> + "ArrayRef":$indices, + CArg<"ArrayRef", "{}">:$attributes)>, ]; let llvmBuilder = [{ SmallVector indices; - indices.reserve($structIndices.size()); - unsigned operandIdx = 0; - for (int32_t structIndex : $structIndices.getValues()) { - if (structIndex == GEPOp::kDynamicIndex) - indices.push_back($indices[operandIdx++]); + indices.reserve($rawConstantIndices.size()); + GEPIndicesAdaptor + gepIndices(op.getRawConstantIndicesAttr(), $dynamicIndices); + for (PointerUnion valueOrAttr : gepIndices) { + if (llvm::Value* value = valueOrAttr.dyn_cast()) + indices.push_back(value); else - indices.push_back(builder.getInt32(structIndex)); + indices.push_back( + builder.getInt32(valueOrAttr.get().getInt())); } Type baseElementType = 
op.getSourceElementType(); llvm::Type *elementType = moduleTranslation.convertType(baseElementType); $res = builder.CreateGEP(elementType, $base, indices); }]; let assemblyFormat = [{ - $base `[` custom($indices, $structIndices) `]` attr-dict + $base `[` custom($dynamicIndices, $rawConstantIndices) `]` attr-dict `:` functional-type(operands, results) (`,` $elem_type^)? }]; let extraClassDeclaration = [{ - constexpr static int kDynamicIndex = std::numeric_limits::min(); + constexpr static int32_t kDynamicIndex = std::numeric_limits::min(); /// Returns the type pointed to by the pointer argument of this GEP. Type getSourceElementType(); + + GEPIndicesAdaptor getIndices(); }]; let hasFolder = 1; let hasVerifier = 1; Index: mlir/include/mlir/IR/BuiltinAttributes.h =================================================================== --- mlir/include/mlir/IR/BuiltinAttributes.h +++ mlir/include/mlir/IR/BuiltinAttributes.h @@ -755,7 +755,10 @@ /// Implicit conversion to ArrayRef. operator ArrayRef() const; - ArrayRef asArrayRef() { return ArrayRef{*this}; } + ArrayRef asArrayRef() const { return ArrayRef{*this}; } + + /// Random access to elements. + T operator[](std::size_t index) const { return asArrayRef()[index]; } /// Builder from ArrayRef. static DenseArrayAttr get(MLIRContext *context, ArrayRef content); @@ -1017,6 +1020,14 @@ } }; +template <> +struct PointerLikeTypeTraits + : public PointerLikeTypeTraits { + static inline mlir::IntegerAttr getFromVoidPointer(void *p) { + return mlir::IntegerAttr::getFromOpaquePointer(p); + } +}; + template <> struct PointerLikeTypeTraits : public PointerLikeTypeTraits { Index: mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp =================================================================== --- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -418,7 +418,154 @@ // Code for LLVM::GEPOp. 
//===----------------------------------------------------------------------===// -constexpr int GEPOp::kDynamicIndex; +constexpr int32_t GEPOp::kDynamicIndex; + +GEPIndicesAdaptor GEPOp::getIndices() { + return GEPIndicesAdaptor(getRawConstantIndicesAttr(), + getDynamicIndices()); +} + +/// Returns the elemental type of any LLVM-compatible vector type or self. +static Type extractVectorElementType(Type type) { + if (auto vectorType = type.dyn_cast()) + return vectorType.getElementType(); + if (auto scalableVectorType = type.dyn_cast()) + return scalableVectorType.getElementType(); + if (auto fixedVectorType = type.dyn_cast()) + return fixedVectorType.getElementType(); + return type; +} + +void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, + Value basePtr, ArrayRef indices, + ArrayRef attributes) { + auto ptrType = + extractVectorElementType(basePtr.getType()).cast(); + assert(!ptrType.isOpaque() && + "expected non-opaque pointer, provide elementType explicitly when " + "opaque pointers are used"); + build(builder, result, resultType, ptrType.getElementType(), basePtr, indices, + attributes); +} + +static void destructIndices(Type currType, ArrayRef indices, + SmallVectorImpl &rawConstantIndices, + SmallVectorImpl &dynamicIndices) { + for (const GEPArg &iter : indices) { + // If the thing we are currently indexing into is a struct we must turn + // any integer constants into constant indices. If this is not possible + // we don't do anything here. The verifier will catch it and emit a proper + // error. All other canonicalization is done in the fold method. 
+ bool requiresConst = !rawConstantIndices.empty() && + currType.isa_and_nonnull(); + if (Value val = iter.dyn_cast()) { + APInt intC; + if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) && + intC.isSignedIntN(kGEPConstantBitWidth)) { + rawConstantIndices.push_back(intC.getSExtValue()); + } else { + rawConstantIndices.push_back(GEPOp::kDynamicIndex); + dynamicIndices.push_back(val); + } + } else { + rawConstantIndices.push_back(iter.get()); + } + + // Skip for very first iteration of this loop. First index does not index + // within the aggregates, but is just a pointer offset. + if (rawConstantIndices.size() == 1 || !currType) + continue; + + currType = + TypeSwitch(currType) + .Case([](auto containerType) { + return containerType.getElementType(); + }) + .Case([&](LLVMStructType structType) -> Type { + int64_t memberIndex = rawConstantIndices.back(); + if (memberIndex >= 0 && static_cast(memberIndex) < + structType.getBody().size()) + return structType.getBody()[memberIndex]; + return nullptr; + }) + .Default(Type(nullptr)); + } +} + +void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, + Type elementType, Value basePtr, ArrayRef indices, + ArrayRef attributes) { + SmallVector rawConstantIndices; + SmallVector dynamicIndices; + destructIndices(elementType, indices, rawConstantIndices, dynamicIndices); + + result.addTypes(resultType); + result.addAttributes(attributes); + result.addAttribute(getRawConstantIndicesAttrName(result.name), + builder.getDenseI32ArrayAttr(rawConstantIndices)); + if (extractVectorElementType(basePtr.getType()) + .cast() + .isOpaque()) + result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); + result.addOperands(basePtr); + result.addOperands(dynamicIndices); +} + +void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, + Value basePtr, ValueRange indices, + ArrayRef attributes) { + build(builder, result, resultType, basePtr, SmallVector(indices), + attributes); +} + 
+void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, + Type elementType, Value basePtr, ValueRange indices, + ArrayRef attributes) { + build(builder, result, resultType, elementType, basePtr, + SmallVector(indices), attributes); +} + +static ParseResult +parseGEPIndices(OpAsmParser &parser, + SmallVectorImpl &indices, + DenseI32ArrayAttr &rawConstantIndices) { + SmallVector constantIndices; + + auto idxParser = [&]() -> ParseResult { + int32_t constantIndex; + OptionalParseResult parsedInteger = + parser.parseOptionalInteger(constantIndex); + if (parsedInteger.hasValue()) { + if (failed(parsedInteger.getValue())) + return failure(); + constantIndices.push_back(constantIndex); + return success(); + } + + constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); + return parser.parseOperand(indices.emplace_back()); + }; + if (parser.parseCommaSeparatedList(idxParser)) + return failure(); + + rawConstantIndices = + DenseI32ArrayAttr::get(parser.getContext(), constantIndices); + return success(); +} + +static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, + OperandRange indices, + DenseI32ArrayAttr rawConstantIndices) { + llvm::interleaveComma( + GEPIndicesAdaptor(rawConstantIndices, indices), printer, + [&](PointerUnion cst) { + if (Value val = cst.dyn_cast()) + printer.printOperand(val); + else + printer << cst.get().getInt(); + }); +} namespace { /// Base class for llvm::Error related to GEP index. @@ -467,69 +614,33 @@ char GEPStaticIndexError::ID = 0; /// For the given `structIndices` and `indices`, check if they're complied -/// with `baseGEPType`, especially check against LLVMStructTypes nested within, -/// and refine/promote struct index from `indices` to `updatedStructIndices` -/// if the latter argument is not null. 
-static llvm::Error -recordStructIndices(Type baseGEPType, unsigned indexPos, - ArrayRef structIndices, ValueRange indices, - SmallVectorImpl *updatedStructIndices, - SmallVectorImpl *remainingIndices) { - if (indexPos >= structIndices.size()) +/// with `baseGEPType`, especially check against LLVMStructTypes nested within. +static llvm::Error verifyStructIndices(Type baseGEPType, unsigned indexPos, + GEPIndicesAdaptor indices) { + if (indexPos >= indices.size()) // Stop searching return llvm::Error::success(); - int32_t gepIndex = structIndices[indexPos]; - bool isStaticIndex = gepIndex != GEPOp::kDynamicIndex; - - unsigned dynamicIndexPos = indexPos; - if (!isStaticIndex) - dynamicIndexPos = llvm::count(structIndices.take_front(indexPos + 1), - LLVM::GEPOp::kDynamicIndex) - - 1; - return llvm::TypeSwitch(baseGEPType) .Case([&](LLVMStructType structType) -> llvm::Error { - // We don't always want to refine the index (e.g. when performing - // verification), so we only refine when updatedStructIndices is not - // null. - if (!isStaticIndex && updatedStructIndices) { - // Try to refine. 
- APInt staticIndexValue; - isStaticIndex = matchPattern(indices[dynamicIndexPos], - m_ConstantInt(&staticIndexValue)); - if (isStaticIndex) { - assert(staticIndexValue.getBitWidth() <= 64 && - llvm::isInt<32>(staticIndexValue.getLimitedValue()) && - "struct index can't fit within int32_t"); - gepIndex = static_cast(staticIndexValue.getSExtValue()); - } - } - if (!isStaticIndex) + if (!indices[indexPos].is()) return llvm::make_error(indexPos); + int32_t gepIndex = indices[indexPos].get().getInt(); ArrayRef elementTypes = structType.getBody(); if (gepIndex < 0 || static_cast(gepIndex) >= elementTypes.size()) return llvm::make_error(indexPos); - if (updatedStructIndices) - (*updatedStructIndices)[indexPos] = gepIndex; - - // Instead of recusively going into every children types, we only + // Instead of recursively going into every children types, we only // dive into the one indexed by gepIndex. - return recordStructIndices(elementTypes[gepIndex], indexPos + 1, - structIndices, indices, updatedStructIndices, - remainingIndices); + return verifyStructIndices(elementTypes[gepIndex], indexPos + 1, + indices); }) .Case([&](auto containerType) -> llvm::Error { - // Currently we don't refine non-struct index even if it's static. - if (remainingIndices) - remainingIndices->push_back(indices[dynamicIndexPos]); - return recordStructIndices(containerType.getElementType(), indexPos + 1, - structIndices, indices, updatedStructIndices, - remainingIndices); + return verifyStructIndices(containerType.getElementType(), indexPos + 1, + indices); }) .Default( [](auto otherType) -> llvm::Error { return llvm::Error::success(); }); @@ -537,122 +648,9 @@ /// Driver function around `recordStructIndices`. Note that we always check /// from the second GEP index since the first one is always dynamic. 
-static llvm::Error -findStructIndices(Type baseGEPType, ArrayRef structIndices, - ValueRange indices, - SmallVectorImpl *updatedStructIndices = nullptr, - SmallVectorImpl *remainingIndices = nullptr) { - if (remainingIndices) - // The first GEP index is always dynamic. - remainingIndices->push_back(indices[0]); - return recordStructIndices(baseGEPType, /*indexPos=*/1, structIndices, - indices, updatedStructIndices, remainingIndices); -} - -void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, - Value basePtr, ValueRange operands, - ArrayRef attributes) { - build(builder, result, resultType, basePtr, operands, - SmallVector(operands.size(), kDynamicIndex), attributes); -} - -/// Returns the elemental type of any LLVM-compatible vector type or self. -static Type extractVectorElementType(Type type) { - if (auto vectorType = type.dyn_cast()) - return vectorType.getElementType(); - if (auto scalableVectorType = type.dyn_cast()) - return scalableVectorType.getElementType(); - if (auto fixedVectorType = type.dyn_cast()) - return fixedVectorType.getElementType(); - return type; -} - -void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, - Type elementType, Value basePtr, ValueRange indices, - ArrayRef attributes) { - build(builder, result, resultType, elementType, basePtr, indices, - SmallVector(indices.size(), kDynamicIndex), attributes); -} - -void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, - Value basePtr, ValueRange indices, - ArrayRef structIndices, - ArrayRef attributes) { - auto ptrType = - extractVectorElementType(basePtr.getType()).cast(); - assert(!ptrType.isOpaque() && - "expected non-opaque pointer, provide elementType explicitly when " - "opaque pointers are used"); - build(builder, result, resultType, ptrType.getElementType(), basePtr, indices, - structIndices, attributes); -} - -void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, - Type elementType, 
Value basePtr, ValueRange indices, - ArrayRef structIndices, - ArrayRef attributes) { - SmallVector remainingIndices; - SmallVector updatedStructIndices(structIndices.begin(), - structIndices.end()); - if (llvm::Error err = - findStructIndices(elementType, structIndices, indices, - &updatedStructIndices, &remainingIndices)) - llvm::report_fatal_error(StringRef(llvm::toString(std::move(err)))); - - assert(remainingIndices.size() == static_cast(llvm::count( - updatedStructIndices, kDynamicIndex)) && - "expected as many index operands as dynamic index attr elements"); - - result.addTypes(resultType); - result.addAttributes(attributes); - result.addAttribute("structIndices", - builder.getI32TensorAttr(updatedStructIndices)); - if (extractVectorElementType(basePtr.getType()) - .cast() - .isOpaque()) - result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); - result.addOperands(basePtr); - result.addOperands(remainingIndices); -} - -static ParseResult -parseGEPIndices(OpAsmParser &parser, - SmallVectorImpl &indices, - DenseIntElementsAttr &structIndices) { - SmallVector constantIndices; - - auto idxParser = [&]() -> ParseResult { - int32_t constantIndex; - OptionalParseResult parsedInteger = - parser.parseOptionalInteger(constantIndex); - if (parsedInteger.hasValue()) { - if (failed(parsedInteger.getValue())) - return failure(); - constantIndices.push_back(constantIndex); - return success(); - } - - constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); - return parser.parseOperand(indices.emplace_back()); - }; - if (parser.parseCommaSeparatedList(idxParser)) - return failure(); - - structIndices = parser.getBuilder().getI32TensorAttr(constantIndices); - return success(); -} - -static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, - OperandRange indices, - DenseIntElementsAttr structIndices) { - unsigned operandIdx = 0; - llvm::interleaveComma(structIndices.getValues(), printer, - [&](int32_t cst) { - if (cst == LLVM::GEPOp::kDynamicIndex) - 
printer.printOperand(indices[operandIdx++]); - else - printer << cst; - }); +static llvm::Error verifyStructIndices(Type baseGEPType, + GEPIndicesAdaptor indices) { + return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices); } LogicalResult LLVM::GEPOp::verify() { @@ -662,14 +660,14 @@ getElemType()))) return failure(); - auto structIndexRange = getStructIndices().getValues(); - // structIndexRange is a kind of iterator, which cannot be converted - // to ArrayRef directly. - SmallVector structIndices(structIndexRange.size()); - for (unsigned i : llvm::seq(0, structIndexRange.size())) - structIndices[i] = structIndexRange[i]; - if (llvm::Error err = findStructIndices(getSourceElementType(), structIndices, - getIndices())) + if (static_cast( + llvm::count(getRawConstantIndices(), kDynamicIndex)) != + getDynamicIndices().size()) + return emitOpError("expected as many dynamic indices as specified in '") + << getRawConstantIndicesAttrName().getValue() << "'"; + + if (llvm::Error err = + verifyStructIndices(getSourceElementType(), getIndices())) return emitOpError() << llvm::toString(std::move(err)); return success(); @@ -2697,10 +2695,49 @@ //===----------------------------------------------------------------------===// OpFoldResult LLVM::GEPOp::fold(ArrayRef operands) { + GEPIndicesAdaptor> indices(getRawConstantIndicesAttr(), + operands.drop_front()); + // gep %x:T, 0 -> %x - if (getBase().getType() == getType() && getIndices().size() == 1 && - getStructIndices().size() == 1 && matchPattern(getIndices()[0], m_Zero())) - return getBase(); + if (getBase().getType() == getType() && indices.size() == 1) + if (auto integer = indices[0].dyn_cast_or_null()) + if (integer.getValue().isZero()) + return getBase(); + + // canonicalize any dynamic indices of constant value to constant indices. 
+ bool changed = false; + SmallVector gepArgs; + for (auto &iter : llvm::enumerate(indices)) { + auto integer = iter.value().dyn_cast_or_null(); + // Constant indices are limited to kGEPConstantBitWidth bits, so if the + // integer does not fit we must keep it dynamic, despite being a constant. + if (!indices.isDynamicIndex(iter.index()) || !integer || + !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) { + + PointerUnion existing = getIndices()[iter.index()]; + if (Value val = existing.dyn_cast()) + gepArgs.emplace_back(val); + else + gepArgs.emplace_back(existing.get().getInt()); + + continue; + } + + changed = true; + gepArgs.emplace_back(integer.getInt()); + } + if (changed) { + SmallVector rawConstantIndices; + SmallVector dynamicIndices; + destructIndices(getSourceElementType(), gepArgs, rawConstantIndices, + dynamicIndices); + + getDynamicIndicesMutable().assign(dynamicIndices); + setRawConstantIndicesAttr( + DenseI32ArrayAttr::get(getContext(), rawConstantIndices)); + return Value{*this}; + } + return {}; } Index: mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp =================================================================== --- mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -1072,24 +1072,23 @@ Value basePtr = processValue(gep->getOperand(0)); Type sourceElementType = processType(gep->getSourceElementType()); - SmallVector indices; - for (llvm::Value *operand : llvm::drop_begin(gep->operand_values())) { - indices.push_back(processValue(operand)); - if (!indices.back()) - return failure(); - } // Treat every indices as dynamic since GEPOp::build will refine those // indices into static attributes later. One small downside of this // approach is that many unused `llvm.mlir.constant` would be emitted // at first place.
- SmallVector structIndices(indices.size(), - LLVM::GEPOp::kDynamicIndex); + SmallVector indices; + for (llvm::Value *operand : llvm::drop_begin(gep->operand_values())) { + Value val = processValue(operand); + if (!val) + return failure(); + indices.push_back(val); + } Type type = processType(inst->getType()); if (!type) return failure(); - instMap[inst] = b.create(loc, type, sourceElementType, basePtr, - indices, structIndices); + instMap[inst] = + b.create(loc, type, sourceElementType, basePtr, indices); return success(); } case llvm::Instruction::InsertValue: { Index: mlir/test/Dialect/LLVMIR/canonicalize.mlir =================================================================== --- mlir/test/Dialect/LLVMIR/canonicalize.mlir +++ mlir/test/Dialect/LLVMIR/canonicalize.mlir @@ -102,8 +102,7 @@ // CHECK-LABEL: fold_gep_neg // CHECK-SAME: %[[a0:arg[0-9]+]] -// CHECK-NEXT: %[[C:.*]] = arith.constant 0 -// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[a0]][%[[C]], 1] +// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[a0]][0, 1] // CHECK-NEXT: llvm.return %[[RES]] llvm.func @fold_gep_neg(%x : !llvm.ptr) -> !llvm.ptr { %c0 = arith.constant 0 : i32 @@ -111,6 +110,17 @@ llvm.return %0 : !llvm.ptr } +// CHECK-LABEL: fold_gep_canon +// CHECK-SAME: %[[a0:arg[0-9]+]] +// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[a0]][2] +// CHECK-NEXT: llvm.return %[[RES]] +llvm.func @fold_gep_canon(%x : !llvm.ptr) -> !llvm.ptr { + %c2 = arith.constant 2 : i32 + %c = llvm.getelementptr %x[%c2] : (!llvm.ptr, i32) -> !llvm.ptr + llvm.return %c : !llvm.ptr +} + + // ----- // Check that LLVM constants participate in cross-dialect constant folding. 
The Index: mlir/test/Dialect/LLVMIR/dynamic-gep-index.mlir =================================================================== --- mlir/test/Dialect/LLVMIR/dynamic-gep-index.mlir +++ mlir/test/Dialect/LLVMIR/dynamic-gep-index.mlir @@ -6,7 +6,7 @@ // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) %0 = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.getelementptr %[[ARG0]][%[[C0]], 1, %[[ARG1]]] - %1 = "llvm.getelementptr"(%arg0, %0, %arg1) {structIndices = dense<[-2147483648, 1, -2147483648]> : tensor<3xi32>} : (!llvm.ptr, array<4 x i32>)>>, i32, i32) -> !llvm.ptr + %1 = "llvm.getelementptr"(%arg0, %0, %arg1) {rawConstantIndices = [:i32 -2147483648, 1, -2147483648]} : (!llvm.ptr, array<4 x i32>)>>, i32, i32) -> !llvm.ptr llvm.return } } Index: mlir/test/Dialect/LLVMIR/invalid.mlir =================================================================== --- mlir/test/Dialect/LLVMIR/invalid.mlir +++ mlir/test/Dialect/LLVMIR/invalid.mlir @@ -146,6 +146,13 @@ // ----- +func.func @gep_too_few_dynamic(%base : !llvm.ptr) { + // expected-error@+1 {{expected as many dynamic indices as specified in 'rawConstantIndices'}} + %1 = "llvm.getelementptr"(%base) {rawConstantIndices = [:i32 -2147483648]} : (!llvm.ptr) -> !llvm.ptr +} + +// ----- + func.func @load_non_llvm_type(%foo : memref) { // expected-error@+1 {{expected LLVM pointer type}} llvm.load %foo : memref