Skip to content

Commit 6ca0612

Browse files
committed
librustls: add rustls_client_connection_new_alpn
This allows constructing a client `rustls_connection` with custom ALPN protocol support that differs from the base `rustls_client_config`.
1 parent 57443c2 commit 6ca0612

2 files changed

Lines changed: 156 additions & 14 deletions

File tree

librustls/src/client.rs

Lines changed: 120 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,9 @@ impl rustls_client_config {
749749
/// point at a valid `rustls_connection`. The caller now owns the `rustls_connection`
750750
/// and must call `rustls_connection_free` when done with it.
751751
///
752+
/// Uses the `rustls_client_config` to determine ALPN protocol support. Prefer
753+
/// `rustls_client_connection_new_alpn` to customize this per-connection.
754+
///
752755
/// If this returns an error code, the memory pointed to by `conn_out` remains
753756
/// unchanged.
754757
///
@@ -784,6 +787,74 @@ impl rustls_client_config {
784787
rustls_result::Ok
785788
}
786789
}
790+
791+
/// Create a new client `rustls_connection` with custom ALPN protocols.
792+
///
793+
/// Operates the same as `rustls_client_connection_new`, but allows specifying
794+
/// custom per-connection ALPN protocols instead of inheriting ALPN protocols
795+
/// from the `rustls_clinet_config`.
796+
///
797+
/// If this returns `RUSTLS_RESULT_OK`, the memory pointed to by `conn_out` is modified to
798+
/// point at a valid `rustls_connection`. The caller now owns the `rustls_connection`
799+
/// and must call `rustls_connection_free` when done with it.
800+
///
801+
/// If this returns an error code, the memory pointed to by `conn_out` remains
802+
/// unchanged.
803+
///
804+
/// The `server_name` parameter can contain a hostname or an IP address in
805+
/// textual form (IPv4 or IPv6). This function will return an error if it
806+
/// cannot be parsed as one of those types.
807+
///
808+
/// `alpn_protocols` must point to a buffer of `rustls_slice_bytes` (built by the caller)
809+
/// with `alpn_protocols_len` elements. Each element of the buffer must be a `rustls_slice_bytes`
810+
/// whose data field points to a single ALPN protocol ID. This function makes a copy of the
811+
/// data in `alpn_protocols` and does not retain any pointers, so the caller can free the
812+
/// pointed-to memory after calling.
813+
///
814+
/// Standard ALPN protocol IDs are defined at
815+
/// <https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids>.
816+
#[no_mangle]
817+
pub extern "C" fn rustls_client_connection_new_alpn(
818+
config: *const rustls_client_config,
819+
server_name: *const c_char,
820+
alpn_protocols: *const rustls_slice_bytes,
821+
alpn_protocols_len: size_t,
822+
conn_out: *mut *mut rustls_connection,
823+
) -> rustls_result {
824+
ffi_panic_boundary! {
825+
let server_name = unsafe {
826+
if server_name.is_null() {
827+
return rustls_result::NullParameter;
828+
}
829+
CStr::from_ptr(server_name)
830+
};
831+
let Ok(server_name) = server_name.to_str() else {
832+
return rustls_result::InvalidDnsNameError;
833+
};
834+
let Ok(server_name) = server_name.try_into() else {
835+
return rustls_result::InvalidDnsNameError;
836+
};
837+
838+
let raw_protocols = try_slice!(alpn_protocols, alpn_protocols_len);
839+
let mut alpn_protocols = Vec::with_capacity(raw_protocols.len());
840+
for p in raw_protocols {
841+
alpn_protocols.push(try_slice!(p.data, p.len).to_vec());
842+
}
843+
844+
set_boxed_mut_ptr(
845+
try_mut_from_ptr_ptr!(conn_out),
846+
Connection::from_client(
847+
ClientConnection::new_with_alpn(
848+
try_clone_arc!(config),
849+
server_name,
850+
alpn_protocols,
851+
)
852+
.unwrap(),
853+
),
854+
);
855+
rustls_result::Ok
856+
}
857+
}
787858
}
788859

