diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefMemorySlot.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRefMemorySlot.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefMemorySlot.h
@@ -0,0 +1,22 @@
+//===- MemRefMemorySlot.h - Implementation of Memory Slot Interfaces ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_IR_MEMREFMEMORYSLOT_H
+#define MLIR_DIALECT_MEMREF_IR_MEMREFMEMORYSLOT_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+/// Registers the BuiltinDialect extension that attaches the memref
+/// implementation of `DestructurableTypeInterface` to `MemRefType`.
+void registerMemorySlotExternalModels(DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_IR_MEMREFMEMORYSLOT_H
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -311,7 +311,8 @@
 
 def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource,[
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
-    DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
+    DeclareOpInterfaceMethods<PromotableAllocationOpInterface>,
+    DeclareOpInterfaceMethods<DestructurableAllocationOpInterface>]> {
   let summary = "stack memory allocation operation";
   let description = [{
     The `alloca` operation allocates memory on the stack, to be automatically
@@ -1162,7 +1163,8 @@
                      "memref", "result",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
-      DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
+      DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let summary = "load operation";
   let description = [{
     The `load` op reads an element from a memref specified by an index list. The
@@ -1752,7 +1754,8 @@
                      "memref", "value",
                      "::llvm::cast<MemRefType>($_self).getElementType()">,
       MemRefsNormalizable,
-      DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
+      DeclareOpInterfaceMethods<PromotableMemOpInterface>,
+      DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>]> {
   let summary = "store operation";
   let description = [{
     Store a value to a memref location given by indices. The value stored should
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -48,6 +48,7 @@
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
 #include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
 #include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
@@ -148,6 +149,7 @@
   memref::registerBufferizableOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
+  memref::registerMemorySlotExternalModels(registry);
   scf::registerBufferizableOpInterfaceExternalModels(registry);
   scf::registerValueBoundsOpInterfaceExternalModels(registry);
   shape::registerBufferizableOpInterfaceExternalModels(registry);
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -1,6 +1,6 @@
 add_mlir_dialect_library(MLIRMemRefDialect
   MemRefDialect.cpp
-  MemRefMem2Reg.cpp
+  MemRefMemorySlot.cpp
   MemRefOps.cpp
   ValueBoundsOpInterfaceImpl.cpp
 
@@ -21,6 +21,7 @@
   MLIRDialectUtils
   MLIRInferTypeOpInterface
   MLIRIR
+  MLIRMemorySlotInterfaces
   MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
   MLIRValueBoundsOpInterface
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
deleted file mode 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp
+++ /dev/null
@@ -1,119 +0,0 @@
-//===- MemRefMem2Reg.cpp - Mem2Reg Interfaces -------------------*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements Mem2Reg-related interfaces for MemRef dialect
-// operations.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "llvm/ADT/TypeSwitch.h"
-
-using namespace mlir;
-
-//===----------------------------------------------------------------------===//
-// AllocaOp interfaces
-//===----------------------------------------------------------------------===//
-
-static bool isSupportedElementType(Type type) {
-  return llvm::isa<MemRefType>(type) ||
-         OpBuilder(type.getContext()).getZeroAttr(type);
-}
-
-SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
-  MemRefType type = getType();
-  if (!isSupportedElementType(type.getElementType()))
-    return {};
-  if (!type.hasStaticShape())
-    return {};
-  // Make sure the memref contains only a single element.
-  if (any_of(type.getShape(), [](uint64_t dim) { return dim != 1; }))
-    return {};
-
-  return {MemorySlot{getResult(), type.getElementType()}};
-}
-
-Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
-                                        RewriterBase &rewriter) {
-  assert(isSupportedElementType(slot.elemType));
-  // TODO: support more types.
-  return TypeSwitch<Type, Value>(slot.elemType)
-      .Case([&](MemRefType t) {
-        return rewriter.create<memref::AllocaOp>(getLoc(), t);
-      })
-      .Default([&](Type t) {
-        return rewriter.create<arith::ConstantOp>(getLoc(), t,
-                                                  rewriter.getZeroAttr(t));
-      });
-}
-
-void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
-                                               Value defaultValue,
-                                               RewriterBase &rewriter) {
-  if (defaultValue.use_empty())
-    rewriter.eraseOp(defaultValue.getDefiningOp());
-  rewriter.eraseOp(*this);
-}
-
-void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
-                                           BlockArgument argument,
-                                           RewriterBase &rewriter) {}
-
-//===----------------------------------------------------------------------===//
-// LoadOp/StoreOp interfaces
-//===----------------------------------------------------------------------===//
-
-bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
-  return getMemRef() == slot.ptr;
-}
-
-Value memref::LoadOp::getStored(const MemorySlot &slot) { return {}; }
-
-bool memref::LoadOp::canUsesBeRemoved(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
-  if (blockingUses.size() != 1)
-    return false;
-  Value blockingUse = (*blockingUses.begin())->get();
-  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
-         getResult().getType() == slot.elemType;
-}
-
-DeletionKind memref::LoadOp::removeBlockingUses(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
-  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
-  // pointer.
-  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
-  return DeletionKind::Delete;
-}
-
-bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
-
-Value memref::StoreOp::getStored(const MemorySlot &slot) {
-  if (getMemRef() != slot.ptr)
-    return {};
-  return getValue();
-}
-
-bool memref::StoreOp::canUsesBeRemoved(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    SmallVectorImpl<OpOperand *> &newBlockingUses) {
-  if (blockingUses.size() != 1)
-    return false;
-  Value blockingUse = (*blockingUses.begin())->get();
-  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
-         getValue() != slot.ptr && getValue().getType() == slot.elemType;
-}
-
-DeletionKind memref::StoreOp::removeBlockingUses(
-    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
-  return DeletionKind::Delete;
-}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -0,0 +1,381 @@
+//===- MemRefMemorySlot.cpp - Memory Slot Interfaces ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements Mem2Reg- and SROA-related memory slot interfaces for
+// MemRef dialect operations and types.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/MemorySlotInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utilities
+//===----------------------------------------------------------------------===//
+
+/// Advances `index` in place to the next index of a tensor of the given
+/// `shape`, with the first dimension varying fastest. Returns failure (leaving
+/// `index` reset to all zeros) if the provided index was the last one.
+static LogicalResult nextIndex(ArrayRef<int64_t> shape,
+                               MutableArrayRef<int64_t> index) {
+  for (size_t i = 0; i < shape.size(); ++i) {
+    index[i]++;
+    if (index[i] < shape[i])
+      return success();
+    index[i] = 0;
+  }
+  return failure();
+}
+
+/// Calls `walker` for each index within a tensor of a given `shape`, providing
+/// the index as an array attribute of the coordinates. `walker` is invoked at
+/// least once, with the all-zero index, even if `shape` has no elements.
+template <typename CallableT>
+static void walkIndicesAsAttr(MLIRContext *ctx, ArrayRef<int64_t> shape,
+                              CallableT &&walker) {
+  Type indexType = IndexType::get(ctx);
+  SmallVector<int64_t> shapeIter(shape.size(), 0);
+  do {
+    SmallVector<Attribute> indexAsAttr;
+    for (int64_t dim : shapeIter)
+      indexAsAttr.push_back(IntegerAttr::get(indexType, dim));
+    walker(ArrayAttr::get(ctx, indexAsAttr));
+  } while (succeeded(nextIndex(shape, shapeIter)));
+}
+
+//===----------------------------------------------------------------------===//
+// Interfaces for AllocaOp
+//===----------------------------------------------------------------------===//
+
+/// Returns true if a default value can be built for a slot of element type
+/// `type`: memrefs, and any type the builder has a zero attribute for.
+static bool isSupportedElementType(Type type) {
+  return llvm::isa<MemRefType>(type) ||
+         OpBuilder(type.getContext()).getZeroAttr(type);
+}
+
+/// Exposes the alloca as a promotable slot if it has a static shape holding
+/// exactly one element of a supported type; otherwise exposes no slot.
+SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
+  MemRefType type = getType();
+  if (!isSupportedElementType(type.getElementType()))
+    return {};
+  if (!type.hasStaticShape())
+    return {};
+  // Make sure the memref contains only a single element.
+  if (type.getNumElements() != 1)
+    return {};
+
+  return {MemorySlot{getResult(), type.getElementType()}};
+}
+
+/// Builds the default value of the slot: a new alloca if the element type is
+/// itself a memref, a zero constant of the element type otherwise.
+Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
+                                        RewriterBase &rewriter) {
+  assert(isSupportedElementType(slot.elemType));
+  // TODO: support more types.
+  return TypeSwitch<Type, Value>(slot.elemType)
+      .Case([&](MemRefType t) {
+        return rewriter.create<memref::AllocaOp>(getLoc(), t);
+      })
+      .Default([&](Type t) {
+        return rewriter.create<arith::ConstantOp>(getLoc(), t,
+                                                  rewriter.getZeroAttr(t));
+      });
+}
+
+/// Erases the alloca, and the op defining `defaultValue` if that value ended
+/// up unused.
+void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
+                                               Value defaultValue,
+                                               RewriterBase &rewriter) {
+  if (defaultValue.use_empty())
+    rewriter.eraseOp(defaultValue.getDefiningOp());
+  rewriter.eraseOp(*this);
+}
+
+/// Nothing to do for block arguments; intentionally empty.
+void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
+                                           BlockArgument argument,
+                                           RewriterBase &rewriter) {}
+
+/// Exposes the alloca as a destructurable slot if its type implements
+/// `DestructurableTypeInterface` and provides a subelement index map. Each
+/// subelement is described as a rank-0 memref of its type.
+SmallVector<DestructurableMemorySlot>
+memref::AllocaOp::getDestructurableSlots() {
+  MemRefType memrefType = getType();
+  auto destructurable = llvm::dyn_cast<DestructurableTypeInterface>(memrefType);
+  if (!destructurable)
+    return {};
+
+  std::optional<DenseMap<Attribute, Type>> destructuredType =
+      destructurable.getSubelementIndexMap();
+  if (!destructuredType)
+    return {};
+
+  DenseMap<Attribute, Type> indexMap;
+  for (auto const &[index, type] : *destructuredType)
+    indexMap.insert({index, MemRefType::get({}, type)});
+
+  return {DestructurableMemorySlot{{getMemref(), memrefType}, indexMap}};
+}
+
+/// Creates one rank-0 alloca per used index, right after this op, and returns
+/// the mapping from each used index to its new slot.
+DenseMap<Attribute, MemorySlot>
+memref::AllocaOp::destructure(const DestructurableMemorySlot &slot,
+                              const SmallPtrSetImpl<Attribute> &usedIndices,
+                              RewriterBase &rewriter) {
+  rewriter.setInsertionPointAfter(*this);
+
+  DenseMap<Attribute, MemorySlot> slotMap;
+
+  auto memrefType = llvm::cast<DestructurableTypeInterface>(getType());
+  for (Attribute usedIndex : usedIndices) {
+    Type elemType = memrefType.getTypeAtIndex(usedIndex);
+    assert(elemType && "used index must be a valid index of the memref");
+    MemRefType elemPtr = MemRefType::get({}, elemType);
+    auto subAlloca = rewriter.create<memref::AllocaOp>(getLoc(), elemPtr);
+    slotMap.try_emplace<MemorySlot>(usedIndex,
+                                    {subAlloca.getResult(), elemType});
+  }
+
+  return slotMap;
+}
+
+/// Erases the original alloca, which must be the destructured slot's pointer.
+void memref::AllocaOp::handleDestructuringComplete(
+    const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
+  assert(slot.ptr == getResult());
+  rewriter.eraseOp(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// Interfaces for LoadOp/StoreOp
+//===----------------------------------------------------------------------===//
+
+/// A load reads the slot if it loads from the slot pointer.
+bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
+  return getMemRef() == slot.ptr;
+}
+
+/// A load never stores anything, so there is no stored value to report.
+Value memref::LoadOp::getStored(const MemorySlot &slot) { return {}; }
+
+/// The load can be removed if its only blocking use is the slot pointer used
+/// as the loaded memref, and the loaded type matches the slot element type.
+bool memref::LoadOp::canUsesBeRemoved(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+  if (blockingUses.size() != 1)
+    return false;
+  Value blockingUse = (*blockingUses.begin())->get();
+  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
+         getResult().getType() == slot.elemType;
+}
+
+/// Forwards the reaching definition to the users of the load, then deletes it.
+DeletionKind memref::LoadOp::removeBlockingUses(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    RewriterBase &rewriter, Value reachingDefinition) {
+  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
+  // pointer.
+  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
+  return DeletionKind::Delete;
+}
+
+/// Returns the index of a memref in attribute form, given its indices.
+/// Returns a null attribute if any index is not a constant integer.
+static Attribute getAttributeIndexFromIndexOperands(MLIRContext *ctx,
+                                                    ValueRange indices) {
+  SmallVector<Attribute> index;
+  for (Value coord : indices) {
+    IntegerAttr coordAttr;
+    if (!matchPattern(coord, m_Constant(&coordAttr)))
+      return {};
+    index.push_back(coordAttr);
+  }
+  return ArrayAttr::get(ctx, index);
+}
+
+/// The load can be rewired if it reads the slot at constant indices that
+/// designate one of the slot's subelements.
+bool memref::LoadOp::canRewire(const DestructurableMemorySlot &slot,
+                               SmallPtrSetImpl<Attribute> &usedIndices,
+                               SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+  if (slot.ptr != getMemRef())
+    return false;
+  Attribute index =
+      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  // Also reject constant indices that are not part of the slot (for example
+  // out-of-bounds ones): no subslot exists to rewire them to.
+  if (!index || !slot.elementPtrs.contains(index))
+    return false;
+  usedIndices.insert(index);
+  return true;
+}
+
+/// Redirects the load to the rank-0 subslot matching its constant index and
+/// drops the now superfluous index operands.
+DeletionKind memref::LoadOp::rewire(const DestructurableMemorySlot &slot,
+                                    DenseMap<Attribute, MemorySlot> &subslots,
+                                    RewriterBase &rewriter) {
+  Attribute index =
+      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  const MemorySlot &memorySlot = subslots.at(index);
+  rewriter.updateRootInPlace(*this, [&]() {
+    setMemRef(memorySlot.ptr);
+    getIndicesMutable().clear();
+  });
+  return DeletionKind::Keep;
+}
+
+/// A store never loads from the slot.
+bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
+
+/// Returns the stored value if the store writes to the slot pointer.
+Value memref::StoreOp::getStored(const MemorySlot &slot) {
+  if (getMemRef() != slot.ptr)
+    return {};
+  return getValue();
+}
+
+/// The store can be removed if its only blocking use is the slot pointer used
+/// as the destination memref, the slot pointer is not the value being stored,
+/// and the stored type matches the slot element type.
+bool memref::StoreOp::canUsesBeRemoved(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    SmallVectorImpl<OpOperand *> &newBlockingUses) {
+  if (blockingUses.size() != 1)
+    return false;
+  Value blockingUse = (*blockingUses.begin())->get();
+  return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
+         getValue() != slot.ptr && getValue().getType() == slot.elemType;
+}
+
+/// Nothing to forward for a store; it is simply deleted.
+DeletionKind memref::StoreOp::removeBlockingUses(
+    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
+    RewriterBase &rewriter, Value reachingDefinition) {
+  return DeletionKind::Delete;
+}
+
+/// The store can be rewired if it writes to the slot (without storing the slot
+/// pointer itself) at constant indices that designate one of the slot's
+/// subelements.
+bool memref::StoreOp::canRewire(const DestructurableMemorySlot &slot,
+                                SmallPtrSetImpl<Attribute> &usedIndices,
+                                SmallVectorImpl<MemorySlot> &mustBeSafelyUsed) {
+  if (slot.ptr != getMemRef() || getValue() == slot.ptr)
+    return false;
+  Attribute index =
+      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  if (!index || !slot.elementPtrs.contains(index))
+    return false;
+  usedIndices.insert(index);
+  return true;
+}
+
+/// Redirects the store to the rank-0 subslot matching its constant index and
+/// drops the now superfluous index operands.
+DeletionKind memref::StoreOp::rewire(const DestructurableMemorySlot &slot,
+                                     DenseMap<Attribute, MemorySlot> &subslots,
+                                     RewriterBase &rewriter) {
+  Attribute index =
+      getAttributeIndexFromIndexOperands(getContext(), getIndices());
+  const MemorySlot &memorySlot = subslots.at(index);
+  rewriter.updateRootInPlace(*this, [&]() {
+    setMemRef(memorySlot.ptr);
+    getIndicesMutable().clear();
+  });
+  return DeletionKind::Keep;
+}
+
+//===----------------------------------------------------------------------===//
+// Interfaces for destructurable types
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Makes statically-shaped memrefs of a few elements destructurable into their
+/// individual elements, each addressed by an array of index-typed coordinates.
+struct MemRefDestructurableTypeExternalModel
+    : public DestructurableTypeInterface::ExternalModel<
+          MemRefDestructurableTypeExternalModel, MemRefType> {
+  /// Maps every valid index of the memref to the element type. Returns nothing
+  /// for dynamically-shaped memrefs, memrefs holding more elements than
+  /// `maxMemrefSizeForDestructuring`, and memrefs with at most one element.
+  std::optional<DenseMap<Attribute, Type>>
+  getSubelementIndexMap(Type type) const {
+    auto memrefType = llvm::cast<MemRefType>(type);
+    constexpr int64_t maxMemrefSizeForDestructuring = 16;
+    // Memrefs with zero or one element have nothing worth destructuring.
+    if (!memrefType.hasStaticShape() ||
+        memrefType.getNumElements() > maxMemrefSizeForDestructuring ||
+        memrefType.getNumElements() <= 1)
+      return {};
+
+    DenseMap<Attribute, Type> destructured;
+    walkIndicesAsAttr(
+        memrefType.getContext(), memrefType.getShape(), [&](Attribute index) {
+          destructured.insert({index, memrefType.getElementType()});
+        });
+
+    return destructured;
+  }
+
+  /// Returns the element type if `index` is an array of in-bounds index-typed
+  /// integer coordinates, one per dimension; returns a null type otherwise.
+  Type getTypeAtIndex(Type type, Attribute index) const {
+    auto memrefType = llvm::cast<MemRefType>(type);
+    auto coordArrAttr = llvm::dyn_cast<ArrayAttr>(index);
+    if (!coordArrAttr || coordArrAttr.size() != memrefType.getShape().size())
+      return {};
+
+    Type indexType = IndexType::get(memrefType.getContext());
+    for (const auto &[coordAttr, dimSize] :
+         llvm::zip(coordArrAttr, memrefType.getShape())) {
+      auto coord = llvm::dyn_cast<IntegerAttr>(coordAttr);
+      if (!coord || coord.getType() != indexType || coord.getInt() < 0 ||
+          coord.getInt() >= dimSize)
+        return {};
+    }
+
+    return memrefType.getElementType();
+  }
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Register external models
+//===----------------------------------------------------------------------===//
+
+/// Attaches the memref implementation of `DestructurableTypeInterface` to
+/// `MemRefType` through an extension of the builtin dialect.
+void mlir::memref::registerMemorySlotExternalModels(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, BuiltinDialect *dialect) {
+    MemRefType::attachInterface<MemRefDestructurableTypeExternalModel>(*ctx);
+  });
+}
diff --git a/mlir/test/Dialect/MemRef/sroa.mlir b/mlir/test/Dialect/MemRef/sroa.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/sroa.mlir
@@ -0,0 +1,175 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(sroa))" --split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @basic
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @basic(%arg0: i32, %arg1: i32) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK-COUNT-2: = memref.alloca() : memref<i32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca() : memref<2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA0:.*]][]
+  memref.store %arg0, %alloca[%c0] : memref<2xi32>
+  // CHECK: memref.store %[[ARG1]], %[[ALLOCA1:.*]][]
+  memref.store %arg1, %alloca[%c1] : memref<2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA0]][]
+  %res = memref.load %alloca[%c0] : memref<2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @basic_high_dimensions
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+func.func @basic_high_dimensions(%arg0: i32, %arg1: i32, %arg2: i32) -> i32 {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK-COUNT-3: = memref.alloca() : memref<i32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca() : memref<2x2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA0:.*]][]
+  memref.store %arg0, %alloca[%c0, %c0] : memref<2x2xi32>
+  // CHECK: memref.store %[[ARG1]], %[[ALLOCA1:.*]][]
+  memref.store %arg1, %alloca[%c0, %c1] : memref<2x2xi32>
+  // CHECK: memref.store %[[ARG2]], %[[ALLOCA2:.*]][]
+  memref.store %arg2, %alloca[%c1, %c0] : memref<2x2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA1]][]
+  %res = memref.load %alloca[%c0, %c1] : memref<2x2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @resolve_alias
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @resolve_alias(%arg0: i32, %arg1: i32) -> i32 {
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<i32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca() : memref<2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][]
+  memref.store %arg0, %alloca[%c0] : memref<2xi32>
+  // CHECK: memref.store %[[ARG1]], %[[ALLOCA]][]
+  memref.store %arg1, %alloca[%c0] : memref<2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][]
+  %res = memref.load %alloca[%c0] : memref<2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @no_direct_use
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @no_direct_use(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<2xi32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca() : memref<2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]]]
+  memref.store %arg0, %alloca[%c0] : memref<2xi32>
+  // CHECK: memref.store %[[ARG1]], %[[ALLOCA]][%[[C1]]]
+  memref.store %arg1, %alloca[%c1] : memref<2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C0]]]
+  %res = memref.load %alloca[%c0] : memref<2xi32>
+  call @use(%alloca) : (memref<2xi32>) -> ()
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}
+
+func.func @use(%foo: memref<2xi32>) { return }
+
+// -----
+
+// CHECK-LABEL: func.func @no_dynamic_indexing
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[INDEX:.*]]: index)
+func.func @no_dynamic_indexing(%arg0: i32, %arg1: i32, %index: index) -> i32 {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<2xi32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca() : memref<2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]]]
+  memref.store %arg0, %alloca[%c0] : memref<2xi32>
+  // CHECK: memref.store %[[ARG1]], %[[ALLOCA]][%[[INDEX]]]
+  memref.store %arg1, %alloca[%index] : memref<2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C0]]]
+  %res = memref.load %alloca[%c0] : memref<2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @no_dynamic_shape
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @no_dynamic_shape(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca(%[[C1]]) : memref<?x2xi32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca(%c1) : memref<?x2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]], %[[C0]]]
+  memref.store %arg0, %alloca[%c0, %c0] : memref<?x2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C0]], %[[C0]]]
+  %res = memref.load %alloca[%c0, %c0] : memref<?x2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @no_out_of_bounds
+// CHECK-SAME: (%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32)
+func.func @no_out_of_bounds(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[C100:.*]] = arith.constant 100 : index
+  %c100 = arith.constant 100 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<2xi32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca() : memref<2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]]]
+  memref.store %arg0, %alloca[%c0] : memref<2xi32>
+  // CHECK: memref.store %[[ARG1]], %[[ALLOCA]][%[[C100]]]
+  memref.store %arg1, %alloca[%c100] : memref<2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C0]]]
+  %res = memref.load %alloca[%c0] : memref<2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @no_load_out_of_bounds
+// CHECK-SAME: (%[[ARG0:.*]]: i32)
+func.func @no_load_out_of_bounds(%arg0: i32) -> i32 {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK: %[[C100:.*]] = arith.constant 100 : index
+  %c100 = arith.constant 100 : index
+  // CHECK-NOT: = memref.alloca()
+  // CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<2xi32>
+  // CHECK-NOT: = memref.alloca()
+  %alloca = memref.alloca() : memref<2xi32>
+  // CHECK: memref.store %[[ARG0]], %[[ALLOCA]][%[[C0]]]
+  memref.store %arg0, %alloca[%c0] : memref<2xi32>
+  // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[C100]]]
+  %res = memref.load %alloca[%c100] : memref<2xi32>
+  // CHECK: return %[[RES]] : i32
+  return %res : i32
+}