NOTE(review): this patch was recovered from a copy in which line breaks were
collapsed and `<...>` template arguments were stripped. Layout was rebuilt from
the hunk line counts; context-line indentation is a best guess. The type list
of the two `.Case(...)` lines in hunk -2885,7 could not be recovered and must
be restored from the upstream file before applying (presumably the `+` line
adds `LLVM::AllocaOp` to the list -- confirm).

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2848,6 +2848,34 @@
 // DialectInlinerInterface
 //===----------------------------------------------------------------------===//
 
+/// Move all `LLVM::AllocaOp` operations with a constant size in `sourceBlock`
+/// to the beginning of `entryBlock`.
+static void moveConstantAllocasToEntryBlock(Block *sourceBlock,
+                                            Block *entryBlock) {
+  assert(entryBlock->isEntryBlock() &&
+         "Second argument must be an entry block.");
+  assert(sourceBlock->getParent() == entryBlock->getParent() &&
+         "Blocks must belong to the same region.");
+  if (sourceBlock == entryBlock)
+    // Nothing to do.
+    return;
+  SmallVector<std::pair<LLVM::AllocaOp, IntegerAttr>> allocasToMove;
+  // Don't walk nested regions of the block, since we can't be sure of their
+  // semantics.
+  for (auto allocaOp : sourceBlock->getOps<LLVM::AllocaOp>()) {
+    IntegerAttr arraySize;
+    if (matchPattern(allocaOp.getArraySize(), m_Constant(&arraySize)))
+      allocasToMove.emplace_back(allocaOp, arraySize);
+  }
+  OpBuilder builder(entryBlock, entryBlock->begin());
+  for (auto &[allocaOp, arraySize] : allocasToMove) {
+    auto newConstant = builder.create<LLVM::ConstantOp>(
+        allocaOp->getLoc(), allocaOp.getArraySize().getType(), arraySize);
+    allocaOp->moveAfter(newConstant);
+    allocaOp.getArraySizeMutable().assign(newConstant.getResult());
+  }
+}
+
 namespace {
 struct LLVMInlinerInterface : public DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
@@ -2885,7 +2913,7 @@
             return false;
           return true;
         })
-        .Case([](auto) { return true; })
+        .Case([](auto) { return true; })
         .Default([](auto) { return false; });
   }
 
@@ -2964,6 +2992,19 @@
           .Default(false);
     });
   }
+
+  void processInlinedCallBlocks(
+      Operation *call,
+      iterator_range<Region::iterator> inlinedBlocks) const override {
+    // alloca operations with a constant size that were in the entry block of
+    // the callee should be moved to the entry block of the caller, as this will
+    // fold into prologue/epilogue code during code generation.
+    // This is not implemented as a standalone pattern because we need to know
+    // which newly inlined block was previously the entry block of the callee.
+    Block &calleeEntryBlock = *inlinedBlocks.begin();
+    Block &callerEntryBlock = *calleeEntryBlock.getParent()->begin();
+    moveConstantAllocasToEntryBlock(&calleeEntryBlock, &callerEntryBlock);
+  }
 };
 
 } // end anonymous namespace
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -20,18 +20,17 @@
 
 // -----
 
-func.func @inner_func_not_inlinable() -> !llvm.ptr {
-  %0 = llvm.mlir.constant(0 : i32) : i32
-  %1 = llvm.alloca %0 x f64 : (i32) -> !llvm.ptr
-  return %1 : !llvm.ptr
+func.func @inner_func_not_inlinable() -> i32 {
+  %0 = llvm.inline_asm has_side_effects "foo", "bar" : () -> i32
+  return %0 : i32
 }
 
-// CHECK-LABEL: func.func @test_not_inline() -> !llvm.ptr {
-// CHECK-NEXT: %[[RES:.*]] = call @inner_func_not_inlinable() : () -> !llvm.ptr
-// CHECK-NEXT: return %[[RES]] : !llvm.ptr
-func.func @test_not_inline() -> !llvm.ptr {
-  %0 = call @inner_func_not_inlinable() : () -> !llvm.ptr
-  return %0 : !llvm.ptr
+// CHECK-LABEL: func.func @test_not_inline() -> i32 {
+// CHECK-NEXT: %[[RES:.*]] = call @inner_func_not_inlinable() : () -> i32
+// CHECK-NEXT: return %[[RES]] : i32
+func.func @test_not_inline() -> i32 {
+  %0 = call @inner_func_not_inlinable() : () -> i32
+  return %0 : i32
 }
 
 // -----
@@ -203,3 +202,70 @@
   llvm.call @callee() { branch_weights = dense<42> : vector<1xi32> } : () -> ()
   llvm.return
 }
+
+// -----
+
+llvm.func @static_alloca() -> !llvm.ptr {
+  %0 = llvm.mlir.constant(4 : i32) : i32
+  %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+  llvm.return %1 : !llvm.ptr
+}
+
+llvm.func @dynamic_alloca(%size : i32) -> !llvm.ptr {
+  %0 = llvm.add %size, %size : i32
+  %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+  llvm.return %1 : !llvm.ptr
+}
+
+// CHECK-LABEL: llvm.func @test_inline
+// Check that the static alloca was moved to the entry block after inlining
+// with its size defined by a constant.
+// CHECK-NOT: ^{{.+}}:
+// CHECK-NEXT: llvm.mlir.constant
+// CHECK-NEXT: llvm.alloca
+// CHECK: llvm.cond_br
+// CHECK-NOT: llvm.call @static_alloca
+// Check that the dynamic alloca was inlined, but that it was not moved to the
+// entry block.
+// CHECK: llvm.add
+// CHECK-NEXT: llvm.alloca
+// CHECK-NOT: llvm.call @dynamic_alloca
+llvm.func @test_inline(%cond : i1, %size : i32) -> f32 {
+  llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  %0 = llvm.call @static_alloca() : () -> !llvm.ptr
+  llvm.br ^bb3(%0: !llvm.ptr)
+^bb2:
+  %1 = llvm.call @dynamic_alloca(%size) : (i32) -> !llvm.ptr
+  llvm.br ^bb3(%1: !llvm.ptr)
+^bb3(%ptr : !llvm.ptr):
+  %2 = llvm.load %ptr : !llvm.ptr -> f32
+  llvm.return %2 : f32
+}
+
+// -----
+
+llvm.func @static_alloca_not_in_entry(%cond : i1) -> f32 {
+  llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  %0 = llvm.mlir.constant(4 : i32) : i32
+  %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+  llvm.br ^bb3(%1: !llvm.ptr)
+^bb2:
+  %2 = llvm.mlir.constant(8 : i32) : i32
+  %3 = llvm.alloca %2 x f32 : (i32) -> !llvm.ptr
+  llvm.br ^bb3(%3: !llvm.ptr)
+^bb3(%ptr : !llvm.ptr):
+  %4 = llvm.load %ptr : !llvm.ptr -> f32
+  llvm.return %4 : f32
+}
+
+// CHECK-LABEL: llvm.func @test_inline
+// Make sure the alloca was not moved to the entry block.
+// CHECK-NOT: llvm.alloca
+// CHECK: llvm.cond_br
+// CHECK: llvm.alloca
+llvm.func @test_inline(%cond : i1) -> f32 {
+  %0 = llvm.call @static_alloca_not_in_entry(%cond) : (i1) -> f32
+  llvm.return %0 : f32
+}