Index: mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
+++ mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h
@@ -53,6 +53,18 @@
                                 PatternRewriter &rewriter) const override;
 };
 
+/// Splits integer stores that write into multiple adjacent fields of an
+/// aggregate through a pointer. The integer is split and a separate,
+/// type-consistent store is generated for every field written to.
+/// This is currently done on a best-effort basis.
+class SplitIntegerStores : public OpRewritePattern<StoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(StoreOp store,
+                                PatternRewriter &rewriter) const override;
+};
+
 } // namespace LLVM
 } // namespace mlir
Index: mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -168,11 +168,20 @@
 
 /// Returns the amount of bytes the provided GEP elements will offset the
 /// pointer by. Returns nullopt if the offset could not be computed.
-static std::optional<uint64_t> gepToByteOffset(DataLayout &layout, Type base,
-                                               ArrayRef<uint32_t> indices) {
-  uint64_t offset = indices[0] * layout.getTypeSize(base);
+static std::optional<uint64_t> gepToByteOffset(DataLayout &layout, GEPOp gep) {
 
-  Type currentType = base;
+  SmallVector<uint32_t> indices;
+  // Ensures all indices are static and fetches them.
+  for (auto index : gep.getIndices()) {
+    IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
+    if (!indexInt)
+      return std::nullopt;
+    indices.push_back(indexInt.getInt());
+  }
+
+  uint64_t offset = indices[0] * layout.getTypeSize(gep.getSourceElementType());
+
+  Type currentType = gep.getSourceElementType();
   for (uint32_t index : llvm::drop_begin(indices)) {
     bool shouldCancel =
         TypeSwitch<Type, bool>(currentType)
@@ -302,34 +311,34 @@
   return success();
 }
 
-LogicalResult
-CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep,
-                                        PatternRewriter &rewriter) const {
+static LogicalResult gepIsCanonical(GEPOp gep, Type *consistentType = nullptr) {
   // GEP of typed pointers are not supported.
   if (!gep.getElemType())
-    return failure();
+    return success();
 
   std::optional<Type> maybeBaseType = gep.getElemType();
   if (!maybeBaseType)
-    return failure();
+    return success();
   Type baseType = *maybeBaseType;
 
   Type typeHint = isElementTypeInconsistent(gep.getBase(), baseType);
   if (!typeHint)
-    return failure();
+    return success();
+  if (consistentType)
+    *consistentType = typeHint;
+  return failure();
+}
 
-  SmallVector<uint32_t> indices;
-  // Ensures all indices are static and fetches them.
-  for (auto index : gep.getIndices()) {
-    IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index);
-    if (!indexInt)
-      return failure();
-    indices.push_back(indexInt.getInt());
-  }
+LogicalResult
+CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep,
+                                        PatternRewriter &rewriter) const {
+  Type typeHint;
+  if (succeeded(gepIsCanonical(gep, &typeHint)))
+    // GEP is already canonical, nothing to do here.
+    return failure();
 
   DataLayout layout = DataLayout::closest(gep);
-  std::optional<uint64_t> desiredOffset =
-      gepToByteOffset(layout, gep.getSourceElementType(), indices);
+  std::optional<uint64_t> desiredOffset = gepToByteOffset(layout, gep);
   if (!desiredOffset)
     return failure();
 
@@ -345,6 +354,145 @@
   return success();
 }
 
