diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
--- a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -357,7 +357,7 @@
 /// types, failure is returned.
 static FailureOr<ArrayRef<Type>>
 getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
-                   int storeSize, unsigned storeOffset) {
+                   unsigned storeSize, unsigned storeOffset) {
   ArrayRef<Type> body = structType.getBody();
   unsigned currentOffset = 0;
   body = body.drop_until([&](Type type) {
@@ -381,10 +381,6 @@
 
   size_t exclusiveEnd = 0;
   for (; exclusiveEnd < body.size() && storeSize > 0; exclusiveEnd++) {
-    // Not yet recursively handling aggregates, only primitives.
-    if (!isa<IntegerType>(body[exclusiveEnd]))
-      return failure();
-
     if (!structType.isPacked()) {
       unsigned alignment = dataLayout.getTypeABIAlignment(body[exclusiveEnd]);
       // No padding allowed inbetween fields at this point in time.
@@ -393,13 +389,29 @@
     }
 
     unsigned fieldSize = dataLayout.getTypeSize(body[exclusiveEnd]);
+    if (fieldSize > storeSize) {
+      // Partial writes into an aggregate are okay since subsequent pattern
+      // applications can further split these up into writes into the
+      // sub-elements.
+      auto subStruct = dyn_cast<LLVMStructType>(body[exclusiveEnd]);
+      if (!subStruct)
+        return failure();
+
+      // Avoid splitting redundantly by making sure the store into the struct
+      // can actually be split.
+      if (failed(getWrittenToFields(dataLayout, subStruct, storeSize,
+                                    /*storeOffset=*/0)))
+        return failure();
+
+      return body.take_front(exclusiveEnd + 1);
+    }
     currentOffset += fieldSize;
     storeSize -= fieldSize;
   }
 
-  // If the storeSize is not 0 at this point we are either partially writing
-  // into a field or writing past the aggregate as a whole. Abort.
-  if (storeSize != 0)
+  // If the storeSize is not 0 at this point we are writing past the aggregate
+  // as a whole. Abort.
+  if (storeSize > 0)
     return failure();
   return body.take_front(exclusiveEnd);
 }
@@ -435,7 +447,8 @@
 /// type-consistent.
 static void splitIntegerStore(const DataLayout &dataLayout, Location loc,
                               RewriterBase &rewriter, Value address,
-                              Value value, unsigned storeOffset,
+                              Value value, unsigned storeSize,
+                              unsigned storeOffset,
                               ArrayRef<Type> writtenToFields) {
   unsigned currentOffset = storeOffset;
   for (Type type : writtenToFields) {
@@ -449,7 +462,12 @@
 
     auto shrOp = rewriter.create<LShrOp>(loc, value, pos);
 
-    IntegerType fieldIntType = rewriter.getIntegerType(fieldSize * 8);
+    // If we are doing a partial write into a direct field the remaining
+    // `storeSize` will be less than the size of the field. We have to truncate
+    // to the `storeSize` to avoid creating a store that wasn't in the original
+    // code.
+    IntegerType fieldIntType =
+        rewriter.getIntegerType(std::min(fieldSize, storeSize) * 8);
     Value valueToStore = rewriter.create<TruncOp>(loc, fieldIntType, shrOp);
 
     // We create an `i8` indexed GEP here as that is the easiest (offset is
@@ -462,6 +480,9 @@
     // No need to care about padding here since we already checked previously
     // that no padding exists in this range.
     currentOffset += fieldSize;
+    // The last field may only be partially written (see the truncation
+    // above), so clamp instead of letting the unsigned counter wrap around.
+    storeSize -= std::min(fieldSize, storeSize);
   }
 }
 
@@ -481,28 +502,31 @@
 
   auto dataLayout = DataLayout::closest(store);
 
+  unsigned storeSize = dataLayout.getTypeSize(sourceType);
   unsigned offset = 0;
   Value address = store.getAddr();
   if (auto gepOp = address.getDefiningOp<GEPOp>()) {
     // Currently only handle canonical GEPs with exactly two indices,
     // indexing a single aggregate deep.
-    // Recursing into sub-structs is left as a future exercise.
     // If the GEP is not canonical we have to fail, otherwise we would not
     // create type-consistent IR.
     if (gepOp.getIndices().size() != 2 ||
         succeeded(getRequiredConsistentGEPType(gepOp)))
       return failure();
 
-    // A GEP might point somewhere into the middle of an aggregate with the
-    // store storing into multiple adjacent elements. Destructure into
-    // the base address with an offset.
-    std::optional<uint64_t> byteOffset = gepToByteOffset(dataLayout, gepOp);
-    if (!byteOffset)
-      return failure();
+    // If the size of the element indexed by the GEP is smaller than the store
+    // size, it is pointing into the middle of an aggregate with the store
+    // storing into multiple adjacent elements. Destructure into the base
+    // address of the aggregate with a store offset.
+    if (storeSize > dataLayout.getTypeSize(gepOp.getResultPtrElementType())) {
+      std::optional<uint64_t> byteOffset = gepToByteOffset(dataLayout, gepOp);
+      if (!byteOffset)
+        return failure();
 
-    offset = *byteOffset;
-    typeHint = gepOp.getSourceElementType();
-    address = gepOp.getBase();
+      offset = *byteOffset;
+      typeHint = gepOp.getSourceElementType();
+      address = gepOp.getBase();
+    }
   }
 
   auto structType = typeHint.dyn_cast<LLVMStructType>();
@@ -512,9 +536,7 @@
   }
 
   FailureOr<ArrayRef<Type>> writtenToFields =
-      getWrittenToFields(dataLayout, structType,
-                         /*storeSize=*/dataLayout.getTypeSize(sourceType),
-                         /*storeOffset=*/offset);
+      getWrittenToFields(dataLayout, structType, storeSize, offset);
   if (failed(writtenToFields))
     return failure();
 
@@ -526,7 +548,7 @@
 
   if (isa<IntegerType>(sourceType)) {
     splitIntegerStore(dataLayout, store.getLoc(), rewriter, address,
-                      store.getValue(), offset, *writtenToFields);
+                      store.getValue(), storeSize, offset, *writtenToFields);
     rewriter.eraseOp(store);
     return success();
   }
diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
--- a/mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -493,3 +493,134 @@
   // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @overlapping_int_aggregate_store
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @overlapping_int_aggregate_store(%arg: i64) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
+
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)> : (i32) -> !llvm.ptr
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST16]] : i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i48
+  // CHECK: %[[TOP_GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+
+  // Normal integer splitting of [[TRUNC]] follows:
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
+  // CHECK: llvm.store %{{.*}}, %[[GEP]]
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
+  // CHECK: llvm.store %{{.*}}, %[[GEP]]
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
+  // CHECK: llvm.store %{{.*}}, %[[GEP]]
+
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @overlapping_vector_aggregate_store
+// CHECK-SAME: %[[ARG:.*]]: vector<4xi16>
+llvm.func @overlapping_vector_aggregate_store(%arg: vector<4 x i16>) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i32) : i32
+
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)> : (i32) -> !llvm.ptr
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST0]] : i32]
+  // CHECK: llvm.store %[[EXTRACT]], %[[GEP]]
+
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST1]] : i32]
+  // CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+  // CHECK: %[[GEP1:.*]] = llvm.getelementptr %[[GEP0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
+  // CHECK: llvm.store %[[EXTRACT]], %[[GEP1]]
+
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST2]] : i32]
+  // CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+  // CHECK: %[[GEP1:.*]] = llvm.getelementptr %[[GEP0]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
+  // CHECK: llvm.store %[[EXTRACT]], %[[GEP1]]
+
+  // CHECK: %[[EXTRACT:.*]] = llvm.extractelement %[[ARG]][%[[CST3]] : i32]
+  // CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16)>)>
+  // CHECK: %[[GEP1:.*]] = llvm.getelementptr %[[GEP0]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16)>
+  // CHECK: llvm.store %[[EXTRACT]], %[[GEP1]]
+
+  llvm.store %arg, %1 : vector<4 x i16>, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @partially_overlapping_aggregate_store
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @partially_overlapping_aggregate_store(%arg: i64) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[CST16:.*]] = llvm.mlir.constant(16 : i64) : i64
+
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)> : (i32) -> !llvm.ptr
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i16
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST16]] : i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i48
+  // CHECK: %[[TOP_GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i16, struct<(i16, i16, i16, i16)>)>
+
+  // Normal integer splitting of [[TRUNC]] follows:
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
+  // CHECK: llvm.store %{{.*}}, %[[GEP]]
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
+  // CHECK: llvm.store %{{.*}}, %[[GEP]]
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[TOP_GEP]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i16, i16, i16, i16)>
+  // CHECK: llvm.store %{{.*}}, %[[GEP]]
+
+  // It is important that there are no more stores at this point.
+  // Specifically a store into the fourth field of %[[TOP_GEP]] would
+  // incorrectly change the semantics of the code.
+  // CHECK-NOT: llvm.store %{{.*}}, %{{.*}}
+
+  llvm.store %arg, %1 : i64, !llvm.ptr
+
+  llvm.return
+}
+
+// -----
+
+// Here a split is undesirable since the store does a partial store into the field.
+
+// CHECK-LABEL: llvm.func @undesirable_overlapping_aggregate_store
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @undesirable_overlapping_aggregate_store(%arg: i64) {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, struct<(i64, i16, i16, i16)>)>
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, struct<(i64, i16, i16, i16)>)> : (i32) -> !llvm.ptr
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, struct<(i64, i16, i16, i16)>)>
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, struct<(i64, i16, i16, i16)>)>
+  // CHECK: llvm.store %[[ARG]], %[[GEP]]
+  llvm.store %arg, %2 : i64, !llvm.ptr
+
+  llvm.return
+}