diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td @@ -167,6 +167,30 @@ ]; } +def GetResultPtrElementType : OpInterface<"GetResultPtrElementType"> { + let description = [{ + An interface for operations that yield an LLVMPointer. Allows the + operation to provide the type of the element an LLVMPointer points to, + if known. This is only a hint as to how to interpret a given pointer, + translating how the current operation understands it. + }]; + + let cppNamespace = "::mlir::LLVM"; + + let methods = [ + InterfaceMethod< + /*desc=*/ [{Returns the element type hint of the result + LLVMPointer, if known. Returns nullptr if the + requested result is not an LLVMPointer or if the + element type is unknown.}], + /*returnType=*/ "Type", + /*methodName=*/ "getResultPtrElementType", + /*args=*/ (ins) + > + ]; +} + + //===----------------------------------------------------------------------===// // LLVM dialect type interfaces. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -178,7 +178,8 @@ // Memory-related operations. 
def LLVM_AllocaOp : LLVM_Op<"alloca", [DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]>, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]>, LLVM_MemOpPatterns { let arguments = (ins AnyInteger:$arraySize, OptionalAttr:$alignment, @@ -239,7 +240,8 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let arguments = (ins LLVM_ScalarOrVectorOf:$base, Variadic>:$dynamicIndices, DenseI32ArrayAttr:$rawConstantIndices, diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h @@ -12,6 +12,7 @@ #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h" #include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h" #include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h" +#include "mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td @@ -30,6 +30,17 @@ let constructor = "::mlir::LLVM::createRequestCWrappersPass()"; } +def LLVMTypeConsistency + : Pass<"llvm-type-consistency", "::mlir::LLVM::LLVMFuncOp"> { + let summary = "Rewrites to improve type consistency"; + let description = [{ + Set of rewrites to improve the coherency of types within an LLVM dialect + program. This will adjust operations operating on pointers so they interpret + their associated pointee type as consistently as possible. 
+ }]; + let constructor = "::mlir::LLVM::createTypeConsistencyPass()"; +} + def NVVMOptimizeForTarget : Pass<"llvm-optimize-for-nvvm-target"> { let summary = "Optimize NVVM IR"; let constructor = "::mlir::NVVM::createOptimizeForTargetPass()"; diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h @@ -0,0 +1,59 @@ +//===- TypeConsistency.h - Rewrites to improve type consistency -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Set of rewrites to improve the coherency of types within an LLVM dialect +// program. This will adjust operations around a given pointer so they interpret +// its pointee type as consistently as possible. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_TYPECONSISTENCY_H +#define MLIR_DIALECT_LLVMIR_TRANSFORMS_TYPECONSISTENCY_H + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace LLVM { + +#define GEN_PASS_DECL_LLVMTYPECONSISTENCY +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc" + +/// Creates a pass that adjusts operations operating on pointers so they +/// interpret pointee types as consistently as possible. +std::unique_ptr createTypeConsistencyPass(); + +/// Transforms uses of pointers to a whole struct to uses of pointers to the +/// first element of a struct. This is achieved by inserting a GEP to the first +/// element when possible. 
+template +class AddFieldGetterToStructDirectUse : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(User user, + PatternRewriter &rewriter) const override; +}; + +/// Canonicalizes GEPs of which the base type and the pointer's type hint do not +/// match. This is done by replacing the original GEP into a GEP with the type +/// hint as a base type when an element of the hinted type aligns with the +/// original GEP. +class CanonicalizeAlignedGep : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GEPOp gep, + PatternRewriter &rewriter) const override; +}; + +} // namespace LLVM +} // namespace mlir + +#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_TYPECONSISTENCY_H diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -281,6 +281,12 @@ return success(); } +Type AllocaOp::getResultPtrElementType() { + // This will become trivial once non-opaque pointers are gone. + return getElemType().has_value() ? *getElemType() + : getResult().getType().getElementType(); +} + //===----------------------------------------------------------------------===// // LLVM::BrOp //===----------------------------------------------------------------------===// @@ -771,6 +777,42 @@ .getElementType(); } +Type GEPOp::getResultPtrElementType() { + // Ensures all indices are static and fetches them. + SmallVector indices; + for (auto index : getIndices()) { + IntegerAttr indexInt = llvm::dyn_cast_if_present(index); + if (!indexInt) + return nullptr; + indices.push_back(indexInt); + } + + // Set the initial type currently being used for indexing. This will be + // updated as the indices get walked over. + Type selectedType = getSourceElementType(); + + // Follow the indexed elements in the gep. 
+ for (IntegerAttr index : llvm::drop_begin(indices)) { + // Ensure the structure of the type being indexed can be reasoned about. + // This includes rejecting any potential typed pointer. + auto destructurable = + llvm::dyn_cast(selectedType); + if (!destructurable) + return nullptr; + + // Follow the type at the index the gep is accessing, making it the new type + // used for indexing. + Type field = destructurable.getTypeAtIndex(index); + if (!field) + return nullptr; + selectedType = field; + } + + // When there are no more indices, the type currently being used for indexing + // is the type of the value pointed at by the returned indexed pointer. + return selectedType; +} + //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -28,18 +28,11 @@ // Interfaces for AllocaOp //===----------------------------------------------------------------------===// -static Type getAllocaElementType(LLVM::AllocaOp alloca) { - // This will become trivial once non-opaque pointers are gone. - return alloca.getElemType().has_value() - ? 
*alloca.getElemType() - : alloca.getResult().getType().getElementType(); -} - llvm::SmallVector LLVM::AllocaOp::getPromotableSlots() { if (!getOperation()->getBlock()->isEntryBlock()) return {}; - return {MemorySlot{getResult(), getAllocaElementType(*this)}}; + return {MemorySlot{getResult(), getResultPtrElementType()}}; } Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot, @@ -68,7 +61,7 @@ if (!mlir::matchPattern(getArraySize(), m_One())) return {}; - Type elemType = getAllocaElementType(*this); + Type elemType = getResultPtrElementType(); auto destructurable = dyn_cast(elemType); if (!destructurable) return {}; @@ -292,52 +285,10 @@ return DeletionKind::Delete; } -/// Returns the type the resulting pointer of the GEP points to. If such a type -/// is not clear, returns null type. -static Type computeReachedGEPType(LLVM::GEPOp gep) { - if (gep.getIndices().empty()) - return {}; - - // Ensures all indices are static and fetches them. - SmallVector indices; - for (auto index : gep.getIndices()) { - IntegerAttr indexInt = llvm::dyn_cast_if_present(index); - if (!indexInt) - return {}; - indices.push_back(indexInt); - } - - // Check the pointer indexing only targets the first element. - if (indices[0].getInt() != 0) - return {}; - - // Set the initial type currently being used for indexing. This will be - // updated as the indices get walked over. - std::optional maybeSelectedType = gep.getElemType(); - if (!maybeSelectedType) - return {}; - Type selectedType = *maybeSelectedType; - - // Follow the indexed elements in the gep. - for (IntegerAttr index : llvm::drop_begin(indices)) { - // Ensure the structure of the type being indexed can be reasoned about. - // This includes rejecting any potential typed pointer. - auto destructurable = - llvm::dyn_cast(selectedType); - if (!destructurable) - return {}; - - // Follow the type at the index the gep is accessing, making it the new type - // used for indexing. 
- Type field = destructurable.getTypeAtIndex(index); - if (!field) - return {}; - selectedType = field; - } - - // When there are no more indices, the type currently being used for indexing - // is the type of the value pointed at by the returned indexed pointer. - return selectedType; +static bool isFirstIndexZero(LLVM::GEPOp gep) { + IntegerAttr index = + llvm::dyn_cast_if_present(gep.getIndices()[0]); + return index && index.getInt() == 0; } LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses( @@ -346,7 +297,9 @@ return success(); if (slot.elemType != getElemType()) return failure(); - Type reachedType = computeReachedGEPType(*this); + if (!isFirstIndexZero(*this)) + return failure(); + Type reachedType = getResultPtrElementType(); if (!reachedType) return failure(); mustBeSafelyUsed.emplace_back({getResult(), reachedType}); @@ -367,7 +320,9 @@ if (getBase() != slot.ptr || slot.elemType != getElemType()) return false; - Type reachedType = computeReachedGEPType(*this); + if (!isFirstIndexZero(*this)) + return false; + Type reachedType = getResultPtrElementType(); if (!reachedType || getIndices().size() < 2) return false; auto firstLevelIndex = cast(getIndices()[1]); diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ LegalizeForExport.cpp OptimizeForNVVM.cpp RequestCWrappers.cpp + TypeConsistency.cpp DEPENDS MLIRLLVMPassIncGen diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp @@ -0,0 +1,371 @@ +//===- TypeConsistency.cpp - Rewrites to improve type consistency ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir { +namespace LLVM { +#define GEN_PASS_DEF_LLVMTYPECONSISTENCY +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h.inc" +} // namespace LLVM +} // namespace mlir + +using namespace mlir; +using namespace LLVM; + +//===----------------------------------------------------------------------===// +// Utils +//===----------------------------------------------------------------------===// + +/// Checks that a pointer value has a pointee type hint consistent with the +/// expected type. Returns the type it actually hints to if it differs, or +/// nullptr if the type is consistent or impossible to analyze. +static Type isElementTypeInconsistent(Value addr, Type expectedType) { + auto defOp = dyn_cast_or_null(addr.getDefiningOp()); + if (!defOp) + return nullptr; + + Type elemType = defOp.getResultPtrElementType(); + if (!elemType) + return nullptr; + + if (elemType == expectedType) + return nullptr; + + return elemType; +} + +/// Checks that two types are the same or can be bitcast into one another. +static bool areCastCompatible(DataLayout &layout, Type lhs, Type rhs) { + return lhs == rhs || (!isa(lhs) && + !isa(rhs) && + layout.getTypeSize(lhs) == layout.getTypeSize(rhs)); +} + +//===----------------------------------------------------------------------===// +// AddFieldGetterToStructDirectUse +//===----------------------------------------------------------------------===// + +/// Gets the type of the first subelement of `type` if `type` is destructurable, +/// nullptr otherwise. 
+static Type getFirstSubelementType(Type type) { + auto destructurable = dyn_cast(type); + if (!destructurable) + return nullptr; + + Type subelementType = destructurable.getTypeAtIndex( + IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0)); + if (subelementType) + return subelementType; + + return nullptr; +} + +/// Extracts a pointer to the first field of an `elemType` from the address +/// pointer of the provided MemOp, and rewires the MemOp so it uses that pointer +/// instead. +template +static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter, + Type elemType) { + PatternRewriter::InsertionGuard guard(rewriter); + + rewriter.setInsertionPointAfterValue(op.getAddr()); + SmallVector firstTypeIndices{0, 0}; + + Value properPtr = rewriter.create( + op->getLoc(), LLVM::LLVMPointerType::get(op.getContext()), elemType, + op.getAddr(), firstTypeIndices); + + rewriter.updateRootInPlace(op, + [&]() { op.getAddrMutable().assign(properPtr); }); +} + +template <> +LogicalResult AddFieldGetterToStructDirectUse::matchAndRewrite( + LoadOp load, PatternRewriter &rewriter) const { + PatternRewriter::InsertionGuard guard(rewriter); + + // Load from typed pointers are not supported. + if (!load.getAddr().getType().isOpaque()) + return failure(); + + Type inconsistentElementType = + isElementTypeInconsistent(load.getAddr(), load.getType()); + if (!inconsistentElementType) + return failure(); + Type firstType = getFirstSubelementType(inconsistentElementType); + if (!firstType) + return failure(); + DataLayout layout = DataLayout::closest(load); + if (!areCastCompatible(layout, firstType, load.getResult().getType())) + return failure(); + + insertFieldIndirection(load, rewriter, inconsistentElementType); + + // If the load does not use the first type but a type that can be casted from + // it, add a bitcast and change the load type. 
+ if (firstType != load.getResult().getType()) { + rewriter.setInsertionPointAfterValue(load.getResult()); + BitcastOp bitcast = rewriter.create( + load->getLoc(), load.getResult().getType(), load.getResult()); + rewriter.updateRootInPlace(load, + [&]() { load.getResult().setType(firstType); }); + rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(), + bitcast); + } + + return success(); +} + +template <> +LogicalResult AddFieldGetterToStructDirectUse::matchAndRewrite( + StoreOp store, PatternRewriter &rewriter) const { + PatternRewriter::InsertionGuard guard(rewriter); + + // Store to typed pointers are not supported. + if (!store.getAddr().getType().isOpaque()) + return failure(); + + Type inconsistentElementType = + isElementTypeInconsistent(store.getAddr(), store.getValue().getType()); + if (!inconsistentElementType) + return failure(); + Type firstType = getFirstSubelementType(inconsistentElementType); + if (!firstType) + return failure(); + + DataLayout layout = DataLayout::closest(store); + // Check that the first field has the right type or can at least be bitcast + // to the right type. + if (!areCastCompatible(layout, firstType, store.getValue().getType())) + return failure(); + + insertFieldIndirection(store, rewriter, inconsistentElementType); + + Value replaceValue = store.getValue(); + if (firstType != store.getValue().getType()) { + rewriter.setInsertionPointAfterValue(store.getValue()); + replaceValue = rewriter.create(store->getLoc(), firstType, + store.getValue()); + } + + rewriter.updateRootInPlace( + store, [&]() { store.getValueMutable().assign(replaceValue); }); + + return success(); +} + +//===----------------------------------------------------------------------===// +// CanonicalizeAlignedGep +//===----------------------------------------------------------------------===// + +/// Returns the amount of bytes the provided GEP elements will offset the +/// pointer by. Returns nullopt if the offset could not be computed. 
+/// `indices` are the constant GEP indices: the first one scales `base` itself,
+/// every following one selects an element of the array or struct reached so
+/// far. Returns nullopt if there are no indices, if an index is out of bounds
+/// for the aggregate it selects from, or if a non-first index is applied to a
+/// type that is neither an LLVM array nor an LLVM struct.
+static std::optional gepToByteOffset(DataLayout &layout, Type base,
+                                     ArrayRef indices) {
+  // Without any index there is no first element to read below; conservatively
+  // give up instead of indexing out of bounds.
+  if (indices.empty())
+    return std::nullopt;
+
+  uint64_t offset = indices[0] * layout.getTypeSize(base);
+
+  Type currentType = base;
+  for (uint32_t index : llvm::drop_begin(indices)) {
+    bool shouldCancel =
+        TypeSwitch(currentType)
+            .Case([&](LLVMArrayType arrayType) {
+              if (arrayType.getNumElements() <= index)
+                return true;
+              offset += index * layout.getTypeSize(arrayType.getElementType());
+              currentType = arrayType.getElementType();
+              return false;
+            })
+            .Case([&](LLVMStructType structType) {
+              ArrayRef body = structType.getBody();
+              if (body.size() <= index)
+                return true;
+              // Skip over every field preceding the selected one, inserting
+              // the padding each of them needs unless the struct is packed.
+              for (uint32_t i = 0; i < index; i++) {
+                if (!structType.isPacked())
+                  offset = llvm::alignTo(offset,
+                                         layout.getTypeABIAlignment(body[i]));
+                offset += layout.getTypeSize(body[i]);
+              }
+              // The selected field may itself be preceded by padding: align
+              // for it as well, so the result agrees with the per-field
+              // alignment findIndicesForOffset applies. E.g. with a 4-byte
+              // aligned i32, field 1 of (i16, i32) starts at byte 4, not 2.
+              if (!structType.isPacked())
+                offset = llvm::alignTo(offset,
+                                       layout.getTypeABIAlignment(body[index]));
+              currentType = body[index];
+              return false;
+            })
+            .Default([](Type) { return true; });
+
+    if (shouldCancel)
+      return std::nullopt;
+  }
+
+  return offset;
+}
+
+/// Fills in `equivalentIndicesOut` with GEP indices that would be equivalent to
+/// offsetting a pointer by `offset` bytes, assuming the GEP has `base` as base
+/// type.
+static LogicalResult
+findIndicesForOffset(DataLayout &layout, Type base, uint64_t offset,
+                     SmallVectorImpl &equivalentIndicesOut) {
+
+  uint64_t baseSize = layout.getTypeSize(base);
+  uint64_t rootIndex = offset / baseSize;
+  if (rootIndex > std::numeric_limits::max())
+    return failure();
+  equivalentIndicesOut.push_back(rootIndex);
+
+  uint64_t distanceToStart = rootIndex * baseSize;
+
+#ifndef NDEBUG
+  auto isWithinCurrentType = [&](Type currentType) {
+    return offset < distanceToStart + layout.getTypeSize(currentType);
+  };
+#endif
+
+  Type currentType = base;
+  while (distanceToStart < offset) {
+    // While an index that does not perfectly align with offset has not been
+    // reached...
+ + assert(isWithinCurrentType(currentType)); + + bool shouldCancel = + TypeSwitch(currentType) + .Case([&](LLVMArrayType arrayType) { + // Find which element of the array contains the offset. + uint64_t elemSize = + layout.getTypeSize(arrayType.getElementType()); + uint64_t index = (offset - distanceToStart) / elemSize; + equivalentIndicesOut.push_back(index); + distanceToStart += index * elemSize; + + // Then, try to find where in the element the offset is. If the + // offset is exactly the beginning of the element, the loop is + // complete. + currentType = arrayType.getElementType(); + + // Only continue if the element in question can be indexed using + // an i32. + return index > std::numeric_limits::max(); + }) + .Case([&](LLVMStructType structType) { + ArrayRef body = structType.getBody(); + uint32_t index = 0; + + // Walk over the elements of the struct to find in which of them + // the offset is. + for (Type elem : body) { + uint64_t elemSize = layout.getTypeSize(elem); + if (!structType.isPacked()) { + distanceToStart = llvm::alignTo( + distanceToStart, layout.getTypeABIAlignment(elem)); + // If the offset is in padding, cancel the rewrite. + if (offset < distanceToStart) + return true; + } + + if (offset < distanceToStart + elemSize) { + // The offset is within this element, stop iterating the + // struct and look within the current element. + equivalentIndicesOut.push_back(index); + currentType = elem; + return false; + } + + // The offset is not within this element, continue walking over + // the struct. + distanceToStart += elemSize; + index++; + } + + // The offset was supposed to be within this struct but is not. + // This can happen if the offset points into final padding. + // Anyway, nothing can be done. + return true; + }) + .Default([](Type) { + // If the offset is within a type that cannot be split, no indices + // will yield this offset. This can happen if the offset is not + // perfectly aligned with a leaf type. + // TODO: support vectors. 
+              return true;
+            });
+
+    if (shouldCancel)
+      return failure();
+  }
+
+  return success();
+}
+
+/// Replaces `gep` with a GEP that uses the base pointer's element type hint as
+/// its base type and addresses the same byte offset. Fails (leaving `gep`
+/// untouched) when the GEP has no explicit element type, when the hint is
+/// unavailable or already matches, when any index is dynamic or negative, or
+/// when the byte offset cannot be expressed as indices into the hinted type.
+LogicalResult
+CanonicalizeAlignedGep::matchAndRewrite(GEPOp gep,
+                                        PatternRewriter &rewriter) const {
+  // GEPs of typed pointers (no explicit element type) are not supported.
+  std::optional maybeBaseType = gep.getElemType();
+  if (!maybeBaseType)
+    return failure();
+  Type baseType = *maybeBaseType;
+
+  Type typeHint = isElementTypeInconsistent(gep.getBase(), baseType);
+  if (!typeHint)
+    return failure();
+
+  SmallVector indices;
+  // Ensures all indices are static and fetches them.
+  for (auto index : gep.getIndices()) {
+    IntegerAttr indexInt = llvm::dyn_cast_if_present(index);
+    if (!indexInt)
+      return failure();
+    // gepToByteOffset walks the indices as unsigned values; a negative
+    // constant would wrap around into a huge forward offset, so give up.
+    if (indexInt.getInt() < 0)
+      return failure();
+    indices.push_back(indexInt.getInt());
+  }
+
+  DataLayout layout = DataLayout::closest(gep);
+  std::optional desiredOffset =
+      gepToByteOffset(layout, gep.getSourceElementType(), indices);
+  if (!desiredOffset)
+    return failure();
+
+  SmallVector newIndices;
+  if (failed(
+          findIndicesForOffset(layout, typeHint, *desiredOffset, newIndices)))
+    return failure();
+
+  // Keep the original inbounds flag on the replacement GEP.
+  rewriter.replaceOpWithNewOp(
+      gep, LLVM::LLVMPointerType::get(getContext()), typeHint, gep.getBase(),
+      newIndices, gep.getInbounds());
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Type consistency pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct LLVMTypeConsistencyPass
+    : public LLVM::impl::LLVMTypeConsistencyBase {
+  void runOnOperation() override {
+    RewritePatternSet rewritePatterns(&getContext());
+    rewritePatterns.add>(&getContext());
+    rewritePatterns.add>(
+        &getContext());
+    rewritePatterns.add(&getContext());
+    FrozenRewritePatternSet frozen(std::move(rewritePatterns));
+
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), frozen)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr LLVM::createTypeConsistencyPass()
{ + return std::make_unique(); +} diff --git a/mlir/test/Dialect/LLVMIR/type-consistency.mlir b/mlir/test/Dialect/LLVMIR/type-consistency.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/type-consistency.mlir @@ -0,0 +1,150 @@ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(llvm-type-consistency))" --split-input-file | FileCheck %s + +// CHECK-LABEL: llvm.func @same_address +llvm.func @same_address(%arg: i32) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr + // CHECK: = llvm.getelementptr %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)> + %7 = llvm.getelementptr %1[8] : (!llvm.ptr) -> !llvm.ptr, i8 + llvm.store %arg, %7 : i32, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @same_address_keep_inbounds +llvm.func @same_address_keep_inbounds(%arg: i32) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr + // CHECK: = llvm.getelementptr inbounds %[[ALLOCA]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)> + %7 = llvm.getelementptr inbounds %1[8] : (!llvm.ptr) -> !llvm.ptr, i8 + llvm.store %arg, %7 : i32, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @struct_store_instead_of_first_field +llvm.func @struct_store_instead_of_first_field(%arg: i32) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)> + // CHECK: llvm.store %{{.*}}, 
%[[GEP]] : i32 + llvm.store %arg, %1 : i32, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @struct_store_instead_of_first_field_same_size +// CHECK-SAME: (%[[ARG:.*]]: f32) +llvm.func @struct_store_instead_of_first_field_same_size(%arg: f32) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr + // CHECK-DAG: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)> + // CHECK-DAG: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32 + // CHECK: llvm.store %[[BITCAST]], %[[GEP]] : i32 + llvm.store %arg, %1 : f32, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @struct_load_instead_of_first_field +llvm.func @struct_load_instead_of_first_field() -> i32 { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)> + // CHECK: %[[RES:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> i32 + %2 = llvm.load %1 : !llvm.ptr -> i32 + // CHECK: llvm.return %[[RES]] : i32 + llvm.return %2 : i32 +} + +// ----- + +// CHECK-LABEL: llvm.func @struct_load_instead_of_first_field_same_size +llvm.func @struct_load_instead_of_first_field_same_size() -> f32 { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32, i32)> : (i32) -> !llvm.ptr + // CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, i32, i32)> + // CHECK: %[[LOADED:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> i32 + // 
CHECK: %[[RES:.*]] = llvm.bitcast %[[LOADED]] : i32 to f32 + %2 = llvm.load %1 : !llvm.ptr -> f32 + // CHECK: llvm.return %[[RES]] : f32 + llvm.return %2 : f32 +} + +// ----- + +// CHECK-LABEL: llvm.func @index_in_final_padding +llvm.func @index_in_final_padding(%arg: i32) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i8)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i8)> : (i32) -> !llvm.ptr + // CHECK: = llvm.getelementptr %[[ALLOCA]][7] : (!llvm.ptr) -> !llvm.ptr, i8 + %7 = llvm.getelementptr %1[7] : (!llvm.ptr) -> !llvm.ptr, i8 + llvm.store %arg, %7 : i32, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @index_out_of_bounds +llvm.func @index_out_of_bounds(%arg: i32) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, i32)> : (i32) -> !llvm.ptr + // CHECK: = llvm.getelementptr %[[ALLOCA]][9] : (!llvm.ptr) -> !llvm.ptr, i8 + %7 = llvm.getelementptr %1[9] : (!llvm.ptr) -> !llvm.ptr, i8 + llvm.store %arg, %7 : i32, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @index_in_padding +llvm.func @index_in_padding(%arg: i16) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i16, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i16, i32)> : (i32) -> !llvm.ptr + // CHECK: = llvm.getelementptr %[[ALLOCA]][2] : (!llvm.ptr) -> !llvm.ptr, i8 + %7 = llvm.getelementptr %1[2] : (!llvm.ptr) -> !llvm.ptr, i8 + llvm.store %arg, %7 : i16, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @index_not_in_padding_because_packed +llvm.func @index_not_in_padding_because_packed(%arg: i16) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", packed (i16, i32)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", packed 
(i16, i32)> : (i32) -> !llvm.ptr + // CHECK: = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", packed (i16, i32)> + %7 = llvm.getelementptr %1[2] : (!llvm.ptr) -> !llvm.ptr, i8 + llvm.store %arg, %7 : i16, !llvm.ptr + llvm.return +} + +// ----- + +// CHECK-LABEL: llvm.func @index_to_struct +// CHECK-SAME: (%[[ARG:.*]]: i32) +llvm.func @index_to_struct(%arg: i32) { + %0 = llvm.mlir.constant(1 : i32) : i32 + // CHECK: %[[ALLOCA:.*]] = llvm.alloca %{{.*}} x !llvm.struct<"foo", (i32, struct<"bar", (i32, i32)>)> + %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, struct<"bar", (i32, i32)>)> : (i32) -> !llvm.ptr + // CHECK: %[[GEP0:.*]] = llvm.getelementptr %[[ALLOCA]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, struct<"bar", (i32, i32)>)> + // CHECK: %[[GEP1:.*]] = llvm.getelementptr %[[GEP0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"bar", (i32, i32)> + %7 = llvm.getelementptr %1[4] : (!llvm.ptr) -> !llvm.ptr, i8 + // CHECK: llvm.store %[[ARG]], %[[GEP1]] + llvm.store %arg, %7 : i32, !llvm.ptr + llvm.return +}