Index: mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -68,6 +68,17 @@
                                 PatternRewriter &rewrite) const override;
 };
 
+/// Splits GEPs with more than two indices into multiple GEPs with exactly
+/// two indices. The created GEPs are then guaranteed to index into only
+/// one aggregate at a time.
+class SplitGEP : public OpRewritePattern<GEPOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GEPOp gep,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // namespace LLVM
 } // namespace mlir
 
Index: mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -778,34 +778,34 @@
 }
 
 Type GEPOp::getResultPtrElementType() {
-  // Ensures all indices are static and fetches them.
-  SmallVector<IntegerAttr> indices;
-  for (auto index : getIndices()) {
-    IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
-    if (!indexInt)
-      return nullptr;
-    indices.push_back(indexInt);
-  }
-
   // Set the initial type currently being used for indexing. This will be
   // updated as the indices get walked over.
   Type selectedType = getSourceElementType();
 
   // Follow the indexed elements in the gep.
-  for (IntegerAttr index : llvm::drop_begin(indices)) {
-    // Ensure the structure of the type being indexed can be reasoned about.
-    // This includes rejecting any potential typed pointer.
-    auto destructurable =
-        llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
-    if (!destructurable)
-      return nullptr;
-
-    // Follow the type at the index the gep is accessing, making it the new type
-    // used for indexing.
-    Type field = destructurable.getTypeAtIndex(index);
-    if (!field)
-      return nullptr;
-    selectedType = field;
+  auto indices = getIndices();
+  for (GEPIndicesAdaptor<ValueRange>::value_type index :
+       llvm::drop_begin(indices)) {
+    // GEPs can only index into arrays, vectors and structs.
+
+    // The resulting type if indexing into an array type is always the element
+    // type, regardless of index.
+    if (auto arrayType = dyn_cast<LLVMArrayType>(selectedType)) {
+      selectedType = arrayType.getElementType();
+      continue;
+    }
+
+    // LLVM's GEP may also index into vectors, which behave like arrays here.
+    // They must not fall through to the struct-only cast below.
+    if (isa<VectorType, LLVMFixedVectorType, LLVMScalableVectorType>(
+            selectedType)) {
+      selectedType = LLVM::getVectorElementType(selectedType);
+      continue;
+    }
+
+    // The GEP verifier ensures that any index into structs are static and
+    // that they refer to a field within the struct.
+    int64_t fieldIndex = cast<IntegerAttr>(index).getValue().getZExtValue();
+    selectedType = cast<LLVMStructType>(selectedType).getBody()[fieldIndex];
   }
 
   // When there are no more indices, the type currently being used for indexing
Index: mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -349,7 +349,9 @@
   Type reachedType = getResultPtrElementType();
   if (!reachedType || getIndices().size() < 2)
     return false;
-  auto firstLevelIndex = cast<IntegerAttr>(getIndices()[1]);
+  auto firstLevelIndex = dyn_cast<IntegerAttr>(getIndices()[1]);
+  if (!firstLevelIndex)
+    return false;
   assert(slot.elementPtrs.contains(firstLevelIndex));
   if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
     return false;
Index: mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -558,6 +558,45 @@
   return success();
 }
 
+LogicalResult SplitGEP::matchAndRewrite(GEPOp gep,
+                                        PatternRewriter &rewriter) const {
+  FailureOr<Type> typeHint = getRequiredConsistentGEPType(gep);
+  if (succeeded(typeHint) || gep.getIndices().size() <= 2) {
+    // GEP is not canonical or a single aggregate deep, nothing to do here.
+    return failure();
+  }
+
+  auto indexToGEPArg =
+      [](GEPIndicesAdaptor<ValueRange>::value_type index) -> GEPArg {
+    if (auto integerAttr = dyn_cast<IntegerAttr>(index))
+      return integerAttr.getValue().getSExtValue();
+    return cast<Value>(index);
+  };
+
+  GEPIndicesAdaptor<ValueRange> indices = gep.getIndices();
+
+  auto splitIter = std::next(indices.begin(), 2);
+
+  // Split off the first GEP using the first two indices.
+  auto subGepOp = rewriter.create<GEPOp>(
+      gep.getLoc(), gep.getType(), gep.getSourceElementType(), gep.getBase(),
+      llvm::map_to_vector(llvm::make_range(indices.begin(), splitIter),
+                          indexToGEPArg),
+      gep.getInbounds());
+
+  // The second GEP indexes on the result pointer element type of the previous
+  // with all the remaining indices and a zero upfront. If this GEP has more
+  // than two indices remaining it'll be further split in subsequent pattern
+  // applications.
+  auto gepArgs = llvm::map_to_vector(llvm::make_range(splitIter, indices.end()),
+                                     indexToGEPArg);
+  gepArgs.insert(gepArgs.begin(), 0);
+  rewriter.replaceOpWithNewOp<GEPOp>(gep, gep.getType(),
+                                     subGepOp.getResultPtrElementType(),
+                                     subGepOp, gepArgs, gep.getInbounds());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Type consistency pass
 //===----------------------------------------------------------------------===//
@@ -572,6 +611,7 @@
       &getContext());
   rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
   rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
+  rewritePatterns.add<SplitGEP>(&getContext());
   FrozenRewritePatternSet frozen(std::move(rewritePatterns));
 
   if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))