diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
--- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -53,6 +53,18 @@
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Splits stores of integers which write into multiple adjacent fields of an
+/// aggregate through a pointer. The integer is split and a store is generated
+/// for every field written to, in a type-consistent manner.
+/// This is currently done on a best-effort basis.
+class SplitIntegerStores : public OpRewritePattern<StoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(StoreOp store,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // namespace LLVM
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -168,11 +168,20 @@
 
 /// Returns the amount of bytes the provided GEP elements will offset the
 /// pointer by. Returns nullopt if the offset could not be computed.
-static std::optional<uint64_t> gepToByteOffset(DataLayout &layout, Type base,
-                                               ArrayRef<uint32_t> indices) {
-  uint64_t offset = indices[0] * layout.getTypeSize(base);
+static std::optional<uint64_t> gepToByteOffset(DataLayout &layout, GEPOp gep) {
 
-  Type currentType = base;
+  SmallVector<uint32_t> indices;
+  // Ensures all indices are static and fetches them.
+  for (auto index : gep.getIndices()) {
+    IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
+    if (!indexInt)
+      return std::nullopt;
+    indices.push_back(indexInt.getInt());
+  }
+
+  uint64_t offset = indices[0] * layout.getTypeSize(gep.getSourceElementType());
+
+  Type currentType = gep.getSourceElementType();
   for (uint32_t index : llvm::drop_begin(indices)) {
     bool shouldCancel =
         TypeSwitch<Type, bool>(currentType)
@@ -302,9 +311,9 @@
   return success();
 }
 
-LogicalResult
-CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep,
-                                        PatternRewriter &rewriter) const {
+/// Returns the consistent type for the GEP if the GEP is not type-consistent.
+/// Returns failure if the GEP is already consistent.
+static FailureOr<Type> getRequiredConsistentGEPType(GEPOp gep) {
   // GEP of typed pointers are not supported.
   if (!gep.getElemType())
     return failure();
@@ -317,34 +326,185 @@
   Type typeHint = isElementTypeInconsistent(gep.getBase(), baseType);
   if (!typeHint)
     return failure();
+  return typeHint;
+}
 
-  SmallVector<uint32_t> indices;
-  // Ensures all indices are static and fetches them.
-  for (auto index : gep.getIndices()) {
-    IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
-    if (!indexInt)
-      return failure();
-    indices.push_back(indexInt.getInt());
+LogicalResult
+CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep,
+                                        PatternRewriter &rewriter) const {
+  FailureOr<Type> typeHint = getRequiredConsistentGEPType(gep);
+  if (failed(typeHint)) {
+    // GEP is already canonical, nothing to do here.
+    return failure();
   }
 
   DataLayout layout = DataLayout::closest(gep);
-  std::optional<uint64_t> desiredOffset =
-      gepToByteOffset(layout, gep.getSourceElementType(), indices);
+  std::optional<uint64_t> desiredOffset = gepToByteOffset(layout, gep);
   if (!desiredOffset)
     return failure();
 
   SmallVector<GEPArg> newIndices;
   if (failed(
-          findIndicesForOffset(layout, typeHint, *desiredOffset, newIndices)))
+          findIndicesForOffset(layout, *typeHint, *desiredOffset, newIndices)))
     return failure();
 
   rewriter.replaceOpWithNewOp<GEPOp>(
-      gep, LLVM::LLVMPointerType::get(getContext()), typeHint, gep.getBase(),
+      gep, LLVM::LLVMPointerType::get(getContext()), *typeHint, gep.getBase(),
       newIndices, gep.getInbounds());
 
   return success();
 }
 
+/// Returns the list of fields of `structType` that are written to by a store
+/// operation writing `storeSize` bytes at `storeOffset` within the struct.
+/// `storeOffset` is required to cleanly point to an immediate field within
+/// the struct.
+/// If the write operation would write to any padding, write beyond the
+/// struct, partially write to a field, or store a currently unsupported
+/// type, failure is returned.
+static FailureOr<ArrayRef<Type>>
+getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
+                   int storeSize, unsigned storeOffset) {
+  ArrayRef<Type> body = structType.getBody();
+  unsigned currentOffset = 0;
+  body = body.drop_until([&](Type type) {
+    if (!structType.isPacked()) {
+      unsigned alignment = dataLayout.getTypeABIAlignment(type);
+      currentOffset = llvm::alignTo(currentOffset, alignment);
+    }
+
+    // currentOffset is guaranteed to eventually equal storeOffset, since
+    // storeOffset is either 0 or stems from a type-consistent GEP indexing
+    // into just a single aggregate.
+    if (currentOffset == storeOffset)
+      return true;
+
+    assert(currentOffset < storeOffset &&
+           "storeOffset should cleanly point into an immediate field");
+
+    currentOffset += dataLayout.getTypeSize(type);
+    return false;
+  });
+
+  size_t exclusiveEnd = 0;
+  for (; exclusiveEnd < body.size() && storeSize > 0; exclusiveEnd++) {
+    // Not yet recursively handling aggregates, only primitives.
+    if (!isa<IntegerType, FloatType>(body[exclusiveEnd]))
+      return failure();
+
+    if (!structType.isPacked()) {
+      unsigned alignment = dataLayout.getTypeABIAlignment(body[exclusiveEnd]);
+      // No padding allowed in between fields at this point in time.
+      if (!llvm::isAligned(llvm::Align(alignment), currentOffset))
+        return failure();
+    }
+
+    unsigned fieldSize = dataLayout.getTypeSize(body[exclusiveEnd]);
+    currentOffset += fieldSize;
+    storeSize -= fieldSize;
+  }
+
+  // If the storeSize is not 0 at this point we are either partially writing
+  // into a field or writing past the aggregate as a whole. Abort.
+  if (storeSize != 0)
+    return failure();
+  return body.take_front(exclusiveEnd);
+}
+
+LogicalResult
+SplitIntegerStores::matchAndRewrite(StoreOp store,
+                                    PatternRewriter &rewriter) const {
+  IntegerType sourceType = dyn_cast<IntegerType>(store.getValue().getType());
+  if (!sourceType) {
+    // We currently only support integer sources.
+    return failure();
+  }
+
+  Type typeHint = isElementTypeInconsistent(store.getAddr(), sourceType);
+  if (!typeHint) {
+    // Nothing to do, since it is already consistent.
+    return failure();
+  }
+
+  auto dataLayout = DataLayout::closest(store);
+
+  unsigned offset = 0;
+  Value address = store.getAddr();
+  if (auto gepOp = address.getDefiningOp<GEPOp>()) {
+    // Currently only canonical GEPs with exactly two indices, indexing a
+    // single aggregate deep, are handled. Recursing into sub-structs is left
+    // as a future exercise.
+    // If the GEP is not canonical we have to fail, otherwise we would not
+    // create type-consistent IR.
+    if (gepOp.getIndices().size() != 2 ||
+        succeeded(getRequiredConsistentGEPType(gepOp)))
+      return failure();
+
+    // A GEP might point somewhere into the middle of an aggregate with the
+    // store storing into multiple adjacent elements. Destructure it into
+    // the base address plus an offset.
+    std::optional<uint64_t> byteOffset = gepToByteOffset(dataLayout, gepOp);
+    if (!byteOffset)
+      return failure();
+
+    offset = *byteOffset;
+    typeHint = gepOp.getSourceElementType();
+    address = gepOp.getBase();
+  }
+
+  auto structType = typeHint.dyn_cast<LLVMStructType>();
+  if (!structType) {
+    // TODO: Handle array types in the future.
+    return failure();
+  }
+
+  FailureOr<ArrayRef<Type>> writtenToFields =
+      getWrittenToFields(dataLayout, structType,
+                         /*storeSize=*/dataLayout.getTypeSize(sourceType),
+                         /*storeOffset=*/offset);
+  if (failed(writtenToFields))
+    return failure();
+
+  unsigned currentOffset = offset;
+  for (Type type : *writtenToFields) {
+    unsigned fieldSize = dataLayout.getTypeSize(type);
+
+    // Extract the data out of the integer by first shifting right and then
+    // truncating it.
+    auto pos = rewriter.create<ConstantOp>(
+        store.getLoc(),
+        rewriter.getIntegerAttr(sourceType, (currentOffset - offset) * 8));
+
+    auto shrOp = rewriter.create<LShrOp>(store.getLoc(), store.getValue(), pos);
+
+    IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8);
+    Value valueToStore =
+        rewriter.create<TruncOp>(store.getLoc(), fieldIntType, shrOp);
+    if (fieldIntType != type) {
+      // Bitcast to the right type. `fieldIntType` was explicitly created
+      // to be of the same size as `type` and must currently be a primitive as
+      // well.
+      valueToStore =
+          rewriter.create<BitcastOp>(store.getLoc(), type, valueToStore);
+    }
+
+    // We create an `i8` indexed GEP here as that is the easiest (offset is
+    // already known). Other patterns turn this into a type-consistent GEP.
+    auto gepOp = rewriter.create<GEPOp>(store.getLoc(), address.getType(),
+                                        rewriter.getI8Type(), address,
+                                        ArrayRef<GEPArg>{currentOffset});
+    rewriter.create<StoreOp>(store.getLoc(), valueToStore, gepOp);
+
+    // No need to care about padding here since we already checked previously
+    // that no padding exists in this range.
+    currentOffset += fieldSize;
+  }
+
+  rewriter.eraseOp(store);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Type consistency pass
 //===----------------------------------------------------------------------===//
