diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "TypeDetail.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
@@ -2854,6 +2855,45 @@
   }
 }
 
+/// Handles an argument carrying the `llvm.byval` attribute when inlining
+/// `call` to `callable`. If the callee may write to argument memory, the
+/// pointee of `argument` is copied into a fresh stack allocation of the
+/// `byval` element type and the new pointer is returned; otherwise the
+/// original `argument` is returned unchanged.
+static Value handleByValArgument(OpBuilder &builder, Operation *call,
+                                 Operation *callable, Value argument,
+                                 NamedAttribute byValAttribute) {
+  auto func = cast<LLVM::LLVMFuncOp>(callable);
+  LLVM::MemoryEffectsAttr memoryEffects = func.getMemoryAttr();
+  // The copy can only be elided if the callee never writes to argument
+  // memory, i.e. its `argMem` effect is `none` or `read`. Without a memory
+  // effects attribute, conservatively assume that the callee may write.
+  bool isReadOnly = memoryEffects &&
+                    memoryEffects.getArgMem() != ModRefInfo::ModRef &&
+                    memoryEffects.getArgMem() != ModRefInfo::Mod;
+  if (isReadOnly)
+    return argument;
+  // Resolve the pointee type and its size.
+  auto ptrType = cast<LLVM::LLVMPointerType>(argument.getType());
+  auto elementType = cast<TypeAttr>(byValAttribute.getValue()).getValue();
+  unsigned int typeSize =
+      DataLayout(call->getParentOfType<ModuleOp>())
+          .getTypeSize(elementType);
+  // Allocate the new value on the stack.
+  Value one = builder.create<LLVM::ConstantOp>(
+      func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(1));
+  Value allocaOp =
+      builder.create<LLVM::AllocaOp>(func.getLoc(), ptrType, elementType, one);
+  // Copy the pointee to the newly allocated value.
+  Value copySize = builder.create<LLVM::ConstantOp>(
+      func.getLoc(), builder.getI64Type(), builder.getI64IntegerAttr(typeSize));
+  Value isVolatile = builder.create<LLVM::ConstantOp>(
+      func.getLoc(), builder.getI1Type(), builder.getBoolAttr(false));
+  builder.create<LLVM::MemcpyOp>(func.getLoc(), allocaOp, argument, copySize,
+                                 isVolatile);
+  return allocaOp;
+}
+
 namespace {
 struct LLVMInlinerInterface : public DialectInlinerInterface {
   using DialectInlinerInterface::DialectInlinerInterface;
@@ -2866,8 +2906,19 @@
     auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
     if (!callOp || !funcOp)
       return false;
-    // TODO: Handle argument and result attributes;
-    if (funcOp.getArgAttrs() || funcOp.getResAttrs())
+    if (auto attrs = funcOp.getArgAttrs()) {
+      for (Attribute argAttr : *attrs) {
+        auto attrDict = cast<DictionaryAttr>(argAttr);
+        for (NamedAttribute attr : attrDict) {
+          if (attr.getName() == LLVMDialect::getByValAttrName())
+            continue;
+          // TODO: Handle all argument attributes;
+          return false;
+        }
+      }
+    }
+    // TODO: Handle result attributes;
+    if (funcOp.getResAttrs())
       return false;
     // TODO: Handle exceptions.
     if (funcOp.getPersonality())
@@ -2942,6 +2993,16 @@
     dst.replaceAllUsesWith(src);
   }
 
+  /// Handles `byval` arguments via handleByValArgument; all other arguments
+  /// are forwarded unchanged.
+  Value handleArgument(OpBuilder &builder, Operation *call, Operation *callable,
+                       Value argument, Type targetType,
+                       DictionaryAttr argumentAttrs) const final {
+    if (auto attr = argumentAttrs.getNamed(LLVMDialect::getByValAttrName()))
+      return handleByValArgument(builder, call, callable, argument, *attr);
+    return argument;
+  }
+
   void processInlinedCallBlocks(
       Operation *call,
       iterator_range<Region::iterator> inlinedBlocks) const override {
diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir
--- a/mlir/test/Dialect/LLVMIR/inlining.mlir
+++ b/mlir/test/Dialect/LLVMIR/inlining.mlir
@@ -187,20 +187,6 @@
 
 // -----
 
-llvm.func @callee(%ptr : !llvm.ptr {llvm.byval = !llvm.ptr}) -> (!llvm.ptr) {
-  llvm.return %ptr : !llvm.ptr
-}
-
-// CHECK-LABEL: llvm.func @caller
-// CHECK-NEXT: llvm.call @callee
-// CHECK-NEXT: return
-llvm.func @caller(%ptr : !llvm.ptr) -> (!llvm.ptr) {
-  %0 = llvm.call @callee(%ptr) : (!llvm.ptr) -> (!llvm.ptr)
-  llvm.return %0 : !llvm.ptr
-}
-
-// -----
-
 llvm.func @static_alloca() -> f32 {
   %0 = llvm.mlir.constant(4 : i32) : i32
   %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
@@ -349,3 +335,47 @@
 ^bb3(%blockArg: f32):
   llvm.return %blockArg : f32
 }
+
+// -----
+
+llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) {
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @test_byval
+// CHECK-SAME: %[[PTR:[a-zA-Z0-9_]+]]: !llvm.ptr
+// CHECK: %[[ALLOCA:.+]] = llvm.alloca %{{.+}} x f64
+// CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[PTR]]
+llvm.func @test_byval(%ptr : !llvm.ptr) {
+  llvm.call @with_byval_arg(%ptr) : (!llvm.ptr) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite>} {
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @test_byval_read_only
+// CHECK-NOT: llvm.call
+// CHECK-NEXT: llvm.return
+llvm.func @test_byval_read_only(%ptr : !llvm.ptr) {
+  llvm.call @with_byval_arg(%ptr) : (!llvm.ptr) -> ()
+  llvm.return
+}
+
+// -----
+
+llvm.func @with_byval_arg(%ptr : !llvm.ptr { llvm.byval = f64 }) attributes {memory = #llvm.memory_effects<other = readwrite, argMem = write, inaccessibleMem = readwrite>} {
+  llvm.return
+}
+
+// CHECK-LABEL: llvm.func @test_byval_write_only
+// CHECK-SAME: %[[PTR:[a-zA-Z0-9_]+]]: !llvm.ptr
+// CHECK: %[[ALLOCA:.+]] = llvm.alloca %{{.+}} x f64
+// CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[PTR]]
+llvm.func @test_byval_write_only(%ptr : !llvm.ptr) {
+  llvm.call @with_byval_arg(%ptr) : (!llvm.ptr) -> ()
+  llvm.return
+}