789860
#[cfg(all(test, any(feature = "ring", feature = "aws-lc-rs")))]
@@ -830,20 +901,8 @@ mod tests {
830901
#[test]
831902
#[cfg_attr(miri, ignore)]
832903
fn test_client_connection_new() {
833-
let builder = rustls_client_config_builder::rustls_client_config_builder_new();
834-
let mut verifier = null_mut();
835-
let result =
836-
rustls_server_cert_verifier::rustls_platform_server_cert_verifier(&mut verifier);
837-
assert_eq!(result, rustls_result::Ok);
838-
assert!(!verifier.is_null());
839-
rustls_client_config_builder::rustls_client_config_builder_set_server_verifier(
840-
builder, verifier,
841-
);
842-
let mut config = null();
843-
let result =
844-
rustls_client_config_builder::rustls_client_config_builder_build(builder, &mut config);
845-
assert_eq!(result, rustls_result::Ok);
846-
assert!(!config.is_null());
904+
let (config, verifier) = test_config();
905+
847906
let mut conn = null_mut();
848907
let result = rustls_client_config::rustls_client_connection_new(
849908
config,
@@ -887,6 +946,53 @@ mod tests {
887946
rustls_server_cert_verifier::rustls_server_cert_verifier_free(verifier);
888947
}
889948

949+
// Build a client connection w/ custom ALPN and ensure no error occurs.
950+
#[test]
951+
#[cfg_attr(miri, ignore)]
952+
fn test_client_connection_new_alpn() {
953+
let (config, verifier) = test_config();
954+
let alpn_protocols = [
955+
rustls_slice_bytes::from(b"h2".as_ref()),
956+
rustls_slice_bytes::from(b"http/1.1".as_ref()),
957+
];
958+
959+
let mut conn = null_mut();
960+
let result = rustls_client_config::rustls_client_connection_new_alpn(
961+
config,
962+
"example.com\0".as_ptr() as *const c_char,
963+
alpn_protocols.as_ptr(),
964+
alpn_protocols.len() as size_t,
965+
&mut conn,
966+
);
967+
if !matches!(result, rustls_result::Ok) {
968+
panic!("expected RUSTLS_RESULT_OK, got {result:?}");
969+
}
970+
971+
rustls_connection::rustls_connection_free(conn);
972+
rustls_server_cert_verifier::rustls_server_cert_verifier_free(verifier);
973+
}
974+
975+
fn test_config() -> (
976+
*const rustls_client_config,
977+
*mut rustls_server_cert_verifier,
978+
) {
979+
let builder = rustls_client_config_builder::rustls_client_config_builder_new();
980+
let mut verifier = null_mut();
981+
let result =
982+
rustls_server_cert_verifier::rustls_platform_server_cert_verifier(&mut verifier);
983+
assert_eq!(result, rustls_result::Ok);
984+
assert!(!verifier.is_null());
985+
rustls_client_config_builder::rustls_client_config_builder_set_server_verifier(
986+
builder, verifier,
987+
);
988+
let mut config = null();
989+
let result =
990+
rustls_client_config_builder::rustls_client_config_builder_build(builder, &mut config);
991+
assert_eq!(result, rustls_result::Ok);
992+
assert!(!config.is_null());
993+
(config, verifier)
994+
}
995+
890996
#[test]
891997
#[cfg_attr(miri, ignore)]
892998
fn test_client_connection_new_ipaddress() {

librustls/src/rustls.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,6 +1586,9 @@ void rustls_client_config_free(const struct rustls_client_config *config);
15861586
* point at a valid `rustls_connection`. The caller now owns the `rustls_connection`
15871587
* and must call `rustls_connection_free` when done with it.
15881588
*
1589+
* Uses the `rustls_client_config` to determine ALPN protocol support. Prefer
1590+
* `rustls_client_connection_new_alpn` to customize this per-connection.
1591+
*
15891592
* If this returns an error code, the memory pointed to by `conn_out` remains
15901593
* unchanged.
15911594
*
@@ -1597,6 +1600,39 @@ rustls_result rustls_client_connection_new(const struct rustls_client_config *co
15971600
const char *server_name,
15981601
struct rustls_connection **conn_out);
15991602

1603+
/**
1604+
* Create a new client `rustls_connection` with custom ALPN protocols.
1605+
*
1606+
* Operates the same as `rustls_client_connection_new`, but allows specifying
1607+
* custom per-connection ALPN protocols instead of inheriting ALPN protocols
1608+
* from the `rustls_clinet_config`.
1609+
*
1610+
* If this returns `RUSTLS_RESULT_OK`, the memory pointed to by `conn_out` is modified to
1611+
* point at a valid `rustls_connection`. The caller now owns the `rustls_connection`
1612+
* and must call `rustls_connection_free` when done with it.
1613+
*
1614+
* If this returns an error code, the memory pointed to by `conn_out` remains
1615+
* unchanged.
1616+
*
1617+
* The `server_name` parameter can contain a hostname or an IP address in
1618+
* textual form (IPv4 or IPv6). This function will return an error if it
1619+
* cannot be parsed as one of those types.
1620+
*
1621+
* `alpn_protocols` must point to a buffer of `rustls_slice_bytes` (built by the caller)
1622+
* with `alpn_protocols_len` elements. Each element of the buffer must be a `rustls_slice_bytes`
1623+
* whose data field points to a single ALPN protocol ID. This function makes a copy of the
1624+
* data in `alpn_protocols` and does not retain any pointers, so the caller can free the
1625+
* pointed-to memory after calling.
1626+
*
1627+
* Standard ALPN protocol IDs are defined at
1628+
* <https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids>.
1629+
*/
1630+
rustls_result rustls_client_connection_new_alpn(const struct rustls_client_config *config,
1631+
const char *server_name,
1632+
const struct rustls_slice_bytes *alpn_protocols,
1633+
size_t alpn_protocols_len,
1634+
struct rustls_connection **conn_out);
1635+
16001636
/**
16011637
* Set the userdata pointer associated with this connection. This will be passed
16021638
* to any callbacks invoked by the connection, if you've set up callbacks in the config.

0 commit comments

Comments
 (0)