diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1157,6 +1157,106 @@
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// Lifetime[Start/End]Op
+//===----------------------------------------------------------------------===//
+
+class MemRef_LifetimeBaseOp<string mnemonic> : MemRef_Op<mnemonic> {
+  let arguments = (ins AnyMemRef:$memref);
+
+  let assemblyFormat = "$memref attr-dict `:` type($memref)";
+
+  let extraClassDeclaration = [{
+    /// Returns the type of the MemRef whose lifetime is being marked.
+    MemRefType getMemRefType() {
+      return ::llvm::cast<MemRefType>(getMemref().getType());
+    }
+
+    /// Returns the size in bytes of the marked MemRef, for use as the size
+    /// operand of the LLVM lifetime intrinsics, or -1 if it is not known
+    /// statically.
+    int64_t getSize() {
+      MemRefType ty = getMemRefType();
+      // Only integer and float element types have a bit width (querying it
+      // for e.g. `index` or vector elements asserts), and only a static shape
+      // has a known element count.
+      if (!ty.hasStaticShape() || !ty.getElementType().isIntOrFloat())
+        return -1;
+      // Widths that are not a power of two (e.g. i24, f80) may be padded in
+      // memory, so report an unknown size rather than risk an underestimate.
+      int64_t elementBits = ty.getElementTypeBitWidth();
+      if (!::llvm::isPowerOf2_64(elementBits))
+        return -1;
+      // Round up per element: sub-byte types such as i1 still occupy a whole
+      // byte each, whereas truncating the total size would report 0 bytes
+      // for e.g. memref<4xi1>.
+      return ty.getNumElements() * ((elementBits + 7) / 8);
+    }
+  }];
+}
+
+def MemRef_LifetimeStartOp : MemRef_LifetimeBaseOp<"lifetime_start"> {
+  let summary = "mark the start of a MemRef's lifetime";
+  let description = [{
+    The `lifetime_start` op marks the start of a MemRef's lifetime. Prior to
+    the execution of this op (if it exists), the MemRef is considered "dead"
+    and any memory access to it is assumed to yield undefined behavior.
+
+    Example:
+
+    ```mlir
+    %0 = memref.alloca() : memref<32xf32>
+    // No expected memory uses of %0
+    memref.lifetime_start %0 : memref<32xf32>
+    // Expected memory uses of %0
+    memref.lifetime_end %0 : memref<32xf32>
+    // No further expected memory uses of %0
+    ```
+
+    The lifetime semantics are not enforced by the op that allocated the MemRef;
+    instead, they serve as assumptions for analyses run on the MemRef. This is
+    useful for lifetime analysis, where the compiler can statically detect
+    memory accesses outside of a MemRef's lifetime. It is also useful for
+    transformation passes such as buffer reuse where allocations of MemRefs can
+    be statically analyzed for similar MemRefs with non-overlapping lifetimes to
+    be combined into a single allocation.
+
+    These operations are not expected to "execute" on a device and are used as
+    markers to define sections of execution where the writer of the program (or
+    the compiler) believes that the MemRef should be alive or dead.
+
+    Lifetime in the MemRef dialect uses the following assumptions, designed to
+    be a constrained subset of the assumptions used for
+    [LLVM Object Lifetime](https://llvm.org/docs/LangRef.html#object-lifetime):
+
+    1. Lifetime markers can only take MemRefs that have been allocated using
+       `alloca` ops. They cannot take subviewed MemRefs as operands.
+    2. A MemRef is considered alive at a point of execution if a
+       `lifetime_start` op is never called on this MemRef (within its use chain)
+       *OR* a `lifetime_start` op has executed prior to this point in execution
+       and a `lifetime_end` op has not.
+    3. Loading from or storing to a dead MemRef is undefined behavior;
+       however, "Pure" ops that do not have memory effects can be called and
+       are expected to return valid results (such as `subview` ops).
+    4. A path of execution within the program cannot execute a `lifetime_start`
+       op on the same MemRef more than once.
+    5. It is possible for a path of execution to not execute a `lifetime_end`
+       op.
+  }];
+}
+
+def MemRef_LifetimeEndOp : MemRef_LifetimeBaseOp<"lifetime_end"> {
+  let summary = "mark the end of a MemRef's lifetime";
+  let description = [{
+    The `lifetime_end` op marks the end of a MemRef's lifetime. After this op is
+    executed, the MemRef is assumed to be "dead". Accessing the MemRef after
+    this point is assumed to yield undefined behavior.
+
+    See the `lifetime_start` op for more information on lifetime within the
+    MemRef dialect.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Analysis/DataLayoutAnalysis.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h"
@@ -22,6 +23,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/MathExtras.h"
@@ -755,6 +757,28 @@
   }
 };
 
+// Lifetime marker operations are lowered to their associated LLVM intrinsic.
+// The size operand is the op's statically known byte size, or -1 if unknown.
+template <typename MemRefOp, typename LLVMOp>
+struct LifetimeOpLowering : public ConvertOpToLLVMPattern<MemRefOp> {
+  using ConvertOpToLLVMPattern<MemRefOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(MemRefOp op, typename MemRefOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefDescriptor desc = MemRefDescriptor(adaptor.getMemref());
+    Value ptr = desc.alignedPtr(rewriter, op.getLoc());
+    IntegerAttr size = rewriter.getI64IntegerAttr(op.getSize());
+    rewriter.replaceOpWithNewOp<LLVMOp>(op, size, ptr);
+    return success();
+  }
+};
+
+using LifetimeStartOpLowering =
+    LifetimeOpLowering<memref::LifetimeStartOp, LLVM::LifetimeStartOp>;
+using LifetimeEndOpLowering =
+    LifetimeOpLowering<memref::LifetimeEndOp, LLVM::LifetimeEndOp>;
+
 // Load operation is lowered to obtaining a pointer to the indexed element
 // and loading it.
 struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
@@ -1878,6 +1902,8 @@
       GenericAtomicRMWOpLowering,
       GlobalMemrefOpLowering,
       GetGlobalMemrefOpLowering,
+      LifetimeEndOpLowering,
+      LifetimeStartOpLowering,
       LoadOpLowering,
       MemRefCastOpLowering,
       MemRefCopyOpLowering,
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -620,3 +620,46 @@
   memref.store %2, %output[%1] {nontemporal = true} : memref<32xf32, affine_map<(d0) -> (d0)>>
   func.return
 }
+
+// -----
+
+// CHECK-LABEL: func @lifetime_start_end
+func.func @lifetime_start_end(%d : index) {
+  %0 = memref.alloca() : memref<17x42xf32>
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: %[[BASE_PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.intr.lifetime.start 2856, %[[BASE_PTR]] : !llvm.ptr
+  memref.lifetime_start %0 : memref<17x42xf32>
+  // CHECK: %[[BASE_PTR_2:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+  // CHECK: llvm.intr.lifetime.end 2856, %[[BASE_PTR_2]] : !llvm.ptr
+  memref.lifetime_end %0 : memref<17x42xf32>
+  %1 = memref.alloca(%d) : memref<?xf32>
+  // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[BASE_PTR_3:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: llvm.intr.lifetime.start -1, %[[BASE_PTR_3]] : !llvm.ptr
+  memref.lifetime_start %1 : memref<?xf32>
+  // CHECK: %[[BASE_PTR_4:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: llvm.intr.lifetime.end -1, %[[BASE_PTR_4]] : !llvm.ptr
+  memref.lifetime_end %1 : memref<?xf32>
+  func.return
+}
+
+// -----
+
+// Sub-byte elements each occupy a whole byte, and element types without a
+// bit width (such as `index`) get an unknown (-1) size.
+
+// CHECK-LABEL: func @lifetime_size_non_byte_elements
+func.func @lifetime_size_non_byte_elements() {
+  %0 = memref.alloca() : memref<4xi1>
+  // CHECK: llvm.intr.lifetime.start 4, %{{.*}} : !llvm.ptr
+  memref.lifetime_start %0 : memref<4xi1>
+  // CHECK: llvm.intr.lifetime.end 4, %{{.*}} : !llvm.ptr
+  memref.lifetime_end %0 : memref<4xi1>
+  %1 = memref.alloca() : memref<4xindex>
+  // CHECK: llvm.intr.lifetime.start -1, %{{.*}} : !llvm.ptr
+  memref.lifetime_start %1 : memref<4xindex>
+  // CHECK: llvm.intr.lifetime.end -1, %{{.*}} : !llvm.ptr
+  memref.lifetime_end %1 : memref<4xindex>
+  func.return
+}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -386,3 +386,18 @@
   %dst = memref.memory_space_cast %src : memref<?xf32> to memref<?xf32, 1>
   return %dst : memref<?xf32, 1>
 }
+
+// CHECK-LABEL: func @memref_lifetime
+func.func @memref_lifetime(%d : index) {
+  %0 = memref.alloca() : memref<32xf32>
+  // CHECK: memref.lifetime_start %{{.*}} : memref<32xf32>
+  memref.lifetime_start %0 : memref<32xf32>
+  // CHECK-NEXT: memref.lifetime_end %{{.*}} : memref<32xf32>
+  memref.lifetime_end %0 : memref<32xf32>
+  %1 = memref.alloca(%d) : memref<?xf32>
+  // CHECK: memref.lifetime_start %{{.*}} : memref<?xf32>
+  memref.lifetime_start %1 : memref<?xf32>
+  // CHECK-NEXT: memref.lifetime_end %{{.*}} : memref<?xf32>
+  memref.lifetime_end %1 : memref<?xf32>
+  return
+}