Index: mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -72,6 +72,100 @@ #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/LLVMIR/LLVMOpsAttrDefs.h.inc" +namespace mlir { +namespace LLVM { +template +class GEPIndicesAdaptor; + +/// Class used for building a 'llvm.getelementptr'. A single instance represents +/// a sum type that is either a 'Value' or a constant 'int32_t' index. The +/// former represents a dynamic index in a GEP operation, while the later is +/// a constant index as is required for indices into struct types. +class GEPArg { +public: + /// Constructs a GEPArg with a constant index. + GEPArg(int32_t integer) : intActive(true), constantIndex(integer) {} + + /// Constructs a GEPArg with a dynamic index. + GEPArg(Value value) : intActive(false), dynamicIndex(value) {} + + /// Assigns an integer to the GEPArg and makes it active. + GEPArg &operator=(int32_t integer) { + intActive = true; + constantIndex = integer; + return *this; + } + + /// Assigns a value to the GEPArg and makes it active. + GEPArg &operator=(Value value) { + intActive = false; + dynamicIndex = value; + return *this; + } + + template + bool is() const = delete; + + /// Returns true if this GEPArg contains a int32_t. + template <> + bool is() const { + return intActive; + } + + /// Returns true if this GEPArg contains a Value. + template <> + bool is() const { + return !intActive; + } + + template + std::conditional_t::value, const int32_t &, Value> + get() const = delete; + + /// Returns the contained int32_t. This operation is invalid if this GEPArg + /// does not contain a int32_t. + template <> + const int32_t &get() const { + assert(is() && "int32_t member not active"); + return constantIndex; + } + + /// Returns the contained Value. This operation is invalid if this GEPArg + /// does not contain a Value. + template <> + Value get() const { + assert(is() && "value type member not active"); + return dynamicIndex; + } + + template + std::conditional_t::value, const int32_t *, Value> + dyn_cast() const = delete; + + /// Returns a pointer to the contained int32_t or nullptr if it does not + /// contain a int32_t. + template <> + const int32_t *dyn_cast() const { + return is() ? &get() : nullptr; + } + + /// Returns the contained Value or a null Value if it does not + /// contain a Value. + template <> + Value dyn_cast() const { + return is() ? get() : nullptr; + } + +private: + /// Poor man's sum type. A boolean value is used to track whether the int32_t + /// or the Value should be deemed as active. + bool intActive; + int32_t constantIndex; + Value dynamicIndex; +}; +} // namespace LLVM +} // namespace mlir + ///// Ops ///// #define GET_OP_CLASSES #include "mlir/Dialect/LLVMIR/LLVMOps.h.inc" @@ -82,6 +176,93 @@ namespace mlir { namespace LLVM { + +/// Class used for convenient random access and iteration over GEP indices. +/// This class is templated to support not only retrieving the dynamic operands +/// of a GEP operation, but also as an adaptor during folding or conversion to +/// LLVM IR. +/// +/// GEP uses a DenseI64ArrayAttr with as many elements as it has indices. If a +/// specific element in it is in the range of a int32_t it is a constant index. +/// If it is larger than the maximum value of a int32_t then it is an index +/// offset by 'GEPOp::kDynamicIndexOffset' into another range containing the +/// dynamic index. +/// +/// This range being accessed is the 'DynamicRange' of this class. This +/// way it can be used as getter in GEPOp via 'GEPIndicesAdapter' +/// or during folding via 'GEPIndicesAdapter>'. +template +class GEPIndicesAdaptor { +public: + /// Return type of 'operator[]' and the iterators 'operator*'. It is depended + /// upon the value type of 'DynamicRange'. If 'DynamicRange' contains + /// Attributes or subclasses thereof, then value_type is 'Attribute'. In + /// all other cases it is a pointer union between the value type of + /// 'DynamicRange' and IntegerAttr. + using value_type = std::conditional_t< + std::is_base_of::value_type>::value, + Attribute, + PointerUnion::value_type>>; + + /// Constructs a GEPIndicesAdaptor with the raw constant indices of a GEPOp + /// and the range that is indexed into for retrieving dynamic indices. + GEPIndicesAdaptor(DenseI64ArrayAttr rawConstantIndices, DynamicRange values) + : rawConstantIndices(rawConstantIndices), values(std::move(values)) {} + + /// Returns the GEP index at the given position. This operation is invalid if + /// the index is out of bounds. + value_type operator[](size_t index) const { + assert(index < size() && "index out of bounds"); + if (isDynamicIndex(index)) + return values[rawConstantIndices[index] - GEPOp::kDynamicIndexOffset]; + + return IntegerAttr::get(ElementsAttr::getElementType(rawConstantIndices), + rawConstantIndices[index]); + } + + /// Returns whether the GEP index at the given position is a dynamic index. + bool isDynamicIndex(size_t index) const { + return rawConstantIndices[index] >= GEPOp::kDynamicIndexOffset; + } + + /// Returns the amount of indices of the GEPOp. + size_t size() const { return rawConstantIndices.size(); } + + /// Returns true if this GEPOp does not have any indices. + bool empty() const { return size() == 0; } + + class iterator : public llvm::indexed_accessor_iterator< + iterator, const GEPIndicesAdaptor *, + value_type, value_type *, value_type> { + using BaseT = + llvm::indexed_accessor_iterator *, + value_type, value_type *, value_type>; + + public: + iterator(const GEPIndicesAdaptor *base, size_t index) + : BaseT(base, index) {} + + value_type operator*() const { + return (*this->getBase())[this->getIndex()]; + } + }; + + /// Returns the begin iterator, iterating over all GEP indices. + iterator begin() const { return iterator(this, 0); } + + /// Returns the end iterator, iterating over all GEP indices. + iterator end() const { return iterator(this, size()); } + +private: + DenseI64ArrayAttr rawConstantIndices; + DynamicRange values; +}; + /// Create an LLVM global containing the string "value" at the module containing /// surrounding the insertion point of builder. Obtain the address of that /// global and use it to compute the address of the first character in the Index: mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -423,51 +423,75 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> { let arguments = (ins LLVM_ScalarOrVectorOf:$base, - Variadic>:$indices, - I32ElementsAttr:$structIndices, + Variadic>:$dynamicIndices, + DenseI64ArrayAttr:$rawConstantIndices, OptionalAttr:$elem_type); let results = (outs LLVM_ScalarOrVectorOf:$res); let skipDefaultBuilders = 1; + + let description = [{ + This operation mirrors LLVM IRs 'getelementptr' operation that is used to + perform pointer arithmetic. + + Like in LLVM IR, it is possible to use both constants as well as SSA values + as indices. In the case of indexing within a structure, it is required to + either use constant indices directly, or supply a constant SSA value. + + Examples: + + ```mlir + // GEP with an SSA value offset + %0 = llvm.getelementptr %1[%2] : (!llvm.ptr, i64) -> !llvm.ptr + + // GEP with a constant offset + %0 = llvm.getelementptr %1[3] : (!llvm.ptr) -> !llvm.ptr + + // GEP with constant offsets into a structure + %0 = llvm.getelementptr %1[0, 1] + : (!llvm.ptr) -> !llvm.ptr + ``` + }]; + let builders = [ - OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices, - CArg<"ArrayRef", "{}">:$attributes)>, - OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices, - "ArrayRef":$structIndices, - CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "Type":$resultType, "Type":$basePtrType, "Value":$basePtr, "ValueRange":$indices, CArg<"ArrayRef", "{}">:$attributes)>, - OpBuilder<(ins "Type":$resultType, "Type":$basePtrType, "Value":$basePtr, - "ValueRange":$indices, + OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ValueRange":$indices, + CArg<"ArrayRef", "{}">:$attributes)>, + OpBuilder<(ins "Type":$resultType, "Value":$basePtr, "ArrayRef":$indices, CArg<"ArrayRef", "{}">:$attributes)>, OpBuilder<(ins "Type":$resultType, "Type":$basePtrType, "Value":$basePtr, - "ValueRange":$indices, "ArrayRef":$structIndices, - CArg<"ArrayRef", "{}">:$attributes)> + "ArrayRef":$indices, + CArg<"ArrayRef", "{}">:$attributes)>, ]; let llvmBuilder = [{ SmallVector indices; - indices.reserve($structIndices.size()); - unsigned operandIdx = 0; - for (int32_t structIndex : $structIndices.getValues()) { - if (structIndex == GEPOp::kDynamicIndex) - indices.push_back($indices[operandIdx++]); + indices.reserve($rawConstantIndices.size()); + GEPIndicesAdaptor + gepIndices(op.getRawConstantIndicesAttr(), $dynamicIndices); + for (PointerUnion valueOrAttr : gepIndices) { + if (llvm::Value* value = valueOrAttr.dyn_cast()) + indices.push_back(value); else - indices.push_back(builder.getInt32(structIndex)); + indices.push_back( + builder.getInt32(valueOrAttr.get().getInt())); } Type baseElementType = op.getSourceElementType(); llvm::Type *elementType = moduleTranslation.convertType(baseElementType); $res = builder.CreateGEP(elementType, $base, indices); }]; let assemblyFormat = [{ - $base `[` custom($indices, $structIndices) `]` attr-dict + $base `[` custom($dynamicIndices, $rawConstantIndices) `]` attr-dict `:` functional-type(operands, results) (`,` $elem_type^)? }]; let extraClassDeclaration = [{ - constexpr static int kDynamicIndex = std::numeric_limits::min(); + constexpr static int64_t kDynamicIndexOffset = std::numeric_limits::max() + int64_t(1); /// Returns the type pointed to by the pointer argument of this GEP. Type getSourceElementType(); + + GEPIndicesAdaptor getIndices(); }]; let hasFolder = 1; let hasVerifier = 1; Index: mlir/include/mlir/IR/BuiltinAttributes.h =================================================================== --- mlir/include/mlir/IR/BuiltinAttributes.h +++ mlir/include/mlir/IR/BuiltinAttributes.h @@ -755,7 +755,10 @@ /// Implicit conversion to ArrayRef. operator ArrayRef() const; - ArrayRef asArrayRef() { return ArrayRef{*this}; } + ArrayRef asArrayRef() const { return ArrayRef{*this}; } + + /// Random access to elements. + T operator[](std::size_t index) const { return asArrayRef()[index]; } /// Builder from ArrayRef. static DenseArrayAttr get(MLIRContext *context, ArrayRef content); @@ -1017,6 +1020,14 @@ } }; +template <> +struct PointerLikeTypeTraits + : public PointerLikeTypeTraits { + static inline mlir::IntegerAttr getFromVoidPointer(void *p) { + return mlir::IntegerAttr::getFromOpaquePointer(p); + } +}; + template <> struct PointerLikeTypeTraits : public PointerLikeTypeTraits { Index: mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp =================================================================== --- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -418,7 +418,156 @@ // Code for LLVM::GEPOp. //===----------------------------------------------------------------------===// -constexpr int GEPOp::kDynamicIndex; +constexpr int64_t GEPOp::kDynamicIndexOffset; + +GEPIndicesAdaptor GEPOp::getIndices() { + return GEPIndicesAdaptor(getRawConstantIndicesAttr(), + getDynamicIndices()); +} + +/// Returns the elemental type of any LLVM-compatible vector type or self. +static Type extractVectorElementType(Type type) { + if (auto vectorType = type.dyn_cast()) + return vectorType.getElementType(); + if (auto scalableVectorType = type.dyn_cast()) + return scalableVectorType.getElementType(); + if (auto fixedVectorType = type.dyn_cast()) + return fixedVectorType.getElementType(); + return type; +} + +void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, + Value basePtr, ArrayRef indices, + ArrayRef attributes) { + auto ptrType = + extractVectorElementType(basePtr.getType()).cast(); + assert(!ptrType.isOpaque() && + "expected non-opaque pointer, provide elementType explicitly when " + "opaque pointers are used"); + build(builder, result, resultType, ptrType.getElementType(), basePtr, indices, + attributes); +} + +static void destructIndices(Type currType, ArrayRef indices, + SmallVectorImpl &rawConstantIndices, + SmallVectorImpl &dynamicIndices) { + for (const GEPArg &iter : indices) { + // If the thing we are currently indexing into is a struct we must turn + // any integer constants into constant indices. If this is not possible + // we don't do anything here. The verifier will catch it and emit a proper + // error. All other canonicalization is done in the fold method. + bool requiresConst = !rawConstantIndices.empty() && + currType.isa_and_nonnull(); + if (Value val = iter.dyn_cast()) { + APInt intC; + if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) && + intC.isSignedIntN(32)) { + rawConstantIndices.push_back(intC.getSExtValue()); + } else { + rawConstantIndices.push_back(dynamicIndices.size() + + GEPOp::kDynamicIndexOffset); + dynamicIndices.push_back(val); + } + } else { + rawConstantIndices.push_back(iter.get()); + } + + // Skip for very first iteration of this loop. First index does not index + // within the aggregates, but is just a pointer offset. + if (rawConstantIndices.size() == 1 || !currType) + continue; + + currType = + TypeSwitch(currType) + .Case([](auto containerType) { + return containerType.getElementType(); + }) + .Case([&](LLVMStructType structType) -> Type { + int64_t memberIndex = rawConstantIndices.back(); + if (memberIndex >= 0 && static_cast(memberIndex) < + structType.getBody().size()) + return structType.getBody()[memberIndex]; + return nullptr; + }) + .Default(Type(nullptr)); + } +} + +void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, + Type elementType, Value basePtr, ArrayRef indices, + ArrayRef attributes) { + SmallVector rawConstantIndices; + SmallVector dynamicIndices; + destructIndices(elementType, indices, rawConstantIndices, dynamicIndices); + + result.addTypes(resultType); + result.addAttributes(attributes); + result.addAttribute(getRawConstantIndicesAttrName(result.name), + builder.getDenseI64ArrayAttr(rawConstantIndices)); + if (extractVectorElementType(basePtr.getType()) + .cast() + .isOpaque()) + result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); + result.addOperands(basePtr); + result.addOperands(dynamicIndices); +} + +void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, + Value basePtr, ValueRange indices, + ArrayRef attributes) { + build(builder, result, resultType, basePtr, SmallVector(indices), + attributes); +} + +void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, + Type elementType, Value basePtr, ValueRange indices, + ArrayRef attributes) { + build(builder, result, resultType, elementType, basePtr, + SmallVector(indices), attributes); +} + +static ParseResult +parseGEPIndices(OpAsmParser &parser, + SmallVectorImpl &indices, + DenseI64ArrayAttr &rawConstantIndices) { + SmallVector constantIndices; + + auto idxParser = [&]() -> ParseResult { + int32_t constantIndex; + OptionalParseResult parsedInteger = + parser.parseOptionalInteger(constantIndex); + if (parsedInteger.hasValue()) { + if (failed(parsedInteger.getValue())) + return failure(); + constantIndices.push_back(constantIndex); + return success(); + } + + constantIndices.push_back(indices.size() + + LLVM::GEPOp::kDynamicIndexOffset); + return parser.parseOperand(indices.emplace_back()); + }; + if (parser.parseCommaSeparatedList(idxParser)) + return failure(); + + rawConstantIndices = + DenseI64ArrayAttr::get(parser.getContext(), constantIndices); + return success(); +} + +static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, + OperandRange indices, + DenseI64ArrayAttr rawConstantIndices) { + llvm::interleaveComma( + GEPIndicesAdaptor(rawConstantIndices, indices), printer, + [&](PointerUnion cst) { + if (Value val = cst.dyn_cast()) + printer.printOperand(val); + else + printer << cst.get().getInt(); + }); +} namespace { /// Base class for llvm::Error related to GEP index. @@ -467,69 +616,33 @@ char GEPStaticIndexError::ID = 0; /// For the given `structIndices` and `indices`, check if they're complied -/// with `baseGEPType`, especially check against LLVMStructTypes nested within, -/// and refine/promote struct index from `indices` to `updatedStructIndices` -/// if the latter argument is not null. -static llvm::Error -recordStructIndices(Type baseGEPType, unsigned indexPos, - ArrayRef structIndices, ValueRange indices, - SmallVectorImpl *updatedStructIndices, - SmallVectorImpl *remainingIndices) { - if (indexPos >= structIndices.size()) +/// with `baseGEPType`, especially check against LLVMStructTypes nested within. +static llvm::Error verifyStructIndices(Type baseGEPType, unsigned indexPos, + GEPIndicesAdaptor indices) { + if (indexPos >= indices.size()) // Stop searching return llvm::Error::success(); - int32_t gepIndex = structIndices[indexPos]; - bool isStaticIndex = gepIndex != GEPOp::kDynamicIndex; - - unsigned dynamicIndexPos = indexPos; - if (!isStaticIndex) - dynamicIndexPos = llvm::count(structIndices.take_front(indexPos + 1), - LLVM::GEPOp::kDynamicIndex) - - 1; - return llvm::TypeSwitch(baseGEPType) .Case([&](LLVMStructType structType) -> llvm::Error { - // We don't always want to refine the index (e.g. when performing - // verification), so we only refine when updatedStructIndices is not - // null. - if (!isStaticIndex && updatedStructIndices) { - // Try to refine. - APInt staticIndexValue; - isStaticIndex = matchPattern(indices[dynamicIndexPos], - m_ConstantInt(&staticIndexValue)); - if (isStaticIndex) { - assert(staticIndexValue.getBitWidth() <= 64 && - llvm::isInt<32>(staticIndexValue.getLimitedValue()) && - "struct index can't fit within int32_t"); - gepIndex = static_cast(staticIndexValue.getSExtValue()); - } - } - if (!isStaticIndex) + if (!indices[indexPos].is()) return llvm::make_error(indexPos); + int32_t gepIndex = indices[indexPos].get().getInt(); ArrayRef elementTypes = structType.getBody(); if (gepIndex < 0 || static_cast(gepIndex) >= elementTypes.size()) return llvm::make_error(indexPos); - if (updatedStructIndices) - (*updatedStructIndices)[indexPos] = gepIndex; - - // Instead of recusively going into every children types, we only + // Instead of recursively going into every children types, we only // dive into the one indexed by gepIndex. - return recordStructIndices(elementTypes[gepIndex], indexPos + 1, - structIndices, indices, updatedStructIndices, - remainingIndices); + return verifyStructIndices(elementTypes[gepIndex], indexPos + 1, + indices); }) .Case([&](auto containerType) -> llvm::Error { - // Currently we don't refine non-struct index even if it's static. - if (remainingIndices) - remainingIndices->push_back(indices[dynamicIndexPos]); - return recordStructIndices(containerType.getElementType(), indexPos + 1, - structIndices, indices, updatedStructIndices, - remainingIndices); + return verifyStructIndices(containerType.getElementType(), indexPos + 1, + indices); }) .Default( [](auto otherType) -> llvm::Error { return llvm::Error::success(); }); @@ -537,122 +650,9 @@ /// Driver function around `recordStructIndices`. Note that we always check /// from the second GEP index since the first one is always dynamic. -static llvm::Error -findStructIndices(Type baseGEPType, ArrayRef structIndices, - ValueRange indices, - SmallVectorImpl *updatedStructIndices = nullptr, - SmallVectorImpl *remainingIndices = nullptr) { - if (remainingIndices) - // The first GEP index is always dynamic. - remainingIndices->push_back(indices[0]); - return recordStructIndices(baseGEPType, /*indexPos=*/1, structIndices, - indices, updatedStructIndices, remainingIndices); -} - -void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, - Value basePtr, ValueRange operands, - ArrayRef attributes) { - build(builder, result, resultType, basePtr, operands, - SmallVector(operands.size(), kDynamicIndex), attributes); -} - -/// Returns the elemental type of any LLVM-compatible vector type or self. -static Type extractVectorElementType(Type type) { - if (auto vectorType = type.dyn_cast()) - return vectorType.getElementType(); - if (auto scalableVectorType = type.dyn_cast()) - return scalableVectorType.getElementType(); - if (auto fixedVectorType = type.dyn_cast()) - return fixedVectorType.getElementType(); - return type; -} - -void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, - Type elementType, Value basePtr, ValueRange indices, - ArrayRef attributes) { - build(builder, result, resultType, elementType, basePtr, indices, - SmallVector(indices.size(), kDynamicIndex), attributes); -} - -void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, - Value basePtr, ValueRange indices, - ArrayRef structIndices, - ArrayRef attributes) { - auto ptrType = - extractVectorElementType(basePtr.getType()).cast(); - assert(!ptrType.isOpaque() && - "expected non-opaque pointer, provide elementType explicitly when " - "opaque pointers are used"); - build(builder, result, resultType, ptrType.getElementType(), basePtr, indices, - structIndices, attributes); -} - -void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType, - Type elementType, Value basePtr, ValueRange indices, - ArrayRef structIndices, - ArrayRef attributes) { - SmallVector remainingIndices; - SmallVector updatedStructIndices(structIndices.begin(), - structIndices.end()); - if (llvm::Error err = - findStructIndices(elementType, structIndices, indices, - &updatedStructIndices, &remainingIndices)) - llvm::report_fatal_error(StringRef(llvm::toString(std::move(err)))); - - assert(remainingIndices.size() == static_cast(llvm::count( - updatedStructIndices, kDynamicIndex)) && - "expected as many index operands as dynamic index attr elements"); - - result.addTypes(resultType); - result.addAttributes(attributes); - result.addAttribute("structIndices", - builder.getI32TensorAttr(updatedStructIndices)); - if (extractVectorElementType(basePtr.getType()) - .cast() - .isOpaque()) - result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType)); - result.addOperands(basePtr); - result.addOperands(remainingIndices); -} - -static ParseResult -parseGEPIndices(OpAsmParser &parser, - SmallVectorImpl &indices, - DenseIntElementsAttr &structIndices) { - SmallVector constantIndices; - - auto idxParser = [&]() -> ParseResult { - int32_t constantIndex; - OptionalParseResult parsedInteger = - parser.parseOptionalInteger(constantIndex); - if (parsedInteger.hasValue()) { - if (failed(parsedInteger.getValue())) - return failure(); - constantIndices.push_back(constantIndex); - return success(); - } - - constantIndices.push_back(LLVM::GEPOp::kDynamicIndex); - return parser.parseOperand(indices.emplace_back()); - }; - if (parser.parseCommaSeparatedList(idxParser)) - return failure(); - - structIndices = parser.getBuilder().getI32TensorAttr(constantIndices); - return success(); -} - -static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, - OperandRange indices, - DenseIntElementsAttr structIndices) { - unsigned operandIdx = 0; - llvm::interleaveComma(structIndices.getValues(), printer, - [&](int32_t cst) { - if (cst == LLVM::GEPOp::kDynamicIndex) - printer.printOperand(indices[operandIdx++]); - else - printer << cst; - }); +static llvm::Error verifyStructIndices(Type baseGEPType, + GEPIndicesAdaptor indices) { + return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices); } LogicalResult LLVM::GEPOp::verify() { @@ -662,14 +662,28 @@ getElemType()))) return failure(); - auto structIndexRange = getStructIndices().getValues(); - // structIndexRange is a kind of iterator, which cannot be converted - // to ArrayRef directly. - SmallVector structIndices(structIndexRange.size()); - for (unsigned i : llvm::seq(0, structIndexRange.size())) - structIndices[i] = structIndexRange[i]; - if (llvm::Error err = findStructIndices(getSourceElementType(), structIndices, - getIndices())) + if (static_cast( + llvm::count_if(getRawConstantIndices(), [](int64_t val) { + return val >= kDynamicIndexOffset; + })) != getDynamicIndices().size()) + return emitOpError("expected as many dynamic indices as specified in '") + << getRawConstantIndicesAttrName().getValue() << "'"; + + for (int64_t index : getRawConstantIndices()) { + if (index < kDynamicIndexOffset) + continue; + + if ((index - kDynamicIndexOffset) < + static_cast(getDynamicIndices().size())) + continue; + + return emitOpError("invalid raw constant index pointing at non-existent " + "dynamic index '") + << (index - kDynamicIndexOffset) << "'"; + } + + if (llvm::Error err = + verifyStructIndices(getSourceElementType(), getIndices())) return emitOpError() << llvm::toString(std::move(err)); return success(); @@ -2697,10 +2711,49 @@ //===----------------------------------------------------------------------===// OpFoldResult LLVM::GEPOp::fold(ArrayRef operands) { + GEPIndicesAdaptor> indices(getRawConstantIndicesAttr(), + operands.drop_front()); + // gep %x:T, 0 -> %x - if (getBase().getType() == getType() && getIndices().size() == 1 && - getStructIndices().size() == 1 && matchPattern(getIndices()[0], m_Zero())) - return getBase(); + if (getBase().getType() == getType() && indices.size() == 1) + if (auto integer = indices[0].dyn_cast_or_null()) + if (integer.getValue().isZero()) + return getBase(); + + // canonicalize any dynamic indices of constant value to constant indices. + bool changed = false; + SmallVector gepArgs; + for (auto &iter : llvm::enumerate(indices)) { + auto integer = iter.value().dyn_cast_or_null(); + // constant indices can only be int32_t, so if integer does not fit we + // are forced to keep it dynamic, despite being a constant. + if (!indices.isDynamicIndex(iter.index()) || !integer || + !integer.getValue().isSignedIntN(32)) { + + PointerUnion existing = getIndices()[iter.index()]; + if (Value val = existing.dyn_cast()) + gepArgs.emplace_back(val); + else + gepArgs.emplace_back(existing.get().getInt()); + + continue; + } + + changed = true; + gepArgs.emplace_back(integer.getInt()); + } + if (changed) { + SmallVector rawConstantIndices; + SmallVector dynamicIndices; + destructIndices(getSourceElementType(), gepArgs, rawConstantIndices, + dynamicIndices); + + getDynamicIndicesMutable().assign(dynamicIndices); + setRawConstantIndicesAttr( + DenseI64ArrayAttr::get(getContext(), rawConstantIndices)); + return Value{*this}; + } + return {}; } Index: mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp =================================================================== --- mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -1072,24 +1072,23 @@ Value basePtr = processValue(gep->getOperand(0)); Type sourceElementType = processType(gep->getSourceElementType()); - SmallVector indices; - for (llvm::Value *operand : llvm::drop_begin(gep->operand_values())) { - indices.push_back(processValue(operand)); - if (!indices.back()) - return failure(); - } // Treat every indices as dynamic since GEPOp::build will refine those // indices into static attributes later. One small downside of this // approach is that many unused `llvm.mlir.constant` would be emitted // at first place. - SmallVector structIndices(indices.size(), - LLVM::GEPOp::kDynamicIndex); + SmallVector indices; + for (llvm::Value *operand : llvm::drop_begin(gep->operand_values())) { + Value val = processValue(operand); + if (!val) + return failure(); + indices.push_back(val); + } Type type = processType(inst->getType()); if (!type) return failure(); - instMap[inst] = b.create(loc, type, sourceElementType, basePtr, - indices, structIndices); + instMap[inst] = + b.create(loc, type, sourceElementType, basePtr, indices); return success(); } case llvm::Instruction::InsertValue: { Index: mlir/test/Dialect/LLVMIR/canonicalize.mlir =================================================================== --- mlir/test/Dialect/LLVMIR/canonicalize.mlir +++ mlir/test/Dialect/LLVMIR/canonicalize.mlir @@ -102,8 +102,7 @@ // CHECK-LABEL: fold_gep_neg // CHECK-SAME: %[[a0:arg[0-9]+]] -// CHECK-NEXT: %[[C:.*]] = arith.constant 0 -// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[a0]][%[[C]], 1] +// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[a0]][0, 1] // CHECK-NEXT: llvm.return %[[RES]] llvm.func @fold_gep_neg(%x : !llvm.ptr) -> !llvm.ptr { %c0 = arith.constant 0 : i32 @@ -111,6 +110,17 @@ llvm.return %0 : !llvm.ptr } +// CHECK-LABEL: fold_gep_canon +// CHECK-SAME: %[[a0:arg[0-9]+]] +// CHECK-NEXT: %[[RES:.*]] = llvm.getelementptr %[[a0]][2] +// CHECK-NEXT: llvm.return %[[RES]] +llvm.func @fold_gep_canon(%x : !llvm.ptr) -> !llvm.ptr { + %c2 = arith.constant 2 : i32 + %c = llvm.getelementptr %x[%c2] : (!llvm.ptr, i32) -> !llvm.ptr + llvm.return %c : !llvm.ptr +} + + // ----- // Check that LLVM constants participate in cross-dialect constant folding. The Index: mlir/test/Dialect/LLVMIR/dynamic-gep-index.mlir =================================================================== --- mlir/test/Dialect/LLVMIR/dynamic-gep-index.mlir +++ mlir/test/Dialect/LLVMIR/dynamic-gep-index.mlir @@ -6,7 +6,7 @@ // CHECK: %[[C0:.+]] = llvm.mlir.constant(0 : i32) %0 = llvm.mlir.constant(0 : i32) : i32 // CHECK: llvm.getelementptr %[[ARG0]][%[[C0]], 1, %[[ARG1]]] - %1 = "llvm.getelementptr"(%arg0, %0, %arg1) {structIndices = dense<[-2147483648, 1, -2147483648]> : tensor<3xi32>} : (!llvm.ptr, array<4 x i32>)>>, i32, i32) -> !llvm.ptr + %1 = "llvm.getelementptr"(%arg0, %0, %arg1) {rawConstantIndices = [:i64 2147483648, 1, 2147483649]} : (!llvm.ptr, array<4 x i32>)>>, i32, i32) -> !llvm.ptr llvm.return } } Index: mlir/test/Dialect/LLVMIR/invalid.mlir =================================================================== --- mlir/test/Dialect/LLVMIR/invalid.mlir +++ mlir/test/Dialect/LLVMIR/invalid.mlir @@ -146,6 +146,20 @@ // ----- +func.func @gep_too_few_dynamic(%base : !llvm.ptr) { + // expected-error@+1 {{expected as many dynamic indices as specified in 'rawConstantIndices'}} + %1 = "llvm.getelementptr"(%base) {rawConstantIndices = [:i64 2147483648]} : (!llvm.ptr) -> !llvm.ptr +} + +// ----- + +func.func @gep_invalid_dynamic_index(%pos : i64,%base : !llvm.ptr) { + // expected-error@+1 {{invalid raw constant index pointing at non-existent dynamic index '1'}} + %1 = "llvm.getelementptr"(%base, %pos) {rawConstantIndices = [:i64 2147483649]} : (!llvm.ptr, i64) -> !llvm.ptr +} + +// ----- + func.func @load_non_llvm_type(%foo : memref) { // expected-error@+1 {{expected LLVM pointer type}} llvm.load %foo : memref