+static FailureOr<const Type *>
+endOfStoredToPrimitives(const DataLayout &dataLayout, const Type *begin,
+                        const Type *end, int storeSize, unsigned currentOffset,
+                        bool packed) {
+  const Type *exclusiveEnd = begin;
+  for (; exclusiveEnd != end && storeSize > 0; exclusiveEnd++) {
+    // Not yet recursively handling aggregates, only primitives.
+    if (!isa<IntegerType, FloatType>(*exclusiveEnd))
+      return failure();
+
+    if (!packed) {
+      unsigned alignment = dataLayout.getTypeABIAlignment(*exclusiveEnd);
+      // No padding allowed in between fields at this point in time.
+      if (!llvm::isAligned(llvm::Align(alignment), currentOffset))
+        return failure();
+    }
+
+    unsigned fieldSize = dataLayout.getTypeSize(*exclusiveEnd);
+    currentOffset += fieldSize;
+    storeSize -= fieldSize;
+  }
+
+  // If the storeSize is not 0 at this point, we are either partially writing
+  // into a field or writing past the aggregate as a whole. Abort.
+  if (storeSize != 0)
+    return failure();
+  return exclusiveEnd;
+}
+
+LogicalResult
+SplitIntegerStores::matchAndRewrite(StoreOp store,
+                                    PatternRewriter &rewriter) const {
+  IntegerType sourceType = dyn_cast<IntegerType>(store.getValue().getType());
+  if (!sourceType)
+    // We currently only support integer sources.
+    return failure();
+
+  Type typeHint = isElementTypeInconsistent(store.getAddr(), sourceType);
+  if (!typeHint)
+    // Nothing to do, since it is already consistent.
+    return failure();
+
+  auto dataLayout = DataLayout::closest(store);
+
+  unsigned offset = 0;
+  Value address = store.getAddr();
+  if (auto gepOp = address.getDefiningOp<GEPOp>()) {
+    // Currently only canonical GEPs with exactly two indices, indexing a
+    // single aggregate deep, are handled.
+    // Recursing into sub-structs is left as a future exercise.
+    // If the GEP is not canonical we have to fail, otherwise we would not
+    // create type-consistent IR.
+    if (gepOp.getIndices().size() != 2 || failed(gepIsCanonical(gepOp)))
+      return failure();
+
+    // A GEP might point into the middle of an aggregate, with the store then
+    // writing into multiple adjacent elements. Destructure it into the base
+    // address plus a byte offset.
+    std::optional<uint64_t> byteOffset = gepToByteOffset(dataLayout, gepOp);
+    if (!byteOffset)
+      return failure();
+
+    offset = *byteOffset;
+    typeHint = gepOp.getSourceElementType();
+    address = gepOp.getBase();
+  }
+
+  auto structType = typeHint.dyn_cast<LLVMStructType>();
+  if (!structType) {
+    // TODO: Handle array types in the future.
+    return failure();
+  }
+
+  ArrayRef<Type> body = structType.getBody();
+  // Currently only stores that write neatly into adjacent fields of primitive
+  // type are handled.
+  // TODO: Write into sub-aggregates by recursively splitting.
+  unsigned currentOffset = 0;
+  const Type *field = llvm::find_if(body, [&](Type t) {
+    if (!structType.isPacked()) {
+      unsigned alignment = dataLayout.getTypeABIAlignment(t);
+      currentOffset = llvm::alignTo(currentOffset, alignment);
+    }
+
+    // currentOffset is guaranteed to eventually be equal to offset, since
+    // offset is either 0 or stems from a type-consistent GEP indexing into
+    // just a single aggregate.
+    if (currentOffset == offset)
+      return true;
+
+    currentOffset += dataLayout.getTypeSize(t);
+    return false;
+  });
+
+  FailureOr<const Type *> endField =
+      endOfStoredToPrimitives(dataLayout, field, body.end(),
+                              /*storeSize=*/dataLayout.getTypeSize(sourceType),
+                              currentOffset, structType.isPacked());
+  if (failed(endField))
+    return failure();
+
+  for (Type t : llvm::make_range(field, *endField)) {
+    unsigned fieldSize = dataLayout.getTypeSize(t);
+
+    // Extract the data out of the integer by first shifting right and then
+    // truncating it.
+    auto pos = rewriter.create<ConstantOp>(
+        store.getLoc(),
+        rewriter.getIntegerAttr(sourceType, (currentOffset - offset) * 8));
+
+    auto lshr = rewriter.create<LShrOp>(store.getLoc(), store.getValue(), pos);
+
+    IntegerType fieldSizedInteger = rewriter.getIntegerType(fieldSize * 8);
+    Value valueToStore =
+        rewriter.create<TruncOp>(store.getLoc(), fieldSizedInteger, lshr);
+    if (fieldSizedInteger != t)
+      // Bitcast to the right type. 'fieldSizedInteger' was explicitly created
+      // to be of the same size as 't' and must currently be a primitive as
+      // well.
+      valueToStore =
+          rewriter.create<BitcastOp>(store.getLoc(), t, valueToStore);
+
+    // We create an 'i8'-indexed GEP here, as the byte offset is already
+    // known. Other patterns will turn this into a type-consistent GEP.
+    auto gep = rewriter.create<GEPOp>(store.getLoc(), address.getType(),
+                                      rewriter.getI8Type(), address,
+                                      ArrayRef<GEPArg>{currentOffset});
+    rewriter.create<StoreOp>(store.getLoc(), valueToStore, gep);
+
+    // No need to care about padding here, since we already checked previously
+    // that no padding exists in this range.
+    currentOffset += fieldSize;
+  }
+
+  rewriter.eraseOp(store);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Type consistency pass
 //===----------------------------------------------------------------------===//
@@ -358,6 +506,7 @@
   rewritePatterns.add<AddFieldGetterToStructDirectUse<LoadOp>,
                       AddFieldGetterToStructDirectUse<StoreOp>>(
       &getContext());
   rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
+  rewritePatterns.add<SplitIntegerStores>(&getContext());
 
   FrozenRewritePatternSet frozen(std::move(rewritePatterns));
   if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))
