diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -26,6 +26,7 @@ Core LINK_LIBS PUBLIC + MLIRAnalysis MLIRCallInterfaces MLIRControlFlowInterfaces MLIRDataLayoutInterfaces diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "TypeDetail.h" +#include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -2848,8 +2849,23 @@ // DialectInlinerInterface //===----------------------------------------------------------------------===// +/// Check whether the given alloca is an input to a lifetime intrinsic, +/// optionally passing through one or more casts on the way. +static bool hasLifetimeMarkers(LLVM::AllocaOp allocaOp) { + SetVector forwardSlice; + bool foundMarker = false; + getForwardSlice(allocaOp.getOperation(), &forwardSlice, [&](Operation *op) { + foundMarker |= isa(op); + // We have to include `LLVM::AllocaOp` so `getForwardSlice` does not exclude + // the initial operation itself. + return !foundMarker && isa(op); + }); + return foundMarker; +} + /// Move all alloca operations with a constant size in the former entry block of -/// the newly inlined callee into the entry block of the caller. +/// the newly inlined callee into the entry block of the caller, and insert +/// lifetime intrinsics that limit their scope to the inlined blocks. static void moveConstantAllocasToEntryBlock( iterator_range inlinedBlocks) { Block *calleeEntryBlock = &(*inlinedBlocks.begin()); @@ -2857,22 +2873,50 @@ if (calleeEntryBlock == callerEntryBlock) // Nothing to do. 
return; - SmallVector> allocasToMove; + SmallVector> allocasToMove; + bool shouldInsertLifetimes = false; // Conservatively only move alloca operations that are part of the entry block // and do not inspect nested regions, since they may execute conditionally or // have other unknown semantics. for (auto allocaOp : calleeEntryBlock->getOps()) { IntegerAttr arraySize; - if (matchPattern(allocaOp.getArraySize(), m_Constant(&arraySize))) - allocasToMove.emplace_back(allocaOp, arraySize); + if (!matchPattern(allocaOp.getArraySize(), m_Constant(&arraySize))) + continue; + bool shouldInsertLifetime = + arraySize.getValue() != 0 && !hasLifetimeMarkers(allocaOp); + shouldInsertLifetimes |= shouldInsertLifetime; + allocasToMove.emplace_back(allocaOp, arraySize, shouldInsertLifetime); } OpBuilder builder(callerEntryBlock, callerEntryBlock->begin()); - for (auto &[allocaOp, arraySize] : allocasToMove) { + for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) { auto newConstant = builder.create( allocaOp->getLoc(), allocaOp.getArraySize().getType(), arraySize); + // Insert a lifetime start intrinsic where the alloca was before moving it. + if (shouldInsertLifetime) { + OpBuilder::InsertionGuard insertionGuard(builder); + builder.setInsertionPoint(allocaOp); + builder.create( + allocaOp.getLoc(), arraySize.getValue().getLimitedValue(), + allocaOp.getResult()); + } allocaOp->moveAfter(newConstant); allocaOp.getArraySizeMutable().assign(newConstant.getResult()); } + if (!shouldInsertLifetimes) + return; + // Insert a lifetime end intrinsic before each return in the callee function. 
+ for (Block &block : inlinedBlocks) { + if (!block.getTerminator()->hasTrait()) + continue; + builder.setInsertionPoint(block.getTerminator()); + for (auto &[allocaOp, arraySize, shouldInsertLifetime] : allocasToMove) { + if (!shouldInsertLifetime) + continue; + builder.create( + allocaOp.getLoc(), arraySize.getValue().getLimitedValue(), + allocaOp.getResult()); + } + } } namespace { @@ -2912,7 +2956,8 @@ return false; return true; }) - .Case([](auto) { return true; }) + .Case([](auto) { return true; }) .Default([](auto) { return false; }); } diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir --- a/mlir/test/Dialect/LLVMIR/inlining.mlir +++ b/mlir/test/Dialect/LLVMIR/inlining.mlir @@ -231,7 +231,9 @@ // CHECK: ^{{.+}}: ^bb1: // CHECK-NOT: llvm.call @static_alloca + // CHECK: llvm.intr.lifetime.start %0 = llvm.call @static_alloca() : () -> f32 + // CHECK: llvm.intr.lifetime.end // CHECK: llvm.br llvm.br ^bb3(%0: f32) // CHECK: ^{{.+}}: @@ -275,3 +277,79 @@ %0 = llvm.call @static_alloca_not_in_entry(%cond) : (i1) -> f32 llvm.return %0 : f32 } + +// ----- + +llvm.func @static_alloca(%cond: i1) -> f32 { + %0 = llvm.mlir.constant(4 : i32) : i32 + %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr + llvm.cond_br %cond, ^bb1, ^bb2 +^bb1: + %2 = llvm.load %1 : !llvm.ptr -> f32 + llvm.return %2 : f32 +^bb2: + %3 = llvm.mlir.constant(3.14192 : f32) : f32 + llvm.return %3 : f32 +} + +// CHECK-LABEL: llvm.func @test_inline +llvm.func @test_inline(%cond0 : i1, %cond1 : i1, %funcArg : f32) -> f32 { + // CHECK-NOT: llvm.cond_br + // CHECK: %[[PTR:.+]] = llvm.alloca + // CHECK: llvm.cond_br %{{.+}}, ^[[BB1:.+]], ^{{.+}} + llvm.cond_br %cond0, ^bb1, ^bb2 + // CHECK: ^[[BB1]] +^bb1: + // Make sure the lifetime start intrinsic has been inserted where the call + // used to be, even though the alloca has been moved to the entry block. 
+ // CHECK-NEXT: llvm.intr.lifetime.start 4, %[[PTR]] + %0 = llvm.call @static_alloca(%cond1) : (i1) -> f32 + // CHECK: llvm.cond_br %{{.+}}, ^[[BB2:.+]], ^[[BB3:.+]] + llvm.br ^bb3(%0: f32) + // Make sure the lifetime end intrinsic has been inserted at both former + // return sites of the callee. + // CHECK: ^[[BB2]]: + // CHECK-NEXT: llvm.load + // CHECK-NEXT: llvm.intr.lifetime.end 4, %[[PTR]] + // CHECK: ^[[BB3]]: + // CHECK-NEXT: llvm.intr.lifetime.end 4, %[[PTR]] +^bb2: + llvm.br ^bb3(%funcArg: f32) +^bb3(%blockArg: f32): + llvm.return %blockArg : f32 +} + +// ----- + +llvm.func @alloca_with_lifetime(%cond: i1) -> f32 { + %0 = llvm.mlir.constant(4 : i32) : i32 + %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr + llvm.intr.lifetime.start 4, %1 : !llvm.ptr + %2 = llvm.load %1 : !llvm.ptr -> f32 + llvm.intr.lifetime.end 4, %1 : !llvm.ptr + %3 = llvm.fadd %2, %2 : f32 + llvm.return %3 : f32 +} + +// CHECK-LABEL: llvm.func @test_inline +llvm.func @test_inline(%cond0 : i1, %cond1 : i1, %funcArg : f32) -> f32 { + // CHECK-NOT: llvm.cond_br + // CHECK: %[[PTR:.+]] = llvm.alloca + // CHECK: llvm.cond_br %{{.+}}, ^[[BB1:.+]], ^{{.+}} + llvm.cond_br %cond0, ^bb1, ^bb2 + // CHECK: ^[[BB1]] +^bb1: + // Make sure the original lifetime intrinsic has been preserved, rather than + // inserting a new one with a larger scope. + // CHECK: llvm.intr.lifetime.start 4, %[[PTR]] + // CHECK-NEXT: llvm.load %[[PTR]] + // CHECK-NEXT: llvm.intr.lifetime.end 4, %[[PTR]] + // CHECK: llvm.fadd + // CHECK-NOT: llvm.intr.lifetime.end + %0 = llvm.call @alloca_with_lifetime(%cond1) : (i1) -> f32 + llvm.br ^bb3(%0: f32) +^bb2: + llvm.br ^bb3(%funcArg: f32) +^bb3(%blockArg: f32): + llvm.return %blockArg : f32 +}