diff --git a/openmp/libomptarget/plugins/remote/include/Utils.h b/openmp/libomptarget/plugins/remote/include/Utils.h --- a/openmp/libomptarget/plugins/remote/include/Utils.h +++ b/openmp/libomptarget/plugins/remote/include/Utils.h @@ -47,20 +47,39 @@ using openmp::libomptarget::remote::TargetOffloadEntry; using openmp::libomptarget::remote::TargetTable; -struct RPCConfig { +struct ClientManagerConfigTy { std::vector ServerAddresses; uint64_t MaxSize; uint64_t BlockSize; - RPCConfig() { - ServerAddresses = {"0.0.0.0:50051"}; - MaxSize = 1 << 30; - BlockSize = 1 << 20; + int Timeout; + + ClientManagerConfigTy() + : ServerAddresses({"0.0.0.0:50051"}), MaxSize(1 << 30), + BlockSize(1 << 20), Timeout(5) { + // TODO: Error handle for incorrect inputs + if (const char *Env = std::getenv("LIBOMPTARGET_RPC_ADDRESS")) { + ServerAddresses.clear(); + std::string AddressString = Env; + const std::string Delimiter = ","; + + size_t Pos; + std::string Token; + while ((Pos = AddressString.find(Delimiter)) != std::string::npos) { + Token = AddressString.substr(0, Pos); + ServerAddresses.push_back(Token); + AddressString.erase(0, Pos + Delimiter.length()); + } + ServerAddresses.push_back(AddressString); + } + if (const char *Env = std::getenv("LIBOMPTARGET_RPC_ALLOCATOR_MAX")) + MaxSize = std::stoi(Env); + if (const char *Env = std::getenv("LIBOMPTARGET_RPC_BLOCK_SIZE")) + BlockSize = std::stoi(Env); + if (const char *Env1 = std::getenv("LIBOMPTARGET_RPC_LATENCY")) + Timeout = std::stoi(Env1); } }; -/// Helper function to parse common environment variables between client/server -void parseEnvironment(RPCConfig &Config); - /// Loads a target binary description into protobuf. void loadTargetBinaryDescription(const __tgt_bin_desc *Desc, TargetBinaryDescription &Request); diff --git a/openmp/libomptarget/plugins/remote/include/openmp.proto b/openmp/libomptarget/plugins/remote/include/openmp.proto --- a/openmp/libomptarget/plugins/remote/include/openmp.proto +++ b/openmp/libomptarget/plugins/remote/include/openmp.proto @@ -16,19 +16,18 @@ rpc InitRequires(I64) returns (I32) {} rpc LoadBinary(Binary) returns (TargetTable) {} - rpc Synchronize(SynchronizeDevice) returns (I32) {} rpc DataAlloc(AllocData) returns (Pointer) {} rpc DataDelete(DeleteData) returns (I32) {} - rpc DataSubmitAsync(stream SubmitDataAsync) returns (I32) {} - rpc DataRetrieveAsync(RetrieveDataAsync) returns (stream Data) {} + rpc DataSubmit(stream SubmitData) returns (I32) {} + rpc DataRetrieve(RetrieveData) returns (stream Data) {} rpc IsDataExchangeable(DevicePair) returns (I32) {} - rpc DataExchangeAsync(ExchangeDataAsync) returns (I32) {} + rpc DataExchange(ExchangeData) returns (I32) {} - rpc RunTargetRegionAsync(TargetRegionAsync) returns (I32) {} - rpc RunTargetTeamRegionAsync(TargetTeamRegionAsync) returns (I32) {} + rpc RunTargetRegion(TargetRegion) returns (I32) {} + rpc RunTargetTeamRegion(TargetTeamRegion) returns (I32) {} } message Null {} @@ -92,32 +91,25 @@ uint64 bin_ptr = 5; } -message SynchronizeDevice { - uint64 queue_ptr = 1; - int32 device_id = 2; -} - message AllocData { uint64 size = 1; uint64 hst_ptr = 2; int32 device_id = 3; } -message SubmitDataAsync { +message SubmitData { bytes data = 1; uint64 hst_ptr = 2; uint64 tgt_ptr = 3; - uint64 queue_ptr = 4; uint64 start = 5; uint64 size = 6; int32 device_id = 7; } -message RetrieveDataAsync { +message RetrieveData { uint64 hst_ptr = 1; uint64 tgt_ptr = 2; uint64 size = 3; - uint64 queue_ptr = 4; int32 device_id = 5; } @@ -128,12 +120,11 @@ int32 ret = 4; } -message ExchangeDataAsync { +message ExchangeData { uint64 src_dev_id = 1; uint64 src_ptr = 2; uint64 dst_dev_id = 3; uint64 dst_ptr = 4; - uint64 queue_ptr = 5; uint64 size = 6; } @@ -142,23 +133,21 @@ int32 device_id = 2; } -message TargetRegionAsync { +message TargetRegion { repeated uint64 tgt_args = 1; repeated int64 tgt_offsets = 2; uint64 tgt_entry_ptr = 3; - uint64 queue_ptr = 4; - int32 device_id = 5; - int32 arg_num = 6; + int32 device_id = 4; + int32 arg_num = 5; } -message TargetTeamRegionAsync { +message TargetTeamRegion { repeated uint64 tgt_args = 1; repeated int64 tgt_offsets = 2; uint64 tgt_entry_ptr = 3; uint64 loop_tripcount = 4; - uint64 queue_ptr = 5; - int32 device_id = 6; - int32 arg_num = 7; - int32 team_num = 8; - int32 thread_limit = 9; + int32 device_id = 5; + int32 arg_num = 6; + int32 team_num = 7; + int32 thread_limit = 8; } diff --git a/openmp/libomptarget/plugins/remote/lib/Utils.cpp b/openmp/libomptarget/plugins/remote/lib/Utils.cpp --- a/openmp/libomptarget/plugins/remote/lib/Utils.cpp +++ b/openmp/libomptarget/plugins/remote/lib/Utils.cpp @@ -14,27 +14,6 @@ #include "omptarget.h" namespace RemoteOffloading { -void parseEnvironment(RPCConfig &Config) { - // TODO: Error handle for incorrect inputs - if (const char *Env = std::getenv("LIBOMPTARGET_RPC_ADDRESS")) { - Config.ServerAddresses.clear(); - std::string AddressString = Env; - const std::string Delimiter = ","; - - size_t Pos = 0; - std::string Token; - while ((Pos = AddressString.find(Delimiter)) != std::string::npos) { - Token = AddressString.substr(0, Pos); - Config.ServerAddresses.push_back(Token); - AddressString.erase(0, Pos + Delimiter.length()); - } - Config.ServerAddresses.push_back(AddressString); - } - if (const char *Env = std::getenv("LIBOMPTARGET_RPC_ALLOCATOR_MAX")) - Config.MaxSize = std::stoi(Env); - if (const char *Env = std::getenv("LIBOMPTARGET_RPC_BLOCK_SIZE")) - Config.BlockSize = std::stoi(Env); -} void loadTargetBinaryDescription(const __tgt_bin_desc *Desc, TargetBinaryDescription &Request) { @@ -101,10 +80,12 @@ // Copy Global Offload Entries __tgt_offload_entry *CurEntry = Desc->HostEntriesBegin; - for (int i = 0; i < Request->entries_size(); i++) { - copyOffloadEntry(Request->entries()[i], CurEntry); - CopiedOffloadEntries[(void *)Request->entry_ptrs()[i]] = CurEntry; + size_t I = 0; + for (auto &Entry : Request->entries()) { + copyOffloadEntry(Entry, CurEntry); + CopiedOffloadEntries[(void *)Request->entry_ptrs()[I]] = CurEntry; CurEntry++; + I++; } Desc->HostEntriesEnd = CurEntry; @@ -113,7 +94,7 @@ auto ImageItr = Request->image_ptrs().begin(); for (auto Image : Request->images()) { // Copy Device Offload Entries - auto *CurEntry = Desc->HostEntriesBegin; + CurEntry = Desc->HostEntriesBegin; bool Found = false; if (!Desc->HostEntriesBegin) { @@ -121,21 +102,19 @@ CurImage->EntriesEnd = nullptr; } - for (int i = 0; i < Image.entries_size(); i++) { + for (size_t I = 0; I < Image.entries_size(); I++) { auto TgtEntry = - CopiedOffloadEntries.find((void *)Request->entry_ptrs()[i]); + CopiedOffloadEntries.find((void *)Request->entry_ptrs()[I]); if (TgtEntry != CopiedOffloadEntries.end()) { if (!Found) CurImage->EntriesBegin = CurEntry; + CurImage->EntriesEnd = CurEntry + 1; Found = true; - if (Found) { - CurImage->EntriesEnd = CurEntry + 1; - } } else { Found = false; - copyOffloadEntry(Image.entries()[i], CurEntry); - CopiedOffloadEntries[(void *)(Request->entry_ptrs()[i])] = CurEntry; + copyOffloadEntry(Image.entries()[I], CurEntry); + CopiedOffloadEntries[(void *)(Request->entry_ptrs()[I])] = CurEntry; } CurEntry++; } @@ -199,10 +178,10 @@ Table->EntriesBegin = new __tgt_offload_entry[TableResponse.entries_size()]; auto *CurEntry = Table->EntriesBegin; - for (int i = 0; i < TableResponse.entries_size(); i++) { - copyOffloadEntry(TableResponse.entries()[i], CurEntry); + for (size_t I = 0; I < TableResponse.entries_size(); I++) { + copyOffloadEntry(TableResponse.entries()[I], CurEntry); HostToRemoteTargetTableMap[CurEntry->addr] = - (void *)TableResponse.entry_ptrs()[i]; + (void *)TableResponse.entry_ptrs()[I]; CurEntry++; } Table->EntriesEnd = CurEntry; @@ -292,10 +271,10 @@ void dump(TargetOffloadEntry Entry) { fprintf(stderr, "Entry: "); - fprintf(stderr, " %s\n", Entry.name().c_str()); - fprintf(stderr, " %d\n", Entry.reserved()); - fprintf(stderr, " %d\n", Entry.flags()); - fprintf(stderr, " %ld\n", Entry.data().size()); + fprintf(stderr, " Name: %s\n", Entry.name().c_str()); + fprintf(stderr, " Reserved: %d\n", Entry.reserved()); + fprintf(stderr, " Flags: %d\n", Entry.flags()); + fprintf(stderr, " Size: %ld\n", Entry.data().size()); dump(static_cast(Entry.data().data()), static_cast((Entry.data().c_str() + Entry.data().size()))); } diff --git a/openmp/libomptarget/plugins/remote/server/OffloadingServer.cpp b/openmp/libomptarget/plugins/remote/server/OffloadingServer.cpp --- a/openmp/libomptarget/plugins/remote/server/OffloadingServer.cpp +++ b/openmp/libomptarget/plugins/remote/server/OffloadingServer.cpp @@ -24,8 +24,7 @@ std::promise ShutdownPromise; int main() { - RPCConfig Config; - parseEnvironment(Config); + ClientManagerConfigTy Config; RemoteOffloadImpl Service(Config.MaxSize, Config.BlockSize); diff --git a/openmp/libomptarget/plugins/remote/server/Server.h b/openmp/libomptarget/plugins/remote/server/Server.h --- a/openmp/libomptarget/plugins/remote/server/Server.h +++ b/openmp/libomptarget/plugins/remote/server/Server.h @@ -40,8 +40,6 @@ std::unordered_map HostToRemoteDeviceImage; - std::unordered_map - HostToRemoteOffloadEntry; std::unordered_map> Descriptions; __tgt_target_table *Table = nullptr; @@ -80,35 +78,29 @@ Status LoadBinary(ServerContext *Context, const Binary *Binary, TargetTable *Reply) override; - Status Synchronize(ServerContext *Context, const SynchronizeDevice *Info, - I32 *Reply) override; Status IsDataExchangeable(ServerContext *Context, const DevicePair *Request, I32 *Reply) override; Status DataAlloc(ServerContext *Context, const AllocData *Request, Pointer *Reply) override; - Status DataSubmitAsync(ServerContext *Context, - ServerReader *Reader, - I32 *Reply) override; - Status DataRetrieveAsync(ServerContext *Context, - const RetrieveDataAsync *Request, - ServerWriter *Writer) override; + Status DataSubmit(ServerContext *Context, ServerReader *Reader, + I32 *Reply) override; + Status DataRetrieve(ServerContext *Context, const RetrieveData *Request, + ServerWriter *Writer) override; - Status DataExchangeAsync(ServerContext *Context, - const ExchangeDataAsync *Request, - I32 *Reply) override; + Status DataExchange(ServerContext *Context, const ExchangeData *Request, + I32 *Reply) override; Status DataDelete(ServerContext *Context, const DeleteData *Request, I32 *Reply) override; - Status RunTargetRegionAsync(ServerContext *Context, - const TargetRegionAsync *Request, - I32 *Reply) override; + Status RunTargetRegion(ServerContext *Context, const TargetRegion *Request, + I32 *Reply) override; - Status RunTargetTeamRegionAsync(ServerContext *Context, - const TargetTeamRegionAsync *Request, - I32 *Reply) override; + Status RunTargetTeamRegion(ServerContext *Context, + const TargetTeamRegion *Request, + I32 *Reply) override; }; #endif diff --git a/openmp/libomptarget/plugins/remote/server/Server.cpp b/openmp/libomptarget/plugins/remote/server/Server.cpp --- a/openmp/libomptarget/plugins/remote/server/Server.cpp +++ b/openmp/libomptarget/plugins/remote/server/Server.cpp @@ -24,7 +24,7 @@ Status RemoteOffloadImpl::Shutdown(ServerContext *Context, const Null *Request, I32 *Reply) { - SERVER_DBG("Shutting down the server"); + SERVER_DBG("Shutting down the server") Reply->set_number(0); ShutdownPromise.set_value(); @@ -35,8 +35,6 @@ RemoteOffloadImpl::RegisterLib(ServerContext *Context, const TargetBinaryDescription *Description, I32 *Reply) { - SERVER_DBG("Registering library"); - auto Desc = std::make_unique<__tgt_bin_desc>(); unloadTargetBinaryDescription(Description, Desc.get(), @@ -49,15 +47,13 @@ else Descriptions[(void *)Description->bin_ptr()] = std::move(Desc); - SERVER_DBG("Registered library"); + SERVER_DBG("Registered library") Reply->set_number(0); return Status::OK; } Status RemoteOffloadImpl::UnregisterLib(ServerContext *Context, const Pointer *Request, I32 *Reply) { - SERVER_DBG("Unregistering library"); - if (Descriptions.find((void *)Request->number()) == Descriptions.end()) { Reply->set_number(1); return Status::OK; @@ -67,7 +63,7 @@ freeTargetBinaryDescription(Descriptions[(void *)Request->number()].get()); Descriptions.erase((void *)Request->number()); - SERVER_DBG("Unregistered library"); + SERVER_DBG("Unregistered library") Reply->set_number(0); return Status::OK; } @@ -75,9 +71,6 @@ Status RemoteOffloadImpl::IsValidBinary(ServerContext *Context, const TargetDeviceImagePtr *DeviceImage, I32 *IsValid) { - SERVER_DBG("Checking if binary (%p) is valid", - (void *)(DeviceImage->image_ptr())); - __tgt_device_image *Image = HostToRemoteDeviceImage[(void *)DeviceImage->image_ptr()]; @@ -90,14 +83,13 @@ } SERVER_DBG("Checked if binary (%p) is valid", - (void *)(DeviceImage->image_ptr())); + (void *)(DeviceImage->image_ptr())) return Status::OK; } Status RemoteOffloadImpl::GetNumberOfDevices(ServerContext *Context, const Null *Null, I32 *NumberOfDevices) { - SERVER_DBG("Getting number of devices"); std::call_once(PM->RTLs.initFlag, &RTLsTy::LoadRTLs, &PM->RTLs); int32_t Devices = 0; @@ -108,39 +100,32 @@ NumberOfDevices->set_number(Devices); - SERVER_DBG("Got number of devices"); + SERVER_DBG("Got number of devices") return Status::OK; } Status RemoteOffloadImpl::InitDevice(ServerContext *Context, const I32 *DeviceNum, I32 *Reply) { - SERVER_DBG("Initializing device %d", DeviceNum->number()); - Reply->set_number(PM->Devices[DeviceNum->number()].RTL->init_device( mapHostRTLDeviceId(DeviceNum->number()))); - SERVER_DBG("Initialized device %d", DeviceNum->number()); + SERVER_DBG("Initialized device %d", DeviceNum->number()) return Status::OK; } Status RemoteOffloadImpl::InitRequires(ServerContext *Context, const I64 *RequiresFlag, I32 *Reply) { - SERVER_DBG("Initializing requires for devices"); - for (auto &Device : PM->Devices) if (Device.RTL->init_requires) Device.RTL->init_requires(RequiresFlag->number()); Reply->set_number(RequiresFlag->number()); - SERVER_DBG("Initialized requires for devices"); + SERVER_DBG("Initialized requires for devices") return Status::OK; } Status RemoteOffloadImpl::LoadBinary(ServerContext *Context, const Binary *Binary, TargetTable *Reply) { - SERVER_DBG("Loading binary (%p) to device %d", (void *)Binary->image_ptr(), - Binary->device_id()); - __tgt_device_image *Image = HostToRemoteDeviceImage[(void *)Binary->image_ptr()]; @@ -150,32 +135,13 @@ loadTargetTable(Table, *Reply, Image); SERVER_DBG("Loaded binary (%p) to device %d", (void *)Binary->image_ptr(), - Binary->device_id()); - return Status::OK; -} - -Status RemoteOffloadImpl::Synchronize(ServerContext *Context, - const SynchronizeDevice *Info, - I32 *Reply) { - SERVER_DBG("Synchronizing device %d (probably won't work)", - Info->device_id()); - - void *AsyncInfo = (void *)Info->queue_ptr(); - Reply->set_number(0); - if (PM->Devices[Info->device_id()].RTL->synchronize) - Reply->set_number(PM->Devices[Info->device_id()].synchronize( - (__tgt_async_info *)AsyncInfo)); - - SERVER_DBG("Synchronized device %d", Info->device_id()); + Binary->device_id()) return Status::OK; } Status RemoteOffloadImpl::IsDataExchangeable(ServerContext *Context, const DevicePair *Request, I32 *Reply) { - SERVER_DBG("Checking if data exchangable between device %d and device %d", - Request->src_dev_id(), Request->dst_dev_id()); - Reply->set_number(-1); if (PM->Devices[mapHostRTLDeviceId(Request->src_dev_id())] .RTL->is_data_exchangable) @@ -183,40 +149,30 @@ .RTL->is_data_exchangable(Request->src_dev_id(), Request->dst_dev_id())); - SERVER_DBG("Checked if data exchangable between device %d and device %d", - Request->src_dev_id(), Request->dst_dev_id()); + SERVER_DBG("Checked if data exchangeable between device %d and device %d", + Request->src_dev_id(), Request->dst_dev_id()) return Status::OK; } Status RemoteOffloadImpl::DataAlloc(ServerContext *Context, const AllocData *Request, Pointer *Reply) { - SERVER_DBG("Allocating %ld bytes on sevice %d", Request->size(), - Request->device_id()); - uint64_t TgtPtr = (uint64_t)PM->Devices[Request->device_id()].RTL->data_alloc( mapHostRTLDeviceId(Request->device_id()), Request->size(), - (void *)Request->hst_ptr()); + (void *)Request->hst_ptr(), TARGET_ALLOC_DEFAULT); Reply->set_number(TgtPtr); - SERVER_DBG("Allocated at " DPxMOD "", DPxPTR((void *)TgtPtr)); + SERVER_DBG("Allocated at " DPxMOD "", DPxPTR((void *)TgtPtr)) return Status::OK; } -Status RemoteOffloadImpl::DataSubmitAsync(ServerContext *Context, - ServerReader *Reader, - I32 *Reply) { - SubmitDataAsync Request; +Status RemoteOffloadImpl::DataSubmit(ServerContext *Context, + ServerReader *Reader, + I32 *Reply) { + SubmitData Request; uint8_t *HostCopy = nullptr; while (Reader->Read(&Request)) { if (Request.start() == 0 && Request.size() == Request.data().size()) { - SERVER_DBG("Submitting %lu bytes async to (%p) on device %d", - Request.data().size(), (void *)Request.tgt_ptr(), - Request.device_id()); - - SERVER_DBG(" Host Pointer Info: %p, %p", (void *)Request.hst_ptr(), - static_cast(Request.data().data())); - Reader->SendInitialMetadata(); Reply->set_number(PM->Devices[Request.device_id()].RTL->data_submit( @@ -225,7 +181,7 @@ SERVER_DBG("Submitted %lu bytes async to (%p) on device %d", Request.data().size(), (void *)Request.tgt_ptr(), - Request.device_id()); + Request.device_id()) return Status::OK; } @@ -234,15 +190,9 @@ Reader->SendInitialMetadata(); } - SERVER_DBG("Submitting %lu-%lu/%lu bytes async to (%p) on device %d", - Request.start(), Request.start() + Request.data().size(), - Request.size(), (void *)Request.tgt_ptr(), Request.device_id()); - memcpy((void *)((char *)HostCopy + Request.start()), Request.data().data(), Request.data().size()); } - SERVER_DBG(" Host Pointer Info: %p, %p", (void *)Request.hst_ptr(), - static_cast(Request.data().data())); Reply->set_number(PM->Devices[Request.device_id()].RTL->data_submit( mapHostRTLDeviceId(Request.device_id()), (void *)Request.tgt_ptr(), @@ -251,15 +201,16 @@ delete[] HostCopy; SERVER_DBG("Submitted %lu bytes to (%p) on device %d", Request.data().size(), - (void *)Request.tgt_ptr(), Request.device_id()); + (void *)Request.tgt_ptr(), Request.device_id()) return Status::OK; } -Status RemoteOffloadImpl::DataRetrieveAsync(ServerContext *Context, - const RetrieveDataAsync *Request, - ServerWriter *Writer) { +Status RemoteOffloadImpl::DataRetrieve(ServerContext *Context, + const RetrieveData *Request, + ServerWriter *Writer) { auto HstPtr = std::make_unique(Request->size()); + auto Ret = PM->Devices[Request->device_id()].RTL->data_retrieve( mapHostRTLDeviceId(Request->device_id()), HstPtr.get(), (void *)Request->tgt_ptr(), Request->size()); @@ -277,17 +228,13 @@ Reply->set_data((char *)HstPtr.get() + Start, End - Start); Reply->set_ret(Ret); - SERVER_DBG("Retrieving %lu-%lu/%lu bytes from (%p) on device %d", Start, - End, Request->size(), (void *)Request->tgt_ptr(), - mapHostRTLDeviceId(Request->device_id())); - if (!Writer->Write(*Reply)) { - CLIENT_DBG("Broken stream when submitting data"); + CLIENT_DBG("Broken stream when submitting data") } SERVER_DBG("Retrieved %lu-%lu/%lu bytes from (%p) on device %d", Start, End, Request->size(), (void *)Request->tgt_ptr(), - mapHostRTLDeviceId(Request->device_id())); + mapHostRTLDeviceId(Request->device_id())) Start += BlockSize; End += BlockSize; @@ -297,10 +244,6 @@ } else { auto *Reply = protobuf::Arena::CreateMessage(Arena.get()); - SERVER_DBG("Retrieve %lu bytes from (%p) on device %d", Request->size(), - (void *)Request->tgt_ptr(), - mapHostRTLDeviceId(Request->device_id())); - Reply->set_start(0); Reply->set_size(Request->size()); Reply->set_data((char *)HstPtr.get(), Request->size()); @@ -308,7 +251,7 @@ SERVER_DBG("Retrieved %lu bytes from (%p) on device %d", Request->size(), (void *)Request->tgt_ptr(), - mapHostRTLDeviceId(Request->device_id())); + mapHostRTLDeviceId(Request->device_id())) Writer->WriteLast(*Reply, WriteOptions()); } @@ -316,16 +259,9 @@ return Status::OK; } -Status RemoteOffloadImpl::DataExchangeAsync(ServerContext *Context, - const ExchangeDataAsync *Request, - I32 *Reply) { - SERVER_DBG( - "Exchanging data asynchronously from device %d (%p) to device %d (%p) of " - "size %lu", - mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(), - mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(), - Request->size()); - +Status RemoteOffloadImpl::DataExchange(ServerContext *Context, + const ExchangeData *Request, + I32 *Reply) { if (PM->Devices[Request->src_dev_id()].RTL->data_exchange) { int32_t Ret = PM->Devices[Request->src_dev_id()].RTL->data_exchange( mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(), @@ -340,30 +276,24 @@ "size %lu", mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(), mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(), - Request->size()); + Request->size()) return Status::OK; } Status RemoteOffloadImpl::DataDelete(ServerContext *Context, const DeleteData *Request, I32 *Reply) { - SERVER_DBG("Deleting data from (%p) on device %d", (void *)Request->tgt_ptr(), - mapHostRTLDeviceId(Request->device_id())); - auto Ret = PM->Devices[Request->device_id()].RTL->data_delete( mapHostRTLDeviceId(Request->device_id()), (void *)Request->tgt_ptr()); Reply->set_number(Ret); SERVER_DBG("Deleted data from (%p) on device %d", (void *)Request->tgt_ptr(), - mapHostRTLDeviceId(Request->device_id())); + mapHostRTLDeviceId(Request->device_id())) return Status::OK; } -Status RemoteOffloadImpl::RunTargetRegionAsync(ServerContext *Context, - const TargetRegionAsync *Request, - I32 *Reply) { - SERVER_DBG("Running TargetRegionAsync on device %d with %d args", - mapHostRTLDeviceId(Request->device_id()), Request->arg_num()); - +Status RemoteOffloadImpl::RunTargetRegion(ServerContext *Context, + const TargetRegion *Request, + I32 *Reply) { std::vector TgtArgs(Request->arg_num()); for (auto I = 0; I < Request->arg_num(); I++) TgtArgs[I] = (uint64_t)Request->tgt_args()[I]; @@ -381,16 +311,14 @@ Reply->set_number(Ret); - SERVER_DBG("Ran TargetRegionAsync on device %d with %d args", - mapHostRTLDeviceId(Request->device_id()), Request->arg_num()); + SERVER_DBG("Ran TargetRegion on device %d with %d args", + mapHostRTLDeviceId(Request->device_id()), Request->arg_num()) return Status::OK; } -Status RemoteOffloadImpl::RunTargetTeamRegionAsync( - ServerContext *Context, const TargetTeamRegionAsync *Request, I32 *Reply) { - SERVER_DBG("Running TargetTeamRegionAsync on device %d with %d args", - mapHostRTLDeviceId(Request->device_id()), Request->arg_num()); - +Status RemoteOffloadImpl::RunTargetTeamRegion(ServerContext *Context, + const TargetTeamRegion *Request, + I32 *Reply) { std::vector TgtArgs(Request->arg_num()); for (auto I = 0; I < Request->arg_num(); I++) TgtArgs[I] = (uint64_t)Request->tgt_args()[I]; @@ -401,6 +329,7 @@ TgtOffsets[I] = (ptrdiff_t)*TgtOffsetItr; void *TgtEntryPtr = ((__tgt_offload_entry *)Request->tgt_entry_ptr())->addr; + int32_t Ret = PM->Devices[Request->device_id()].RTL->run_team_region( mapHostRTLDeviceId(Request->device_id()), TgtEntryPtr, (void **)TgtArgs.data(), TgtOffsets.data(), Request->arg_num(), @@ -408,8 +337,8 @@ Reply->set_number(Ret); - SERVER_DBG("Ran TargetTeamRegionAsync on device %d with %d args", - mapHostRTLDeviceId(Request->device_id()), Request->arg_num()); + SERVER_DBG("Ran TargetTeamRegion on device %d with %d args", + mapHostRTLDeviceId(Request->device_id()), Request->arg_num()) return Status::OK; } diff --git a/openmp/libomptarget/plugins/remote/src/Client.h b/openmp/libomptarget/plugins/remote/src/Client.h --- a/openmp/libomptarget/plugins/remote/src/Client.h +++ b/openmp/libomptarget/plugins/remote/src/Client.h @@ -30,11 +30,10 @@ using namespace google; class RemoteOffloadClient { - const int Timeout; - int DebugLevel; - uint64_t MaxSize; - int64_t BlockSize; + const int Timeout; + const uint64_t MaxSize; + const int64_t BlockSize; std::unique_ptr Stub; std::unique_ptr Arena; @@ -45,8 +44,8 @@ std::map> DevicesToTables; template - auto remoteCall(Fn1 Preprocess, Fn2 Postprocess, TReturn ErrorValue, - bool Timeout = true); + auto remoteCall(Fn1 Preprocessor, Fn2 Postprocessor, TReturn ErrorValue, + bool CanTimeOut = true); public: RemoteOffloadClient(std::shared_ptr Channel, int Timeout, @@ -77,35 +76,29 @@ int32_t initRequires(int64_t RequiresFlags); __tgt_target_table *loadBinary(int32_t DeviceId, __tgt_device_image *Image); - int64_t synchronize(int32_t DeviceId, __tgt_async_info *AsyncInfo); - int32_t isDataExchangeable(int32_t SrcDevId, int32_t DstDevId); void *dataAlloc(int32_t DeviceId, int64_t Size, void *HstPtr); int32_t dataDelete(int32_t DeviceId, void *TgtPtr); - int32_t dataSubmitAsync(int32_t DeviceId, void *TgtPtr, void *HstPtr, - int64_t Size, __tgt_async_info *AsyncInfo); - int32_t dataRetrieveAsync(int32_t DeviceId, void *HstPtr, void *TgtPtr, - int64_t Size, __tgt_async_info *AsyncInfo); + int32_t dataSubmit(int32_t DeviceId, void *TgtPtr, void *HstPtr, + int64_t Size); + int32_t dataRetrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr, + int64_t Size); - int32_t dataExchangeAsync(int32_t SrcDevId, void *SrcPtr, int32_t DstDevId, - void *DstPtr, int64_t Size, - __tgt_async_info *AsyncInfo); - - int32_t runTargetRegionAsync(int32_t DeviceId, void *TgtEntryPtr, - void **TgtArgs, ptrdiff_t *TgtOffsets, - int32_t ArgNum, __tgt_async_info *AsyncInfo); - - int32_t runTargetTeamRegionAsync(int32_t DeviceId, void *TgtEntryPtr, - void **TgtArgs, ptrdiff_t *TgtOffsets, - int32_t ArgNum, int32_t TeamNum, - int32_t ThreadLimit, uint64_t LoopTripCount, - __tgt_async_info *AsyncInfo); + int32_t isDataExchangeable(int32_t SrcDevId, int32_t DstDevId); + int32_t dataExchange(int32_t SrcDevId, void *SrcPtr, int32_t DstDevId, + void *DstPtr, int64_t Size); + + int32_t runTargetRegion(int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, + ptrdiff_t *TgtOffsets, int32_t ArgNum); + int32_t runTargetTeamRegion(int32_t DeviceId, void *TgtEntryPtr, + void **TgtArgs, ptrdiff_t *TgtOffsets, + int32_t ArgNum, int32_t TeamNum, + int32_t ThreadLimit, uint64_t LoopTripCount); }; class RemoteClientManager { private: - std::vector Addresses; std::vector Clients; std::vector Devices; @@ -113,16 +106,16 @@ int DebugLevel; public: - RemoteClientManager(std::vector Addresses, int Timeout, - uint64_t MaxSize, int64_t BlockSize) - : Addresses(Addresses) { + RemoteClientManager() { + ClientManagerConfigTy Config; + grpc::ChannelArguments ChArgs; ChArgs.SetMaxReceiveMessageSize(-1); DebugLevel = getDebugLevel(); - for (auto Address : Addresses) { + for (auto Address : Config.ServerAddresses) { Clients.push_back(RemoteOffloadClient( grpc::CreateChannel(Address, grpc::InsecureChannelCredentials()), - Timeout, MaxSize, BlockSize)); + Config.Timeout, Config.MaxSize, Config.BlockSize)); } } @@ -138,30 +131,25 @@ int32_t initRequires(int64_t RequiresFlags); __tgt_target_table *loadBinary(int32_t DeviceId, __tgt_device_image *Image); - int64_t synchronize(int32_t DeviceId, __tgt_async_info *AsyncInfo); - int32_t isDataExchangeable(int32_t SrcDevId, int32_t DstDevId); void *dataAlloc(int32_t DeviceId, int64_t Size, void *HstPtr); int32_t dataDelete(int32_t DeviceId, void *TgtPtr); - int32_t dataSubmitAsync(int32_t DeviceId, void *TgtPtr, void *HstPtr, - int64_t Size, __tgt_async_info *AsyncInfo); - int32_t dataRetrieveAsync(int32_t DeviceId, void *HstPtr, void *TgtPtr, - int64_t Size, __tgt_async_info *AsyncInfo); + int32_t dataSubmit(int32_t DeviceId, void *TgtPtr, void *HstPtr, + int64_t Size); + int32_t dataRetrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr, + int64_t Size); - int32_t dataExchangeAsync(int32_t SrcDevId, void *SrcPtr, int32_t DstDevId, - void *DstPtr, int64_t Size, - __tgt_async_info *AsyncInfo); - - int32_t runTargetRegionAsync(int32_t DeviceId, void *TgtEntryPtr, - void **TgtArgs, ptrdiff_t *TgtOffsets, - int32_t ArgNum, __tgt_async_info *AsyncInfo); - - int32_t runTargetTeamRegionAsync(int32_t DeviceId, void *TgtEntryPtr, - void **TgtArgs, ptrdiff_t *TgtOffsets, - int32_t ArgNum, int32_t TeamNum, - int32_t ThreadLimit, uint64_t LoopTripCount, - __tgt_async_info *AsyncInfo); + int32_t isDataExchangeable(int32_t SrcDevId, int32_t DstDevId); + int32_t dataExchange(int32_t SrcDevId, void *SrcPtr, int32_t DstDevId, + void *DstPtr, int64_t Size); + + int32_t runTargetRegion(int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, + ptrdiff_t *TgtOffsets, int32_t ArgNum); + int32_t runTargetTeamRegion(int32_t DeviceId, void *TgtEntryPtr, + void **TgtArgs, ptrdiff_t *TgtOffsets, + int32_t ArgNum, int32_t TeamNum, + int32_t ThreadLimit, uint64_t LoopTripCount); }; #endif diff --git a/openmp/libomptarget/plugins/remote/src/Client.cpp b/openmp/libomptarget/plugins/remote/src/Client.cpp --- a/openmp/libomptarget/plugins/remote/src/Client.cpp +++ b/openmp/libomptarget/plugins/remote/src/Client.cpp @@ -24,31 +24,30 @@ using grpc::Status; template -auto RemoteOffloadClient::remoteCall(Fn1 Preprocess, Fn2 Postprocess, - TReturn ErrorValue, bool Timeout) { +auto RemoteOffloadClient::remoteCall(Fn1 Preprocessor, Fn2 Postprocessor, + TReturn ErrorValue, bool CanTimeOut) { ArenaAllocatorLock->lock(); if (Arena->SpaceAllocated() >= MaxSize) Arena->Reset(); ArenaAllocatorLock->unlock(); ClientContext Context; - if (Timeout) { + if (CanTimeOut) { auto Deadline = std::chrono::system_clock::now() + std::chrono::seconds(Timeout); Context.set_deadline(Deadline); } Status RPCStatus; - auto Reply = Preprocess(RPCStatus, Context); + auto Reply = Preprocessor(RPCStatus, Context); - // TODO: Error handle more appropriately if (!RPCStatus.ok()) { - CLIENT_DBG("%s", RPCStatus.error_message().c_str()); + CLIENT_DBG("%s", RPCStatus.error_message().c_str()) } else { - return Postprocess(Reply); + return Postprocessor(Reply); } - CLIENT_DBG("Failed"); + CLIENT_DBG("Failed") return ErrorValue; } @@ -56,7 +55,7 @@ ClientContext Context; Null Request; I32 Reply; - CLIENT_DBG("Shutting down server."); + CLIENT_DBG("Shutting down server.") auto Status = Stub->Shutdown(&Context, Request, &Reply); if (Status.ok()) return Reply.number(); @@ -65,7 +64,7 @@ int32_t RemoteOffloadClient::registerLib(__tgt_bin_desc *Desc) { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](auto &RPCStatus, auto &Context) { auto *Request = protobuf::Arena::CreateMessage( Arena.get()); @@ -73,14 +72,13 @@ loadTargetBinaryDescription(Desc, *Request); Request->set_bin_ptr((uint64_t)Desc); - CLIENT_DBG("Registering library"); RPCStatus = Stub->RegisterLib(&Context, *Request, Reply); return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](const auto &Reply) { if (Reply->number() == 0) { - CLIENT_DBG("Registered library"); + CLIENT_DBG("Registered library") return 0; } return 1; @@ -90,24 +88,23 @@ int32_t RemoteOffloadClient::unregisterLib(__tgt_bin_desc *Desc) { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](auto &RPCStatus, auto &Context) { auto *Request = protobuf::Arena::CreateMessage(Arena.get()); auto *Reply = protobuf::Arena::CreateMessage(Arena.get()); Request->set_number((uint64_t)Desc); - CLIENT_DBG("Unregistering library"); RPCStatus = Stub->UnregisterLib(&Context, *Request, Reply); return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](const auto &Reply) { if (Reply->number() == 0) { - CLIENT_DBG("Unregistered library"); + CLIENT_DBG("Unregistered library") return 0; } - CLIENT_DBG("Failed to unregister library"); + CLIENT_DBG("Failed to unregister library") return 1; }, /* Error Value */ 1); @@ -115,7 +112,7 @@ int32_t RemoteOffloadClient::isValidBinary(__tgt_device_image *Image) { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](auto &RPCStatus, auto &Context) { auto *Request = protobuf::Arena::CreateMessage(Arena.get()); @@ -127,16 +124,15 @@ while (EntryItr != Image->EntriesEnd) Request->add_entry_ptrs((uint64_t)EntryItr++); - CLIENT_DBG("Validating binary"); RPCStatus = Stub->IsValidBinary(&Context, *Request, Reply); return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](const auto &Reply) { if (Reply->number()) { - CLIENT_DBG("Validated binary"); + CLIENT_DBG("Validated binary") } else { - CLIENT_DBG("Could not validate binary"); + CLIENT_DBG("Could not validate binary") } return Reply->number(); }, @@ -145,22 +141,21 @@ int32_t RemoteOffloadClient::getNumberOfDevices() { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](Status &RPCStatus, ClientContext &Context) { auto *Request = protobuf::Arena::CreateMessage(Arena.get()); auto *Reply = protobuf::Arena::CreateMessage(Arena.get()); - CLIENT_DBG("Getting number of devices"); RPCStatus = Stub->GetNumberOfDevices(&Context, *Request, Reply); return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](const auto &Reply) { if (Reply->number()) { - CLIENT_DBG("Found %d devices", Reply->number()); + CLIENT_DBG("Found %d devices", Reply->number()) } else { - CLIENT_DBG("Could not get the number of devices"); + CLIENT_DBG("Could not get the number of devices") } return Reply->number(); }, @@ -169,24 +164,23 @@ int32_t RemoteOffloadClient::initDevice(int32_t DeviceId) { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](auto &RPCStatus, auto &Context) { auto *Request = protobuf::Arena::CreateMessage(Arena.get()); auto *Reply = protobuf::Arena::CreateMessage(Arena.get()); Request->set_number(DeviceId); - CLIENT_DBG("Initializing device %d", DeviceId); RPCStatus = Stub->InitDevice(&Context, *Request, Reply); return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](const auto &Reply) { if (!Reply->number()) { - CLIENT_DBG("Initialized device %d", DeviceId); + CLIENT_DBG("Initialized device %d", DeviceId) } else { - CLIENT_DBG("Could not initialize device %d", DeviceId); + CLIENT_DBG("Could not initialize device %d", DeviceId) } return Reply->number(); }, @@ -195,21 +189,20 @@ int32_t RemoteOffloadClient::initRequires(int64_t RequiresFlags) { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](auto &RPCStatus, auto &Context) { auto *Request = protobuf::Arena::CreateMessage(Arena.get()); auto *Reply = protobuf::Arena::CreateMessage(Arena.get()); Request->set_number(RequiresFlags); - CLIENT_DBG("Initializing requires"); RPCStatus = Stub->InitRequires(&Context, *Request, Reply); return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](const auto &Reply) { if (Reply->number()) { - CLIENT_DBG("Initialized requires"); + CLIENT_DBG("Initialized requires") } else { - CLIENT_DBG("Could not initialize requires"); + CLIENT_DBG("Could not initialize requires") } return Reply->number(); }, @@ -219,7 +212,7 @@ __tgt_target_table *RemoteOffloadClient::loadBinary(int32_t DeviceId, __tgt_device_image *Image) { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](auto &RPCStatus, auto &Context) { auto *ImageMessage = protobuf::Arena::CreateMessage(Arena.get()); @@ -227,14 +220,13 @@ ImageMessage->set_image_ptr((uint64_t)Image->ImageStart); ImageMessage->set_device_id(DeviceId); - CLIENT_DBG("Loading Image %p to device %d", Image, DeviceId); RPCStatus = Stub->LoadBinary(&Context, *ImageMessage, Reply); return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](auto &Reply) { if (Reply->entries_size() == 0) { - CLIENT_DBG("Could not load image %p onto device %d", Image, DeviceId); + CLIENT_DBG("Could not load image %p onto device %d", Image, DeviceId) return (__tgt_target_table *)nullptr; } DevicesToTables[DeviceId] = std::make_unique<__tgt_target_table>(); @@ -242,46 +234,18 @@ RemoteEntries[DeviceId]); CLIENT_DBG("Loaded Image %p to device %d with %d entries", Image, - DeviceId, Reply->entries_size()); + DeviceId, Reply->entries_size()) return DevicesToTables[DeviceId].get(); }, /* Error Value */ (__tgt_target_table *)nullptr, - /* Timeout */ false); -} - -int64_t RemoteOffloadClient::synchronize(int32_t DeviceId, - __tgt_async_info *AsyncInfo) { - return remoteCall( - /* Preprocess */ - [&](auto &RPCStatus, auto &Context) { - auto *Reply = protobuf::Arena::CreateMessage(Arena.get()); - auto *Info = - protobuf::Arena::CreateMessage(Arena.get()); - - Info->set_device_id(DeviceId); - Info->set_queue_ptr((uint64_t)AsyncInfo); - - CLIENT_DBG("Synchronizing device %d", DeviceId); - RPCStatus = Stub->Synchronize(&Context, *Info, Reply); - return Reply; - }, - /* Postprocess */ - [&](auto &Reply) { - if (Reply->number()) { - CLIENT_DBG("Synchronized device %d", DeviceId); - } else { - CLIENT_DBG("Could not synchronize device %d", DeviceId); - } - return Reply->number(); - }, - /* Error Value */ -1); + /* CanTimeOut */ false); } int32_t RemoteOffloadClient::isDataExchangeable(int32_t SrcDevId, int32_t DstDevId) { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](auto &RPCStatus, auto &Context) { auto *Request = protobuf::Arena::CreateMessage(Arena.get()); auto *Reply = protobuf::Arena::CreateMessage(Arena.get()); @@ -289,18 +253,16 @@ Request->set_src_dev_id(SrcDevId); Request->set_dst_dev_id(DstDevId); - CLIENT_DBG("Asking if data is exchangeable between %d, %d", SrcDevId, - DstDevId); RPCStatus = Stub->IsDataExchangeable(&Context, *Request, Reply); return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](auto &Reply) { if (Reply->number()) { - CLIENT_DBG("Data is exchangeable between %d, %d", SrcDevId, DstDevId); + CLIENT_DBG("Data is exchangeable between %d, %d", SrcDevId, DstDevId) } else { CLIENT_DBG("Data is not exchangeable between %d, %d", SrcDevId, - DstDevId); + DstDevId) } return Reply->number(); }, @@ -310,7 +272,7 @@ void *RemoteOffloadClient::dataAlloc(int32_t DeviceId, int64_t Size, void *HstPtr) { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](auto &RPCStatus, auto &Context) { auto *Reply = protobuf::Arena::CreateMessage(Arena.get()); auto *Request = protobuf::Arena::CreateMessage(Arena.get()); @@ -319,40 +281,38 @@ Request->set_size(Size); Request->set_hst_ptr((uint64_t)HstPtr); - CLIENT_DBG("Allocating %ld bytes on device %d", Size, DeviceId); RPCStatus = Stub->DataAlloc(&Context, *Request, Reply); return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](auto &Reply) { if (Reply->number()) { CLIENT_DBG("Allocated %ld bytes on device %d at %p", Size, DeviceId, - (void *)Reply->number()); + (void *)Reply->number()) } else { CLIENT_DBG("Could not allocate %ld bytes on device %d at %p", Size, - DeviceId, (void *)Reply->number()); + DeviceId, (void *)Reply->number()) } return (void *)Reply->number(); }, /* Error Value */ (void *)nullptr); } -int32_t RemoteOffloadClient::dataSubmitAsync(int32_t DeviceId, void *TgtPtr, - void *HstPtr, int64_t Size, - __tgt_async_info *AsyncInfo) { +int32_t RemoteOffloadClient::dataSubmit(int32_t DeviceId, void *TgtPtr, + void *HstPtr, int64_t Size) { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](auto &RPCStatus, auto &Context) { auto *Reply = protobuf::Arena::CreateMessage(Arena.get()); - std::unique_ptr> Writer( - Stub->DataSubmitAsync(&Context, Reply)); + std::unique_ptr> Writer( + Stub->DataSubmit(&Context, Reply)); if (Size > BlockSize) { int64_t Start = 0, End = BlockSize; for (auto I = 0; I < ceil((float)Size / BlockSize); I++) { auto *Request = - protobuf::Arena::CreateMessage(Arena.get()); + protobuf::Arena::CreateMessage(Arena.get()); Request->set_device_id(DeviceId); Request->set_data((char *)HstPtr + Start, End - Start); @@ -360,13 +320,9 @@ Request->set_tgt_ptr((uint64_t)TgtPtr); Request->set_start(Start); Request->set_size(Size); - Request->set_queue_ptr((uint64_t)AsyncInfo); - - CLIENT_DBG("Submitting %ld-%ld/%ld bytes async on device %d at %p", - Start, End, Size, DeviceId, TgtPtr) if (!Writer->Write(*Request)) { - CLIENT_DBG("Broken stream when submitting data"); + CLIENT_DBG("Broken stream when submitting data") Reply->set_number(0); return Reply; } @@ -378,7 +334,7 @@ } } else { auto *Request = - protobuf::Arena::CreateMessage(Arena.get()); + protobuf::Arena::CreateMessage(Arena.get()); Request->set_device_id(DeviceId); Request->set_data(HstPtr, Size); @@ -387,10 +343,8 @@ Request->set_start(0); Request->set_size(Size); - CLIENT_DBG("Submitting %ld bytes async on device %d at %p", Size, - DeviceId, TgtPtr) if (!Writer->Write(*Request)) { - CLIENT_DBG("Broken stream when submitting data"); + CLIENT_DBG("Broken stream when submitting data") Reply->set_number(0); return Reply; } @@ -401,11 +355,11 @@ return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](auto &Reply) { if (!Reply->number()) { - CLIENT_DBG("Async submitted %ld bytes on device %d at %p", Size, - DeviceId, TgtPtr) + CLIENT_DBG(" submitted %ld bytes on device %d at %p", Size, DeviceId, + TgtPtr) } else { CLIENT_DBG("Could not async submit %ld bytes on device %d at %p", Size, DeviceId, TgtPtr) @@ -413,27 +367,25 @@ return Reply->number(); }, /* Error Value */ -1, - /* Timeout */ false); + /* CanTimeOut */ false); } -int32_t RemoteOffloadClient::dataRetrieveAsync(int32_t DeviceId, void *HstPtr, - void *TgtPtr, int64_t Size, - __tgt_async_info *AsyncInfo) { +int32_t RemoteOffloadClient::dataRetrieve(int32_t DeviceId, void *HstPtr, + void *TgtPtr, int64_t Size) { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](auto &RPCStatus, auto &Context) { auto *Request = - protobuf::Arena::CreateMessage(Arena.get()); + protobuf::Arena::CreateMessage(Arena.get()); Request->set_device_id(DeviceId); Request->set_size(Size); Request->set_hst_ptr((int64_t)HstPtr); Request->set_tgt_ptr((int64_t)TgtPtr); - Request->set_queue_ptr((uint64_t)AsyncInfo); auto *Reply = protobuf::Arena::CreateMessage(Arena.get()); std::unique_ptr> Reader( - Stub->DataRetrieveAsync(&Context, *Request)); + Stub->DataRetrieve(&Context, *Request)); Reader->WaitForInitialMetadata(); while (Reader->Read(Reply)) { if (Reply->ret()) { @@ -444,18 +396,10 @@ } if (Reply->start() == 0 && Reply->size() == Reply->data().size()) { - CLIENT_DBG("Async retrieving %ld bytes on device %d at %p for %p", - Size, DeviceId, TgtPtr, HstPtr) - memcpy(HstPtr, Reply->data().data(), Reply->data().size()); return Reply; } - CLIENT_DBG("Retrieving %lu-%lu/%lu bytes async from (%p) to (%p) " - "on Device %d", - Reply->start(), Reply->start() + Reply->data().size(), - Reply->size(), (void *)Request->tgt_ptr(), HstPtr, - Request->device_id()); memcpy((void *)((char *)HstPtr + Reply->start()), Reply->data().data(), Reply->data().size()); @@ -464,54 +408,49 @@ return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](auto &Reply) { if (!Reply->ret()) { - CLIENT_DBG("Async retrieve %ld bytes on Device %d", Size, DeviceId); + CLIENT_DBG("Retrieved %ld bytes on Device %d", Size, DeviceId) } else { CLIENT_DBG("Could not async retrieve %ld bytes on Device %d", Size, - DeviceId); + DeviceId) } return Reply->ret(); }, /* Error Value */ -1, - /* Timeout */ false); + /* CanTimeOut */ false); } -int32_t RemoteOffloadClient::dataExchangeAsync(int32_t SrcDevId, void *SrcPtr, - int32_t DstDevId, void *DstPtr, - int64_t Size, - __tgt_async_info *AsyncInfo) { +int32_t RemoteOffloadClient::dataExchange(int32_t SrcDevId, void *SrcPtr, + int32_t DstDevId, void *DstPtr, + int64_t Size) { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](auto &RPCStatus, auto &Context) { auto *Reply = protobuf::Arena::CreateMessage(Arena.get()); auto *Request = - protobuf::Arena::CreateMessage(Arena.get()); + protobuf::Arena::CreateMessage(Arena.get()); Request->set_src_dev_id(SrcDevId); Request->set_src_ptr((uint64_t)SrcPtr); Request->set_dst_dev_id(DstDevId); Request->set_dst_ptr((uint64_t)DstPtr); Request->set_size(Size); - Request->set_queue_ptr((uint64_t)AsyncInfo); - CLIENT_DBG( - "Exchanging %ld bytes on device %d at %p for %p on device %d", Size, - SrcDevId, SrcPtr, DstPtr, DstDevId); - RPCStatus = Stub->DataExchangeAsync(&Context, *Request, Reply); + RPCStatus = Stub->DataExchange(&Context, *Request, Reply); return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](auto &Reply) { if (Reply->number()) { CLIENT_DBG( "Exchanged %ld bytes on device %d at %p for %p on device %d", - Size, SrcDevId, SrcPtr, DstPtr, DstDevId); + Size, SrcDevId, SrcPtr, DstPtr, DstDevId) } else { CLIENT_DBG("Could not exchange %ld bytes on device %d at %p for %p " "on device %d", - Size, SrcDevId, SrcPtr, DstPtr, DstDevId); + Size, SrcDevId, SrcPtr, DstPtr, DstDevId) } return Reply->number(); }, @@ -520,7 +459,7 @@ int32_t RemoteOffloadClient::dataDelete(int32_t DeviceId, void *TgtPtr) { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](auto &RPCStatus, auto &Context) { auto *Reply = protobuf::Arena::CreateMessage(Arena.get()); auto *Request = protobuf::Arena::CreateMessage(Arena.get()); @@ -528,11 +467,10 @@ Request->set_device_id(DeviceId); Request->set_tgt_ptr((uint64_t)TgtPtr); - CLIENT_DBG("Deleting data at %p on device %d", TgtPtr, DeviceId) RPCStatus = Stub->DataDelete(&Context, *Request, Reply); return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](auto &Reply) { if (!Reply->number()) { CLIENT_DBG("Deleted data at %p on device %d", TgtPtr, DeviceId) @@ -545,18 +483,18 @@ /* Error Value */ -1); } -int32_t RemoteOffloadClient::runTargetRegionAsync( - int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets, - int32_t ArgNum, __tgt_async_info *AsyncInfo) { +int32_t RemoteOffloadClient::runTargetRegion(int32_t DeviceId, + void *TgtEntryPtr, void **TgtArgs, + ptrdiff_t *TgtOffsets, + int32_t ArgNum) { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](auto &RPCStatus, auto &Context) { auto *Reply = protobuf::Arena::CreateMessage(Arena.get()); auto *Request = - protobuf::Arena::CreateMessage(Arena.get()); + protobuf::Arena::CreateMessage(Arena.get()); Request->set_device_id(DeviceId); - Request->set_queue_ptr((uint64_t)AsyncInfo); Request->set_tgt_entry_ptr( (uint64_t)RemoteEntries[DeviceId][TgtEntryPtr]); @@ -571,37 +509,34 @@ Request->set_arg_num(ArgNum); - CLIENT_DBG("Running target region async on device %d", DeviceId); - RPCStatus = Stub->RunTargetRegionAsync(&Context, *Request, Reply); + RPCStatus = Stub->RunTargetRegion(&Context, *Request, Reply); return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](auto &Reply) { if (!Reply->number()) { - CLIENT_DBG("Ran target region async on device %d", DeviceId); + CLIENT_DBG("Ran target region async on device %d", DeviceId) } else { - CLIENT_DBG("Could not run target region async on device %d", - DeviceId); + CLIENT_DBG("Could not run target region async on device %d", DeviceId) } return Reply->number(); }, /* Error Value */ -1, - /* Timeout */ false); + /* CanTimeOut */ false); } -int32_t RemoteOffloadClient::runTargetTeamRegionAsync( +int32_t RemoteOffloadClient::runTargetTeamRegion( int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets, int32_t ArgNum, int32_t TeamNum, int32_t ThreadLimit, - uint64_t LoopTripcount, __tgt_async_info *AsyncInfo) { + uint64_t LoopTripcount) { return remoteCall( - /* Preprocess */ + /* Preprocessor */ [&](auto &RPCStatus, auto &Context) { auto *Reply = protobuf::Arena::CreateMessage(Arena.get()); auto *Request = - protobuf::Arena::CreateMessage(Arena.get()); + protobuf::Arena::CreateMessage(Arena.get()); Request->set_device_id(DeviceId); - Request->set_queue_ptr((uint64_t)AsyncInfo); Request->set_tgt_entry_ptr( (uint64_t)RemoteEntries[DeviceId][TgtEntryPtr]); @@ -620,25 +555,23 @@ Request->set_thread_limit(ThreadLimit); Request->set_loop_tripcount(LoopTripcount); - CLIENT_DBG("Running target team region async on device %d", DeviceId); - RPCStatus = Stub->RunTargetTeamRegionAsync(&Context, *Request, Reply); + RPCStatus = Stub->RunTargetTeamRegion(&Context, *Request, Reply); return Reply; }, - /* Postprocess */ + /* Postprocessor */ [&](auto &Reply) { if (!Reply->number()) { - CLIENT_DBG("Ran target team region async on device %d", DeviceId); + CLIENT_DBG("Ran target team region async on device %d", DeviceId) } else { CLIENT_DBG("Could not run target team region async on device %d", - DeviceId); + DeviceId) } return Reply->number(); }, /* Error Value */ -1, - /* Timeout */ false); + /* CanTimeOut */ false); } -// TODO: Better error handling for the next three functions int32_t RemoteClientManager::shutdown(void) { int32_t Ret = 0; for (auto &Client : Clients) @@ -684,7 +617,7 @@ std::pair RemoteClientManager::mapDeviceId(int32_t DeviceId) { for (size_t ClientIdx = 0; ClientIdx < Devices.size(); ClientIdx++) { - if (!(DeviceId >= Devices[ClientIdx])) + if (DeviceId < Devices[ClientIdx]) return {ClientIdx, DeviceId}; DeviceId -= Devices[ClientIdx]; } @@ -711,13 +644,6 @@ return Clients[ClientIdx].loadBinary(DeviceIdx, Image); } -int64_t RemoteClientManager::synchronize(int32_t DeviceId, - __tgt_async_info *AsyncInfo) { - int32_t ClientIdx, DeviceIdx; - std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); - return Clients[ClientIdx].synchronize(DeviceIdx, AsyncInfo); -} - int32_t RemoteClientManager::isDataExchangeable(int32_t SrcDevId, int32_t DstDevId) { int32_t SrcClientIdx, SrcDeviceIdx, DstClientIdx, DstDeviceIdx; @@ -739,51 +665,47 @@ return Clients[ClientIdx].dataDelete(DeviceIdx, TgtPtr); } -int32_t RemoteClientManager::dataSubmitAsync(int32_t DeviceId, void *TgtPtr, - void *HstPtr, int64_t Size, - __tgt_async_info *AsyncInfo) { +int32_t RemoteClientManager::dataSubmit(int32_t DeviceId, void *TgtPtr, + void *HstPtr, int64_t Size) { int32_t ClientIdx, DeviceIdx; std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); - return Clients[ClientIdx].dataSubmitAsync(DeviceIdx, TgtPtr, HstPtr, Size, - AsyncInfo); + return Clients[ClientIdx].dataSubmit(DeviceIdx, TgtPtr, HstPtr, Size); } -int32_t RemoteClientManager::dataRetrieveAsync(int32_t DeviceId, void *HstPtr, - void *TgtPtr, int64_t Size, - __tgt_async_info *AsyncInfo) { +int32_t RemoteClientManager::dataRetrieve(int32_t DeviceId, void *HstPtr, + void *TgtPtr, int64_t Size) { int32_t ClientIdx, DeviceIdx; std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); - return Clients[ClientIdx].dataRetrieveAsync(DeviceIdx, HstPtr, TgtPtr, Size, - AsyncInfo); + return Clients[ClientIdx].dataRetrieve(DeviceIdx, HstPtr, TgtPtr, Size); } -int32_t RemoteClientManager::dataExchangeAsync(int32_t SrcDevId, void *SrcPtr, - int32_t DstDevId, void *DstPtr, - int64_t Size, - __tgt_async_info *AsyncInfo) { +int32_t RemoteClientManager::dataExchange(int32_t SrcDevId, void *SrcPtr, + int32_t DstDevId, void *DstPtr, + int64_t Size) { int32_t SrcClientIdx, SrcDeviceIdx, DstClientIdx, DstDeviceIdx; std::tie(SrcClientIdx, SrcDeviceIdx) = mapDeviceId(SrcDevId); std::tie(DstClientIdx, DstDeviceIdx) = mapDeviceId(DstDevId); - return Clients[SrcClientIdx].dataExchangeAsync( - SrcDeviceIdx, SrcPtr, DstDeviceIdx, DstPtr, Size, AsyncInfo); + return Clients[SrcClientIdx].dataExchange(SrcDeviceIdx, SrcPtr, DstDeviceIdx, + DstPtr, Size); } -int32_t RemoteClientManager::runTargetRegionAsync( - int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets, - int32_t ArgNum, __tgt_async_info *AsyncInfo) { +int32_t RemoteClientManager::runTargetRegion(int32_t DeviceId, + void *TgtEntryPtr, void **TgtArgs, + ptrdiff_t *TgtOffsets, + int32_t ArgNum) { int32_t ClientIdx, DeviceIdx; std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); - return Clients[ClientIdx].runTargetRegionAsync( - DeviceIdx, TgtEntryPtr, TgtArgs, TgtOffsets, ArgNum, AsyncInfo); + return Clients[ClientIdx].runTargetRegion(DeviceIdx, TgtEntryPtr, TgtArgs, + TgtOffsets, ArgNum); } -int32_t RemoteClientManager::runTargetTeamRegionAsync( +int32_t RemoteClientManager::runTargetTeamRegion( int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets, int32_t ArgNum, int32_t TeamNum, int32_t ThreadLimit, - uint64_t LoopTripCount, __tgt_async_info *AsyncInfo) { + uint64_t LoopTripCount) { int32_t ClientIdx, DeviceIdx; std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId); - return Clients[ClientIdx].runTargetTeamRegionAsync( - DeviceIdx, TgtEntryPtr, TgtArgs, TgtOffsets, ArgNum, TeamNum, ThreadLimit, - LoopTripCount, AsyncInfo); + return Clients[ClientIdx].runTargetTeamRegion(DeviceIdx, TgtEntryPtr, TgtArgs, + TgtOffsets, ArgNum, TeamNum, + ThreadLimit, LoopTripCount); } diff --git a/openmp/libomptarget/plugins/remote/src/rtl.cpp b/openmp/libomptarget/plugins/remote/src/rtl.cpp --- a/openmp/libomptarget/plugins/remote/src/rtl.cpp +++ b/openmp/libomptarget/plugins/remote/src/rtl.cpp @@ -27,15 +27,7 @@ __attribute__((constructor(101))) void initRPC() { DP("Init RPC library!\n"); - RPCConfig Config; - parseEnvironment(Config); - - int Timeout = 5; - if (const char *Env1 = std::getenv("LIBOMPTARGET_RPC_LATENCY")) - Timeout = std::stoi(Env1); - - Manager = new RemoteClientManager(Config.ServerAddresses, Timeout, - Config.MaxSize, Config.BlockSize); + Manager = new RemoteClientManager(); } __attribute__((destructor(101))) void deinitRPC() { @@ -76,17 +68,13 @@ return Manager->loadBinary(DeviceId, (__tgt_device_image *)Image); } -int32_t __tgt_rtl_synchronize(int32_t DeviceId, __tgt_async_info *AsyncInfo) { - return Manager->synchronize(DeviceId, AsyncInfo); -} - int32_t __tgt_rtl_is_data_exchangable(int32_t SrcDevId, int32_t DstDevId) { return Manager->isDataExchangeable(SrcDevId, DstDevId); } void *__tgt_rtl_data_alloc(int32_t DeviceId, int64_t Size, void *HstPtr, - int32_t kind) { - if (kind != TARGET_ALLOC_DEFAULT) { + int32_t Kind) { + if (Kind != TARGET_ALLOC_DEFAULT) { REPORT("Invalid target data allocation kind or requested allocator not " "implemented yet\n"); return NULL; @@ -97,24 +85,12 @@ int32_t __tgt_rtl_data_submit(int32_t DeviceId, void *TgtPtr, void *HstPtr, int64_t Size) { - return Manager->dataSubmitAsync(DeviceId, TgtPtr, HstPtr, Size, nullptr); -} - -int32_t __tgt_rtl_data_submit_async(int32_t DeviceId, void *TgtPtr, - void *HstPtr, int64_t Size, - __tgt_async_info *AsyncInfo) { - return Manager->dataSubmitAsync(DeviceId, TgtPtr, HstPtr, Size, AsyncInfo); + return Manager->dataSubmit(DeviceId, TgtPtr, HstPtr, Size); } int32_t __tgt_rtl_data_retrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr, int64_t Size) { - return Manager->dataRetrieveAsync(DeviceId, HstPtr, TgtPtr, Size, nullptr); -} - -int32_t __tgt_rtl_data_retrieve_async(int32_t DeviceId, void *HstPtr, - void *TgtPtr, int64_t Size, - __tgt_async_info *AsyncInfo) { - return Manager->dataRetrieveAsync(DeviceId, HstPtr, TgtPtr, Size, AsyncInfo); + return Manager->dataRetrieve(DeviceId, HstPtr, TgtPtr, Size); } int32_t __tgt_rtl_data_delete(int32_t DeviceId, void *TgtPtr) { @@ -123,31 +99,15 @@ int32_t __tgt_rtl_data_exchange(int32_t SrcDevId, void *SrcPtr, int32_t DstDevId, void *DstPtr, int64_t Size) { - return Manager->dataExchangeAsync(SrcDevId, SrcPtr, DstDevId, DstPtr, Size, - nullptr); + return Manager->dataExchange(SrcDevId, SrcPtr, DstDevId, DstPtr, Size); } -int32_t __tgt_rtl_data_exchange_async(int32_t SrcDevId, void *SrcPtr, - int32_t DstDevId, void *DstPtr, - int64_t Size, - __tgt_async_info *AsyncInfo) { - return Manager->dataExchangeAsync(SrcDevId, SrcPtr, DstDevId, DstPtr, Size, - AsyncInfo); -} int32_t __tgt_rtl_run_target_region(int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets, int32_t ArgNum) { - return Manager->runTargetRegionAsync(DeviceId, TgtEntryPtr, TgtArgs, - TgtOffsets, ArgNum, nullptr); -} - -int32_t __tgt_rtl_run_target_region_async(int32_t DeviceId, void *TgtEntryPtr, - void **TgtArgs, ptrdiff_t *TgtOffsets, - int32_t ArgNum, - __tgt_async_info *AsyncInfo) { - return Manager->runTargetRegionAsync(DeviceId, TgtEntryPtr, TgtArgs, - TgtOffsets, ArgNum, AsyncInfo); + return Manager->runTargetRegion(DeviceId, TgtEntryPtr, TgtArgs, TgtOffsets, + ArgNum); } int32_t __tgt_rtl_run_target_team_region(int32_t DeviceId, void *TgtEntryPtr, @@ -155,18 +115,9 @@ int32_t ArgNum, int32_t TeamNum, int32_t ThreadLimit, uint64_t LoopTripCount) { - return Manager->runTargetTeamRegionAsync(DeviceId, TgtEntryPtr, TgtArgs, - TgtOffsets, ArgNum, TeamNum, - ThreadLimit, LoopTripCount, nullptr); -} - -int32_t __tgt_rtl_run_target_team_region_async( - int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets, - int32_t ArgNum, int32_t TeamNum, int32_t ThreadLimit, - uint64_t LoopTripCount, __tgt_async_info *AsyncInfo) { - return Manager->runTargetTeamRegionAsync( - DeviceId, TgtEntryPtr, TgtArgs, TgtOffsets, ArgNum, TeamNum, ThreadLimit, - LoopTripCount, AsyncInfo); + return Manager->runTargetTeamRegion(DeviceId, TgtEntryPtr, TgtArgs, + TgtOffsets, ArgNum, TeamNum, ThreadLimit, + LoopTripCount); } // Exposed library API function