Skip to content

Commit d2609ab

Browse files
authored
gloo: add connection retries
Differential Revision: D70345714 Pull Request resolved: #413
1 parent 5ca057d commit d2609ab

File tree

13 files changed

+300
-28
lines changed

13 files changed

+300
-28
lines changed

gloo/common/logging.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <climits>
1212
#include <exception>
1313
#include <functional>
14+
#include <iostream>
1415
#include <limits>
1516
#include <vector>
1617

@@ -156,4 +157,7 @@ BINARY_COMP_HELPER(LessEquals, <=)
156157
#define GLOO_ENFORCE_GT(x, y, ...) \
157158
GLOO_ENFORCE_THAT_IMPL(Greater((x), (y)), #x " > " #y, __VA_ARGS__)
158159

160+
// Writes "Gloo error: " followed by the result of ::gloo::MakeString applied
// to the macro arguments to std::cerr, terminated with std::endl (newline and
// flush). Expands to a bare stream expression, so the call site supplies the
// trailing semicolon.
#define GLOO_ERROR(...) \
161+
std::cerr << "Gloo error: " << ::gloo::MakeString(__VA_ARGS__) << std::endl
162+
159163
} // namespace gloo

gloo/common/utils.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,14 @@ bool isStoreExtendedApiEnabled() {
4242
(std::string(res) == "True" || std::string(res) == "1");
4343
}
4444

45+
// Reports whether connection retries have been turned off through the
// GLOO_DISABLE_CONNECTION_RETRIES environment variable.
//
// Retries count as disabled only when the variable is set to exactly "True"
// or "1"; an unset variable or any other value leaves them enabled. The
// environment is read once, on the first call, and the result is cached in a
// function-local static (Meyers singleton), so later changes to the variable
// have no effect for the rest of the process.
bool disableConnectionRetries() {
  static const bool disable = []() {
    const char* res = std::getenv("GLOO_DISABLE_CONNECTION_RETRIES");
    if (res == nullptr) {
      return false;
    }
    // Build the std::string once instead of once per comparison.
    const std::string value(res);
    return value == "True" || value == "1";
  }();
  return disable;
}
54+
4555
} // namespace gloo

gloo/common/utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,6 @@ bool useRankAsSeqNumber();
1818

1919
bool isStoreExtendedApiEnabled();
2020

21+
// Returns true when the GLOO_DISABLE_CONNECTION_RETRIES environment variable
// is set to "True" or "1". The variable is read once and the result is cached.
bool disableConnectionRetries();
22+
2123
} // namespace gloo

gloo/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
2424
"${CMAKE_CURRENT_SOURCE_DIR}/linux_test.cc"
2525
"${CMAKE_CURRENT_SOURCE_DIR}/multiproc_test.cc"
2626
"${CMAKE_CURRENT_SOURCE_DIR}/transport_test.cc"
27+
"${CMAKE_CURRENT_SOURCE_DIR}/tcp_test.cc"
2728
)
2829
list(APPEND GLOO_TEST_LIBRARIES rt)
2930
endif()

gloo/test/tcp_test.cc

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#include <gtest/gtest.h>
2+
3+
#include <gloo/transport/tcp/helpers.h>
4+
#include <gloo/transport/tcp/loop.h>
5+
6+
namespace gloo {
7+
namespace transport {
8+
namespace tcp {
9+
10+
// Checks that connectLoop reports failure through its callback as a
// TimeoutError when the connection cannot be established within `timeout`.
TEST(TcpTest, ConnectTimeout) {
  // Synchronization state shared with the callback. It is reference counted
  // (rather than captured by reference from this stack frame) so that the
  // callback can still touch it safely if this test stops waiting and returns
  // before the callback has run.
  struct State {
    std::mutex m;
    std::condition_variable cv;
    bool done = false;
  };
  auto state = std::make_shared<State>();

  auto loop = std::make_shared<Loop>();

  // Use bad address
  auto remote = Address("::1", 10);
  auto timeout = std::chrono::milliseconds(100);
  auto fn = [state](std::shared_ptr<Socket>, const Error& e) {
    // The lock is held for the whole callback, so the expectations below are
    // evaluated before the waiting test body can resume and finish.
    std::lock_guard<std::mutex> lock(state->m);
    state->done = true;
    state->cv.notify_all();

    EXPECT_TRUE(e);
    EXPECT_TRUE(dynamic_cast<const TimeoutError*>(&e));
  };
  connectLoop(loop, remote, timeout, std::move(fn));

  // Bound the wait: if the callback is never invoked, fail this test instead
  // of hanging the whole test binary. The deadline is far above `timeout`.
  std::unique_lock<std::mutex> lock(state->m);
  const bool finished = state->cv.wait_for(
      lock, std::chrono::seconds(30), [&] { return state->done; });
  ASSERT_TRUE(finished) << "connectLoop callback was not invoked";
}
33+
34+
} // namespace tcp
35+
} // namespace transport
36+
} // namespace gloo

