diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp
@@ -95,9 +95,64 @@
   }
 }
 
+/// Tries to find and return the alignment of the pointer `value` by looking for
+/// an alignment attribute on the defining allocation op or, for arguments of a
+/// function's entry block, on the function argument. If no such attribute is
+/// found, returns 1 (i.e., assume that no alignment is guaranteed).
+static unsigned getAlignmentOf(Value value) {
+  if (Operation *definingOp = value.getDefiningOp()) {
+    if (auto alloca = dyn_cast(definingOp))
+      return alloca.getAlignment().value_or(1);
+    if (auto addressOf = dyn_cast(definingOp))
+      if (auto global = SymbolTable::lookupNearestSymbolFrom(
+              definingOp, addressOf.getGlobalNameAttr()))
+        return global.getAlignment().value_or(1);
+    // We don't currently handle this operation; assume no alignment.
+    return 1;
+  }
+  // Since there is no defining op, this is a block argument. Only arguments of
+  // a function's entry block are function arguments; arguments of any other
+  // block must not be looked up in the function's argument attributes.
+  auto func = dyn_cast(value.getParentBlock()->getParentOp());
+  if (func && value.getParentBlock()->isEntryBlock()) {
+    // Use the alignment attribute set for this argument, if any.
+    auto blockArg = value.cast();
+    if (Attribute alignAttr = func.getArgAttr(
+            blockArg.getArgNumber(), LLVM::LLVMDialect::getAlignAttrName()))
+      return cast(alignAttr).getValue().getLimitedValue();
+  }
+  // We didn't find anything useful; assume no alignment.
+  return 1;
+}
+
+/// Copies the data from a byval pointer argument into newly alloca'ed memory
+/// and returns the value of the alloca.
+static Value handleByValArgumentInit(OpBuilder &builder, Location loc,
+                                     Value argument, Type elementType,
+                                     unsigned elementTypeSize,
+                                     unsigned targetAlignment) {
+  // Allocate the new value on the stack.
+  Value one = builder.create(loc, builder.getI64Type(),
+                                              builder.getI64IntegerAttr(1));
+  Value allocaOp = builder.create(
+      loc, argument.getType(), elementType, one, targetAlignment);
+  // Copy the pointee to the newly allocated value.
+  Value copySize = builder.create(
+      loc, builder.getI64Type(), builder.getI64IntegerAttr(elementTypeSize));
+  Value isVolatile = builder.create(
+      loc, builder.getI1Type(), builder.getBoolAttr(false));
+  builder.create(loc, allocaOp, argument, copySize, isVolatile);
+  return allocaOp;
+}
+
+/// Handles a function argument marked with the byval attribute by introducing a
+/// memcpy if necessary, either due to the pointee being writeable in the
+/// callee, and/or due to an alignment mismatch. `requestedAlignment` specifies
+/// the alignment set in the "align" argument attribute (or 1 if no align
+/// attribute was set).
 static Value handleByValArgument(OpBuilder &builder, Operation *callable,
-                                 Value argument,
-                                 NamedAttribute byValAttribute) {
+                                 Value argument, Type elementType,
+                                 unsigned requestedAlignment) {
   auto func = cast(callable);
   LLVM::MemoryEffectsAttr memoryEffects = func.getMemoryAttr();
   // If there is no memory effects attribute, assume that the function is
@@ -105,34 +160,21 @@
   bool isReadOnly = memoryEffects &&
                     memoryEffects.getArgMem() != LLVM::ModRefInfo::ModRef &&
                     memoryEffects.getArgMem() != LLVM::ModRefInfo::Mod;
-  if (isReadOnly)
+  // Check if there's an alignment mismatch requiring us to copy.
+  DataLayout dataLayout(callable->getParentOfType());
+  unsigned minimumAlignment = dataLayout.getTypeABIAlignment(elementType);
+  if (isReadOnly && (requestedAlignment <= minimumAlignment ||
+                     getAlignmentOf(argument) >= requestedAlignment))
     return argument;
-  // Resolve the pointee type and its size.
-  auto ptrType = cast(argument.getType());
-  Type elementType = cast(byValAttribute.getValue()).getValue();
-  unsigned int typeSize =
-      DataLayout(callable->getParentOfType())
-          .getTypeSize(elementType);
-  // Allocate the new value on the stack.
-  Value one = builder.create(
-      func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(1));
-  Value allocaOp =
-      builder.create(func.getLoc(), ptrType, elementType, one);
-  // Copy the pointee to the newly allocated value.
-  Value copySize = builder.create(
-      func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(typeSize));
-  Value isVolatile = builder.create(
-      func.getLoc(), builder.getI1Type(), builder.getBoolAttr(false));
-  builder.create(func.getLoc(), allocaOp, argument, copySize,
-                                   isVolatile);
-  return allocaOp;
+  unsigned targetAlignment = std::max(requestedAlignment, minimumAlignment);
+  return handleByValArgumentInit(builder, func.getLoc(), argument, elementType,
+                                 dataLayout.getTypeSize(elementType),
+                                 targetAlignment);
 }
 
 /// Returns true if the given argument or result attribute is supported by the
 /// inliner, false otherwise.
 static bool isArgOrResAttrSupported(NamedAttribute attr) {
-  if (attr.getName() == LLVM::LLVMDialect::getAlignAttrName())
-    return false;
   if (attr.getName() == LLVM::LLVMDialect::getInAllocaAttrName())
     return false;
   if (attr.getName() == LLVM::LLVMDialect::getNoAliasAttrName())
@@ -289,9 +331,19 @@
   Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
                        Value argument, Type targetType,
                        DictionaryAttr argumentAttrs) const final {
-    if (auto attr =
-            argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName()))
-      return handleByValArgument(builder, callable, argument, *attr);
+    if (std::optional attr =
+            argumentAttrs.getNamed(LLVM::LLVMDialect::getByValAttrName())) {
+      Type elementType = cast(attr->getValue()).getValue();
+      unsigned requestedAlignment = 1;
+      if (std::optional alignAttr =
+              argumentAttrs.getNamed(LLVM::LLVMDialect::getAlignAttrName())) {
+        requestedAlignment = cast(alignAttr->getValue())
+                                 .getValue()
+                                 .getLimitedValue();
+      }
+      return handleByValArgument(builder, callable, argument, elementType,
+                                 requestedAlignment);
+    }
     return argument;
   }
 
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -399,6 +399,86 @@
 
 // -----
 
+llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects} {
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @test_byval_input_aligned
+// CHECK-SAME: %[[UNALIGNED:[a-zA-Z0-9_]+]]: !llvm.ptr
+// CHECK-SAME: %[[ALIGNED:[a-zA-Z0-9_]+]]: !llvm.ptr
+llvm.func @test_byval_input_aligned(%unaligned : !llvm.ptr, %aligned : !llvm.ptr { llvm.align = 16 }) {
+  // Make sure only the unaligned input triggers a memcpy.
+  // CHECK: %[[ALLOCA:.+]] = llvm.alloca %{{.+}} x i16 {alignment = 16
+  // CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[UNALIGNED]]
+  llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
+  // CHECK-NOT: memcpy
+  llvm.call @aligned_byval_arg(%aligned) : (!llvm.ptr) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects} {
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @test_byval_alloca
+llvm.func @test_byval_alloca() {
+  // Make sure only the unaligned alloca triggers a memcpy.
+  %size = llvm.mlir.constant(1 : i64) : i64
+  // CHECK: %[[ALLOCA:.+]] = llvm.alloca {{.+}}alignment = 1
+  // CHECK: "llvm.intr.memcpy"(%{{.+}}, %[[ALLOCA]]
+  %unaligned = llvm.alloca %size x i16 { alignment = 1 } : (i64) -> !llvm.ptr
+  llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
+  // CHECK-NOT: memcpy
+  %aligned = llvm.alloca %size x i16 { alignment = 16 } : (i64) -> !llvm.ptr
+  llvm.call @aligned_byval_arg(%aligned) : (!llvm.ptr) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.mlir.global private @unaligned_global(42 : i64) : i64
+llvm.mlir.global private @aligned_global(42 : i64) { alignment = 64 } : i64
+
+llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects} {
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @test_byval_global
+llvm.func @test_byval_global() {
+  // Make sure only the unaligned global triggers a memcpy.
+  // CHECK: %[[UNALIGNED:.+]] = llvm.mlir.addressof @unaligned_global
+  // CHECK: %[[ALLOCA:.+]] = llvm.alloca
+  // CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[UNALIGNED]]
+  // CHECK-NOT: llvm.alloca
+  %unaligned = llvm.mlir.addressof @unaligned_global : !llvm.ptr
+  llvm.call @aligned_byval_arg(%unaligned) : (!llvm.ptr) -> ()
+  %aligned = llvm.mlir.addressof @aligned_global : !llvm.ptr
+  llvm.call @aligned_byval_arg(%aligned) : (!llvm.ptr) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @aligned_byval_arg(%ptr : !llvm.ptr { llvm.byval = i16, llvm.align = 16 }) attributes {memory = #llvm.memory_effects} {
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @test_byval_block_arg
+llvm.func @test_byval_block_arg(%aligned : !llvm.ptr { llvm.align = 16 }, %unaligned : !llvm.ptr) {
+  // %ptr is an argument of ^bb1, not of the function, so the llvm.align of the
+  // function argument with the same index must not be used for it.
+  // CHECK: %[[ALLOCA:.+]] = llvm.alloca %{{.+}} x i16 {alignment = 16
+  // CHECK: "llvm.intr.memcpy"(%[[ALLOCA]]
+  llvm.br ^bb1(%unaligned : !llvm.ptr)
+^bb1(%ptr : !llvm.ptr):
+  llvm.call @aligned_byval_arg(%ptr) : (!llvm.ptr) -> ()
+  llvm.return
+}
+
+// -----
+
 llvm.func @ignored_attrs(%ptr : !llvm.ptr { llvm.inreg, llvm.nocapture, llvm.nofree, llvm.preallocated = i32, llvm.returned, llvm.alignstack = 32 : i64, llvm.writeonly, llvm.noundef, llvm.nonnull }, %x : i32 { llvm.zeroext }) -> (!llvm.ptr { llvm.noundef, llvm.inreg, llvm.nonnull }) {
   llvm.return %ptr : !llvm.ptr
 }
@@ -413,7 +493,7 @@
 
 // -----
 
-llvm.func @disallowed_arg_attr(%ptr : !llvm.ptr { llvm.align = 16 : i32 }) {
+llvm.func @disallowed_arg_attr(%ptr : !llvm.ptr { llvm.noalias }) {
   llvm.return
 }
 