diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp @@ -14,6 +14,7 @@ #include "LLVMInlining.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/Support/Debug.h" @@ -183,10 +184,10 @@ } /// Handles a function argument marked with the byval attribute by introducing a -/// memcpy if necessary, either due to the pointee being writeable in the -/// callee, and/or due to an alignment mismatch. `requestedAlignment` specifies -/// the alignment set in the "align" argument attribute (or 1 if no align -/// attribute was set). +/// memcpy or realigning the defining alloca if necessary, either due to the +/// pointee being writeable in the callee, and/or due to an alignment mismatch. +/// `requestedAlignment` specifies the alignment set in the "align" argument +/// attribute (or 1 if no align attribute was set). static Value handleByValArgument(OpBuilder &builder, Operation *callable, Value argument, Type elementType, unsigned requestedAlignment) { @@ -198,11 +199,28 @@ memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef && memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod; // Check if there's an alignment mismatch requiring us to copy. - DataLayout dataLayout(callable->getParentOfType<DataLayoutOpInterface>()); + DataLayout dataLayout = DataLayout::closest(callable); unsigned minimumAlignment = dataLayout.getTypeABIAlignment(elementType); - if (isReadOnly && (requestedAlignment <= minimumAlignment || - getAlignmentOf(argument) >= requestedAlignment)) - return argument; + if (isReadOnly) { + if (requestedAlignment <= minimumAlignment) + return argument; + if (getAlignmentOf(argument) >= requestedAlignment) + return argument; + // Convert from bit-alignment to byte-alignment.
+ unsigned stackAlignment = dataLayout.getStackAlignment() >> 3; + // If this is a stack-allocated value and the requested alignment does not + // exceed the natural stack alignment, realign the alloca to avoid a copy. + // If the requested alignment exceeds the natural stack alignment, this will + // trigger a dynamic stack realignment, so we don't do this before it's + // strictly necessary. If no stack alignment is specified, conservatively + // don't try realigning the alloca. + if (requestedAlignment <= stackAlignment) { + if (auto alloca = argument.getDefiningOp<LLVM::AllocaOp>()) { + alloca.setAlignment(requestedAlignment); + return argument; + } + } + } unsigned targetAlignment = std::max(requestedAlignment, minimumAlignment); return handleByValArgumentInit(builder, func.getLoc(), argument, elementType, dataLayout.getTypeSize(elementType), diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -483,13 +483,13 @@ mlir::Attribute mlir::DataLayout::getAllocaMemorySpace() const { checkValid(); - MLIRContext *context = scope->getContext(); if (allocaMemorySpace) return *allocaMemorySpace; DataLayoutEntryInterface entry; if (originalLayout) entry = originalLayout.getSpecForIdentifier( - originalLayout.getAllocaMemorySpaceIdentifier(context)); + originalLayout.getAllocaMemorySpaceIdentifier( + originalLayout.getContext())); if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope)) allocaMemorySpace = iface.getAllocaMemorySpace(entry); else @@ -499,13 +499,13 @@ unsigned mlir::DataLayout::getStackAlignment() const { checkValid(); - MLIRContext *context = scope->getContext(); if (stackAlignment) return *stackAlignment; DataLayoutEntryInterface entry; if (originalLayout) entry = originalLayout.getSpecForIdentifier( - originalLayout.getStackAlignmentIdentifier(context)); + originalLayout.getStackAlignmentIdentifier( + originalLayout.getContext()));
if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope)) stackAlignment = iface.getStackAlignment(entry); else diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir --- a/mlir/test/Dialect/LLVMIR/inlining.mlir +++ b/mlir/test/Dialect/LLVMIR/inlining.mlir @@ -507,6 +507,32 @@ // ----- +module attributes { dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.stack_alignment", 32 : i32>>} { + +llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 4 }) attributes {memory = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite>} { + llvm.return +} + +llvm.func @declared_func(%ptr : !llvm.ptr) + +// CHECK-LABEL: llvm.func @test_byval_realign_alloca +llvm.func @test_byval_realign_alloca() { + %size = llvm.mlir.constant(1 : i64) : i64 + // Because the natural stack alignment is at least as large as the target + // alignment, the alloca will just be realigned instead of doing a copy. + // CHECK-NOT: llvm.alloca {{.+}}alignment = 2 + // CHECK: llvm.alloca {{.+}}alignment = 4 + // CHECK-NOT: memcpy + %ptr = llvm.alloca %size x i16 { alignment = 2 } : (i64) -> !llvm.ptr + llvm.call @aligned_byval_arg(%ptr) : (!llvm.ptr) -> () + llvm.call @declared_func(%ptr) : (!llvm.ptr) -> () + llvm.return +} + +} + +// ----- + llvm.func @ignored_attrs(%ptr : !llvm.ptr { llvm.inreg, llvm.nocapture, llvm.nofree, llvm.preallocated = i32, llvm.returned, llvm.alignstack = 32 : i64, llvm.writeonly, llvm.noundef, llvm.nonnull }, %x : i32 { llvm.zeroext }) -> (!llvm.ptr { llvm.noundef, llvm.inreg, llvm.nonnull }) { llvm.return %ptr : !llvm.ptr }