diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td @@ -167,6 +167,31 @@ ]; } +def GetResultPtrElementType : OpInterface<"GetResultPtrElementType"> { + let description = [{ + An interface for operations that yield an LLVMPointer. Allows the + operation to provide the type of the element an LLVMPointer opaque + pointer points to, if known. This is only a hint as to how to + interpret a given pointer, reflecting how the current operation + understands it. + }]; + + let cppNamespace = "::mlir::LLVM"; + + let methods = [ + InterfaceMethod< + /*desc=*/ [{Returns the element type hint of the n-th + LLVMPointer result, if known. Returns nullptr if the + requested result is not an LLVMPointer or if the + element type is unknown.}], + /*returnType=*/ "Type", + /*methodName=*/ "getResultPtrElementType", + /*args=*/ (ins "size_t":$nth) + > + ]; +} + + //===----------------------------------------------------------------------===// // LLVM dialect type interfaces. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -174,7 +174,8 @@ // Memory-related operations.
def LLVM_AllocaOp : LLVM_Op<"alloca", [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]>, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]>, LLVM_MemOpPatterns { let arguments = (ins AnyInteger:$arraySize, OptionalAttr:$alignment, @@ -235,7 +236,8 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let arguments = (ins LLVM_ScalarOrVectorOf:$base, Variadic>:$dynamicIndices, DenseI32ArrayAttr:$rawConstantIndices, diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h @@ -12,6 +12,7 @@ #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h" #include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h" #include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h" +#include "mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td @@ -30,6 +30,17 @@ let constructor = "::mlir::LLVM::createRequestCWrappersPass()"; } +def LLVMTypeConsistency + : Pass<"llvm-type-consistency", "::mlir::LLVM::LLVMFuncOp"> { + let summary = "Rewrites to improve type consistency"; + let description = [{ + Set of rewrites to improve the coherency of types within an LLVM dialect + program. This will adjust operations operating on pointers so they interpret + their associated pointee type as consistently as possible.
+ }]; + let constructor = "::mlir::LLVM::createTypeConsistencyPass()"; +} + def NVVMOptimizeForTarget : Pass<"llvm-optimize-for-nvvm-target"> { let summary = "Optimize NVVM IR"; let constructor = "::mlir::NVVM::createOptimizeForTargetPass()"; diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h @@ -0,0 +1,59 @@ +//===- TypeConsistency.h - Rewrites to improve type consistency -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Set of rewrites to improve the coherency of types within an LLVM dialect +// program. This will adjust operations around a given pointer so they interpret +// its pointee type as consistently as possible. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_TYPECONSISTENCY_H +#define MLIR_DIALECT_LLVMIR_TRANSFORMS_TYPECONSISTENCY_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace LLVM { + +#define GEN_PASS_DECL_LLVMTYPECONSISTENCY +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc" + +/// Creates a pass that adjusts operations operating on pointers so they +/// interpret pointee types as consistently as possible. +std::unique_ptr createTypeConsistencyPass(); + +/// Transforms uses of pointers to a whole struct to uses of pointers to the +/// first element of a struct. Where possible, this is done by inserting a GEP +/// to the first element.
+template +class AddFieldGetterToStructDirectUse : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(User user, + PatternRewriter &rewriter) const override; +}; + +/// Canonicalizes GEPs of which the base type and the pointer's type hint do not +/// match. This is done by replacing the original GEP with a GEP that uses the +/// type hint as base type, when an element of the hinted type starts at the +/// byte offset computed by the original GEP. +class CanonicalizeAlignedGep : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GEPOp gep, + PatternRewriter &rewriter) const override; +}; + +} // namespace LLVM +} // namespace mlir + +#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_TYPECONSISTENCY_H diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -281,6 +281,14 @@ return success(); } +Type AllocaOp::getResultPtrElementType(size_t nth) { + if (nth != 0) + return nullptr; + // This will become trivial once non-opaque pointers are gone. + return getElemType().has_value() ? *getElemType() + : getResult().getType().getElementType(); +} + //===----------------------------------------------------------------------===// // LLVM::BrOp //===----------------------------------------------------------------------===// @@ -763,6 +771,45 @@ .getElementType(); } +Type GEPOp::getResultPtrElementType(size_t nth) { + if (nth != 0) + return nullptr; + + // Ensures all indices are static and fetches them. + SmallVector indices; + for (auto index : getIndices()) { + IntegerAttr indexInt = llvm::dyn_cast_if_present(index); + if (!indexInt) + return nullptr; + indices.push_back(indexInt); + } + + // Set the initial type currently being used for indexing. This will be + // updated as the indices get walked over.
+ Type selectedType = getSourceElementType(); + + // Follow the indexed elements in the gep. + for (IntegerAttr index : llvm::drop_begin(indices)) { + // Ensure the structure of the type being indexed can be reasoned about. + // This includes rejecting any potential typed pointer. + auto destructurable = + llvm::dyn_cast(selectedType); + if (!destructurable) + return nullptr; + + // Follow the type at the index the gep is accessing, making it the new type + // used for indexing. + Type field = destructurable.getTypeAtIndex(index); + if (!field) + return nullptr; + selectedType = field; + } + + // When there are no more indices, the type currently being used for indexing + // is the type of the value pointed at by the returned indexed pointer. + return selectedType; +} + //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -28,18 +28,11 @@ // Interfaces for AllocaOp //===----------------------------------------------------------------------===// -static Type getAllocaElementType(LLVM::AllocaOp alloca) { - // This will become trivial once non-opaque pointers are gone. - return alloca.getElemType().has_value() - ?
*alloca.getElemType() - : alloca.getResult().getType().getElementType(); -} - llvm::SmallVector LLVM::AllocaOp::getPromotableSlots() { if (!getOperation()->getBlock()->isEntryBlock()) return {}; - return {MemorySlot{getResult(), getAllocaElementType(*this)}}; + return {MemorySlot{getResult(), getResultPtrElementType(0)}}; } Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot, @@ -68,7 +61,7 @@ if (!mlir::matchPattern(getArraySize(), m_One())) return {}; - Type elemType = getAllocaElementType(*this); + Type elemType = getResultPtrElementType(0); auto destructurable = dyn_cast(elemType); if (!destructurable) return {}; @@ -292,52 +285,13 @@ return DeletionKind::Delete; } -/// Returns the type the resulting pointer of the GEP points to. If such a type -/// is not clear, returns null type. -static Type computeReachedGEPType(LLVM::GEPOp gep) { - if (gep.getIndices().empty()) - return {}; - - // Ensures all indices are static and fetches them. - SmallVector indices; - for (auto index : gep.getIndices()) { - IntegerAttr indexInt = llvm::dyn_cast_if_present(index); - if (!indexInt) - return {}; - indices.push_back(indexInt); - } - - // Check the pointer indexing only targets the first element. - if (indices[0].getInt() != 0) - return {}; - - // Set the initial type currently being used for indexing. This will be - // updated as the indices get walked over. - std::optional maybeSelectedType = gep.getElemType(); - if (!maybeSelectedType) - return {}; - Type selectedType = *maybeSelectedType; - - // Follow the indexed elements in the gep. - for (IntegerAttr index : llvm::drop_begin(indices)) { - // Ensure the structure of the type being indexed can be reasoned about. - // This includes rejecting any potential typed pointer. - auto destructurable = - llvm::dyn_cast(selectedType); - if (!destructurable) - return {}; - - // Follow the type at the index the gep is accessing, making it the new type - // used for indexing.
- Type field = destructurable.getTypeAtIndex(index); - if (!field) - return {}; - selectedType = field; - } - - // When there are no more indices, the type currently being used for indexing - // is the type of the value pointed at by the returned indexed pointer. - return selectedType; +/// Returns true if the first index of `gep` is the constant zero. +static bool isFirstIndexZero(LLVM::GEPOp gep) { + if (gep.getIndices().empty()) + return false; + IntegerAttr index = + llvm::dyn_cast_if_present<IntegerAttr>(gep.getIndices()[0]); + return index && index.getInt() == 0; } LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses( @@ -346,7 +300,9 @@ return success(); if (slot.elemType != getElemType()) return failure(); - Type reachedType = computeReachedGEPType(*this); + if (!isFirstIndexZero(*this)) + return failure(); + Type reachedType = getResultPtrElementType(0); if (!reachedType) return failure(); mustBeSafelyUsed.emplace_back({getResult(), reachedType}); @@ -367,7 +323,9 @@ if (getBase() != slot.ptr || slot.elemType != getElemType()) return false; - Type reachedType = computeReachedGEPType(*this); + if (!isFirstIndexZero(*this)) + return false; + Type reachedType = getResultPtrElementType(0); if (!reachedType || getIndices().size() < 2) return false; auto firstLevelIndex = cast(getIndices()[1]); diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ LegalizeForExport.cpp OptimizeForNVVM.cpp RequestCWrappers.cpp + TypeConsistency.cpp DEPENDS MLIRLLVMPassIncGen diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp @@ -0,0 +1,358 @@ +//===- TypeConsistency.cpp - Rewrites to improve type consistency ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Interfaces/MemorySlotInterfaces.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/MathExtras.h" +#include +#include +#include + +namespace mlir { +namespace LLVM { +#define GEN_PASS_DEF_LLVMTYPECONSISTENCY +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc" +} // namespace LLVM +} // namespace mlir + +using namespace mlir; +using namespace LLVM; + +//===----------------------------------------------------------------------===// +// Utils +//===----------------------------------------------------------------------===// + +/// Checks that a pointer value has a pointee type hint consistent with the +/// expected type. Returns the type it actually hints to if it differs, or +/// nullptr if the type is consistent or impossible to analyze. +static Type isElementTypeInconsistent(Value addr, Type expectedType) { + auto defOp = dyn_cast_or_null(addr.getDefiningOp()); + if (!defOp) + return nullptr; + + // Query the hint of the specific result `addr` refers to. + Type elemType = + defOp.getResultPtrElementType(cast<OpResult>(addr).getResultNumber()); + if (!elemType) + return nullptr; + + if (elemType == expectedType) + return nullptr; + + return elemType; +} + +//===----------------------------------------------------------------------===// +// AddFieldGetterToStructDirectUse +//===----------------------------------------------------------------------===// + +/// Gets the type of the first subelement of `type` if `type` is destructurable, +/// nullptr otherwise.
+static Type getFirstSubelementType(Type type) { + auto destructurable = dyn_cast(type); + if (!destructurable) + return nullptr; + + Type indexType = destructurable.getTypeAtIndex( + IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0)); + if (indexType) + return indexType; + + indexType = destructurable.getTypeAtIndex( + IntegerAttr::get(IntegerType::get(type.getContext(), 64), 0)); + if (indexType) + return indexType; + + return nullptr; +} + +/// Attempts to extract a pointer to the first field of an `elemType` from the +/// address pointer of the provided MemOp, and rewire the MemOp so it uses that +/// pointer instead. +template +static LogicalResult +attemptFieldIndirection(MemOp op, PatternRewriter &rewriter, Type elemType) { + PatternRewriter::InsertionGuard guard(rewriter); + + rewriter.setInsertionPointAfterValue(op.getAddr()); + SmallVector firstTypeIndices{0, 0}; + + // TODO: Simplify once typed pointers are removed from LLVM dialect. + // Opaque pointers reuse the address type to preserve its address space. + Value properPtr = + op.getAddr().getType().isOpaque() + ?
rewriter.create( + op->getLoc(), op.getAddr().getType(), + elemType, op.getAddr(), firstTypeIndices) + : rewriter.create(op->getLoc(), + LLVM::LLVMPointerType::get(op.getContext()), + op.getAddr(), firstTypeIndices); + + rewriter.updateRootInPlace(op, [&]() { + MutableOperandRange addr = op.getAddrMutable(); + addr.clear(); + addr.append(properPtr); + }); + + return success(); +} + +template <> +LogicalResult AddFieldGetterToStructDirectUse::matchAndRewrite( + LoadOp load, PatternRewriter &rewriter) const { + Type inconsistentElementType = + isElementTypeInconsistent(load.getAddr(), load.getType()); + if (!inconsistentElementType) + return failure(); + if (!getFirstSubelementType(inconsistentElementType)) + return failure(); + return attemptFieldIndirection(load, rewriter, + inconsistentElementType); +} + +template <> +LogicalResult AddFieldGetterToStructDirectUse::matchAndRewrite( + StoreOp store, PatternRewriter &rewriter) const { + PatternRewriter::InsertionGuard guard(rewriter); + + Type inconsistentElementType = + isElementTypeInconsistent(store.getAddr(), store.getValue().getType()); + if (!inconsistentElementType) + return failure(); + Type firstType = getFirstSubelementType(inconsistentElementType); + if (!firstType) + return failure(); + + DataLayout layout = DataLayout::closest(store); + // Check that the first field has the right type or can at least be bitcast + // to the right type.
+ if (firstType != store.getValue().getType() && + (isa(firstType) || + layout.getTypeSize(firstType) != + layout.getTypeSize(store.getValue().getType()))) + return failure(); + + if (failed(attemptFieldIndirection(store, rewriter, + inconsistentElementType))) + return failure(); + + Value replaceValue; + if (firstType == store.getValue().getType()) { + replaceValue = store.getValue(); + } else { + rewriter.setInsertionPointAfterValue(store.getValue()); + replaceValue = rewriter.create(store->getLoc(), firstType, + store.getValue()); + } + + rewriter.updateRootInPlace(store, [&]() { + MutableOperandRange value = store.getValueMutable(); + value.clear(); + value.append(replaceValue); + }); + + return success(); +} + +//===----------------------------------------------------------------------===// +// CanonicalizeAlignedGep +//===----------------------------------------------------------------------===// + +/// Returns the number of bytes by which the provided GEP indices offset the +/// pointer. Returns nullopt if the offset could not be computed.
+static std::optional<uint64_t> gepToByteOffset(DataLayout &layout, Type base, + ArrayRef<uint32_t> indices) { + if (indices.empty()) + return std::nullopt; + // Widen before multiplying so that large indices cannot overflow. + uint64_t offset = + static_cast<uint64_t>(indices[0]) * layout.getTypeSize(base); + + Type currentType = base; + for (uint32_t index : llvm::drop_begin(indices)) { + bool shouldCancel = + TypeSwitch<Type, bool>(currentType) + .Case([&](LLVMArrayType arrayType) { + if (arrayType.getNumElements() <= index) + return true; + offset += static_cast<uint64_t>(index) * + layout.getTypeSize(arrayType.getElementType()); + currentType = arrayType.getElementType(); + return false; + }) + .Case([&](LLVMStructType structType) { + ArrayRef<Type> body = structType.getBody(); + if (body.size() <= index) + return true; + for (uint32_t i = 0; i < index; i++) { + if (!structType.isPacked()) + offset = llvm::alignTo(offset, + layout.getTypeABIAlignment(body[i])); + offset += layout.getTypeSize(body[i]); + } + // The selected field itself also starts at an aligned offset. + if (!structType.isPacked()) + offset = llvm::alignTo( + offset, layout.getTypeABIAlignment(body[index])); + currentType = body[index]; + return false; + }) + .Default([](Type) { return true; }); + + if (shouldCancel) + return std::nullopt; + } + + return offset; +} + +/// Fills in `equivalentIndicesOut` with GEP indices that would be equivalent to +/// offsetting a pointer by `offset` bytes, assuming the GEP has `base` as base +/// type. Fails if no such indices exist, for instance when `offset` points +/// into struct padding or into the middle of a non-aggregate element.
+static LogicalResult +findIndicesForOffset(DataLayout &layout, Type base, uint64_t offset, + SmallVectorImpl<GEPArg> &equivalentIndicesOut) { + uint64_t baseSize = layout.getTypeSize(base); + // A zero-sized base type cannot be used to reach any offset. + if (baseSize == 0) + return failure(); + uint64_t rootIndex = offset / baseSize; + offset %= baseSize; + if (rootIndex > std::numeric_limits<int32_t>::max()) + return failure(); + equivalentIndicesOut.push_back(rootIndex); + + // From here on, `offset` is relative to the start of `currentType`. + Type currentType = base; + while (offset > 0) { + bool shouldCancel = + TypeSwitch<Type, bool>(currentType) + .Case([&](LLVMArrayType arrayType) { + uint64_t elemSize = + layout.getTypeSize(arrayType.getElementType()); + if (elemSize == 0) + return true; + uint64_t index = offset / elemSize; + offset %= elemSize; + equivalentIndicesOut.push_back(index); + currentType = arrayType.getElementType(); + return index > std::numeric_limits<int32_t>::max(); + }) + .Case([&](LLVMStructType structType) { + ArrayRef<Type> body = structType.getBody(); + // Byte offset of the current field from the start of the struct. + uint64_t fieldStart = 0; + uint32_t index = 0; + for (Type elem : body) { + if (!structType.isPacked()) { + fieldStart = llvm::alignTo(fieldStart, + layout.getTypeABIAlignment(elem)); + // The offset points into the padding before this field. + if (offset < fieldStart) + return true; + } + uint64_t elemSize = layout.getTypeSize(elem); + if (offset < fieldStart + elemSize) { + equivalentIndicesOut.push_back(index); + offset -= fieldStart; + currentType = elem; + return false; + } + fieldStart += elemSize; + index++; + } + // The offset points into the trailing padding of the struct. + return true; + }) + .Default([](Type) { return true; }); + + if (shouldCancel) + return failure(); + } + + return success(); +} + +LogicalResult +CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep, + PatternRewriter &rewriter) const { + std::optional<Type> maybeBaseType = gep.getElemType(); + if (!maybeBaseType) + return failure(); + Type baseType = *maybeBaseType; + + Type typeHint = isElementTypeInconsistent(gep.getBase(), baseType); + if (!typeHint) + return failure(); + + SmallVector<uint32_t> indices; + // Ensures all indices are static and fetches them.
+ for (auto index : gep.getIndices()) { + IntegerAttr indexInt = llvm::dyn_cast_if_present<IntegerAttr>(index); + if (!indexInt) + return failure(); + // Negative indices are not supported by the byte offset computation. + if (indexInt.getInt() < 0) + return failure(); + indices.push_back(indexInt.getInt()); + } + + DataLayout layout = DataLayout::closest(gep); + std::optional<uint64_t> desiredOffset = + gepToByteOffset(layout, gep.getSourceElementType(), indices); + if (!desiredOffset) + return failure(); + + SmallVector<GEPArg> newIndices; + if (failed( + findIndicesForOffset(layout, typeHint, *desiredOffset, newIndices))) + return failure(); + + // The element type was checked to be present above, so the GEP is rebuilt + // with the type hint as its element type. Reusing the original result type + // preserves the address space of the pointer. + rewriter.replaceOpWithNewOp<GEPOp>(gep, gep.getType(), typeHint, + gep.getBase(), newIndices, + gep.getInbounds()); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Type consistency pass +//===----------------------------------------------------------------------===// + +namespace { +struct LLVMTypeConsistencyPass + : public LLVM::impl::LLVMTypeConsistencyBase { + void runOnOperation() override { + RewritePatternSet rewritePatterns(&getContext()); + rewritePatterns.add>(&getContext()); + rewritePatterns.add>( + &getContext()); + rewritePatterns.add(&getContext()); + FrozenRewritePatternSet frozen(std::move(rewritePatterns)); + + if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen))) + signalPassFailure(); + } +}; +} // namespace + +std::unique_ptr LLVM::createTypeConsistencyPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir @@ -0,0 +1,92 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(llvm-type-consistency))" --split-input-file | FileCheck %s + +// CHECK-LABEL:
llvm.func @same_address +llvm.func @same_address(%arg: i32) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr + // CHECK: = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)> + %7 = llvm.getelementptr %1[8] : (!llvm.ptr) -> !llvm.ptr, i8 + llvm.store %arg, %7 : i32, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @same_address_keep_inbounds +llvm.func @same_address_keep_inbounds(%arg: i32) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr + // CHECK: = llvm.getelementptr inbounds %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)> + %7 = llvm.getelementptr inbounds %1[8] : (!llvm.ptr) -> !llvm.ptr, i8 + llvm.store %arg, %7 : i32, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @struct_use_instead_of_first_field +llvm.func @struct_use_instead_of_first_field(%arg: i32) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)> + // CHECK: llvm.store %{{.*}}, %[[GEP]] : i32 + llvm.store %arg, %1 : i32, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @struct_use_instead_of_first_field_same_size +// CHECK-SAME: (%[[ARG:.*]]: f32) +llvm.func @struct_use_instead_of_first_field_same_size(%arg: f32) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)> + %1 =
llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr + // CHECK-DAG: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)> + // CHECK-DAG: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32 + // CHECK: llvm.store %[[BITCAST]], %[[GEP]] : i32 + llvm.store %arg, %1 : f32, !llvm.ptr + llvm.return +}
+ +// ----- + +// CHECK-LABEL: llvm.func @struct_field_alignment +llvm.func @struct_field_alignment(%arg: i32) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr + // CHECK: = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)> + %7 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(i8, i32)> + llvm.store %arg, %7 : i32, !llvm.ptr + llvm.return +}
+ +// ----- + +// CHECK-LABEL: llvm.func @offset_in_padding +llvm.func @offset_in_padding(%arg: i8) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i8, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i8, i32)> : (i32) -> !llvm.ptr + // CHECK: = llvm.getelementptr %[[ALLOCA]][2] : (!llvm.ptr) -> !llvm.ptr, i8 + %7 = llvm.getelementptr %1[2] : (!llvm.ptr) -> !llvm.ptr, i8 + llvm.store %arg, %7 : i8, !llvm.ptr + llvm.return +}
+ +// ----- + +// CHECK-LABEL: llvm.func @negative_index +llvm.func @negative_index(%arg: i32) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr + // CHECK: = llvm.getelementptr %[[ALLOCA]][-4] : (!llvm.ptr) -> !llvm.ptr, i8 + %7 = llvm.getelementptr %1[-4] : (!llvm.ptr) -> !llvm.ptr, i8 + llvm.store %arg, %7 : i32, !llvm.ptr + llvm.return +}