gloo/transport/tcp/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ else()
77
"${CMAKE_CURRENT_SOURCE_DIR}/context.cc"
88
"${CMAKE_CURRENT_SOURCE_DIR}/device.cc"
99
"${CMAKE_CURRENT_SOURCE_DIR}/error.cc"
10+
"${CMAKE_CURRENT_SOURCE_DIR}/helpers.cc"
1011
"${CMAKE_CURRENT_SOURCE_DIR}/listener.cc"
1112
"${CMAKE_CURRENT_SOURCE_DIR}/loop.cc"
1213
"${CMAKE_CURRENT_SOURCE_DIR}/pair.cc"

gloo/transport/tcp/address.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,29 @@ Address::Address(const struct sockaddr* addr, size_t addrlen) {
2828
memcpy(&impl_.ss, addr, addrlen);
2929
}
3030

31+
// Builds an address from a textual IP and a port.
//
// `ip` is parsed with inet_pton, first as an IPv4 literal and then as an IPv6
// literal. `port` is given in host byte order and stored in network byte
// order (htons). `seq` is stored as this address's sequence number.
//
// Throws std::invalid_argument when `ip` is empty or is neither a valid IPv4
// nor a valid IPv6 literal.
Address::Address(const std::string& ip, uint16_t port, sequence_number_t seq) {
  if (ip.empty()) {
    throw std::invalid_argument("Invalid IP address");
  }

  // Start from an all-zero sockaddr_storage. Only the family, port and
  // address fields are assigned below; without this, every other byte (for
  // IPv6: sin6_flowinfo and sin6_scope_id) would keep whatever impl_.ss held.
  // NOTE(review): Impl's declaration is not visible here; if it already
  // zero-initializes `ss`, this is a harmless no-op.
  memset(&impl_.ss, 0, sizeof(impl_.ss));

  sockaddr_in* addr4 = reinterpret_cast<sockaddr_in*>(&impl_.ss);
  sockaddr_in6* addr6 = reinterpret_cast<sockaddr_in6*>(&impl_.ss);
  // Check if the IP address is an IPv4 or IPv6 address
  if (inet_pton(AF_INET, ip.c_str(), &addr4->sin_addr) == 1) {
    // IPv4 address
    addr4->sin_family = AF_INET;
    addr4->sin_port = htons(port);
  } else if (inet_pton(AF_INET6, ip.c_str(), &addr6->sin6_addr) == 1) {
    // IPv6 address
    addr6->sin6_family = AF_INET6;
    addr6->sin6_port = htons(port);
  } else {
    throw std::invalid_argument("Invalid IP address");
  }

  // Store sequence number
  impl_.seq = seq;
}
53+
3154
Address& Address::operator=(Address&& other) {
3255
std::lock_guard<std::mutex> lock(m_);
3356
impl_.ss = std::move(other.impl_.ss);

gloo/transport/tcp/address.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@
88

99
#pragma once
1010

11-
#include <sys/socket.h>
12-
#include <unistd.h>
1311
#include <mutex>
1412

13+
#ifdef _WIN32
14+
#include "gloo/common/win.h" // @manual
15+
#else
16+
#include <sys/socket.h>
17+
#endif
18+
1519
#include "gloo/transport/address.h"
1620

1721
namespace gloo {
@@ -32,6 +36,11 @@ class Address : public ::gloo::transport::Address {
3236

3337
explicit Address(const std::vector<char>&);
3438

39+
// Constructs an address from a textual IP (IPv4 or IPv6 literal) and a port.
// `seq` is the sequence number stored with the address and defaults to -1.
// The definition in address.cc throws std::invalid_argument when `ip` is
// empty or cannot be parsed as an IP literal.
explicit Address(
40+
const std::string& ip,
41+
uint16_t port,
42+
sequence_number_t seq = -1);
43+
3544
Address(const Address& other);
3645

3746
Address& operator=(Address&& other);

gloo/transport/tcp/device.cc

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "gloo/common/error.h"
1818
#include "gloo/common/linux.h"
1919
#include "gloo/common/logging.h"
20+
#include "gloo/common/utils.h"
2021
#include "gloo/transport/tcp/context.h"
2122
#include "gloo/transport/tcp/helpers.h"
2223
#include "gloo/transport/tcp/pair.h"
@@ -334,20 +335,39 @@ void Device::connectAsListener(
334335
//
335336
// Opens an outgoing connection to `remote` and, once a socket is available,
// writes remote.getSeq() to it with write<sequence_number_t>, passing `fn`
// along as that write's callback.
//
// Reading this hunk: a line that follows an `N-` marker was removed by the
// commit, a line that follows an `N+` marker was added, and a line that
// follows a doubled number is unchanged context.
void Device::connectAsInitiator(
336337
const Address& remote,
337-
std::chrono::milliseconds /* unused */,
338+
std::chrono::milliseconds timeout,
338339
connect_callback_t fn) {
339-
const auto& sockaddr = remote.getSockaddr();
340-
341-
// Create new socket to connect to peer.
342-
auto socket = Socket::createForFamily(sockaddr.ss_family);
343-
socket->reuseAddr(true);
344-
socket->noDelay(true);
345-
socket->connect(sockaddr);
346-
347-
// Write sequence number for peer to new socket.
348-
// TODO(pietern): Use timeout.
349-
write<sequence_number_t>(
350-
loop_, std::move(socket), remote.getSeq(), std::move(fn));
340+
// Helper shared by both connection paths below. It captures its own copy of
// loop_ and the sequence number, so it does not refer back to `remote`.
auto writeSeq = [loop = loop_, seq = remote.getSeq()](
341+
std::shared_ptr<Socket> socket, connect_callback_t fn) {
342+
// Write sequence number for peer to new socket.
343+
write<sequence_number_t>(loop, std::move(socket), seq, std::move(fn));
344+
};
345+
346+
// Retries disabled: create the socket and call connect() on it directly, as
// the removed code did. `timeout` is not used on this path.
if (disableConnectionRetries()) {
347+
const auto& sockaddr = remote.getSockaddr();
348+
349+
// Create new socket to connect to peer.
350+
auto socket = Socket::createForFamily(sockaddr.ss_family);
351+
socket->reuseAddr(true);
352+
socket->noDelay(true);
353+
socket->connect(sockaddr);
354+
355+
writeSeq(std::move(socket), std::move(fn));
356+
} else {
357+
// Otherwise hand the connection attempt to connectLoop together with
// `timeout`, and continue in the callback below.
connectLoop(
358+
loop_,
359+
remote,
360+
timeout,
361+
// NOTE(review): the captured `loop` is never referenced in this lambda's
// body; writeSeq already carries its own copy of loop_.
[loop = loop_, fn = std::move(fn), writeSeq = std::move(writeSeq)](
362+
std::shared_ptr<Socket> socket, const Error& error) {
363+
// On failure, report the error to the caller's callback and skip the
// sequence-number write.
if (error) {
364+
fn(socket, error);
365+
return;
366+
}
367+
368+
// NOTE(review): this lambda is not declared `mutable`, so the captured `fn`
// is const here and std::move(fn) passes a copy rather than moving.
writeSeq(std::move(socket), std::move(fn));
369+
});
370+
}
351371
}
352372

353373
} // namespace tcp

gloo/transport/tcp/error.cc

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,32 @@ std::string Error::what() const {
2323

2424
// Formats the error as "<syscall_>: <strerror(error_)>". The line added by
// this commit (after the `26+` marker) also appends ", remote=" followed by
// remote_.str(); the line after `26-` is the removed version.
std::string SystemError::what() const {
2525
std::ostringstream ss;
26-
ss << syscall_ << ": " << strerror(error_);
26+
ss << syscall_ << ": " << strerror(error_) << ", remote=" << remote_.str();
2727
return ss.str();
2828
}
2929

3030
// Describes a short read: reports actual_ bytes received against expected_
// bytes wanted. The line added by this commit (after the `33+` marker) also
// appends ", remote=" followed by remote_.str().
std::string ShortReadError::what() const {
3131
std::ostringstream ss;
3232
ss << "short read: got " << actual_ << " bytes while expecting to read "
33-
<< expected_ << " bytes";
33+
<< expected_ << " bytes, remote=" << remote_.str();
3434
return ss.str();
3535
}
3636

3737
// Describes a short write: reports actual_ bytes written against expected_
// bytes wanted. The line added by this commit (after the `40+` marker) also
// appends ", remote=" followed by remote_.str().
std::string ShortWriteError::what() const {
3838
std::ostringstream ss;
3939
ss << "short write: wrote " << actual_ << " bytes while expecting to write "
40-
<< expected_ << " bytes";
40+
<< expected_ << " bytes, remote=" << remote_.str();
4141
return ss.str();
4242
}
4343

44+
// Returns the stored message (msg_) unchanged.
std::string TimeoutError::what() const {
45+
return msg_;
46+
}
47+
48+
// Returns the stored message (msg_) unchanged.
std::string LoopError::what() const {
49+
return msg_;
50+
}
51+
4452
} // namespace tcp
4553
} // namespace transport
4654
} // namespace gloo

0 commit comments

Comments
 (0)