Skip to content

Commit 893f338

Browse files
author
Michael Penick
committed
CPP-793 Add SNI support to SocketConnector and SSL backend
1 parent 3f8136a commit 893f338

9 files changed

Lines changed: 149 additions & 78 deletions

File tree

cpp-driver/gtests/src/unit/mockssandra.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,13 @@ int ClientConnection::accept() {
228228
return uv_read_start(tcp_.as_stream(), on_alloc, on_read);
229229
}
230230

231+
const char* ClientConnection::sni_server_name() const {
232+
if (ssl_) {
233+
return SSL_get_servername(ssl_, TLSEXT_NAMETYPE_host_name);
234+
}
235+
return NULL;
236+
}
237+
231238
void ClientConnection::on_close(uv_handle_t* handle) {
232239
ClientConnection* connection = static_cast<ClientConnection*>(handle->data);
233240
connection->handle_close();

cpp-driver/gtests/src/unit/mockssandra.hpp

Lines changed: 26 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ class ClientConnection {
103103
protected:
104104
int accept();
105105

106+
const char* sni_server_name() const;
107+
106108
private:
107109
static void on_close(uv_handle_t* handle);
108110
void handle_close();
@@ -144,6 +146,7 @@ class ClientConnection {
144146
class ClientConnectionFactory {
145147
public:
146148
virtual ClientConnection* create(ServerConnection* server) const = 0;
149+
virtual ~ClientConnectionFactory() {}
147150
};
148151

149152
class ServerConnectionTask : public RefCounted<ServerConnectionTask> {
@@ -1246,53 +1249,39 @@ class SimpleCluster : public Cluster {
12461249

12471250
class SimpleEchoServer {
12481251
public:
1249-
SimpleEchoServer(const Address& address = Address("127.0.0.1", 8888))
1250-
: event_loop_group_(1)
1251-
, server_(new internal::ServerConnection(address, factory_)) {}
1252+
SimpleEchoServer()
1253+
: factory_(new EchoClientConnectionFactory())
1254+
, event_loop_group_(1) {}
12521255

12531256
~SimpleEchoServer() { close(); }
12541257

12551258
void close() {
1256-
server_->close();
1257-
server_->wait_close();
1259+
if (server_) {
1260+
server_->close();
1261+
server_->wait_close();
1262+
}
12581263
}
12591264

12601265
String use_ssl(const String& cn = "") {
1261-
String key(Ssl::generate_key());
1262-
String cert(Ssl::generate_cert(key, cn));
1263-
if (!server_->use_ssl(key, cert)) {
1264-
return "";
1265-
}
1266-
return cert;
1266+
ssl_key_ = Ssl::generate_key();
1267+
ssl_cert_ = Ssl::generate_cert(ssl_key_, cn);
1268+
return ssl_cert_;
12671269
}
12681270

1269-
void use_close_immediately() { factory_.use_close_immediately(); }
1271+
void use_connection_factory(internal::ClientConnectionFactory* factory) {
1272+
factory_.reset(factory);
1273+
}
12701274

1271-
int listen() {
1275+
int listen(const Address& address = Address("127.0.0.1", 8888)) {
1276+
server_.reset(new internal::ServerConnection(address, *factory_));
1277+
if (!ssl_key_.empty() && !ssl_cert_.empty() && !server_->use_ssl(ssl_key_, ssl_cert_)) {
1278+
return -1;
1279+
}
12721280
server_->listen(&event_loop_group_);
12731281
return server_->wait_listen();
12741282
}
12751283

1276-
void reset(const Address& address) {
1277-
server_.reset(new internal::ServerConnection(address, factory_));
1278-
}
1279-
12801284
private:
1281-
class CloseConnection : public internal::ClientConnection {
1282-
public:
1283-
CloseConnection(internal::ServerConnection* server)
1284-
: internal::ClientConnection(server) {}
1285-
1286-
virtual int on_accept() {
1287-
int rc = accept();
1288-
if (rc != 0) {
1289-
return rc;
1290-
}
1291-
close();
1292-
return rc;
1293-
}
1294-
};
1295-
12961285
class EchoConnection : public internal::ClientConnection {
12971286
public:
12981287
EchoConnection(internal::ServerConnection* server)
@@ -1301,29 +1290,19 @@ class SimpleEchoServer {
13011290
virtual void on_read(const char* data, size_t len) { write(data, len); }
13021291
};
13031292

1304-
class ClientConnectionFactory : public internal::ClientConnectionFactory {
1293+
class EchoClientConnectionFactory : public internal::ClientConnectionFactory {
13051294
public:
1306-
ClientConnectionFactory()
1307-
: close_immediately_(false) {}
1308-
1309-
void use_close_immediately() { close_immediately_ = true; }
1310-
13111295
virtual internal::ClientConnection* create(internal::ServerConnection* server) const {
1312-
if (close_immediately_) {
1313-
return new CloseConnection(server);
1314-
} else {
1315-
return new EchoConnection(server);
1316-
}
1296+
return new EchoConnection(server);
13171297
}
1318-
1319-
private:
1320-
bool close_immediately_;
13211298
};
13221299

13231300
private:
1324-
ClientConnectionFactory factory_;
1301+
ScopedPtr<internal::ClientConnectionFactory> factory_;
13251302
SimpleEventLoopGroup event_loop_group_;
13261303
internal::ServerConnection::Ptr server_;
1304+
String ssl_key_;
1305+
String ssl_cert_;
13271306
};
13281307

13291308
} // namespace mockssandra

cpp-driver/gtests/src/unit/tests/test_socket.cpp

Lines changed: 85 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,54 @@
2424
#define SSL_VERIFY_PEER_DNS_ABSOLUTE_HOSTNAME SSL_VERIFY_PEER_DNS_RELATIVE_HOSTNAME "."
2525
#define SSL_VERIFY_PEER_DNS_IP_ADDRESS "127.254.254.254"
2626

27+
using mockssandra::internal::ClientConnection;
28+
using mockssandra::internal::ClientConnectionFactory;
29+
using mockssandra::internal::ServerConnection;
30+
31+
class CloseConnection : public ClientConnection {
32+
public:
33+
CloseConnection(ServerConnection* server)
34+
: ClientConnection(server) {}
35+
36+
virtual int on_accept() {
37+
int rc = accept();
38+
if (rc != 0) {
39+
return rc;
40+
}
41+
close();
42+
return rc;
43+
}
44+
};
45+
46+
class CloseConnectionFactory : public ClientConnectionFactory {
47+
public:
48+
virtual ClientConnection* create(ServerConnection* server) const {
49+
return new CloseConnection(server);
50+
}
51+
};
52+
53+
class SniServerNameConnection : public ClientConnection {
54+
public:
55+
SniServerNameConnection(ServerConnection* server)
56+
: ClientConnection(server) {}
57+
58+
virtual void on_read(const char* data, size_t len) {
59+
const char* server_name = sni_server_name();
60+
if (server_name) {
61+
write(String(server_name) + " - Closed");
62+
} else {
63+
write("<unknown> - Closed");
64+
}
65+
}
66+
};
67+
68+
class SniServerNameConnectionFactory : public ClientConnectionFactory {
69+
public:
70+
virtual ClientConnection* create(ServerConnection* server) const {
71+
return new SniServerNameConnection(server);
72+
}
73+
};
74+
2775
using namespace datastax;
2876
using namespace datastax::internal;
2977
using namespace datastax::internal::core;
@@ -88,13 +136,16 @@ class SocketUnitTest : public LoopTest {
88136
return settings;
89137
}
90138

91-
void listen() { ASSERT_EQ(server_.listen(), 0); }
92-
93-
void reset(const Address& address) { server_.reset(address); }
139+
void listen(const Address& address = Address("127.0.0.1", 8888)) {
140+
ASSERT_EQ(server_.listen(address), 0);
141+
}
94142

95143
void close() { server_.close(); }
96144

97-
void use_close_immediately() { server_.use_close_immediately(); }
145+
void use_close_immediately() { server_.use_connection_factory(new CloseConnectionFactory()); }
146+
void use_sni_server_name() {
147+
server_.use_connection_factory(new SniServerNameConnectionFactory());
148+
}
98149

99150
virtual void TearDown() {
100151
LoopTest::TearDown();
@@ -167,10 +218,10 @@ TEST_F(SocketUnitTest, Simple) {
167218
}
168219

169220
TEST_F(SocketUnitTest, Ssl) {
170-
listen();
171-
172221
SocketSettings settings(use_ssl());
173222

223+
listen();
224+
174225
String result;
175226
SocketConnector::Ptr connector(
176227
new SocketConnector(Address("127.0.0.1", 8888), bind_callback(on_socket_connected, &result)));
@@ -182,6 +233,24 @@ TEST_F(SocketUnitTest, Ssl) {
182233
EXPECT_EQ(result, "The socket is successfully connected and wrote data - Closed");
183234
}
184235

236+
TEST_F(SocketUnitTest, SslSniServerName) {
237+
SocketSettings settings(use_ssl());
238+
239+
use_sni_server_name();
240+
listen();
241+
242+
String result;
243+
SocketConnector::Ptr connector(
244+
new SocketConnector(Address("127.0.0.1", 8888, "TestSniServerName"),
245+
bind_callback(on_socket_connected, &result)));
246+
247+
connector->with_settings(settings)->connect(loop());
248+
249+
uv_run(loop(), UV_RUN_DEFAULT);
250+
251+
EXPECT_EQ(result, "TestSniServerName - Closed");
252+
}
253+
185254
TEST_F(SocketUnitTest, Refused) {
186255
bool is_refused = false;
187256
SocketConnector::Ptr connector(new SocketConnector(
@@ -194,11 +263,11 @@ TEST_F(SocketUnitTest, Refused) {
194263
}
195264

196265
TEST_F(SocketUnitTest, SslClose) {
266+
SocketSettings settings(use_ssl());
267+
197268
use_close_immediately();
198269
listen();
199270

200-
SocketSettings settings(use_ssl());
201-
202271
Vector<SocketConnector::Ptr> connectors;
203272

204273
bool is_closed = false;
@@ -241,10 +310,10 @@ TEST_F(SocketUnitTest, Cancel) {
241310
}
242311

243312
TEST_F(SocketUnitTest, SslCancel) {
244-
listen();
245-
246313
SocketSettings settings(use_ssl());
247314

315+
listen();
316+
248317
Vector<SocketConnector::Ptr> connectors;
249318

250319
bool is_canceled = false;
@@ -268,9 +337,10 @@ TEST_F(SocketUnitTest, SslCancel) {
268337
}
269338

270339
TEST_F(SocketUnitTest, SslVerifyIdentity) {
340+
SocketSettings settings(use_ssl("127.0.0.1"));
341+
271342
listen();
272343

273-
SocketSettings settings(use_ssl("127.0.0.1"));
274344
settings.ssl_context->set_verify_flags(CASS_SSL_VERIFY_PEER_IDENTITY);
275345

276346
String result;
@@ -295,11 +365,11 @@ TEST_F(SocketUnitTest, SslVerifyIdentityDns) {
295365
return;
296366
}
297367

298-
reset(Address(SSL_VERIFY_PEER_DNS_IP_ADDRESS,
299-
8888)); // Ensure the echo server is listening on the correct address
300-
listen();
301-
302368
SocketSettings settings(use_ssl(SSL_VERIFY_PEER_DNS_RELATIVE_HOSTNAME));
369+
370+
listen(Address(SSL_VERIFY_PEER_DNS_IP_ADDRESS,
371+
8888)); // Ensure the echo server is listening on the correct address
372+
303373
settings.ssl_context->set_verify_flags(CASS_SSL_VERIFY_PEER_IDENTITY_DNS);
304374

305375
String result;

cpp-driver/src/socket_connector.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,8 @@ void SocketConnector::internal_connect(uv_loop_t* loop) {
159159
}
160160

161161
if (settings_.ssl_context) {
162-
ssl_session_.reset(settings_.ssl_context->create_session(address_, hostname_));
162+
ssl_session_.reset(
163+
settings_.ssl_context->create_session(address_, hostname_, address_.server_name()));
163164
}
164165

165166
connector_.reset(new TcpConnector(address_));

cpp-driver/src/ssl.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,11 @@ namespace datastax { namespace internal { namespace core {
3232

3333
class SslSession : public Allocated {
3434
public:
35-
SslSession(const Address& address, const String& hostname, int flags)
35+
SslSession(const Address& address, const String& hostname, const String& sni_server_name,
36+
int flags)
3637
: address_(address)
3738
, hostname_(hostname)
39+
, sni_server_name_(sni_server_name)
3840
, verify_flags_(flags)
3941
, error_code_(CASS_OK) {}
4042

@@ -59,6 +61,7 @@ class SslSession : public Allocated {
5961
protected:
6062
Address address_;
6163
String hostname_;
64+
String sni_server_name_;
6265
int verify_flags_;
6366
rb::RingBuffer incoming_;
6467
rb::RingBuffer outgoing_;
@@ -78,7 +81,8 @@ class SslContext : public RefCounted<SslContext> {
7881
void set_verify_flags(int flags) { verify_flags_ = flags; }
7982
bool is_cert_validation_enabled() { return verify_flags_ != CASS_SSL_VERIFY_NONE; }
8083

81-
virtual SslSession* create_session(const Address& address, const String& hostname) = 0;
84+
virtual SslSession* create_session(const Address& address, const String& hostname,
85+
const String& sni_server_name) = 0;
8286
virtual CassError add_trusted_cert(const char* cert, size_t cert_length) = 0;
8387
virtual CassError set_cert(const char* cert, size_t cert_length) = 0;
8488
virtual CassError set_private_key(const char* key, size_t key_length, const char* password,

cpp-driver/src/ssl/ssl_no_impl.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,16 @@
1919
using namespace datastax;
2020
using namespace datastax::internal::core;
2121

22-
NoSslSession::NoSslSession(const Address& address, const String& hostname)
23-
: SslSession(address, hostname, CASS_SSL_VERIFY_NONE) {
22+
NoSslSession::NoSslSession(const Address& address, const String& hostname,
23+
const String& sni_server_name)
24+
: SslSession(address, hostname, sni_server_name, CASS_SSL_VERIFY_NONE) {
2425
error_code_ = CASS_ERROR_LIB_NOT_IMPLEMENTED;
2526
error_message_ = "SSL support not built into driver";
2627
}
2728

28-
SslSession* NoSslContext::create_session(const Address& address, const String& hostname) {
29-
return new NoSslSession(address, hostname);
29+
SslSession* NoSslContext::create_session(const Address& address, const String& hostname,
30+
const String& sni_server_name) {
31+
return new NoSslSession(address, hostname, sni_server_name);
3032
}
3133

3234
CassError NoSslContext::add_trusted_cert(const char* cert, size_t cert_length) {

cpp-driver/src/ssl/ssl_no_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace datastax { namespace internal { namespace core {
2121

2222
class NoSslSession : public SslSession {
2323
public:
24-
NoSslSession(const Address& address, const String& hostname);
24+
NoSslSession(const Address& address, const String& hostname, const String& sni_server_name);
2525

2626
virtual bool is_handshake_done() const { return false; }
2727
virtual void do_handshake() {}

0 commit comments

Comments
 (0)