diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -173,7 +173,8 @@
 // Memory-related operations.
 
 def LLVM_AllocaOp : LLVM_Op<"alloca",
-    [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]>,
+    [DeclareOpInterfaceMethods<PromotableAllocationOpInterface>,
+     DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>]>,
   LLVM_MemOpPatterns {
   let arguments = (ins AnyInteger:$arraySize,
                    OptionalAttr<I64Attr>:$alignment,
@@ -232,7 +233,9 @@
 }
 
 def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
-    DeclareOpInterfaceMethods<PromotableOpInterface>]> {
+    DeclareOpInterfaceMethods<PromotableOpInterface>,
+    DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>,
+    DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base,
                    Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$dynamicIndices,
                    DenseI32ArrayAttr:$rawConstantIndices,
@@ -316,7 +319,8 @@
 }
 
 def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
-    [DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
+    [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+     DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
   dag args = (ins Arg<LLVM_AnyPointer, "", [MemRead]>:$addr,
               OptionalAttr<I64Attr>:$alignment,
               UnitAttr:$volatile_,
@@ -388,7 +392,8 @@
 }
 
 def LLVM_StoreOp : LLVM_MemAccessOpBase<"store",
-    [DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
+    [DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+     DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
   dag args = (ins LLVM_LoadableType:$value,
               Arg<LLVM_AnyPointer, "", [MemWrite]>:$addr,
              OptionalAttr<I64Attr>:$alignment,
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
@@ -16,6 +16,7 @@
 #include "mlir/IR/Types.h"
 #include "mlir/Interfaces/DataLayoutInterfaces.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include <optional>
 
 namespace llvm {
@@ -103,6 +104,7 @@
 class LLVMStructType
     : public Type::TypeBase<LLVMStructType, Type, detail::LLVMStructTypeStorage,
                             DataLayoutTypeInterface::Trait,
+                            DestructurableTypeInterface::Trait,
                             TypeTrait::IsMutable> {
 public:
   /// Inherit base constructors.
@@ -198,6 +200,12 @@
   LogicalResult verifyEntries(DataLayoutEntryListRef entries,
                               Location loc) const;
+
+  /// Destructures the struct into its indexed field types.
+  Optional<DenseMap<Attribute, Type>> getSubelementIndexMap();
+
+  /// Returns which type is stored at a given integer index within the struct.
+  Type getTypeAtIndex(Attribute index);
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
 include "mlir/IR/AttrTypeBase.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/Interfaces/MemorySlotInterfaces.td"
 
 /// Base class for all LLVM dialect types.
 class LLVMType<string typeName, string typeMnemonic, list<Trait> traits = []>
@@ -24,7 +25,8 @@
 //===----------------------------------------------------------------------===//
 
 def LLVMArrayType : LLVMType<"LLVMArray", "array", [
-    DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {
+    DeclareTypeInterfaceMethods<DataLayoutTypeInterface>,
+    DeclareTypeInterfaceMethods<DestructurableTypeInterface>]> {
  let summary = "LLVM array type";
  let description = [{
    The `!llvm.array` type represents a fixed-size array of element types.
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -19,6 +19,8 @@
 set(LLVM_TARGET_DEFINITIONS MemorySlotInterfaces.td)
 mlir_tablegen(MemorySlotOpInterfaces.h.inc -gen-op-interface-decls)
 mlir_tablegen(MemorySlotOpInterfaces.cpp.inc -gen-op-interface-defs)
+mlir_tablegen(MemorySlotTypeInterfaces.h.inc -gen-type-interface-decls)
+mlir_tablegen(MemorySlotTypeInterfaces.cpp.inc -gen-type-interface-defs)
 add_public_tablegen_target(MLIRMemorySlotInterfacesIncGen)
 add_dependencies(mlir-generic-headers MLIRMemorySlotInterfacesIncGen)
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.h
@@ -23,6 +23,13 @@
   Type elemType;
 };
 
+/// Memory slot carrying information about its destructuring procedure.
+struct DestructurableMemorySlot : public MemorySlot {
+  /// Maps an index within the memory slot to the type of the pointer that
+  /// will be generated to access the element directly.
+  DenseMap<Attribute, Type> elementPtrs;
+};
+
 /// Returned by operation promotion logic requesting the deletion of an
 /// operation.
 enum class DeletionKind {
@@ -35,5 +42,6 @@
 } // namespace mlir
 
 #include "mlir/Interfaces/MemorySlotOpInterfaces.h.inc"
+#include "mlir/Interfaces/MemorySlotTypeInterfaces.h.inc"
 
 #endif // MLIR_INTERFACES_MEMORYSLOTINTERFACES_H
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -195,4 +195,143 @@
   ];
 }
 
+def DestructurableAllocationOpInterface
+    : OpInterface<"DestructurableAllocationOpInterface"> {
+  let description = [{
+    Describes operations allocating memory slots of aggregates that can be
+    destructured into multiple smaller allocations.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Returns the list of slots for which destructuring should be attempted,
+        specifying in which way the slot should be destructured into subslots.
+        The subslots are indexed by attributes. This computes the type of the
+        pointer to be generated for each subslot. The type of the memory slot
+        must implement `DestructurableTypeInterface`.
+      }],
+      "::llvm::SmallVector<::mlir::DestructurableMemorySlot>",
+      "getDestructurableSlots",
+      (ins)
+    >,
+    InterfaceMethod<[{
+        Destructures this slot into multiple subslots. The newly generated
+        slots may belong to a different allocator. The original slot must
+        still exist at the end of this call.
+
+        The builder is located at the beginning of the block where the slot
+        pointer is defined.
+      }],
+      "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot>",
+      "destructure",
+      (ins "const ::mlir::DestructurableMemorySlot &":$slot,
+           "::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices,
+           "::mlir::OpBuilder &":$builder)
+    >,
+    InterfaceMethod<[{
+        Hook triggered once the destructuring of a slot is complete, meaning
+        the original slot is no longer referred to and could be deleted.
+        This will only be called for slots declared by this operation.
+      }],
+      "void", "handleDestructuringComplete",
+      (ins "const ::mlir::DestructurableMemorySlot &":$slot)
+    >,
+  ];
+}
+
+def SafeMemorySlotAccessOpInterface
+    : OpInterface<"SafeMemorySlotAccessOpInterface"> {
+  let description = [{
+    Describes operations using memory slots in a type-safe manner.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Returns whether all accesses in this operation to the provided slot
+        are done in a type-safe manner. To be type-safe, the access must only
+        interpret the slot's content using the slot's type, without assuming
+        any context around the slot. For example, a type-safe load must not
+        load outside the bounds of the slot.
+
+        If the type-safety of the accesses depends on the type-safety of the
+        accesses to further memory slots, the result of this method is
+        conditioned on the type-safety of the accesses to the slots this
+        method adds to `mustBeSafelyUsed`.
+      }],
+      "::mlir::LogicalResult",
+      "ensureOnlySafeAccesses",
+      (ins "const ::mlir::MemorySlot &":$slot,
+           "::mlir::SmallVectorImpl<::mlir::MemorySlot> &":$mustBeSafelyUsed)
+    >
+  ];
+}
+
+def DestructurableAccessorOpInterface
+    : OpInterface<"DestructurableAccessorOpInterface"> {
+  let description = [{
+    Describes operations that can access a sub-element of a destructurable
+    slot.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        For a given destructurable memory slot, returns whether this
+        operation can rewire its uses of the slot to use the slots generated
+        after destructuring. This may involve creating new operations, and
+        usually amounts to checking if the pointer types match.
+
+        This method must also register the indices it will access within the
+        `usedIndices` set. If the accessor generates new slots mapping to
+        subelements, they must be registered in `mustBeSafelyUsed` to ensure
+        they are used in a locally type-safe manner.
+      }],
+      "bool",
+      "canRewire",
+      (ins "const ::mlir::DestructurableMemorySlot &":$slot,
+           "::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices,
+           "::mlir::SmallVectorImpl<::mlir::MemorySlot> &":$mustBeSafelyUsed)
+    >,
+    InterfaceMethod<[{
+        Rewires the use of a slot to the generated subslots, without deleting
+        any operation. Returns whether the accessor should be deleted.
+      }],
+      "::mlir::DeletionKind",
+      "rewire",
+      (ins "const ::mlir::DestructurableMemorySlot &":$slot,
+           "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot> &":$subslots)
+    >
+  ];
+}
+
+def DestructurableTypeInterface
+    : TypeInterface<"DestructurableTypeInterface"> {
+  let description = [{
+    Describes a type that can be broken down into indexable sub-element
+    types.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Destructures the type into a map from index attributes to subelement
+        types. Returns nothing if the type cannot be destructured.
+      }],
+      "::std::optional<::llvm::DenseMap<::mlir::Attribute, ::mlir::Type>>",
+      "getSubelementIndexMap",
+      (ins)
+    >,
+    InterfaceMethod<[{
+        Indicates which type is held at the provided index, returning a null
+        Type if no type could be computed. While this can return information
+        even when the type cannot be completely destructured, it must be
+        coherent with the types returned by `getSubelementIndexMap` when they
+        exist.
+      }],
+      "::mlir::Type",
+      "getTypeAtIndex",
+      (ins "::mlir::Attribute":$index)
+    >
+  ];
+}
+
 #endif // MLIR_INTERFACES_MEMORYSLOTINTERFACES
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -36,6 +36,7 @@
 #define GEN_PASS_DECL_MEM2REG
 #define GEN_PASS_DECL_PRINTIRPASS
 #define GEN_PASS_DECL_PRINTOPSTATS
+#define GEN_PASS_DECL_SROA
 #define GEN_PASS_DECL_STRIPDEBUGINFO
 #define GEN_PASS_DECL_SCCP
 #define GEN_PASS_DECL_SYMBOLDCE
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -214,6 +214,22 @@
   let constructor = "mlir::createSCCPPass()";
 }
 
+def SROA : Pass<"sroa"> {
+  let summary = "Scalar Replacement of Aggregates";
+  let description = [{
+    Scalar Replacement of Aggregates. Replaces allocations of aggregates
+    with independent allocations of their elements.
+
+    Allocators must implement `DestructurableAllocationOpInterface` to
+    provide the list of memory slots for which destructuring should be
+    attempted.
+
+    This pass will only be applied if all accessors of the aggregate
+    implement the `DestructurableAccessorOpInterface`. If the accessors
+    provide a view into the struct, users of the view must ensure it is used
+    in a type-safe manner and within bounds by implementing
+    `SafeMemorySlotAccessOpInterface`.
+  }];
+}
+
 def StripDebugInfo : Pass<"strip-debuginfo"> {
   let summary = "Strip debug info from all operations";
   let description = [{
diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Transforms/SROA.h
@@ -0,0 +1,55 @@
+//===-- SROA.h - Scalar Replacement Of Aggregates ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_SROA_H
+#define MLIR_TRANSFORMS_SROA_H
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+
+/// Information computed by destructurable memory slot analysis used to
+/// perform actual destructuring of the slot. This struct is only constructed
+/// if destructuring is possible, and contains the necessary data to perform
+/// it.
+struct MemorySlotDestructuringInfo {
+  /// Set of the indices that are actually used when accessing the
+  /// subelements.
+  SmallPtrSet<Attribute, 8> usedIndices;
+  /// Blocking uses of each user of the memory slot that must be removed.
+  DenseMap<Operation *, SmallPtrSet<OpOperand *, 4>> userToBlockingUses;
+  /// List of accessors of the memory slot that must be rewired.
+  SmallVector<DestructurableAccessorOpInterface> accessors;
+};
+
+/// Computes information for slot destructuring. This will compute whether
+/// this slot can be destructured and data to perform the destructuring.
+/// Returns nothing if the slot cannot be destructured.
+std::optional<MemorySlotDestructuringInfo>
+computeDestructuringInfo(DestructurableMemorySlot &slot);
+
+/// Performs the destructuring of a destructurable slot given associated
+/// destructuring information. The provided slot will be destructured into
+/// subslots as specified by its allocator.
+void destructureSlot(DestructurableMemorySlot &slot,
+                     DestructurableAllocationOpInterface allocator,
+                     OpBuilder &builder, MemorySlotDestructuringInfo &info);
+
+/// Pattern applying SROA to the regions of the operations on which it
+/// matches.
+class SROAPattern : public RewritePattern {
+public:
+  using RewritePattern::RewritePattern;
+
+  SROAPattern(MLIRContext *ctx, PatternBenefit benefit = 1)
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, ctx) {
+    // The pattern rewrites the matched op in place a bounded number of
+    // times; declare this to the greedy driver.
+    setHasBoundedRewriteRecursion();
+  }
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+};
+
+} // namespace mlir
+
+#endif // MLIR_TRANSFORMS_SROA_H
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/STLExtras.h"
@@ -60,6 +61,57 @@
   erase();
 }
 
+SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
+  if (!mlir::matchPattern(getArraySize(), m_One()))
+    return {};
+
+  Type elemType = getAllocaElementType(*this);
+  auto destructurable = dyn_cast<DestructurableTypeInterface>(elemType);
+  if (!destructurable)
+    return {};
+
+  Optional<DenseMap<Attribute, Type>> destructuredType =
+      destructurable.getSubelementIndexMap();
+  if (!destructuredType)
+    return {};
+
+  DenseMap<Attribute, Type> allocaTypeMap;
+  for (Attribute index : llvm::make_first_range(destructuredType.value()))
+    allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())});
+
+  return {DestructurableMemorySlot{{getResult(), elemType}, {allocaTypeMap}}};
+}
+
+DenseMap<Attribute, MemorySlot>
+LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
+                            SmallPtrSetImpl<Attribute> &usedIndices,
+                            OpBuilder &builder) {
+  assert(slot.ptr == getResult());
+  Type elemType =
+      getElemType() ? *getElemType() : getResult().getType().getElementType();
+
+  builder.setInsertionPointAfter(*this);
+
+  auto destructurableType = cast<DestructurableTypeInterface>(elemType);
+  DenseMap<Attribute, MemorySlot> slotMap;
+  for (Attribute index : usedIndices) {
+    Type elemType = destructurableType.getTypeAtIndex(index);
+    assert(elemType && "used index must exist");
+    auto subAlloca = builder.create<LLVM::AllocaOp>(
+        getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
+        getArraySize());
+    slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
+  }
+
+  return slotMap;
+}
+
+void LLVM::AllocaOp::handleDestructuringComplete(
+    const DestructurableMemorySlot &slot) {
+  assert(slot.ptr == getResult());
+  erase();
+}
+
 //===----------------------------------------------------------------------===//
 // Interfaces for LoadOp/StoreOp
 //===----------------------------------------------------------------------===//
@@ -125,6 +177,17 @@
   return DeletionKind::Delete;
 }
 
+LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
+    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+  return success(getAddr() != slot.ptr || getType() == slot.elemType);
+}
+
+LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
+    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+  return success(getAddr() != slot.ptr ||
+                 getValue().getType() == slot.elemType);
+}
+
 //===----------------------------------------------------------------------===//
 // Interfaces for discardable OPs
 //===----------------------------------------------------------------------===//
@@ -193,6 +256,10 @@
   return DeletionKind::Delete;
 }
 
+//===----------------------------------------------------------------------===//
+// Interfaces for GEPOp
+//===----------------------------------------------------------------------===//
+
 static bool hasAllZeroIndices(LLVM::GEPOp gepOp) {
   return llvm::all_of(gepOp.getIndices(), [](auto index) {
     auto indexAttr = index.template dyn_cast<IntegerAttr>();
@@ -200,10 +267,6 @@
   });
 }
 
-//===----------------------------------------------------------------------===//
-// Interfaces for GEPOp
-//===----------------------------------------------------------------------===//
-
 bool LLVM::GEPOp::canUsesBeRemoved(
     const SmallPtrSetImpl<OpOperand *> &blockingUses,
     SmallVectorImpl<OpOperand *> &newBlockingUses) {
@@ -218,3 +281,166 @@
   return DeletionKind::Delete;
 }
 
+/// Returns the type the resulting pointer of the GEP points to. If such a
+/// type is not clear, returns null type.
+static Type computeReachedGEPType(LLVM::GEPOp gep) {
+  if (gep.getIndices().empty())
+    return {};
+
+  // Check the pointer indexing only targets the first element.
+  auto firstIndex = gep.getIndices()[0];
+  IntegerAttr indexInt = firstIndex.dyn_cast<IntegerAttr>();
+  if (!indexInt || indexInt.getInt() != 0)
+    return {};
+
+  // Set the initial type currently being used for indexing. This will be
+  // updated as the indices get walked over.
+  std::optional<Type> maybeSelectedType = gep.getElemType();
+  if (!maybeSelectedType)
+    return {};
+  Type selectedType = *maybeSelectedType;
+
+  // Follow the indexed elements in the gep.
+  for (const LLVM::GEPIndicesAdaptor<ValueRange>::value_type &index :
+       llvm::drop_begin(gep.getIndices())) {
+    // Ensure the index is static and obtain it.
+    IntegerAttr indexInt = index.dyn_cast<IntegerAttr>();
+    if (!indexInt)
+      return {};
+
+    // Ensure the structure of the type being indexed can be reasoned about.
+    // This includes rejecting any potential typed pointer.
+    auto destructurable = selectedType.dyn_cast<DestructurableTypeInterface>();
+    if (!destructurable)
+      return {};
+
+    // Follow the type at the index the gep is accessing, making it the new
+    // type used for indexing.
+    Type field = destructurable.getTypeAtIndex(indexInt);
+    if (!field)
+      return {};
+    selectedType = field;
+  }
+
+  // When there are no more indices, the type currently being used for
+  // indexing is the type of the value pointed at by the returned indexed
+  // pointer.
+  return selectedType;
+}
+
+LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
+    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+  if (getBase() != slot.ptr)
+    return success();
+  if (slot.elemType != getElemType())
+    return failure();
+  Type reachedType = computeReachedGEPType(*this);
+  if (!reachedType)
+    return failure();
+  mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
+  return success();
+}
+
+bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
+                            SmallPtrSetImpl<Attribute> &usedIndices,
+                            SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+  auto basePtrType = getBase().getType().dyn_cast<LLVM::LLVMPointerType>();
+  if (!basePtrType)
+    return false;
+
+  // Typed pointers are not supported. This should be removed once typed
+  // pointers are removed from the LLVM dialect.
+  if (!basePtrType.isOpaque())
+    return false;
+
+  if (getBase() != slot.ptr || slot.elemType != getElemType())
+    return false;
+  Type reachedType = computeReachedGEPType(*this);
+  if (!reachedType || getIndices().size() < 2)
+    return false;
+  auto firstLevelIndex = cast<IntegerAttr>(getIndices()[1]);
+  assert(slot.elementPtrs.contains(firstLevelIndex));
+  if (!slot.elementPtrs.at(firstLevelIndex).isa<LLVM::LLVMPointerType>())
+    return false;
+  mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
+  usedIndices.insert(firstLevelIndex);
+  return true;
+}
+
+DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
+                                 DenseMap<Attribute, MemorySlot> &subslots) {
+  IntegerAttr firstLevelIndex = getIndices()[1].dyn_cast<IntegerAttr>();
+  const MemorySlot &newSlot = subslots.at(firstLevelIndex);
+
+  ArrayRef<int32_t> remainingIndices = getRawConstantIndices().slice(2);
+
+  // If the GEP would become trivial after this transformation, eliminate it.
+  // A GEP should only be eliminated if it has no indices (except the first
+  // pointer index), as simplifying GEPs with all-zero indices would discard
+  // structure information useful for further destructuring.
+  if (remainingIndices.empty()) {
+    getResult().replaceAllUsesWith(newSlot.ptr);
+    return DeletionKind::Delete;
+  }
+
+  // Rewire the indices by popping off the second index. Start with a single
+  // zero, then add the indices beyond the second.
+  SmallVector<int32_t> newIndices(1);
+  newIndices.append(remainingIndices.begin(), remainingIndices.end());
+  setRawConstantIndices(newIndices);
+
+  // Rewire the pointed type.
+  setElemType(newSlot.elemType);
+
+  // Rewire the pointer.
+  getBaseMutable().assign(newSlot.ptr);
+
+  return DeletionKind::Keep;
+}
+
+//===----------------------------------------------------------------------===//
+// Interfaces for destructurable types
+//===----------------------------------------------------------------------===//
+
+std::optional<DenseMap<Attribute, Type>>
+LLVM::LLVMStructType::getSubelementIndexMap() {
+  Type i32 = IntegerType::get(getContext(), 32);
+  DenseMap<Attribute, Type> destructured;
+  for (const auto &[index, elemType] : llvm::enumerate(getBody()))
+    destructured.insert({IntegerAttr::get(i32, index), elemType});
+  return destructured;
+}
+
+Type LLVM::LLVMStructType::getTypeAtIndex(Attribute index) {
+  auto indexAttr = index.dyn_cast<IntegerAttr>();
+  if (!indexAttr || !indexAttr.getType().isInteger(32))
+    return {};
+  int32_t indexInt = indexAttr.getInt();
+  ArrayRef<Type> body = getBody();
+  if (indexInt < 0 || body.size() <= static_cast<uint32_t>(indexInt))
+    return {};
+  return body[indexInt];
+}
+
+std::optional<DenseMap<Attribute, Type>>
+LLVM::LLVMArrayType::getSubelementIndexMap() const {
+  constexpr size_t maxArraySizeForDestructuring = 16;
+  if (getNumElements() > maxArraySizeForDestructuring)
+    return {};
+  int32_t numElements = getNumElements();
+
+  Type i32 = IntegerType::get(getContext(), 32);
+  DenseMap<Attribute, Type> destructured;
+  for (int32_t index = 0; index < numElements; ++index)
+    destructured.insert({IntegerAttr::get(i32, index), getElementType()});
+  return destructured;
+}
+
+Type LLVM::LLVMArrayType::getTypeAtIndex(Attribute index) const {
+  auto indexAttr = index.dyn_cast<IntegerAttr>();
+  if (!indexAttr || !indexAttr.getType().isInteger(32))
+    return {};
+  int32_t indexInt = indexAttr.getInt();
+  if (indexInt < 0 || getNumElements() <= static_cast<uint32_t>(indexInt))
+    return {};
+  return getElementType();
+}
diff --git a/mlir/lib/Interfaces/MemorySlotInterfaces.cpp b/mlir/lib/Interfaces/MemorySlotInterfaces.cpp
--- a/mlir/lib/Interfaces/MemorySlotInterfaces.cpp
+++ b/mlir/lib/Interfaces/MemorySlotInterfaces.cpp
@@ -9,3 +9,4 @@
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 
 #include "mlir/Interfaces/MemorySlotOpInterfaces.cpp.inc"
+#include "mlir/Interfaces/MemorySlotTypeInterfaces.cpp.inc"
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -12,6 +12,7 @@
   OpStats.cpp
   PrintIR.cpp
   SCCP.cpp
+  SROA.cpp
   StripDebugInfo.cpp
   SymbolDCE.cpp
   SymbolPrivatize.cpp
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -0,0 +1,210 @@
+//===-- SROA.cpp - Scalar Replacement Of Aggregates -------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/SROA.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SROA
+#include "mlir/Transforms/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+std::optional<MemorySlotDestructuringInfo>
+mlir::computeDestructuringInfo(DestructurableMemorySlot &slot) {
+  assert(isa<DestructurableTypeInterface>(slot.elemType));
+
+  MemorySlotDestructuringInfo info;
+
+  SmallVector<MemorySlot> usedSafelyWorklist;
+
+  auto scheduleAsBlockingUse = [&](OpOperand &use) {
+    SmallPtrSetImpl<OpOperand *> &blockingUses =
+        info.userToBlockingUses.getOrInsertDefault(use.getOwner());
+    blockingUses.insert(&use);
+  };
+
+  // Initialize the analysis with the immediate users of the slot.
+  for (OpOperand &use : slot.ptr.getUses()) {
+    if (auto accessor =
+            dyn_cast<DestructurableAccessorOpInterface>(use.getOwner())) {
+      if (accessor.canRewire(slot, info.usedIndices, usedSafelyWorklist)) {
+        info.accessors.push_back(accessor);
+        continue;
+      }
+    }
+
+    // If it cannot be shown that the operation uses the slot safely, maybe
+    // it can be promoted out of using the slot?
+    scheduleAsBlockingUse(use);
+  }
+
+  SmallPtrSet<OpOperand *, 16> visited;
+  while (!usedSafelyWorklist.empty()) {
+    MemorySlot mustBeUsedSafely = usedSafelyWorklist.pop_back_val();
+    for (OpOperand &subslotUse : mustBeUsedSafely.ptr.getUses()) {
+      if (!visited.insert(&subslotUse).second)
+        continue;
+      Operation *subslotUser = subslotUse.getOwner();
+
+      if (auto memOp = dyn_cast<SafeMemorySlotAccessOpInterface>(subslotUser))
+        if (succeeded(memOp.ensureOnlySafeAccesses(mustBeUsedSafely,
+                                                   usedSafelyWorklist)))
+          continue;
+
+      // If it cannot be shown that the operation uses the slot safely,
+      // maybe it can be promoted out of using the slot?
+      scheduleAsBlockingUse(subslotUse);
+    }
+  }
+
+  SetVector<Operation *> forwardSlice;
+  mlir::getForwardSlice(slot.ptr, &forwardSlice);
+  for (Operation *user : forwardSlice) {
+    // If the next operation has no blocking uses, everything is fine.
+    if (!info.userToBlockingUses.contains(user))
+      continue;
+
+    SmallPtrSet<OpOperand *, 4> &blockingUses = info.userToBlockingUses[user];
+    auto promotable = dyn_cast<PromotableOpInterface>(user);
+
+    // An operation that has blocking uses must be promoted. If it is not
+    // promotable, destructuring must fail.
+    if (!promotable)
+      return {};
+
+    SmallVector<OpOperand *> newBlockingUses;
+    // If the operation decides it cannot deal with removing the blocking
+    // uses, destructuring must fail.
+    if (!promotable.canUsesBeRemoved(blockingUses, newBlockingUses))
+      return {};
+
+    // Then, register any new blocking uses for coming operations.
+    for (OpOperand *blockingUse : newBlockingUses) {
+      assert(llvm::is_contained(user->getResults(), blockingUse->get()));
+
+      SmallPtrSetImpl<OpOperand *> &newUserBlockingUseSet =
+          info.userToBlockingUses.getOrInsertDefault(blockingUse->getOwner());
+      newUserBlockingUseSet.insert(blockingUse);
+    }
+  }
+
+  return info;
+}
+
+void mlir::destructureSlot(DestructurableMemorySlot &slot,
+                           DestructurableAllocationOpInterface allocator,
+                           OpBuilder &builder,
+                           MemorySlotDestructuringInfo &info) {
+  OpBuilder::InsertionGuard guard(builder);
+
+  builder.setInsertionPointToStart(slot.ptr.getParentBlock());
+  DenseMap<Attribute, MemorySlot> subslots =
+      allocator.destructure(slot, info.usedIndices, builder);
+
+  SetVector<Operation *> usersToRewire;
+  for (Operation *user : llvm::make_first_range(info.userToBlockingUses))
+    usersToRewire.insert(user);
+  for (DestructurableAccessorOpInterface accessor : info.accessors)
+    usersToRewire.insert(accessor);
+  usersToRewire = mlir::topologicalSort(usersToRewire);
+
+  llvm::SmallVector<Operation *> toErase;
+  for (Operation *toRewire : llvm::reverse(usersToRewire)) {
+    builder.setInsertionPointAfter(toRewire);
+    if (auto accessor = dyn_cast<DestructurableAccessorOpInterface>(toRewire)) {
+      if (accessor.rewire(slot, subslots) == DeletionKind::Delete)
+        toErase.push_back(accessor);
+      continue;
+    }
+
+    auto promotable = cast<PromotableOpInterface>(toRewire);
+    if (promotable.removeBlockingUses(info.userToBlockingUses[promotable],
+                                      builder) == DeletionKind::Delete)
+      toErase.push_back(promotable);
+  }
+
+  for (Operation *toEraseOp : toErase)
+    toEraseOp->erase();
+
+  assert(slot.ptr.use_empty() && "after destructuring, the original slot "
+                                 "pointer should no longer be used");
+
+  DEBUG_WITH_TYPE("sroa", llvm::dbgs() << "[sroa] Destructured memory slot: "
+                                       << slot.ptr << "\n");
+
+  allocator.handleDestructuringComplete(slot);
+}
+
+LogicalResult SROAPattern::matchAndRewrite(Operation *op,
+                                           PatternRewriter &rewriter) const {
+  if (op->getNumRegions() == 0)
+    return rewriter.notifyMatchFailure(op, "no region to apply SROA to");
+
+  struct DestructuringJob {
+    DestructurableAllocationOpInterface allocator;
+    DestructurableMemorySlot slot;
+    MemorySlotDestructuringInfo info;
+  };
+
+  std::vector<DestructuringJob> toDestructure;
+
+  for (Region &region : op->getRegions())
+    for (auto allocator :
+         region.getOps<DestructurableAllocationOpInterface>())
+      for (DestructurableMemorySlot slot : allocator.getDestructurableSlots())
+        if (auto info = computeDestructuringInfo(slot))
+          toDestructure.emplace_back<DestructuringJob>(
+              {allocator, std::move(slot), std::move(info.value())});
+
+  if (toDestructure.empty())
+    return rewriter.notifyMatchFailure(
+        op, "no operation to destructure within subregions");
+
+  // Because MemorySlot-related interfaces cannot use RewriterBase to modify
+  // the IR (it does not support some kinds of mutations), an escape hatch is
+  // used to mutate IR outside of the context of the rewriter. This is
+  // achieved by marking the parent op as mutated in place and creating
+  // operations via a secondary OpBuilder.
+  OpBuilder builder(rewriter.getContext());
+
+  rewriter.updateRootInPlace(op, [&]() {
+    for (DestructuringJob &job : toDestructure)
+      destructureSlot(job.slot, job.allocator, builder, job.info);
+  });
+
+  return success();
+}
+
+namespace {
+
+struct SROA : public impl::SROABase<SROA> {
+  void runOnOperation() override {
+    Operation *scopeOp = getOperation();
+    bool changed = false;
+
+    RewritePatternSet rewritePatterns(&getContext());
+    rewritePatterns.add<SROAPattern>(&getContext());
+    FrozenRewritePatternSet frozen(std::move(rewritePatterns));
+    (void)applyOpPatternsAndFold({scopeOp}, frozen, GreedyRewriteConfig(),
+                                 &changed);
+
+    if (!changed)
+      markAllAnalysesPreserved();
+  }
+};
+
+} // namespace
diff --git a/mlir/test/Dialect/LLVMIR/sroa.mlir b/mlir/test/Dialect/LLVMIR/sroa.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/LLVMIR/sroa.mlir
@@ -0,0 +1,207 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(sroa))" --split-input-file | FileCheck %s
+
+// CHECK-LABEL: llvm.func @basic_struct
+llvm.func @basic_struct() -> i32 {
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+  // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @basic_array
+llvm.func @basic_array() -> i32 {
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @multi_level_direct
+llvm.func @multi_level_direct() -> i32 {
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, struct<"bar", (i8, array<10 x array<10 x i32>>, i8)>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.getelementptr inbounds %1[0, 2, 1, 5, 8] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, struct<"bar", (i8, array<10 x array<10 x i32>>, i8)>)>
+  // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// The first application of SROA would generate a GEP with indices [0, 0].
+// This test ensures this GEP is not eliminated during the first application.
+// Even though eliminating it would be correct, it would prevent the second
+// application of SROA from eliminating the array. GEPs should be eliminated
+// only when they are truly trivial (with indices [0]).
+
+// CHECK-LABEL: llvm.func @multi_level_direct_two_applications
+llvm.func @multi_level_direct_two_applications() -> i32 {
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, array<10 x i32>, i8)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.getelementptr inbounds %1[0, 2, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, array<10 x i32>, i8)>
+  // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @multi_level_indirect
+llvm.func @multi_level_indirect() -> i32 {
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, struct<"bar", (i8, array<10 x array<10 x i32>>, i8)>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.getelementptr inbounds %1[0, 2, 1, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, struct<"bar", (i8, array<10 x array<10 x i32>>, i8)>)>
+  %3 = llvm.getelementptr inbounds %2[0, 8] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+  %4 = llvm.load %3 : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %4 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @resolve_alias
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+llvm.func @resolve_alias(%arg: i32) -> i32 {
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+  %3 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+  // CHECK: llvm.store %[[ARG]], %[[ALLOCA]]
+  llvm.store %arg, %2 : i32, !llvm.ptr
+  // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+  %4 = llvm.load %3 : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %4 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @no_non_single_support
+llvm.func @no_non_single_support() -> i32 {
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant
+  %0 = llvm.mlir.constant(2 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  // CHECK-NOT: = llvm.alloca
+  %2 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @no_pointer_indexing
+llvm.func @no_pointer_indexing() -> i32 {
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  // CHECK-NOT: = llvm.alloca
+  %2 = llvm.getelementptr %1[1, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @no_direct_use
+llvm.func @no_direct_use() -> i32 {
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  // CHECK-NOT: = llvm.alloca
+  %2 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.call @use(%1) : (!llvm.ptr) -> ()
+  llvm.return %3 : i32
+}
+
+llvm.func @use(!llvm.ptr)
+
+// -----
+
+// CHECK-LABEL: llvm.func @direct_promotable_use_is_fine
+llvm.func @direct_promotable_use_is_fine() -> i32 {
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+  // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  // This is a direct use of the slot but it can be removed because it
+  // implements PromotableOpInterface.
+  llvm.intr.lifetime.start 2, %1 : !llvm.ptr
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @direct_promotable_use_is_fine_on_accessor
+llvm.func @direct_promotable_use_is_fine_on_accessor() -> i32 {
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x i32
+  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.getelementptr %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
+  // CHECK: %[[RES:.*]] = llvm.load %[[ALLOCA]]
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  // This does not provide side-effect info but it can be removed because it
+  // implements PromotableOpInterface.
+  llvm.intr.lifetime.start 2, %2 : !llvm.ptr
+  // CHECK: llvm.return %[[RES]] : i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @no_dynamic_indexing
+llvm.func @no_dynamic_indexing() -> i32 {
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  // CHECK-NOT: = llvm.alloca
+  %2 = llvm.getelementptr %1[0, %0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32>
+  %3 = llvm.load %2 : !llvm.ptr -> i32
+  llvm.return %3 : i32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @no_typed_pointers
+llvm.func @no_typed_pointers() -> i32 {
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(1 : i32)
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[SIZE]] x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr<array<10 x i32>>
+  %1 = llvm.alloca %0 x !llvm.array<10 x i32> {alignment = 8 : i64} : (i32) -> !llvm.ptr<array<10 x i32>>
+  // CHECK-NOT: = llvm.alloca
+  %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr<array<10 x i32>>) -> !llvm.ptr<i32>
+  %3 = llvm.load %2 : !llvm.ptr<i32>
+  llvm.return %3 : i32
+}
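
As an illustration of the rewrite this pass performs, the sketch below shows `--sroa` applied to the `basic_struct` test case above. The "after" IR is an assumption reconstructed from that test's CHECK lines, not verbatim pass output: only the third field of the struct is ever accessed, so the aggregate slot collapses into a single scalar slot and the `getelementptr` disappears.

// Input IR, as in the basic_struct test.
llvm.func @basic_struct() -> i32 {
  %0 = llvm.mlir.constant(1 : i32) : i32
  %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f64, i32)> : (i32) -> !llvm.ptr
  %2 = llvm.getelementptr inbounds %1[0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f64, i32)>
  %3 = llvm.load %2 : !llvm.ptr -> i32
  llvm.return %3 : i32
}

// After --sroa (reconstructed from the CHECK lines; SSA value names are
// illustrative): the i32 subslot replaces the whole struct allocation, and
// the load now reads the subslot pointer directly.
llvm.func @basic_struct() -> i32 {
  %0 = llvm.mlir.constant(1 : i32) : i32
  %1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
  %2 = llvm.load %1 : !llvm.ptr -> i32
  llvm.return %2 : i32
}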