diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2848,6 +2848,33 @@
 // DialectInlinerInterface
 //===----------------------------------------------------------------------===//
 
+/// Move all alloca operations with a constant size in the former entry block of
+/// the newly inlined callee into the entry block of the caller.
+static void moveConstantAllocasToEntryBlock(
+    iterator_range<Region::iterator> inlinedBlocks) {
+  Block *calleeEntryBlock = &(*inlinedBlocks.begin());
+  Block *callerEntryBlock = &(*calleeEntryBlock->getParent()->begin());
+  if (calleeEntryBlock == callerEntryBlock)
+    // Nothing to do.
+    return;
+  SmallVector<std::pair<LLVM::AllocaOp, IntegerAttr>> allocasToMove;
+  // Conservatively only move alloca operations that are part of the entry block
+  // and do not inspect nested regions, since they may execute conditionally or
+  // have other unknown semantics.
+  for (auto allocaOp : calleeEntryBlock->getOps<LLVM::AllocaOp>()) {
+    IntegerAttr arraySize;
+    if (matchPattern(allocaOp.getArraySize(), m_Constant(&arraySize)))
+      allocasToMove.emplace_back(allocaOp, arraySize);
+  }
+  OpBuilder builder(callerEntryBlock, callerEntryBlock->begin());
+  for (auto &[allocaOp, arraySize] : allocasToMove) {
+    auto newConstant = builder.create<LLVM::ConstantOp>(
+        allocaOp->getLoc(), allocaOp.getArraySize().getType(), arraySize);
+    allocaOp->moveAfter(newConstant);
+    allocaOp.getArraySizeMutable().assign(newConstant.getResult());
+  }
+}
+
 namespace {
 struct LLVMInlinerInterface : public DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
@@ -2885,7 +2912,7 @@
             return false;
           return true;
         })
-        .Case<LLVM::CallOp>([](auto) { return true; })
+        .Case<LLVM::CallOp, LLVM::AllocaOp>([](auto) { return true; })
         .Default([](auto) { return false; });
   }
 
@@ -2918,6 +2945,17 @@
     dst.replaceAllUsesWith(src);
   }
 
+  void processInlinedCallBlocks(
+      Operation *call,
+      iterator_range<Region::iterator> inlinedBlocks) const override {
+    // Alloca operations with a constant size that were in the entry block of
+    // the callee should be moved to the entry block of the caller, as this will
+    // fold into prologue/epilogue code during code generation.
+    // This is not implemented as a standalone pattern because we need to know
+    // which newly inlined block was previously the entry block of the callee.
+    moveConstantAllocasToEntryBlock(inlinedBlocks);
+  }
+
 private:
   /// Returns true if all attributes of `callOp` are handled during inlining.
   [[nodiscard]] static bool isLegalToInlineCallAttributes(LLVM::CallOp callOp) {
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -20,18 +20,17 @@
 
 // -----
 
-func.func @inner_func_not_inlinable() -> !llvm.ptr {
-  %0 = llvm.mlir.constant(0 : i32) : i32
-  %1 = llvm.alloca %0 x f64 : (i32) -> !llvm.ptr
-  return %1 : !llvm.ptr
+func.func @inner_func_not_inlinable() -> i32 {
+  %0 = llvm.inline_asm has_side_effects "foo", "bar" : () -> i32
+  return %0 : i32
 }
 
-// CHECK-LABEL: func.func @test_not_inline() -> !llvm.ptr {
-// CHECK-NEXT: %[[RES:.*]] = call @inner_func_not_inlinable() : () -> !llvm.ptr
-// CHECK-NEXT: return %[[RES]] : !llvm.ptr
-func.func @test_not_inline() -> !llvm.ptr {
-  %0 = call @inner_func_not_inlinable() : () -> !llvm.ptr
-  return %0 : !llvm.ptr
+// CHECK-LABEL: func.func @test_not_inline() -> i32 {
+// CHECK-NEXT: %[[RES:.*]] = call @inner_func_not_inlinable() : () -> i32
+// CHECK-NEXT: return %[[RES]] : i32
+func.func @test_not_inline() -> i32 {
+  %0 = call @inner_func_not_inlinable() : () -> i32
+  return %0 : i32
 }
 
 // -----
@@ -203,3 +202,76 @@
   llvm.call @callee() { branch_weights = dense<42> : vector<1xi32> } : () -> ()
   llvm.return
 }
+
+// -----
+
+llvm.func @static_alloca() -> f32 {
+  %0 = llvm.mlir.constant(4 : i32) : i32
+  %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 : !llvm.ptr -> f32
+  llvm.return %2 : f32
+}
+
+llvm.func @dynamic_alloca(%size : i32) -> f32 {
+  %0 = llvm.add %size, %size : i32
+  %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 : !llvm.ptr -> f32
+  llvm.return %2 : f32
+}
+
+// CHECK-LABEL: llvm.func @test_inline
+llvm.func @test_inline(%cond : i1, %size : i32) -> f32 {
+  // Check that the static alloca was moved to the entry block after inlining
+  // with its size defined by a constant.
+  // CHECK-NOT: ^{{.+}}:
+  // CHECK-NEXT: llvm.mlir.constant
+  // CHECK-NEXT: llvm.alloca
+  // CHECK: llvm.cond_br
+  llvm.cond_br %cond, ^bb1, ^bb2
+  // CHECK: ^{{.+}}:
+^bb1:
+  // CHECK-NOT: llvm.call @static_alloca
+  %0 = llvm.call @static_alloca() : () -> f32
+  // CHECK: llvm.br
+  llvm.br ^bb3(%0: f32)
+  // CHECK: ^{{.+}}:
+^bb2:
+  // Check that the dynamic alloca was inlined, but that it was not moved to
+  // the entry block.
+  // CHECK: llvm.add
+  // CHECK-NEXT: llvm.alloca
+  // CHECK-NOT: llvm.call @dynamic_alloca
+  %1 = llvm.call @dynamic_alloca(%size) : (i32) -> f32
+  // CHECK: llvm.br
+  llvm.br ^bb3(%1: f32)
+  // CHECK: ^{{.+}}:
+^bb3(%arg : f32):
+  llvm.return %arg : f32
+}
+
+// -----
+
+llvm.func @static_alloca_not_in_entry(%cond : i1) -> f32 {
+  llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  %0 = llvm.mlir.constant(4 : i32) : i32
+  %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+  llvm.br ^bb3(%1: !llvm.ptr)
+^bb2:
+  %2 = llvm.mlir.constant(8 : i32) : i32
+  %3 = llvm.alloca %2 x f32 : (i32) -> !llvm.ptr
+  llvm.br ^bb3(%3: !llvm.ptr)
+^bb3(%ptr : !llvm.ptr):
+  %4 = llvm.load %ptr : !llvm.ptr -> f32
+  llvm.return %4 : f32
+}
+
+// CHECK-LABEL: llvm.func @test_inline
+llvm.func @test_inline(%cond : i1) -> f32 {
+  // Make sure the alloca was not moved to the entry block.
+  // CHECK-NOT: llvm.alloca
+  // CHECK: llvm.cond_br
+  // CHECK: llvm.alloca
+  %0 = llvm.call @static_alloca_not_in_entry(%cond) : (i1) -> f32
+  llvm.return %0 : f32
+}
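
For illustration, the following is a hand-written sketch of roughly the IR one
would expect for the first @test_inline function above after the inliner runs
with this patch applied. It is consistent with the CHECK lines but was not
produced by the patch itself, and the value names (%c4, %hoisted, %sz, %dyn)
are invented for readability.

llvm.func @test_inline(%cond : i1, %size : i32) -> f32 {
  // The constant-sized alloca from @static_alloca is hoisted into the entry
  // block, behind a freshly materialized size constant.
  %c4 = llvm.mlir.constant(4 : i32) : i32
  %hoisted = llvm.alloca %c4 x f32 : (i32) -> !llvm.ptr
  llvm.cond_br %cond, ^bb1, ^bb2
^bb1:
  %0 = llvm.load %hoisted : !llvm.ptr -> f32
  llvm.br ^bb3(%0 : f32)
^bb2:
  // The alloca from @dynamic_alloca keeps its dynamic size and therefore
  // stays in the block it was inlined into.
  %sz = llvm.add %size, %size : i32
  %dyn = llvm.alloca %sz x f32 : (i32) -> !llvm.ptr
  %1 = llvm.load %dyn : !llvm.ptr -> f32
  llvm.br ^bb3(%1 : f32)
^bb3(%arg : f32):
  llvm.return %arg : f32
}

Creating a new constant in the caller's entry block, rather than reusing the
one inlined from the callee, keeps the hoisted alloca independent of the
inlined blocks, whose remaining contents may later be simplified or merged
away.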