Index: mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
===================================================================
--- mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
+++ mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp
@@ -348,20 +348,74 @@
   return success();
 }
 
-/// Returns the list of fields of `structType` that are written to by a store
-/// operation writing `storeSize` bytes at `storeOffset` within the struct.
-/// `storeOffset` is required to cleanly point to an immediate field within
-/// the struct.
-/// If the write operation were to write to any padding, write beyond the
-/// struct, partially write to a field, or contains currently unsupported
-/// types, failure is returned.
-static FailureOr<ArrayRef<Type>>
-getWrittenToFields(const DataLayout &dataLayout, LLVMStructType structType,
+namespace {
+/// Class abstracting over both array and struct types, turning each into ranges
+/// of their sub-types.
+class DestructurableTypeRange
+    : public llvm::indexed_accessor_range<DestructurableTypeRange,
+                                          DestructurableTypeInterface, Type,
+                                          Type *, Type> {
+
+  using Base = llvm::indexed_accessor_range<
+      DestructurableTypeRange, DestructurableTypeInterface, Type, Type *, Type>;
+
+public:
+  using Base::Base;
+
+  /// Constructs a DestructurableTypeRange from either a LLVMStructType or
+  /// LLVMArrayType.
+  explicit DestructurableTypeRange(DestructurableTypeInterface base)
+      : Base(base, 0, [&]() -> ptrdiff_t {
+          return TypeSwitch<DestructurableTypeInterface, ptrdiff_t>(base)
+              .Case([](LLVMStructType structType) {
+                return structType.getBody().size();
+              })
+              .Case([](LLVMArrayType arrayType) {
+                return arrayType.getNumElements();
+              })
+              .Default([](auto) -> ptrdiff_t {
+                llvm_unreachable(
+                    "Only LLVMStructType or LLVMArrayType supported");
+              });
+        }()) {}
+
+  /// Returns true if this is a range over a packed struct.
+  bool isPacked() const {
+    if (auto structType = dyn_cast<LLVMStructType>(getBase()))
+      return structType.isPacked();
+    return false;
+  }
+
+private:
+  static Type dereference(DestructurableTypeInterface base, ptrdiff_t index) {
+    // i32 chosen because the implementations of ArrayType and StructType
+    // specifically expect it to be 32 bit. They will fail otherwise.
+    Type result = base.getTypeAtIndex(
+        IntegerAttr::get(IntegerType::get(base.getContext(), 32), index));
+    assert(result && "Should always succeed");
+    return result;
+  }
+
+  friend Base;
+};
+} // namespace
+
+/// Returns the list of elements of `destructurableType` that are written to by
+/// a store operation writing `storeSize` bytes at `storeOffset`.
+/// `storeOffset` is required to cleanly point to an immediate element within
+/// the type. If the write operation were to write to any padding, write beyond
+/// the aggregate or only partially writes to an element, failure is returned.
+static FailureOr<DestructurableTypeRange>
+getWrittenToFields(const DataLayout &dataLayout,
+                   DestructurableTypeInterface destructurableType,
                    unsigned storeSize, unsigned storeOffset) {
-  ArrayRef<Type> body = structType.getBody();
+  DestructurableTypeRange destructurableTypeRange(destructurableType);
+
   unsigned currentOffset = 0;
-  body = body.drop_until([&](Type type) {
-    if (!structType.isPacked()) {
+  for (; !destructurableTypeRange.empty();
+       destructurableTypeRange = destructurableTypeRange.drop_front()) {
+    Type type = destructurableTypeRange.front();
+    if (!destructurableTypeRange.isPacked()) {
       unsigned alignment = dataLayout.getTypeABIAlignment(type);
       currentOffset = llvm::alignTo(currentOffset, alignment);
     }
@@ -370,40 +424,43 @@
     // 0 or stems from a type-consistent GEP indexing into just a single
     // aggregate.
     if (currentOffset == storeOffset)
-      return true;
+      break;
 
     assert(currentOffset < storeOffset &&
            "storeOffset should cleanly point into an immediate field");
 
     currentOffset += dataLayout.getTypeSize(type);
-    return false;
-  });
+  }
 
   size_t exclusiveEnd = 0;
-  for (; exclusiveEnd < body.size() && storeSize > 0; exclusiveEnd++) {
-    if (!structType.isPacked()) {
-      unsigned alignment = dataLayout.getTypeABIAlignment(body[exclusiveEnd]);
+  for (; exclusiveEnd < destructurableTypeRange.size() && storeSize > 0;
+       exclusiveEnd++) {
+    if (!destructurableTypeRange.isPacked()) {
+      unsigned alignment =
+          dataLayout.getTypeABIAlignment(destructurableTypeRange[exclusiveEnd]);
       // No padding allowed inbetween fields at this point in time.
       if (!llvm::isAligned(llvm::Align(alignment), currentOffset))
         return failure();
     }
 
-    unsigned fieldSize = dataLayout.getTypeSize(body[exclusiveEnd]);
+    unsigned fieldSize =
+        dataLayout.getTypeSize(destructurableTypeRange[exclusiveEnd]);
     if (fieldSize > storeSize) {
       // Partial writes into an aggregate are okay since subsequent pattern
       // applications can further split these up into writes into the
       // sub-elements.
-      auto subStruct = dyn_cast<LLVMStructType>(body[exclusiveEnd]);
-      if (!subStruct)
+      auto subAggregate = dyn_cast<DestructurableTypeInterface>(
+          destructurableTypeRange[exclusiveEnd]);
+      if (!subAggregate)
         return failure();
-      // Avoid splitting redundantly by making sure the store into the struct
-      // can actually be split.
-      if (failed(getWrittenToFields(dataLayout, subStruct, storeSize,
+      // Avoid splitting redundantly by making sure the store into the
+      // aggregate can actually be split.
+      if (failed(getWrittenToFields(dataLayout, subAggregate, storeSize,
                                     /*storeOffset=*/0)))
         return failure();
-      return body.take_front(exclusiveEnd + 1);
+      return destructurableTypeRange.take_front(exclusiveEnd + 1);
     }
     currentOffset += fieldSize;
     storeSize -= fieldSize;
@@ -413,7 +470,7 @@
   // as a whole. Abort.
   if (storeSize > 0)
     return failure();
-  return body.take_front(exclusiveEnd);
+  return destructurableTypeRange.take_front(exclusiveEnd);
 }
 
 /// Splits a store of the vector `value` into `address` at `storeOffset` into
@@ -449,7 +506,7 @@
                               RewriterBase &rewriter, Value address,
                               Value value, unsigned storeSize,
                               unsigned storeOffset,
-                              ArrayRef<Type> writtenToFields) {
+                              DestructurableTypeRange writtenToFields) {
   unsigned currentOffset = storeOffset;
   for (Type type : writtenToFields) {
     unsigned fieldSize = dataLayout.getTypeSize(type);
@@ -527,26 +584,24 @@
     }
   }
 
-  auto structType = typeHint.dyn_cast<LLVMStructType>();
-  if (!structType) {
-    // TODO: Handle array types in the future.
+  auto destructurableType = typeHint.dyn_cast<DestructurableTypeInterface>();
+  if (!destructurableType)
     return failure();
-  }
 
-  FailureOr<ArrayRef<Type>> writtenToFields =
-      getWrittenToFields(dataLayout, structType, storeSize, offset);
-  if (failed(writtenToFields))
+  FailureOr<DestructurableTypeRange> writtenToElements =
+      getWrittenToFields(dataLayout, destructurableType, storeSize, offset);
+  if (failed(writtenToElements))
     return failure();
 
-  if (writtenToFields->size() <= 1) {
+  if (writtenToElements->size() <= 1) {
     // Other patterns should take care of this case, we are only interested in
-    // splitting field stores.
+    // splitting element stores.
     return failure();
   }
 
   if (isa<IntegerType>(sourceType)) {
     splitIntegerStore(dataLayout, store.getLoc(), rewriter, address,
-                      store.getValue(), storeSize, offset, *writtenToFields);
+                      store.getValue(), storeSize, offset, *writtenToElements);
     rewriter.eraseOp(store);
     return success();
   }
Index: mlir/test/Dialect/LLVMIR/type-consistency.mlir
===================================================================
--- mlir/test/Dialect/LLVMIR/type-consistency.mlir
+++ mlir/test/Dialect/LLVMIR/type-consistency.mlir
@@ -624,3 +624,28 @@
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: llvm.func @coalesced_store_ints_array
+// CHECK-SAME: %[[ARG:.*]]: i64
+llvm.func @coalesced_store_ints_array(%arg: i64) {
+  // CHECK: %[[CST0:.*]] = llvm.mlir.constant(0 : i64) : i64
+  // CHECK: %[[CST32:.*]] = llvm.mlir.constant(32 : i64) : i64
+
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.array<2 x i32>
+  %1 = llvm.alloca %0 x !llvm.array<2 x i32> : (i32) -> !llvm.ptr
+
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x i32>
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST0]]
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  // CHECK: %[[SHR:.*]] = llvm.lshr %[[ARG]], %[[CST32]] : i64
+  // CHECK: %[[TRUNC:.*]] = llvm.trunc %[[SHR]] : i64 to i32
+  // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<2 x i32>
+  // CHECK: llvm.store %[[TRUNC]], %[[GEP]]
+  llvm.store %arg, %1 : i64, !llvm.ptr
+  // CHECK-NOT: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.return
+}