Index: mlir/test/Dialect/LLVMIR/type-consistency.mlir
===================================================================
--- mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -148,3 +148,126 @@
   llvm.store %arg, %7 : i32, !llvm.ptr
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @coalesced_store_ints
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_ints(%arg: i64) {
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[THIRTY_TWO:.*]] = llvm.mlir.constant(32 : i64) : i64
+
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[ZERO]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[THIRTY_TWO]] : i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32)>
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @coalesced_store_ints_offset
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_ints_offset(%arg: i64) {
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[THIRTY_TWO:.*]] = llvm.mlir.constant(32 : i64) : i64
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i64, i32, i32)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i64, i32, i32)> : (i32) -> !llvm.ptr
+  %3 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32)>
+
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[ZERO]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32)>
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[THIRTY_TWO]] : i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i64, i32, i32)>
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  llvm.store %arg, %3 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @coalesced_store_floats
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_floats(%arg: i64) {
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[THIRTY_TWO:.*]] = llvm.mlir.constant(32 : i64) : i64
+  %0 = llvm.mlir.constant(1 : i32) : i32
+
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (f32, f32)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (f32, f32)> : (i32) -> !llvm.ptr
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)>
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[ZERO]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
+  // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[THIRTY_TWO]] : i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[BIT_CAST:.*]] = llvm.bitcast %[[TRUNC]] : i32 to f32
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (f32, f32)>
+  // CHECK: llvm.store %[[BIT_CAST]], %[[GEP]]
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// Padding in between the fields prevents the store from being split.
+
+// CHECK-LABEL: llvm.func @coalesced_store_padding_inbetween
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_padding_inbetween(%arg: i64) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, i32)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, i32)> : (i32) -> !llvm.ptr
+  // CHECK: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// Padding at the end of the struct prevents the store from being split.
+
+// CHECK-LABEL: llvm.func @coalesced_store_padding_end
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_padding_end(%arg: i64) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i16)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i16)> : (i32) -> !llvm.ptr
+  // CHECK: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @coalesced_store_past_end
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_past_end(%arg: i64) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32)> : (i32) -> !llvm.ptr
+  // CHECK: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  llvm.return
+}