Skip to content

Commit 239bb75

Browse files
committed
librustls: add rustls_client_connection_new_alpn
This allows constructing a client `rustls_connection` with custom ALPN protocol support that differs from the base `rustls_client_config`.
1 parent 4c64659 commit 239bb75

2 files changed

Lines changed: 175 additions & 33 deletions

File tree

librustls/src/client.rs

Lines changed: 139 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,9 @@ impl rustls_client_config {
749749
/// point at a valid `rustls_connection`. The caller now owns the `rustls_connection`
750750
/// and must call `rustls_connection_free` when done with it.
751751
///
752+
/// Uses the `rustls_client_config` to determine ALPN protocol support. Prefer
753+
/// `rustls_client_connection_new_alpn` to customize this per-connection.
754+
///
752755
/// If this returns an error code, the memory pointed to by `conn_out` remains
753756
/// unchanged.
754757
///
@@ -762,28 +765,96 @@ impl rustls_client_config {
762765
conn_out: *mut *mut rustls_connection,
763766
) -> rustls_result {
764767
ffi_panic_boundary! {
765-
let server_name = unsafe {
766-
if server_name.is_null() {
767-
return rustls_result::NullParameter;
768-
}
769-
CStr::from_ptr(server_name)
770-
};
771-
let Ok(server_name) = server_name.to_str() else {
772-
return rustls_result::InvalidDnsNameError;
773-
};
774-
let Ok(server_name) = server_name.try_into() else {
775-
return rustls_result::InvalidDnsNameError;
776-
};
768+
Self::rustls_client_connection_new_alpn_inner(
769+
config,
770+
server_name,
771+
try_clone_arc!(config).alpn_protocols.clone(),
772+
conn_out,
773+
)
774+
}
775+
}
777776

778-
set_boxed_mut_ptr(
779-
try_mut_from_ptr_ptr!(conn_out),
780-
Connection::from_client(
781-
ClientConnection::new(try_clone_arc!(config), server_name).unwrap(),
782-
),
783-
);
784-
rustls_result::Ok
777+
/// Create a new client `rustls_connection` with custom ALPN protocols.
778+
///
779+
/// Operates the same as `rustls_client_connection_new`, but allows specifying
780+
/// custom per-connection ALPN protocols instead of inheriting ALPN protocols
781+
/// from the `rustls_clinet_config`.
782+
///
783+
/// If this returns `RUSTLS_RESULT_OK`, the memory pointed to by `conn_out` is modified to
784+
/// point at a valid `rustls_connection`. The caller now owns the `rustls_connection`
785+
/// and must call `rustls_connection_free` when done with it.
786+
///
787+
/// If this returns an error code, the memory pointed to by `conn_out` remains
788+
/// unchanged.
789+
///
790+
/// The `server_name` parameter can contain a hostname or an IP address in
791+
/// textual form (IPv4 or IPv6). This function will return an error if it
792+
/// cannot be parsed as one of those types.
793+
///
794+
/// `alpn_protocols` must point to a buffer of `rustls_slice_bytes` (built by the caller)
795+
/// with `alpn_protocols_len` elements. Each element of the buffer must be a `rustls_slice_bytes`
796+
/// whose data field points to a single ALPN protocol ID. This function makes a copy of the
797+
/// data in `alpn_protocols` and does not retain any pointers, so the caller can free the
798+
/// pointed-to memory after calling.
799+
///
800+
/// Standard ALPN protocol IDs are defined at
801+
/// <https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids>.
802+
#[no_mangle]
803+
pub extern "C" fn rustls_client_connection_new_alpn(
804+
config: *const rustls_client_config,
805+
server_name: *const c_char,
806+
alpn_protocols: *const rustls_slice_bytes,
807+
alpn_protocols_len: size_t,
808+
conn_out: *mut *mut rustls_connection,
809+
) -> rustls_result {
810+
ffi_panic_boundary! {
811+
let raw_protocols = try_slice!(alpn_protocols, alpn_protocols_len);
812+
let mut alpn_protocols = Vec::with_capacity(raw_protocols.len());
813+
for p in raw_protocols {
814+
alpn_protocols.push(try_slice!(p.data, p.len).to_vec());
815+
}
816+
817+
Self::rustls_client_connection_new_alpn_inner(
818+
config,
819+
server_name,
820+
alpn_protocols,
821+
conn_out,
822+
)
785823
}
786824
}
825+
826+
fn rustls_client_connection_new_alpn_inner(
827+
config: *const rustls_client_config,
828+
server_name: *const c_char,
829+
alpn_protocols: Vec<Vec<u8>>,
830+
conn_out: *mut *mut rustls_connection,
831+
) -> rustls_result {
832+
let server_name = unsafe {
833+
if server_name.is_null() {
834+
return rustls_result::NullParameter;
835+
}
836+
CStr::from_ptr(server_name)
837+
};
838+
let Ok(server_name) = server_name.to_str() else {
839+
return rustls_result::InvalidDnsNameError;
840+
};
841+
let Ok(server_name) = server_name.try_into() else {
842+
return rustls_result::InvalidDnsNameError;
843+
};
844+
845+
set_boxed_mut_ptr(
846+
try_mut_from_ptr_ptr!(conn_out),
847+
Connection::from_client(
848+
ClientConnection::new_with_alpn(
849+
try_clone_arc!(config),
850+
server_name,
851+
alpn_protocols,
852+
)
853+
.unwrap(),
854+
),
855+
);
856+
rustls_result::Ok
857+
}
787858
}
788859

