diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -16,6 +16,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/CopyOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/Mem2RegInterfaces.h" #include "mlir/Interfaces/ShapedOpInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -15,6 +15,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/Mem2RegInterfaces.td" include "mlir/Interfaces/ShapedOpInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -298,7 +299,8 @@ //===----------------------------------------------------------------------===// def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource,[ - DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> { + DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>, + DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> { let summary = "stack memory allocation operation"; let description = [{ The `alloca` operation allocates memory on the stack, to be automatically @@ -1148,7 +1150,8 @@ [TypesMatchWith<"result type matches element type of 'memref'", "memref", "result", "$_self.cast<MemRefType>().getElementType()">, - MemRefsNormalizable]> { + MemRefsNormalizable, + DeclareOpInterfaceMethods<PromotableMemOpInterface>]> { let summary = "load operation"; let description = [{ The `load` op reads an element from a memref specified by an index list.
The @@ -1737,7 +1740,8 @@ [TypesMatchWith<"type of 'value' matches element type of 'memref'", "memref", "value", "$_self.cast<MemRefType>().getElementType()">, - MemRefsNormalizable]> { + MemRefsNormalizable, + DeclareOpInterfaceMethods<PromotableMemOpInterface>]> { let summary = "store operation"; let description = [{ Store a value to a memref location given by indices. The value stored should diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt --- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRMemRefDialect MemRefDialect.cpp + MemRefMem2Reg.cpp MemRefOps.cpp ValueBoundsOpInterfaceImpl.cpp diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMem2Reg.cpp @@ -0,0 +1,119 @@ +//===- MemRefMem2Reg.cpp - Mem2Reg Interfaces -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements Mem2Reg-related interfaces for MemRef dialect +// operations.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Interfaces/Mem2RegInterfaces.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// AllocaOp interfaces +//===----------------------------------------------------------------------===// + +static bool isSupportedElementType(Type type) { + return type.isa<MemRefType>() || + OpBuilder(type.getContext()).getZeroAttr(type); +} + +SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() { + MemRefType type = getType(); + if (!isSupportedElementType(type.getElementType())) + return {}; + if (!type.hasStaticShape()) + return {}; + // Make sure the memref contains only a single element. + if (any_of(type.getShape(), [](uint64_t dim) { return dim != 1; })) + return {}; + + return {MemorySlot{getResult(), type.getElementType()}}; +} + +Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot, + OpBuilder &builder) { + assert(isSupportedElementType(slot.elemType)); + // TODO: support more types + return TypeSwitch<Type, Value>(slot.elemType) + .Case<MemRefType>( + [&](auto t) { return builder.create<memref::AllocaOp>(getLoc(), t); }) + .Default([&](Type t) { + return builder.create<arith::ConstantOp>(getLoc(), t, + builder.getZeroAttr(t)); + }); +} + +void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot, + Value defaultValue) { + if (defaultValue.use_empty()) + defaultValue.getDefiningOp()->erase(); + erase(); +} + +void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot, + BlockArgument argument, + OpBuilder &builder) {} + +//===----------------------------------------------------------------------===// +// LoadOp/StoreOp interfaces +//===----------------------------------------------------------------------===// + +bool memref::LoadOp::loadsFrom(const MemorySlot &slot) { + return getMemRef() ==
slot.ptr; +} + +Value memref::LoadOp::getStored(const MemorySlot &slot) { return {}; } + +bool memref::LoadOp::canUsesBeRemoved( + const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, + SmallVectorImpl<OpOperand *> &newBlockingUses) { + if (blockingUses.size() != 1) + return false; + Value blockingUse = (*blockingUses.begin())->get(); + return blockingUse == slot.ptr && getMemRef() == slot.ptr && + getResult().getType() == slot.elemType; +} + +DeletionKind memref::LoadOp::removeBlockingUses( + const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, + OpBuilder &builder, Value reachingDefinition) { + // `canUsesBeRemoved` checked this blocking use must be the loaded slot + // pointer. + getResult().replaceAllUsesWith(reachingDefinition); + return DeletionKind::Delete; +} + +bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; } + +Value memref::StoreOp::getStored(const MemorySlot &slot) { + if (getMemRef() != slot.ptr) + return Value(); + return getValue(); +} + +bool memref::StoreOp::canUsesBeRemoved( + const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, + SmallVectorImpl<OpOperand *> &newBlockingUses) { + if (blockingUses.size() != 1) + return false; + Value blockingUse = (*blockingUses.begin())->get(); + return blockingUse == slot.ptr && getMemRef() == slot.ptr && + getValue() != slot.ptr && getValue().getType() == slot.elemType; +} + +DeletionKind memref::StoreOp::removeBlockingUses( + const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses, + OpBuilder &builder, Value reachingDefinition) { + return DeletionKind::Delete; +} diff --git a/mlir/test/Transforms/mem2reg-memref.mlir b/mlir/test/Transforms/mem2reg-memref.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/mem2reg-memref.mlir @@ -0,0 +1,142 @@ +// RUN: ns-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @basic +func.func @basic() -> i32 { + // CHECK-NOT: = memref.alloca + // CHECK: %[[RES:.*]] =
arith.constant 5 : i32 + // CHECK-NOT: = memref.alloca + %0 = arith.constant 5 : i32 + %1 = memref.alloca() : memref<i32> + memref.store %0, %1[] : memref<i32> + %2 = memref.load %1[] : memref<i32> + // CHECK: return %[[RES]] : i32 + return %2 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @basic_default +func.func @basic_default() -> i32 { + // CHECK-NOT: = memref.alloca + // CHECK: %[[RES:.*]] = arith.constant 0 : i32 + // CHECK-NOT: = memref.alloca + %0 = arith.constant 5 : i32 + %1 = memref.alloca() : memref<i32> + %2 = memref.load %1[] : memref<i32> + // CHECK-NOT: memref.store + memref.store %0, %1[] : memref<i32> + // CHECK: return %[[RES]] : i32 + return %2 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @basic_float +func.func @basic_float() -> f32 { + // CHECK-NOT: = memref.alloca + // CHECK: %[[RES:.*]] = arith.constant {{.*}} : f32 + %0 = arith.constant 5.2 : f32 + // CHECK-NOT: = memref.alloca + %1 = memref.alloca() : memref<f32> + memref.store %0, %1[] : memref<f32> + %2 = memref.load %1[] : memref<f32> + // CHECK: return %[[RES]] : f32 + return %2 : f32 +} + +// ----- + +// CHECK-LABEL: func.func @basic_ranked +func.func @basic_ranked() -> i32 { + // CHECK-NOT: = memref.alloca + // CHECK: %[[RES:.*]] = arith.constant 5 : i32 + // CHECK-NOT: = memref.alloca + %0 = arith.constant 0 : index + %1 = arith.constant 5 : i32 + %2 = memref.alloca() : memref<1x1xi32> + memref.store %1, %2[%0, %0] : memref<1x1xi32> + %3 = memref.load %2[%0, %0] : memref<1x1xi32> + // CHECK: return %[[RES]] : i32 + return %3 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @reject_multiple_elements +func.func @reject_multiple_elements() -> i32 { + // CHECK: %[[INDEX:.*]] = arith.constant 0 : index + %0 = arith.constant 0 : index + // CHECK: %[[STORED:.*]] = arith.constant 5 : i32 + %1 = arith.constant 5 : i32 + // CHECK: %[[ALLOCA:.*]] = memref.alloca() + %2 = memref.alloca() : memref<1x2xi32> + // CHECK: memref.store %[[STORED]], %[[ALLOCA]][%[[INDEX]], %[[INDEX]]] + memref.store %1, %2[%0, %0] : memref<1x2xi32> +
// CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[INDEX]], %[[INDEX]]] + %3 = memref.load %2[%0, %0] : memref<1x2xi32> + // CHECK: return %[[RES]] : i32 + return %3 : i32 +} + +// ----- + +// CHECK-LABEL: func.func @cycle +// CHECK-SAME: (%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: i64) +func.func @cycle(%arg0: i64, %arg1: i1, %arg2: i64) { + // CHECK-NOT: = memref.alloca + %alloca = memref.alloca() : memref<i64> + memref.store %arg2, %alloca[] : memref<i64> + // CHECK: cf.cond_br %[[ARG1:.*]], ^[[BB1:.*]](%[[ARG2]] : i64), ^[[BB2:.*]](%[[ARG2]] : i64) + cf.cond_br %arg1, ^bb1, ^bb2 +// CHECK: ^[[BB1]](%[[USE:.*]]: i64): +^bb1: + %use = memref.load %alloca[] : memref<i64> + // CHECK: call @use(%[[USE]]) + func.call @use(%use) : (i64) -> () + memref.store %arg0, %alloca[] : memref<i64> + // CHECK: cf.br ^[[BB2]](%[[ARG0]] : i64) + cf.br ^bb2 +// CHECK: ^[[BB2]](%[[FWD:.*]]: i64): +^bb2: + // CHECK: cf.br ^[[BB1]](%[[FWD]] : i64) + cf.br ^bb1 +} + +func.func @use(%arg: i64) { return } + +// ----- + +// CHECK-LABEL: func.func @recursive +// CHECK-SAME: (%[[ARG:.*]]: i64) +func.func @recursive(%arg: i64) -> i64 { + // CHECK-NOT: = memref.alloca() + %alloca0 = memref.alloca() : memref<memref<memref<i64>>> + %alloca1 = memref.alloca() : memref<memref<i64>> + %alloca2 = memref.alloca() : memref<i64> + memref.store %arg, %alloca2[] : memref<i64> + memref.store %alloca2, %alloca1[] : memref<memref<i64>> + memref.store %alloca1, %alloca0[] : memref<memref<memref<i64>>> + %load0 = memref.load %alloca0[] : memref<memref<memref<i64>>> + %load1 = memref.load %load0[] : memref<memref<i64>> + %load2 = memref.load %load1[] : memref<i64> + // CHECK: return %[[ARG]] : i64 + return %load2 : i64 +} + +// ----- + +// CHECK-LABEL: func.func @deny_store_of_alloca +func.func @deny_store_of_alloca(%arg: memref<memref<i32>>) -> i32 { + // CHECK: %[[VALUE:.*]] = arith.constant 5 : i32 + %0 = arith.constant 5 : i32 + // CHECK: %[[ALLOCA:.*]] = memref.alloca + %1 = memref.alloca() : memref<i32> + // CHECK: memref.store %[[VALUE]], %[[ALLOCA]][] + memref.store %0, %1[] : memref<i32> + // CHECK: memref.store %[[ALLOCA]], %{{.*}}[] +
memref.store %1, %arg[] : memref<memref<i32>> + // CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][] + %2 = memref.load %1[] : memref<i32> + // CHECK: return %[[RES]] : i32 + return %2 : i32 +}