diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
@@ -95,9 +95,51 @@
   }
 }
 
+/// Tries to find and return the alignment of the pointer `value` by looking for
+/// an alignment attribute on the defining allocation op or function argument.
+/// Only arguments of the entry block of the enclosing function are treated as
+/// function arguments: arguments of any other block share no numbering with
+/// the function signature, so their alignment is unknown. If no alignment
+/// attribute is found, returns 1 (i.e., assume that no alignment is
+/// guaranteed).
+static uint64_t getAlignmentOf(Value value) {
+  if (Operation *definingOp = value.getDefiningOp()) {
+    // TODO: Peel off getelementptr
+    if (auto alloca = dyn_cast<LLVM::AllocaOp>(definingOp))
+      return alloca.getAlignment().value_or(1);
+    if (auto addressOf = dyn_cast<LLVM::AddressOfOp>(definingOp))
+      if (auto global = SymbolTable::lookupNearestSymbolFrom<LLVM::GlobalOp>(
+              definingOp, addressOf.getGlobalNameAttr()))
+        return global.getAlignment().value_or(1);
+  } else if (auto arg = value.dyn_cast<BlockArgument>()) {
+    // This is a block argument with no defining op. It only corresponds to a
+    // function argument if it belongs to the entry block of a function; the
+    // argument number of any other block argument must not be used to index
+    // the function's argument attributes.
+    Block *parentBlock = arg.getOwner();
+    auto func = dyn_cast<LLVM::LLVMFuncOp>(parentBlock->getParentOp());
+    if (func && parentBlock->isEntryBlock()) {
+      // Check if the corresponding function has an alignment attribute set
+      // for this argument: if so, use it.
+      if (auto alignAttr = func.getArgAttr(
+              arg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName())) {
+        return cast<IntegerAttr>(alignAttr).getValue().getLimitedValue();
+      }
+    }
+  }
+  // We didn't find anything useful; assume no alignment.
+  return 1;
+}
+
+/// Handles a function argument marked with the byval attribute by introducing a
+/// memcpy if necessary, either due to the pointee being writeable in the
+/// callee, and/or due to an alignment mismatch. `requestedAlignment` specifies
+/// the alignment set in the "align" argument attribute (or 1 if no align
+/// attribute was set).
static Value handleByValArgument(OpBuilder &builder, Operation *callable, - Value argument, - NamedAttribute byValAttribute) { + Value argument, Type elementType, + uint64_t requestedAlignment) { auto func = cast(callable); LLVM::MemoryEffectsAttr memoryEffects = func.getMemoryAttr(); // If there is no memory effects attribute, assume that the function is @@ -105,19 +140,29 @@ bool isReadOnly = memoryEffects && memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef && memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod; - if (isReadOnly) + // Check if there's an alignment mismatch requiring us to copy. + DataLayout dataLayout(callable->getParentOfType()); + bool needsRealignment = [&]() { + if (requestedAlignment <= 1) + return false; + uint64_t minimumAlignment = dataLayout.getTypeABIAlignment(elementType); + if (minimumAlignment >= requestedAlignment) + return false; + // getAlignmentOf returns 1 if it fails to find any existing alignment on + // the pointer, conservatively triggering a copy. + uint64_t currentAlignment = getAlignmentOf(argument); + return currentAlignment < requestedAlignment; + }(); + if (isReadOnly && !needsRealignment) return argument; // Resolve the pointee type and its size. auto ptrType = cast(argument.getType()); - Type elementType = cast(byValAttribute.getValue()).getValue(); - unsigned int typeSize = - DataLayout(callable->getParentOfType()) - .getTypeSize(elementType); + unsigned int typeSize = dataLayout.getTypeSize(elementType); // Allocate the new value on the stack. Value one = builder.create( func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(1)); - Value allocaOp = - builder.create(func.getLoc(), ptrType, elementType, one); + Value allocaOp = builder.create( + func.getLoc(), ptrType, elementType, one, requestedAlignment); // Copy the pointee to the newly allocated value. 
Value copySize = builder.create( func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(typeSize)); @@ -131,8 +176,6 @@ /// Returns true if the given argument or result attribute is supported by the /// inliner, false otherwise. static bool isArgOrResAttrSupported(NamedAttribute attr) { - if (attr.getName() == LLVM::LLVMDialect::getAlignAttrName()) - return false; if (attr.getName() == LLVM::LLVMDialect::getInAllocaAttrName()) return false; if (attr.getName() == LLVM::LLVMDialect::getNoAliasAttrName()) @@ -290,8 +333,18 @@ Value argument, Type targetType, DictionaryAttr argumentAttrs) const final { if (auto attr = - argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) - return handleByValArgument(builder, callable, argument, *attr); + argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) { + Type elementType = cast(attr->getValue()).getValue(); + uint64_t requestedAlignment = 1; + if (auto alignAttr = + argumentAttrs.getNamed(LLVM::LLVMDialect::getAlignAttrName())) { + requestedAlignment = cast(alignAttr->getValue()) + .getValue() + .getLimitedValue(); + } + return handleByValArgument(builder, callable, argument, elementType, + requestedAlignment); + } return argument; } diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir --- a/mlir/test/Dialect/LLVMIR/inlining.mlir +++ b/mlir/test/Dialect/LLVMIR/inlining.mlir @@ -399,6 +399,68 @@ // ----- +llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects} { + llvm.return +} + +// CHECK-LABEL: llvm.func @test_byval_input_aligned +// CHECK-SAME: %[[UNALIGNED:[a-zA-Z0-9_]+]]: !llvm.ptr +// CHECK-SAME: %[[ALIGNED:[a-zA-Z0-9_]+]]: !llvm.ptr +llvm.func @test_byval_input_aligned(%unaligned : !llvm.ptr, %aligned : !llvm.ptr { llvm.align = 16 }) { + // Make sure only the unaligned input triggers a memcpy. 
+ // CHECK: %[[ALLOCA:.+]] = llvm.alloca %{{.+}} x i16 {alignment = 16 + // CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[UNALIGNED]] + llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> () + // CHECK-NOT: memcpy + llvm.call @aligned_byval_arg(%aligned) : (!llvm.ptr) -> () + llvm.return +} + +// ----- + +llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects} { + llvm.return +} + +// CHECK-LABEL: llvm.func @test_byval_alloca +llvm.func @test_byval_alloca() { + // Make sure only the unaligned alloca triggers a memcpy. + %size = llvm.mlir.constant(1 : i64) : i64 + // CHECK: %[[ALLOCA:.+]] = llvm.alloca {{.+}}alignment = 1 + // CHECK: "llvm.intr.memcpy"(%{{.+}}, %[[ALLOCA]] + %unaligned = llvm.alloca %size x i16 { alignment = 1 } : (i64) -> !llvm.ptr + llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> () + // CHECK-NOT: memcpy + %aligned = llvm.alloca %size x i16 { alignment = 16 } : (i64) -> !llvm.ptr + llvm.call @aligned_byval_arg(%aligned) : (!llvm.ptr) -> () + llvm.return +} + +// ----- + +llvm.mlir.global private @unaligned_global(42 : i64) : i64 +llvm.mlir.global private @aligned_global(42 : i64) { alignment = 64 } : i64 + +llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects} { + llvm.return +} + +// CHECK-LABEL: llvm.func @test_byval_global +llvm.func @test_byval_global() { + // Make sure only the unaligned global triggers a memcpy. 
+ // CHECK: %[[UNALIGNED:.+]] = llvm.mlir.addressof @unaligned_global + // CHECK: %[[ALLOCA:.+]] = llvm.alloca + // CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[UNALIGNED]] + // CHECK-NOT: llvm.alloca + %unaligned = llvm.mlir.addressof @unaligned_global : !llvm.ptr + llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> () + %aligned = llvm.mlir.addressof @aligned_global : !llvm.ptr + llvm.call @aligned_byval_arg(%aligned) : (!llvm.ptr) -> () + llvm.return +} + +// ----- + llvm.func @ignored_attrs(%ptr : !llvm.ptr { llvm.inreg, llvm.nocapture, llvm.nofree, llvm.preallocated = i32, llvm.returned, llvm.alignstack = 32 : i64, llvm.writeonly, llvm.noundef, llvm.nonnull }, %x : i32 { llvm.zeroext }) -> (!llvm.ptr { llvm.noundef, llvm.inreg, llvm.nonnull }) { llvm.return %ptr : !llvm.ptr } @@ -413,7 +475,7 @@ // ----- -llvm.func @disallowed_arg_attr(%ptr : !llvm.ptr { llvm.align = 16 : i32 }) { +llvm.func @disallowed_arg_attr(%ptr : !llvm.ptr { llvm.noalias }) { llvm.return }