Index: streamexecutor/lib/unittests/DeviceTest.cpp =================================================================== --- streamexecutor/lib/unittests/DeviceTest.cpp +++ streamexecutor/lib/unittests/DeviceTest.cpp @@ -29,14 +29,21 @@ class DeviceTest : public ::testing::Test { public: DeviceTest() - : HostA5{0, 1, 2, 3, 4}, HostB5{5, 6, 7, 8, 9}, + : Device(&PDevice), HostA5{0, 1, 2, 3, 4}, HostB5{5, 6, 7, 8, 9}, HostA7{10, 11, 12, 13, 14, 15, 16}, HostB7{17, 18, 19, 20, 21, 22, 23}, - DeviceA5(se::GlobalDeviceMemory::makeFromElementCount(HostA5, 5)), - DeviceB5(se::GlobalDeviceMemory::makeFromElementCount(HostB5, 5)), - DeviceA7(se::GlobalDeviceMemory::makeFromElementCount(HostA7, 7)), - DeviceB7(se::GlobalDeviceMemory::makeFromElementCount(HostB7, 7)), - Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35}, - Device(&PDevice) {} + DeviceA5(getOrDie(Device.allocateDeviceMemory(5))), + DeviceB5(getOrDie(Device.allocateDeviceMemory(5))), + DeviceA7(getOrDie(Device.allocateDeviceMemory(7))), + DeviceB7(getOrDie(Device.allocateDeviceMemory(7))), + Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35} { + se::dieIfError(Device.synchronousCopyH2D(HostA5, DeviceA5)); + se::dieIfError(Device.synchronousCopyH2D(HostB5, DeviceB5)); + se::dieIfError(Device.synchronousCopyH2D(HostA7, DeviceA7)); + se::dieIfError(Device.synchronousCopyH2D(HostB7, DeviceB7)); + } + + SimpleHostPlatformDevice PDevice; + se::Device Device; - // Device memory is backed by host arrays. + // Initial device memory contents, copied to the device by the constructor. int HostA5[5]; @@ -51,9 +58,6 @@ // Host memory to be used as actual host memory. 
int Host5[5]; int Host7[7]; - - SimpleHostPlatformDevice PDevice; - se::Device Device; }; #define EXPECT_NO_ERROR(E) EXPECT_FALSE(static_cast(E)) @@ -186,12 +190,12 @@ TEST_F(DeviceTest, SyncCopyH2DToArrayRefByCount) { EXPECT_NO_ERROR(Device.synchronousCopyH2D(ArrayRef(Host5), DeviceA5, 5)); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } EXPECT_NO_ERROR(Device.synchronousCopyH2D(ArrayRef(Host5), DeviceB5, 2)); for (int I = 0; I < 2; ++I) { - EXPECT_EQ(HostB5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceB5, I), Host5[I]); } EXPECT_ERROR(Device.synchronousCopyH2D(ArrayRef(Host7), DeviceA5, 7)); @@ -204,7 +208,7 @@ TEST_F(DeviceTest, SyncCopyH2DToArrayRef) { EXPECT_NO_ERROR(Device.synchronousCopyH2D(ArrayRef(Host5), DeviceA5)); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } EXPECT_ERROR(Device.synchronousCopyH2D(ArrayRef(Host5), DeviceA7)); @@ -215,7 +219,7 @@ TEST_F(DeviceTest, SyncCopyH2DToPointer) { EXPECT_NO_ERROR(Device.synchronousCopyH2D(Host5, DeviceA5, 5)); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } EXPECT_ERROR(Device.synchronousCopyH2D(Host7, DeviceA5, 7)); @@ -225,13 +229,13 @@ EXPECT_NO_ERROR(Device.synchronousCopyH2D( ArrayRef(Host5 + 1, 4), DeviceA5.asSlice().drop_front(1), 4)); for (int I = 1; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } EXPECT_NO_ERROR(Device.synchronousCopyH2D( ArrayRef(Host5), DeviceB5.asSlice().drop_back(1), 2)); for (int I = 0; I < 2; ++I) { - EXPECT_EQ(HostB5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceB5, I), Host5[I]); } EXPECT_ERROR( @@ -248,7 +252,7 @@ EXPECT_NO_ERROR( Device.synchronousCopyH2D(ArrayRef(Host5), DeviceA5.asSlice())); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } 
EXPECT_ERROR( @@ -261,7 +265,7 @@ TEST_F(DeviceTest, SyncCopyH2DSliceToPointer) { EXPECT_NO_ERROR(Device.synchronousCopyH2D(Host5, DeviceA5.asSlice(), 5)); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } EXPECT_ERROR(Device.synchronousCopyH2D(Host7, DeviceA5.asSlice(), 7)); @@ -272,12 +276,12 @@ TEST_F(DeviceTest, SyncCopyD2DByCount) { EXPECT_NO_ERROR(Device.synchronousCopyD2D(DeviceA5, DeviceB5, 5)); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], HostB5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), getDeviceValue(DeviceB5, I)); } EXPECT_NO_ERROR(Device.synchronousCopyD2D(DeviceA7, DeviceB7, 2)); for (int I = 0; I < 2; ++I) { - EXPECT_EQ(HostA7[I], HostB7[I]); + EXPECT_EQ(getDeviceValue(DeviceA7, I), getDeviceValue(DeviceB7, I)); } EXPECT_ERROR(Device.synchronousCopyD2D(DeviceA5, DeviceB5, 7)); @@ -290,7 +294,7 @@ TEST_F(DeviceTest, SyncCopyD2D) { EXPECT_NO_ERROR(Device.synchronousCopyD2D(DeviceA5, DeviceB5)); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], HostB5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), getDeviceValue(DeviceB5, I)); } EXPECT_ERROR(Device.synchronousCopyD2D(DeviceA7, DeviceB5)); @@ -302,13 +306,13 @@ EXPECT_NO_ERROR( Device.synchronousCopyD2D(DeviceA5.asSlice().drop_front(1), DeviceB5, 4)); for (int I = 0; I < 4; ++I) { - EXPECT_EQ(HostA5[I + 1], HostB5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I + 1), getDeviceValue(DeviceB5, I)); } EXPECT_NO_ERROR( Device.synchronousCopyD2D(DeviceA7.asSlice().drop_back(1), DeviceB7, 2)); for (int I = 0; I < 2; ++I) { - EXPECT_EQ(HostA7[I], HostB7[I]); + EXPECT_EQ(getDeviceValue(DeviceA7, I), getDeviceValue(DeviceB7, I)); } EXPECT_ERROR(Device.synchronousCopyD2D(DeviceA5.asSlice(), DeviceB5, 7)); @@ -322,7 +326,7 @@ EXPECT_NO_ERROR( Device.synchronousCopyD2D(DeviceA7.asSlice().drop_back(2), DeviceB5)); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA7[I], HostB5[I]); + EXPECT_EQ(getDeviceValue(DeviceA7, I), 
getDeviceValue(DeviceB5, I)); } EXPECT_ERROR( @@ -336,13 +340,13 @@ EXPECT_NO_ERROR( Device.synchronousCopyD2D(DeviceA5, DeviceB7.asSlice().drop_front(2), 5)); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], HostB7[I + 2]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), getDeviceValue(DeviceB7, I + 2)); } EXPECT_NO_ERROR( Device.synchronousCopyD2D(DeviceA7, DeviceB7.asSlice().drop_back(3), 2)); for (int I = 0; I < 2; ++I) { - EXPECT_EQ(HostA7[I], HostB7[I]); + EXPECT_EQ(getDeviceValue(DeviceA7, I), getDeviceValue(DeviceB7, I)); } EXPECT_ERROR(Device.synchronousCopyD2D(DeviceA5, DeviceB5.asSlice(), 7)); @@ -356,7 +360,7 @@ EXPECT_NO_ERROR( Device.synchronousCopyD2D(DeviceA5, DeviceB7.asSlice().drop_back(2))); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], HostB7[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), getDeviceValue(DeviceB7, I)); } EXPECT_ERROR(Device.synchronousCopyD2D(DeviceA7, DeviceB5.asSlice())); @@ -368,13 +372,13 @@ EXPECT_NO_ERROR( Device.synchronousCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice(), 5)); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], HostB5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), getDeviceValue(DeviceB5, I)); } EXPECT_NO_ERROR( Device.synchronousCopyD2D(DeviceA7.asSlice(), DeviceB7.asSlice(), 2)); for (int I = 0; I < 2; ++I) { - EXPECT_EQ(HostA7[I], HostB7[I]); + EXPECT_EQ(getDeviceValue(DeviceA7, I), getDeviceValue(DeviceB7, I)); } EXPECT_ERROR( @@ -391,7 +395,7 @@ EXPECT_NO_ERROR( Device.synchronousCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice())); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], HostB5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), getDeviceValue(DeviceB5, I)); } EXPECT_ERROR( Index: streamexecutor/lib/unittests/SimpleHostPlatformDevice.h =================================================================== --- streamexecutor/lib/unittests/SimpleHostPlatformDevice.h +++ streamexecutor/lib/unittests/SimpleHostPlatformDevice.h @@ -132,4 +132,12 @@ } }; +/// Gets the value at the given index from a 
GlobalDeviceMemory instance +/// created by SimpleHostPlatformDevice. +template +T getDeviceValue(const streamexecutor::GlobalDeviceMemory &Memory, + size_t Index) { + return static_cast(Memory.getHandle())[Index]; +} + #endif // STREAMEXECUTOR_LIB_UNITTESTS_SIMPLEHOSTPLATFORMDEVICE_H Index: streamexecutor/lib/unittests/StreamTest.cpp =================================================================== --- streamexecutor/lib/unittests/StreamTest.cpp +++ streamexecutor/lib/unittests/StreamTest.cpp @@ -31,32 +31,41 @@ class StreamTest : public ::testing::Test { public: StreamTest() - : HostA5{0, 1, 2, 3, 4}, HostB5{5, 6, 7, 8, 9}, + : Device(&PDevice), + Stream(llvm::make_unique(&PDevice)), + HostA5{0, 1, 2, 3, 4}, HostB5{5, 6, 7, 8, 9}, HostA7{10, 11, 12, 13, 14, 15, 16}, HostB7{17, 18, 19, 20, 21, 22, 23}, - DeviceA5(se::GlobalDeviceMemory::makeFromElementCount(HostA5, 5)), - DeviceB5(se::GlobalDeviceMemory::makeFromElementCount(HostB5, 5)), - DeviceA7(se::GlobalDeviceMemory::makeFromElementCount(HostA7, 7)), - DeviceB7(se::GlobalDeviceMemory::makeFromElementCount(HostB7, 7)), Host5{24, 25, 26, 27, 28}, Host7{29, 30, 31, 32, 33, 34, 35}, - Stream(llvm::make_unique(&PDevice)) {} + DeviceA5(getOrDie(Device.allocateDeviceMemory(5))), + DeviceB5(getOrDie(Device.allocateDeviceMemory(5))), + DeviceA7(getOrDie(Device.allocateDeviceMemory(7))), + DeviceB7(getOrDie(Device.allocateDeviceMemory(7))) { + se::dieIfError(Device.synchronousCopyH2D(HostA5, DeviceA5)); + se::dieIfError(Device.synchronousCopyH2D(HostB5, DeviceB5)); + se::dieIfError(Device.synchronousCopyH2D(HostA7, DeviceA7)); + se::dieIfError(Device.synchronousCopyH2D(HostB7, DeviceB7)); + } protected: - // Device memory is backed by host arrays. + SimpleHostPlatformDevice PDevice; + se::Device Device; + se::Stream Stream; + + // Initial device memory contents, copied to the device by the constructor. 
int HostA5[5]; int HostB5[5]; int HostA7[7]; int HostB7[7]; - se::GlobalDeviceMemory DeviceA5; - se::GlobalDeviceMemory DeviceB5; - se::GlobalDeviceMemory DeviceA7; - se::GlobalDeviceMemory DeviceB7; // Host memory to be used as actual host memory. int Host5[5]; int Host7[7]; - SimpleHostPlatformDevice PDevice; - se::Stream Stream; + // Device memory. + se::GlobalDeviceMemory DeviceA5; + se::GlobalDeviceMemory DeviceB5; + se::GlobalDeviceMemory DeviceA7; + se::GlobalDeviceMemory DeviceB7; }; using llvm::ArrayRef; @@ -151,13 +160,13 @@ Stream.thenCopyH2D(ArrayRef(Host5), DeviceA5, 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } Stream.thenCopyH2D(ArrayRef(Host5), DeviceB5, 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { - EXPECT_EQ(HostB5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceB5, I), Host5[I]); } Stream.thenCopyH2D(ArrayRef(Host7), DeviceA5, 7); @@ -168,7 +177,7 @@ Stream.thenCopyH2D(ArrayRef(Host5), DeviceA5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } Stream.thenCopyH2D(ArrayRef(Host7), DeviceA5); @@ -179,7 +188,7 @@ Stream.thenCopyH2D(Host5, DeviceA5, 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } Stream.thenCopyH2D(Host7, DeviceA5, 7); @@ -191,13 +200,13 @@ DeviceA5.asSlice().drop_front(1), 4); EXPECT_TRUE(Stream.isOK()); for (int I = 1; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } Stream.thenCopyH2D(ArrayRef(Host5), DeviceB5.asSlice().drop_back(1), 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { - EXPECT_EQ(HostB5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceB5, I), Host5[I]); } Stream.thenCopyH2D(ArrayRef(Host5), DeviceA5.asSlice(), 7); @@ -209,7 +218,7 @@ 
Stream.thenCopyH2D(ArrayRef(Host5), DeviceA5.asSlice()); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } Stream.thenCopyH2D(ArrayRef(Host7), DeviceA5.asSlice()); @@ -220,7 +229,7 @@ Stream.thenCopyH2D(Host5, DeviceA5.asSlice(), 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], Host5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), Host5[I]); } Stream.thenCopyH2D(Host7, DeviceA5.asSlice(), 7); @@ -233,13 +242,13 @@ Stream.thenCopyD2D(DeviceA5, DeviceB5, 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], HostB5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), getDeviceValue(DeviceB5, I)); } Stream.thenCopyD2D(DeviceA7, DeviceB7, 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { - EXPECT_EQ(HostA7[I], HostB7[I]); + EXPECT_EQ(getDeviceValue(DeviceA7, I), getDeviceValue(DeviceB7, I)); } Stream.thenCopyD2D(DeviceA7, DeviceB5, 7); @@ -250,7 +259,7 @@ Stream.thenCopyD2D(DeviceA5, DeviceB5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], HostB5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), getDeviceValue(DeviceB5, I)); } Stream.thenCopyD2D(DeviceA7, DeviceB5); @@ -261,13 +270,13 @@ Stream.thenCopyD2D(DeviceA5.asSlice().drop_front(1), DeviceB5, 4); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 4; ++I) { - EXPECT_EQ(HostA5[I + 1], HostB5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I + 1), getDeviceValue(DeviceB5, I)); } Stream.thenCopyD2D(DeviceA7.asSlice().drop_back(1), DeviceB7, 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { - EXPECT_EQ(HostA7[I], HostB7[I]); + EXPECT_EQ(getDeviceValue(DeviceA7, I), getDeviceValue(DeviceB7, I)); } Stream.thenCopyD2D(DeviceA5.asSlice(), DeviceB5, 7); @@ -279,7 +288,7 @@ Stream.thenCopyD2D(DeviceA7.asSlice().drop_back(2), DeviceB5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA7[I], HostB5[I]); 
+ EXPECT_EQ(getDeviceValue(DeviceA7, I), getDeviceValue(DeviceB5, I)); } Stream.thenCopyD2D(DeviceA5.asSlice().drop_back(1), DeviceB7); @@ -290,13 +299,13 @@ Stream.thenCopyD2D(DeviceA5, DeviceB7.asSlice().drop_front(2), 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], HostB7[I + 2]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), getDeviceValue(DeviceB7, I + 2)); } Stream.thenCopyD2D(DeviceA7, DeviceB7.asSlice().drop_back(3), 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { - EXPECT_EQ(HostA7[I], HostB7[I]); + EXPECT_EQ(getDeviceValue(DeviceA7, I), getDeviceValue(DeviceB7, I)); } Stream.thenCopyD2D(DeviceA5, DeviceB7.asSlice(), 7); @@ -308,7 +317,7 @@ Stream.thenCopyD2D(DeviceA5, DeviceB7.asSlice().drop_back(2)); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], HostB7[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), getDeviceValue(DeviceB7, I)); } Stream.thenCopyD2D(DeviceA5, DeviceB7.asSlice()); @@ -320,13 +329,13 @@ Stream.thenCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice(), 5); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], HostB5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), getDeviceValue(DeviceB5, I)); } Stream.thenCopyD2D(DeviceA7.asSlice(), DeviceB7.asSlice(), 2); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 2; ++I) { - EXPECT_EQ(HostA7[I], HostB7[I]); + EXPECT_EQ(getDeviceValue(DeviceA7, I), getDeviceValue(DeviceB7, I)); } Stream.thenCopyD2D(DeviceA7.asSlice(), DeviceB5.asSlice(), 7); @@ -338,7 +347,7 @@ Stream.thenCopyD2D(DeviceA5.asSlice(), DeviceB5.asSlice()); EXPECT_TRUE(Stream.isOK()); for (int I = 0; I < 5; ++I) { - EXPECT_EQ(HostA5[I], HostB5[I]); + EXPECT_EQ(getDeviceValue(DeviceA5, I), getDeviceValue(DeviceB5, I)); } Stream.thenCopyD2D(DeviceA5.asSlice(), DeviceB7.asSlice());