diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp @@ -14,6 +14,7 @@ #include "LLVMInlining.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/Support/Debug.h" @@ -124,14 +125,44 @@ } } +/// If `requestedAlignment` is higher than the alignment specified on `alloca`, +/// realigns `alloca` if this does not exceed the natural stack alignment. +/// Returns the post-alignment of `alloca`, whether it was realigned or not. +static unsigned tryToEnforceAllocaAlignment(LLVM::AllocaOp alloca, + unsigned requestedAlignment, + DataLayout const &dataLayout) { + unsigned allocaAlignment = alloca.getAlignment().value_or(1); + if (requestedAlignment <= allocaAlignment) + // No realignment necessary. + return allocaAlignment; + unsigned naturalStackAlignmentBits = dataLayout.getStackAlignment(); + // If the natural stack alignment is not specified, the data layout returns + // zero. Optimistically allow realignment in this case. + if (naturalStackAlignmentBits == 0 || + // If the requested alignment exceeds the natural stack alignment, this + // will trigger a dynamic stack realignment, so we prefer to copy... + 8 * requestedAlignment <= naturalStackAlignmentBits || + // ...unless the alloca already triggers dynamic stack realignment. Then + // we might as well further increase the alignment to avoid a copy. + 8 * allocaAlignment > naturalStackAlignmentBits) { + alloca.setAlignment(requestedAlignment); + allocaAlignment = requestedAlignment; + } + return allocaAlignment; +} + /// Tries to find and return the alignment of the pointer `value` by looking for /// an alignment attribute on the defining allocation op or function argument. 
-/// If no such attribute is found, returns 1 (i.e., assume that no alignment is -/// guaranteed). -static unsigned getAlignmentOf(Value value) { +/// If the found alignment is lower than `requestedAlignment`, tries to realign +/// the pointer, then returns the resulting post-alignment, regardless of +/// whether it was realigned or not. If no existing alignment attribute is +/// found, returns 1 (i.e., assume that no alignment is guaranteed). +static unsigned tryToEnforceAlignment(Value value, unsigned requestedAlignment, + DataLayout const &dataLayout) { if (Operation *definingOp = value.getDefiningOp()) { if (auto alloca = dyn_cast(definingOp)) - return alloca.getAlignment().value_or(1); + return tryToEnforceAllocaAlignment(alloca, requestedAlignment, + dataLayout); if (auto addressOf = dyn_cast(definingOp)) if (auto global = SymbolTable::lookupNearestSymbolFrom( definingOp, addressOf.getGlobalNameAttr())) @@ -143,8 +174,8 @@ // comes directly from a function argument, so check that this is the case. Operation *parentOp = value.getParentBlock()->getParentOp(); if (auto func = dyn_cast(parentOp)) { - // Use the alignment attribute set for this argument in the parent - // function if it has been set. + // Use the alignment attribute set for this argument in the parent function + // if it has been set. auto blockArg = value.cast(); if (Attribute alignAttr = func.getArgAttr( blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName())) @@ -154,19 +185,19 @@ return 1; } -/// Copies the data from a byval pointer argument into newly alloca'ed memory -/// and returns the value of the alloca. +/// Introduces a new alloca and copies the memory pointed to by `argument` to +/// the address of the new alloca, then returns the value of the new alloca. 
static Value handleByValArgumentInit(OpBuilder &builder, Location loc, Value argument, Type elementType, unsigned elementTypeSize, unsigned targetAlignment) { - Block *entryBlock = &(*argument.getParentRegion()->begin()); // Allocate the new value on the stack. Value allocaOp; { // Since this is a static alloca, we can put it directly in the entry block, // so they can be absorbed into the prologue/epilogue at code generation. OpBuilder::InsertionGuard insertionGuard(builder); + Block *entryBlock = &(*argument.getParentRegion()->begin()); builder.setInsertionPointToStart(entryBlock); Value one = builder.create(loc, builder.getI64Type(), builder.getI64IntegerAttr(1)); @@ -183,10 +214,10 @@ } /// Handles a function argument marked with the byval attribute by introducing a -/// memcpy if necessary, either due to the pointee being writeable in the -/// callee, and/or due to an alignment mismatch. `requestedAlignment` specifies -/// the alignment set in the "align" argument attribute (or 1 if no align -/// attribute was set). +/// memcpy or realigning the defining operation, if required either due to the +/// pointee being writeable in the callee, and/or due to an alignment mismatch. +/// `requestedAlignment` specifies the alignment set in the "align" argument +/// attribute (or 1 if no align attribute was set). static Value handleByValArgument(OpBuilder &builder, Operation *callable, Value argument, Type elementType, unsigned requestedAlignment) { @@ -198,11 +229,16 @@ memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef && memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod; // Check if there's an alignment mismatch requiring us to copy. 
- DataLayout dataLayout(callable->getParentOfType()); + DataLayout dataLayout = DataLayout::closest(callable); unsigned minimumAlignment = dataLayout.getTypeABIAlignment(elementType); - if (isReadOnly && (requestedAlignment <= minimumAlignment || - getAlignmentOf(argument) >= requestedAlignment)) - return argument; + if (isReadOnly) { + if (requestedAlignment <= minimumAlignment) + return argument; + unsigned currentAlignment = + tryToEnforceAlignment(argument, requestedAlignment, dataLayout); + if (currentAlignment >= requestedAlignment) + return argument; + } unsigned targetAlignment = std::max(requestedAlignment, minimumAlignment); return handleByValArgumentInit(builder, func.getLoc(), argument, elementType, dataLayout.getTypeSize(elementType), diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -483,13 +483,13 @@ mlir::Attribute mlir::DataLayout::getAllocaMemorySpace() const { checkValid(); - MLIRContext *context = scope->getContext(); if (allocaMemorySpace) return *allocaMemorySpace; DataLayoutEntryInterface entry; if (originalLayout) entry = originalLayout.getSpecForIdentifier( - originalLayout.getAllocaMemorySpaceIdentifier(context)); + originalLayout.getAllocaMemorySpaceIdentifier( + originalLayout.getContext())); if (auto iface = dyn_cast_or_null(scope)) allocaMemorySpace = iface.getAllocaMemorySpace(entry); else @@ -499,13 +499,13 @@ unsigned mlir::DataLayout::getStackAlignment() const { checkValid(); - MLIRContext *context = scope->getContext(); if (stackAlignment) return *stackAlignment; DataLayoutEntryInterface entry; if (originalLayout) entry = originalLayout.getSpecForIdentifier( - originalLayout.getStackAlignmentIdentifier(context)); + originalLayout.getStackAlignmentIdentifier( + originalLayout.getContext())); if (auto iface = dyn_cast_or_null(scope)) stackAlignment = 
// -----

llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr)

// NOTE(review): the #llvm.memory_effects payload was stripped by extraction;
// the read-only effects below are reconstructed — confirm against upstream.
llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
  llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> ()
  llvm.return
}

