[MLIR][LLVM] Inline calls whose callee carries access-group metadata

Before this patch, isLegalToInline() refused to inline any operation that
implements AccessGroupOpInterface and has access groups attached, emitting
the debug message "Cannot inline: unhandled access group metadata".

This patch removes that check. It adds a static helper handleAccessGroups()
and calls it on the inlined blocks right after handleInlinedAllocas() and
handleAliasScopes().

handleAccessGroups() returns early unless two conditions hold: the call op
implements AccessGroupOpInterface, and its access-group list is non-null.
When both hold, it visits every op implementing AccessGroupOpInterface in the
inlined blocks. It sets each op's access groups to
concatArrayAttr(<op's current groups, possibly null>, <call's groups>).
The new test expectations show the op's own groups come first, followed by
the call's groups. An op with no groups ends up with just the call's groups.

Test changes in inlining.mlir:
The two "test_not_inline" cases, which expected access-group-carrying
callees to stay un-inlined, are deleted. Two new cases are added:
(1) An inlined load that already had the callee group ends up with
[callee group, caller group].
(2) An inlined load with no groups ends up with [caller group], and a store
already present in the caller keeps exactly [caller group].

NOTE(review): Block::getOps yields only the operations directly contained in
each inlined block. Ops nested inside those operations' regions are not
visited. Confirm that no op carrying access groups can appear in a nested
region of inlinable code, or switch to Operation::walk.

NOTE(review): this copy of the patch looks damaged and will not apply as-is.
Line breaks and indentation are collapsed. Text between '<' and '>' appears
to have been stripped, e.g. the dyn_cast/getOps/iterator_range template
arguments and the #llvm.access_group parameters. In the LLVMInlining.cpp
section, the jump from "if (isa" straight to "inlinedBlocks) const override"
suggests the remaining context and the next hunk header were lost as well.
Regenerate the patch from the source tree rather than hand-repairing it.

diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp @@ -238,6 +238,27 @@ appendCallOpAliasScopes(call, inlinedBlocks); } +/// Appends any access groups of the call operation to any inlined memory +/// operation. +static void handleAccessGroups(Operation *call, + iterator_range inlinedBlocks) { + auto callAccessGroupInterface = dyn_cast(call); + if (!callAccessGroupInterface) + return; + + auto accessGroups = callAccessGroupInterface.getAccessGroupsOrNull(); + if (!accessGroups) + return; + + // Simply append the call op's access groups to any operation implementing + // AccessGroupOpInterface. + for (Block &block : inlinedBlocks) + for (auto accessGroupOpInterface : + block.getOps()) + accessGroupOpInterface.setAccessGroups(concatArrayAttr( + accessGroupOpInterface.getAccessGroupsOrNull(), accessGroups)); +} + /// If `requestedAlignment` is higher than the alignment specified on `alloca`, /// realigns `alloca` if this does not exceed the natural stack alignment. /// Returns the post-alignment of `alloca`, whether it was realigned or not. @@ -433,16 +454,6 @@ bool isLegalToInline(Operation *op, Region *, bool, IRMapping &) const final { if (isPure(op)) return true; - // Some attributes on memory operations require handling during - // inlining. Since this is not yet implemented, refuse to inline memory - // operations that have any of these attributes.
- if (auto iface = dyn_cast(op)) { - if (iface.getAccessGroupsOrNull()) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot inline: unhandled access group metadata\n"); - return false; - } - } // clang-format off if (isa inlinedBlocks) const override { handleInlinedAllocas(call, inlinedBlocks); handleAliasScopes(call, inlinedBlocks); + handleAccessGroups(call, inlinedBlocks); } // Keeping this (immutable) state on the interface allows us to look up diff --git a/mlir/test/Dialect/LLVMIR/inlining.mlir b/mlir/test/Dialect/LLVMIR/inlining.mlir --- a/mlir/test/Dialect/LLVMIR/inlining.mlir +++ b/mlir/test/Dialect/LLVMIR/inlining.mlir @@ -52,42 +52,6 @@ return %0 : i32 } -// ----- - -#group = #llvm.access_group> - -llvm.func @inlinee(%ptr : !llvm.ptr) -> i32 { - %0 = llvm.load %ptr { access_groups = [#group] } : !llvm.ptr -> i32 - llvm.return %0 : i32 -} - -// CHECK-LABEL: func @test_not_inline -llvm.func @test_not_inline(%ptr : !llvm.ptr) -> i32 { - // CHECK-NEXT: llvm.call @inlinee - %0 = llvm.call @inlinee(%ptr) : (!llvm.ptr) -> (i32) - llvm.return %0 : i32 -} - -// ----- - -#group = #llvm.access_group> - -func.func private @with_mem_attr(%ptr : !llvm.ptr) { - %0 = llvm.mlir.constant(42 : i32) : i32 - // Do not inline load/store operations that carry attributes requiring - // handling while inlining, until this is supported by the inliner.
- llvm.store %0, %ptr { access_groups = [#group] }: i32, !llvm.ptr - return -} - -// CHECK-LABEL: func.func @test_not_inline -// CHECK-NEXT: call @with_mem_attr -// CHECK-NEXT: return -func.func @test_not_inline(%ptr : !llvm.ptr) { - call @with_mem_attr(%ptr) : (!llvm.ptr) -> () - return -} - // ----- // Check that llvm.return is correctly handled @@ -584,3 +548,47 @@ llvm.call @disallowed_arg_attr(%ptr) : (!llvm.ptr) -> () llvm.return } + +// ----- + +#callee = #llvm.access_group> +#caller = #llvm.access_group> + +llvm.func @inlinee(%ptr : !llvm.ptr) -> i32 { + %0 = llvm.load %ptr { access_groups = [#callee] } : !llvm.ptr -> i32 + llvm.return %0 : i32 +} + +// CHECK-DAG: #[[$CALLEE:.*]] = #llvm.access_group +// CHECK-DAG: #[[$CALLER:.*]] = #llvm.access_group + +// CHECK-LABEL: func @caller +// CHECK: llvm.load +// CHECK-SAME: access_groups = [#[[$CALLEE]], #[[$CALLER]]] +llvm.func @caller(%ptr : !llvm.ptr) -> i32 { + %0 = llvm.call @inlinee(%ptr) { access_groups = [#caller] } : (!llvm.ptr) -> (i32) + llvm.return %0 : i32 +} + +// ----- + +#caller = #llvm.access_group> + +llvm.func @inlinee(%ptr : !llvm.ptr) -> i32 { + %0 = llvm.load %ptr : !llvm.ptr -> i32 + llvm.return %0 : i32 +} + +// CHECK-DAG: #[[$CALLER:.*]] = #llvm.access_group + +// CHECK-LABEL: func @caller +// CHECK: llvm.load +// CHECK-SAME: access_groups = [#[[$CALLER]]] +// CHECK: llvm.store +// CHECK-SAME: access_groups = [#[[$CALLER]]] +llvm.func @caller(%ptr : !llvm.ptr) -> i32 { + %c5 = llvm.mlir.constant(5 : i32) : i32 + %0 = llvm.call @inlinee(%ptr) { access_groups = [#caller] } : (!llvm.ptr) -> (i32) + llvm.store %c5, %ptr { access_groups = [#caller] } : i32, !llvm.ptr + llvm.return %0 : i32 +}