Index: mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -79,6 +79,17 @@
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Splits GEPs with more than two indices into multiple GEPs with exactly
+/// two indices. The created GEPs are then guaranteed to index into only
+/// one aggregate at a time.
+class SplitGEP : public OpRewritePattern<GEPOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GEPOp gepOp,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // namespace LLVM
 } // namespace mlir
 
Index: mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -778,34 +778,32 @@
 }
 
 Type GEPOp::getResultPtrElementType() {
-  // Ensures all indices are static and fetches them.
-  SmallVector<IntegerAttr> indices;
-  for (auto index : getIndices()) {
-    IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
-    if (!indexInt)
-      return nullptr;
-    indices.push_back(indexInt);
-  }
-
   // Set the initial type currently being used for indexing. This will be
   // updated as the indices get walked over.
   Type selectedType = getSourceElementType();
 
   // Follow the indexed elements in the gep.
-  for (IntegerAttr index : llvm::drop_begin(indices)) {
-    // Ensure the structure of the type being indexed can be reasoned about.
-    // This includes rejecting any potential typed pointer.
-    auto destructurable =
-        llvm::dyn_cast<DestructurableTypeInterface>(selectedType);
-    if (!destructurable)
-      return nullptr;
-
-    // Follow the type at the index the gep is accessing, making it the new type
-    // used for indexing.
-    Type field = destructurable.getTypeAtIndex(index);
-    if (!field)
-      return nullptr;
-    selectedType = field;
+  auto indices = getIndices();
+  for (GEPIndicesAdaptor<ValueRange>::value_type index :
+       llvm::drop_begin(indices)) {
+    // The GEP verifier accepts structs, arrays and vectors as indexed types.
+
+    // The resulting type if indexing into an array type is always the element
+    // type, regardless of index.
+    if (auto arrayType = dyn_cast<LLVMArrayType>(selectedType)) {
+      selectedType = arrayType.getElementType();
+      continue;
+    }
+
+    // Vectors (and anything else that is not destructurable) cannot be
+    // reasoned about here: bail out instead of asserting in a cast.
+    auto destructurable = dyn_cast<DestructurableTypeInterface>(selectedType);
+    if (!destructurable)
+      return nullptr;
+
+    // The GEP verifier ensures that any index into structs are static and
+    // that they refer to a field within the struct.
+    selectedType = destructurable.getTypeAtIndex(cast<IntegerAttr>(index));
   }
 
   // When there are no more indices, the type currently being used for indexing
Index: mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -349,7 +349,11 @@
   Type reachedType = getResultPtrElementType();
   if (!reachedType || getIndices().size() < 2)
     return false;
-  auto firstLevelIndex = cast<IntegerAttr>(getIndices()[1]);
+  // Dynamic indices (legal when indexing into arrays) do not select a single
+  // sub-slot, so such a GEP cannot be rewired.
+  auto firstLevelIndex = dyn_cast<IntegerAttr>(getIndices()[1]);
+  if (!firstLevelIndex)
+    return false;
   assert(slot.elementPtrs.contains(firstLevelIndex));
   if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
     return false;
Index: mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -565,6 +565,46 @@
   return success();
 }
 
+LogicalResult SplitGEP::matchAndRewrite(GEPOp gepOp,
+                                        PatternRewriter &rewriter) const {
+  FailureOr<Type> typeHint = getRequiredConsistentGEPType(gepOp);
+  if (succeeded(typeHint) || gepOp.getIndices().size() <= 2) {
+    // GEP is not canonical or a single aggregate deep, nothing to do here.
+    return failure();
+  }
+
+  auto indexToGEPArg =
+      [](GEPIndicesAdaptor<ValueRange>::value_type index) -> GEPArg {
+    if (auto integerAttr = dyn_cast<IntegerAttr>(index))
+      return integerAttr.getValue().getSExtValue();
+    return cast<Value>(index);
+  };
+
+  GEPIndicesAdaptor<ValueRange> indices = gepOp.getIndices();
+
+  auto splitIter = std::next(indices.begin(), 2);
+
+  // Split off the first GEP using the first two indices.
+  auto subGepOp = rewriter.create<GEPOp>(
+      gepOp.getLoc(), gepOp.getType(), gepOp.getSourceElementType(),
+      gepOp.getBase(),
+      llvm::map_to_vector(llvm::make_range(indices.begin(), splitIter),
+                          indexToGEPArg),
+      gepOp.getInbounds());
+
+  // The second GEP indexes on the result pointer element type of the previous
+  // with all the remaining indices and a zero upfront. If this GEP has more
+  // than two indices remaining it'll be further split in subsequent pattern
+  // applications.
+  SmallVector<GEPArg> newIndices = {0};
+  llvm::transform(llvm::make_range(splitIter, indices.end()),
+                  std::back_inserter(newIndices), indexToGEPArg);
+  rewriter.replaceOpWithNewOp<GEPOp>(gepOp, gepOp.getType(),
+                                     subGepOp.getResultPtrElementType(),
+                                     subGepOp, newIndices, gepOp.getInbounds());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Type consistency pass
 //===----------------------------------------------------------------------===//
@@ -580,6 +620,7 @@
   rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
   rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
   rewritePatterns.add<BitcastStores>(&getContext());
+  rewritePatterns.add<SplitGEP>(&getContext());
   FrozenRewritePatternSet frozen(std::move(rewritePatterns));
 
   if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))
Index: mlir/test/Dialect/LLVMIR/type-consistency.mlir
===================================================================
--- mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -433,3 +433,63 @@
   // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @gep_split
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @gep_split(%arg: i64) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.array<2 x struct<"foo", (i64)>>
+  %1 = llvm.alloca %0 x !llvm.array<2 x struct<"foo", (i64)>> : (i32) -> !llvm.ptr
+  %3 = llvm.getelementptr %1[0, 1, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x struct<"foo", (i64)>>
+  // CHECK: %[[TOP_GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x struct<"foo", (i64)>>
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64)>
+  // CHECK: llvm.store %[[ARG]], %[[GEP]]
+  llvm.store %arg, %3 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @coalesced_store_ints_subaggregate
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_ints_subaggregate(%arg: i64) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, struct<(i32, i32)>)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, struct<(i32, i32)>)> : (i32) -> !llvm.ptr
+  %3 = llvm.getelementptr %1[0, 1, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, struct<(i32, i32)>)>
+
+  // CHECK: %[[TOP_GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, struct<(i32, i32)>)>
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, i32)>
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i32, i32)>
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  llvm.store %arg, %3 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @gep_result_ptr_type_dynamic
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @gep_result_ptr_type_dynamic(%arg: i64) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.array<2 x struct<"foo", (i64)>>
+  %1 = llvm.alloca %0 x !llvm.array<2 x struct<"foo", (i64)>> : (i32) -> !llvm.ptr
+  %3 = llvm.getelementptr %1[0, %arg, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.array<2 x struct<"foo", (i64)>>
+  // CHECK: %[[TOP_GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, %[[ARG]]] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.array<2 x struct<"foo", (i64)>>
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64)>
+  // CHECK: llvm.store %[[ARG]], %[[GEP]]
+  llvm.store %arg, %3 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}