@@ -358,6 +518,7 @@
     rewritePatterns.add<AddFieldGetterToStructDirectUse<LoadOp>,
                         AddFieldGetterToStructDirectUse<StoreOp>>(
         &getContext());
     rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
+    rewritePatterns.add<SplitIntegerStores>(&getContext());
     FrozenRewritePatternSet frozen(std::move(rewritePatterns));
     if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))
diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -148,3 +148,155 @@
   llvm.store %arg, %7 : i32, !llvm.ptr
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @coalesced_store_ints
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_ints(%arg: i64) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @coalesced_store_ints_offset
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_ints_offset(%arg: i64) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, i32, i32)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32, i32)> : (i32) -> !llvm.ptr
+  %3 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32)>
+
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32)>
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32)>
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  llvm.store %arg, %3 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @coalesced_store_floats
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_floats(%arg: i64) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+  %0 = llvm.mlir.constant(1 : i32) : i32
+
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (f32, f32)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (f32, f32)> : (i32) -> !llvm.ptr
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)>
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
+  // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)>
+  // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// Padding test purposefully not modified.
+
+// CHECK-LABEL: llvm.func @coalesced_store_padding_inbetween
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_padding_inbetween(%arg: i64) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, i32)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, i32)> : (i32) -> !llvm.ptr
+  // CHECK: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// Padding test purposefully not modified.
+
+// CHECK-LABEL: llvm.func @coalesced_store_padding_end
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_padding_end(%arg: i64) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i16)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i16)> : (i32) -> !llvm.ptr
+  // CHECK: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @coalesced_store_past_end
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_past_end(%arg: i64) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32)> : (i32) -> !llvm.ptr
+  // CHECK: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @coalesced_store_packed_struct
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_packed_struct(%arg: i64) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
+  // CHECK: %[[CST48:.*]] = llvm.mlir.constant(48 : i64) : i64
+
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", packed (i16, i32, i16)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i16, i32, i16)> : (i32) -> !llvm.ptr
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i16, i32, i16)>
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST16]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i16, i32, i16)>
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST48]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i16, i32, i16)>
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}