789860
#[cfg(all(test, any(feature = "ring", feature = "aws-lc-rs")))]
@@ -830,20 +901,8 @@ mod tests {
830901
#[test]
831902
#[cfg_attr(miri, ignore)]
832903
fn test_client_connection_new() {
833-
let builder = rustls_client_config_builder::rustls_client_config_builder_new();
834-
let mut verifier = null_mut();
835-
let result =
836-
rustls_server_cert_verifier::rustls_platform_server_cert_verifier(&mut verifier);
837-
assert_eq!(result, rustls_result::Ok);
838-
assert!(!verifier.is_null());
839-
rustls_client_config_builder::rustls_client_config_builder_set_server_verifier(
840-
builder, verifier,
841-
);
842-
let mut config = null();
843-
let result =
844-
rustls_client_config_builder::rustls_client_config_builder_build(builder, &mut config);
845-
assert_eq!(result, rustls_result::Ok);
846-
assert!(!config.is_null());
904+
let (config, verifier) = test_config();
905+
847906
let mut conn = null_mut();
848907
let result = rustls_client_config::rustls_client_connection_new(
849908
config,
@@ -887,6 +946,53 @@ mod tests {
887946
rustls_server_cert_verifier::rustls_server_cert_verifier_free(verifier);
888947
}
889948

949+
// Build a client connection w/ custom ALPN and ensure no error occurs.
950+
#[test]
951+
#[cfg_attr(miri, ignore)]
952+
fn test_client_connection_new_alpn() {
953+
let (config, verifier) = test_config();
954+
let alpn_protocols = [
955+
rustls_slice_bytes::from(b"h2".as_ref()),
956+
rustls_slice_bytes::from(b"http/1.1".as_ref()),
957+
];
958+
959+
let mut conn = null_mut();
960+
let result = rustls_client_config::rustls_client_connection_new_alpn(
961+
config,
962+
"example.com\0".as_ptr() as *const c_char,
963+
alpn_protocols.as_ptr(),
964+
alpn_protocols.len() as size_t,
965+
&mut conn,
966+
);
967+
if !matches!(result, rustls_result::Ok) {
968+
panic!("expected RUSTLS_RESULT_OK, got {result:?}");
969+
}
970+
971+
rustls_connection::rustls_connection_free(conn);
972+
rustls_server_cert_verifier::rustls_server_cert_verifier_free(verifier);
973+
}
974+
975+
fn test_config() -> (
976+
*const rustls_client_config,
977+
*mut rustls_server_cert_verifier,
978+
) {
979+
let builder = rustls_client_config_builder::rustls_client_config_builder_new();
980+
let mut verifier = null_mut();
981+
let result =
982+
rustls_server_cert_verifier::rustls_platform_server_cert_verifier(&mut verifier);
983+
assert_eq!(result, rustls_result::Ok);
984+
assert!(!verifier.is_null());
985+
rustls_client_config_builder::rustls_client_config_builder_set_server_verifier(
986+
builder, verifier,
987+
);
988+
let mut config = null();
989+
let result =
990+
rustls_client_config_builder::rustls_client_config_builder_build(builder, &mut config);
991+
assert_eq!(result, rustls_result::Ok);
992+
assert!(!config.is_null());
993+
(config, verifier)
994+
}
995+
890996
#[test]
891997
#[cfg_attr(miri, ignore)]
892998
fn test_client_connection_new_ipaddress() {

librustls/src/rustls.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,6 +1586,9 @@ void rustls_client_config_free(const struct rustls_client_config *config);
15861586
* point at a valid `rustls_connection`. The caller now owns the `rustls_connection`
15871587
* and must call `rustls_connection_free` when done with it.
15881588
*
1589+
* Uses the `rustls_client_config` to determine ALPN protocol support. Prefer
1590+
* `rustls_client_connection_new_alpn` to customize this per-connection.
1591+
*
15891592
* If this returns an error code, the memory pointed to by `conn_out` remains
15901593
* unchanged.
15911594
*
@@ -1597,6 +1600,39 @@ rustls_result rustls_client_connection_new(const struct rustls_client_config *co
15971600
const char *server_name,
15981601
struct rustls_connection **conn_out);
15991602

1603+
/**
1604+
* Create a new client `rustls_connection` with custom ALPN protocols.
1605+
*
1606+
* Operates the same as `rustls_client_connection_new`, but allows specifying
1607+
* custom per-connection ALPN protocols instead of inheriting ALPN protocols
1608+
* from the `rustls_clinet_config`.
1609+
*
1610+
* If this returns `RUSTLS_RESULT_OK`, the memory pointed to by `conn_out` is modified to
1611+
* point at a valid `rustls_connection`. The caller now owns the `rustls_connection`
1612+
* and must call `rustls_connection_free` when done with it.
1613+
*
1614+
* If this returns an error code, the memory pointed to by `conn_out` remains
1615+
* unchanged.
1616+
*
1617+
* The `server_name` parameter can contain a hostname or an IP address in
1618+
* textual form (IPv4 or IPv6). This function will return an error if it
1619+
* cannot be parsed as one of those types.
1620+
*
1621+
* `alpn_protocols` must point to a buffer of `rustls_slice_bytes` (built by the caller)
1622+
* with `alpn_protocols_len` elements. Each element of the buffer must be a `rustls_slice_bytes`
1623+
* whose data field points to a single ALPN protocol ID. This function makes a copy of the
1624+
* data in `alpn_protocols` and does not retain any pointers, so the caller can free the
1625+
* pointed-to memory after calling.
1626+
*
1627+
* Standard ALPN protocol IDs are defined at
1628+
* <https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids>.
1629+
*/
1630+
rustls_result rustls_client_connection_new_alpn(const struct rustls_client_config *config,
1631+
const char *server_name,
1632+
const struct rustls_slice_bytes *alpn_protocols,
1633+
size_t alpn_protocols_len,
1634+
struct rustls_connection **conn_out);
1635+
16001636
/**
16011637
* Set the userdata pointer associated with this connection. This will be passed
16021638
* to any callbacks invoked by the connection, if you've set up callbacks in the config.

0 commit comments

Comments
 (0)