// The requested alignment (16) exceeds the alloca's (1), and no natural stack
// alignment is specified, so the alloca is realigned in place — no copy.
// CHECK-LABEL: llvm.func @test_byval_realign_alloca
llvm.func @test_byval_realign_alloca() {
  %size = llvm.mlir.constant(4 : i64) : i64
  // CHECK-NOT: llvm.alloca{{.+}}alignment = 1
  // CHECK: llvm.alloca {{.+}}alignment = 16 : i64
  // CHECK-NOT: llvm.intr.memcpy
  %unaligned = llvm.alloca %size x i16 { alignment = 1 } : (i64) -> !llvm.ptr
  llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
  llvm.return
}

// -----

module attributes {
  dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.stack_alignment", 32 : i32>>
} {

llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr)

llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
  llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> ()
  llvm.return
}

// CHECK-LABEL: llvm.func @test_exceeds_natural_stack_alignment
llvm.func @test_exceeds_natural_stack_alignment() {
  %size = llvm.mlir.constant(4 : i64) : i64
  // Natural stack alignment is exceeded, so prefer a copy instead of
  // triggering a dynamic stack realignment.
  // CHECK-DAG: %[[SRC:[a-zA-Z0-9_]+]] = llvm.alloca{{.+}}alignment = 2
  // CHECK-DAG: %[[DST:[a-zA-Z0-9_]+]] = llvm.alloca{{.+}}alignment = 16
  // CHECK: "llvm.intr.memcpy"(%[[DST]], %[[SRC]]
  %unaligned = llvm.alloca %size x i16 { alignment = 2 } : (i64) -> !llvm.ptr
  llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
  llvm.return
}

}

// -----

module attributes {
  dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.stack_alignment", 32 : i32>>
} {

llvm.func @func_that_uses_ptr(%ptr : !llvm.ptr)

llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = read>} {
  llvm.call @func_that_uses_ptr(%ptr) : (!llvm.ptr) -> ()
  llvm.return
}

// CHECK-LABEL: llvm.func @test_alignment_exceeded_anyway
llvm.func @test_alignment_exceeded_anyway() {
  %size = llvm.mlir.constant(4 : i64) : i64
  // Natural stack alignment is lower than the target alignment, but the
  // alloca's existing alignment already exceeds it, so we might as well avoid
  // the copy.
  // CHECK-NOT: llvm.alloca{{.+}}alignment = 1
  // CHECK: llvm.alloca {{.+}}alignment = 16 : i64
  // CHECK-NOT: llvm.intr.memcpy
  %unaligned = llvm.alloca %size x i16 { alignment = 8 } : (i64) -> !llvm.ptr
  llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
  llvm.return
}

}

// -----

llvm.mlir.global private @unaligned_global(42 : i64) : i64