diff --git a/Cargo.lock b/Cargo.lock index bb8811c..57ec27b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -694,6 +694,31 @@ dependencies = [ "libc", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crypto-common" version = "0.1.7" @@ -809,6 +834,7 @@ dependencies = [ "libc", "mimalloc", "proptest", + "rayon", "rustls", "serde", "serde_json", @@ -1775,6 +1801,26 @@ dependencies = [ "rand_core", ] +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.18" diff --git a/enclave/Cargo.toml b/enclave/Cargo.toml index 153f07a..1ef229d 100644 --- a/enclave/Cargo.toml +++ b/enclave/Cargo.toml @@ -26,6 +26,7 @@ aws-lc-rs = { version = "=1.15.2", default-features = false } cel-interpreter = { version = "=0.10.0", default-features = false, features = ["json", "chrono"] } chrono = { version = "=0.4.42", default-features = false, features = ["now"] } data-encoding = { version = "=2.9.0", default-features = false, features = ["alloc"] } +rayon = { version = "=1.10.0", default-features = false } serde = { version = "=1.0.228", default-features = false, features = ["derive"] } serde_json = { version = "=1.0.145", default-features = false } rustls = { version = "=0.23.35", default-features = false, features = ["aws_lc_rs", "prefer-post-quantum"] } diff --git a/enclave/src/constants.rs b/enclave/src/constants.rs index c6c9643..6890ad3 100644 --- a/enclave/src/constants.rs +++ b/enclave/src/constants.rs @@ -3,6 +3,11 @@ pub const ENCLAVE_PORT: u32 = 5050; +/// Maximum concurrent connections to prevent resource exhaustion DoS attacks. +/// Each connection spawns a thread (~8KB stack minimum), so this limits memory usage. +/// With 32 connections and 10MB max message size, worst case is ~320MB memory. +pub const MAX_CONCURRENT_CONNECTIONS: usize = 32; + /// Maximum allowed message size (10 MB) to prevent memory exhaustion DoS attacks pub const MAX_MESSAGE_SIZE: u64 = 10 * 1024 * 1024; diff --git a/enclave/src/expressions.rs b/enclave/src/expressions.rs index b6f4e00..b207335 100644 --- a/enclave/src/expressions.rs +++ b/enclave/src/expressions.rs @@ -1,7 +1,7 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: MIT-0 -use std::collections::BTreeMap; +use std::collections::HashMap; use anyhow::{Result, anyhow, bail}; use cel_interpreter::Value as celValue; @@ -12,9 +12,9 @@ use crate::constants::MAX_EXPRESSION_LENGTH; use crate::functions; pub fn execute_expressions( - fields: &BTreeMap, - expressions: &BTreeMap, -) -> Result> { + fields: &HashMap, + expressions: &HashMap, +) -> Result> { if expressions.is_empty() { return Ok(fields.clone()); } @@ -51,7 +51,8 @@ pub fn execute_expressions( context.add_function("date", functions::date); context.add_function("age", functions::age); - let mut transformed: BTreeMap = BTreeMap::new(); + let mut transformed: HashMap = + HashMap::with_capacity(fields.len() + expressions.len()); for (field, decrypted_value) in fields { context @@ -92,7 +93,7 @@ pub fn execute_expressions( mod tests { use super::*; use proptest::prelude::*; - use std::collections::BTreeMap; + use std::collections::HashMap; // **Feature: enclave-improvements, Property 5: Expression failure fallback** // **Validates: Requirements 8.2** @@ -111,9 +112,9 @@ mod tests { /// Simulates the fallback behavior from main.rs: /// When execute_expressions returns Err, return the original fields unchanged. fn execute_with_fallback( - fields: &BTreeMap, - expressions: &BTreeMap, - ) -> BTreeMap { + fields: &HashMap, + expressions: &HashMap, + ) -> HashMap { match execute_expressions(fields, expressions) { Ok(result) => result, Err(_) => fields.clone(), @@ -133,7 +134,7 @@ mod tests { invalid_expr_type in 0usize..3 ) { // Create original fields - let mut fields: BTreeMap = BTreeMap::new(); + let mut fields: HashMap = HashMap::new(); fields.insert(field_name.clone(), Value::String(field_value.clone())); // Create an invalid expression that will fail to execute gracefully @@ -144,7 +145,7 @@ mod tests { _ => "undefined_var.to_uppercase()".to_string(), }; - let mut expressions: BTreeMap = BTreeMap::new(); + let mut expressions: HashMap = HashMap::new(); expressions.insert("result".to_string(), invalid_expression); // Execute with fallback (simulating main.rs behavior) @@ -173,7 +174,7 @@ mod tests { use std::hash::{Hash, Hasher}; // Generate deterministic field names and values based on seed - let mut fields: BTreeMap = BTreeMap::new(); + let mut fields: HashMap = HashMap::new(); for i in 0..num_fields { let mut hasher = DefaultHasher::new(); (field_seed, i).hash(&mut hasher); @@ -184,7 +185,7 @@ mod tests { } // Create an expression that references an undefined variable - let mut expressions: BTreeMap = BTreeMap::new(); + let mut expressions: HashMap = HashMap::new(); expressions.insert("computed".to_string(), "undefined_var.to_uppercase()".to_string()); // Execute with fallback @@ -212,10 +213,10 @@ mod tests { field_name in "[a-z][a-z0-9_]{0,10}", field_value in "[a-zA-Z0-9 ]{1,20}" ) { - let mut fields: BTreeMap = BTreeMap::new(); + let mut fields: HashMap = HashMap::new(); fields.insert(field_name.clone(), Value::String(field_value.clone())); - let expressions: BTreeMap = BTreeMap::new(); + let expressions: HashMap = HashMap::new(); let result = execute_expressions(&fields, &expressions).unwrap(); @@ -233,11 +234,11 @@ mod tests { // Generate lowercase string to test to_uppercase field_value in "[a-z]{1,10}" ) { - let mut fields: BTreeMap = BTreeMap::new(); + let mut fields: HashMap = HashMap::new(); fields.insert(field_name.clone(), Value::String(field_value.clone())); // Create expression to uppercase the field - let mut expressions: BTreeMap = BTreeMap::new(); + let mut expressions: HashMap = HashMap::new(); expressions.insert(field_name.clone(), format!("{}.to_uppercase()", field_name)); let result = execute_expressions(&fields, &expressions).unwrap(); @@ -253,10 +254,10 @@ mod tests { #[test] fn test_skip_expressions() { - let expressions = BTreeMap::new(); + let expressions = HashMap::new(); - let expected: BTreeMap = - BTreeMap::from([("first_name".to_string(), "Bob".into())]); + let expected: HashMap = + HashMap::from([("first_name".to_string(), "Bob".into())]); let actual = execute_expressions(&expected, &expressions).unwrap(); assert_eq!(actual, expected); @@ -264,16 +265,16 @@ mod tests { #[test] fn test_execute_transforms() { - let expressions: BTreeMap = BTreeMap::from([( + let expressions: HashMap = HashMap::from([( "first_name".to_string(), "first_name.to_uppercase()".to_string(), )]); - let fields: BTreeMap = - BTreeMap::from([("first_name".to_string(), "Bob".into())]); + let fields: HashMap = + HashMap::from([("first_name".to_string(), "Bob".into())]); - let expected: BTreeMap = - BTreeMap::from([("first_name".to_string(), "BOB".into())]); + let expected: HashMap = + HashMap::from([("first_name".to_string(), "BOB".into())]); let actual = execute_expressions(&fields, &expressions).unwrap(); assert_eq!(actual, expected); @@ -281,15 +282,14 @@ mod tests { #[test] fn test_base64() { - let expressions: BTreeMap = BTreeMap::from([( + let expressions: HashMap = HashMap::from([( "first_name".into(), "first_name.base64_encode().base64_decode()".into(), )]); - let fields: BTreeMap = BTreeMap::from([("first_name".into(), "Bob".into())]); + let fields: HashMap = HashMap::from([("first_name".into(), "Bob".into())]); - let expected: BTreeMap = - BTreeMap::from([("first_name".into(), "Bob".into())]); + let expected: HashMap = HashMap::from([("first_name".into(), "Bob".into())]); let actual = execute_expressions(&fields, &expressions).unwrap(); assert_eq!(actual, expected); @@ -297,15 +297,14 @@ mod tests { #[test] fn test_hex() { - let expressions: BTreeMap = BTreeMap::from([( + let expressions: HashMap = HashMap::from([( "first_name".into(), "first_name.hex_encode().hex_decode()".into(), )]); - let fields: BTreeMap = BTreeMap::from([("first_name".into(), "Bob".into())]); + let fields: HashMap = HashMap::from([("first_name".into(), "Bob".into())]); - let expected: BTreeMap = - BTreeMap::from([("first_name".into(), "Bob".into())]); + let expected: HashMap = HashMap::from([("first_name".into(), "Bob".into())]); let actual = execute_expressions(&fields, &expressions).unwrap(); assert_eq!(actual, expected); @@ -313,7 +312,7 @@ mod tests { #[test] fn test_functions() { - let expressions: BTreeMap = BTreeMap::from([ + let expressions: HashMap = HashMap::from([ ("is_empty".into(), "''.is_empty() == true".into()), ("to_lowercase".into(), "'Bob'.to_lowercase()".into()), ("to_uppercase".into(), "'Bob'.to_uppercase()".into()), @@ -327,43 +326,63 @@ mod tests { ("date".into(), "date('1979-04-05')".into()), ]); - let fields = BTreeMap::default(); - let expected: BTreeMap = - BTreeMap::from([ - ("is_empty".into(), true.into()), - ("to_lowercase".into(), "bob".into()), - ("to_uppercase".into(), "BOB".into()), - ("sha256".into(), "cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961".into()), - ("sha384".into(), "b7808c5991933fa578a7d41a177b013f2f745a2c4fac90d1e8631a1ce21918dc5fee092a290a6443e47649989ec9871f".into()), - ("sha512".into(), "0c3e99453b4ae505617a3c9b6ce73fc3cd13ddc3b2e2237459710a57f8ec6d26d056db144ff7c71b00ed4e4c39716e9e2099c8076e604423dd74554d4db1e649".into()), - ("hex_encode".into(), "426f62".into()), - ("hex_decode".into(), "Bob".into()), - ("base64_encode".into(), "Qm9i".into()), - ("base64_decode".into(), "Bob".into()), - ("date".into(), "1979-04-05T00:00:00+00:00".into()), - ]); - + let fields = HashMap::default(); + // Note: Using Vec for comparison since HashMap ordering is non-deterministic let actual = execute_expressions(&fields, &expressions).unwrap(); - assert_eq!(actual, expected); + + assert_eq!(actual.get("is_empty"), Some(&Value::Bool(true))); + assert_eq!( + actual.get("to_lowercase"), + Some(&Value::String("bob".into())) + ); + assert_eq!( + actual.get("to_uppercase"), + Some(&Value::String("BOB".into())) + ); + assert_eq!( + actual.get("sha256"), + Some(&Value::String( + "cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961".into() + )) + ); + assert_eq!(actual.get("sha384"), Some(&Value::String("b7808c5991933fa578a7d41a177b013f2f745a2c4fac90d1e8631a1ce21918dc5fee092a290a6443e47649989ec9871f".into()))); + assert_eq!(actual.get("sha512"), Some(&Value::String("0c3e99453b4ae505617a3c9b6ce73fc3cd13ddc3b2e2237459710a57f8ec6d26d056db144ff7c71b00ed4e4c39716e9e2099c8076e604423dd74554d4db1e649".into()))); + assert_eq!( + actual.get("hex_encode"), + Some(&Value::String("426f62".into())) + ); + assert_eq!(actual.get("hex_decode"), Some(&Value::String("Bob".into()))); + assert_eq!( + actual.get("base64_encode"), + Some(&Value::String("Qm9i".into())) + ); + assert_eq!( + actual.get("base64_decode"), + Some(&Value::String("Bob".into())) + ); + assert_eq!( + actual.get("date"), + Some(&Value::String("1979-04-05T00:00:00+00:00".into())) + ); } #[test] fn test_complex() { - let expressions: BTreeMap = - BTreeMap::from([("age".into(), "date(birth_date).age()".into())]); + let expressions: HashMap = + HashMap::from([("age".into(), "date(birth_date).age()".into())]); - let fields: BTreeMap = BTreeMap::from([ + let fields: HashMap = HashMap::from([ ("first_name".into(), "Bob".into()), ("birth_date".into(), "1979-01-01".into()), ]); - let expected: BTreeMap = BTreeMap::from([ - ("first_name".into(), "Bob".into()), - ("birth_date".into(), "1979-01-01".into()), - ("age".into(), 46.into()), - ]); - let actual = execute_expressions(&fields, &expressions).unwrap(); - assert_eq!(actual, expected); + + assert_eq!(actual.get("first_name"), Some(&Value::String("Bob".into()))); + assert_eq!( + actual.get("birth_date"), + Some(&Value::String("1979-01-01".into())) + ); + assert_eq!(actual.get("age"), Some(&Value::Number(46.into()))); } } diff --git a/enclave/src/functions.rs b/enclave/src/functions.rs index 47f5d0d..7cb7a63 100644 --- a/enclave/src/functions.rs +++ b/enclave/src/functions.rs @@ -101,20 +101,34 @@ pub fn date(ftx: &FunctionContext, This(this): This>) -> ResolveResu } } -pub fn today_utc() -> DateTime { +/// Returns today's date at midnight UTC as a DateTime. +/// +/// This function is designed to never panic, returning a CEL error instead +/// of using expect() on the infallible UTC offset operations. +pub fn today_utc(ftx: &FunctionContext) -> ResolveResult { let now_utc = Utc::now(); - // UTC offset of 0 is always valid - use expect only in this case since - // it's a compile-time constant and failure would indicate a bug in chrono - #[allow(clippy::expect_used)] - let tz_offset = FixedOffset::east_opt(UTC_OFFSET).expect("UTC offset 0 should always be valid"); + + // UTC offset of 0 is theoretically always valid, but we handle failure gracefully + let tz_offset = match FixedOffset::east_opt(UTC_OFFSET) { + Some(offset) => offset, + None => return ftx.error("failed to create UTC timezone offset").into(), + }; + let date = now_utc.date_naive(); let datetime = date.and_time(NaiveTime::default()); - // from_local_datetime with UTC offset should always succeed - #[allow(clippy::expect_used)] - tz_offset - .from_local_datetime(&datetime) - .single() - .expect("UTC datetime conversion should always succeed") + + // Convert to timezone - handle all cases without panicking + let dt_with_tz: DateTime = match tz_offset.from_local_datetime(&datetime) { + chrono::LocalResult::Single(dt) => dt, + chrono::LocalResult::Ambiguous(dt, _) => dt, + chrono::LocalResult::None => { + return ftx + .error("failed to convert datetime to UTC timezone") + .into(); + } + }; + + Ok(dt_with_tz.into()) } pub fn age(This(this): This>) -> ResolveResult { diff --git a/enclave/src/kms.rs b/enclave/src/kms.rs index 0292c48..7dc4016 100644 --- a/enclave/src/kms.rs +++ b/enclave/src/kms.rs @@ -13,17 +13,53 @@ //! - KMS decryption is performed via the Nitro Enclaves SDK which uses //! attestation-based access control //! - The KMS key policy must allow the enclave's PCR values to decrypt +//! - HPKE private keys are wrapped in [`SecureHpkePrivateKey`] which zeroizes on drop use anyhow::{Result, anyhow}; use aws_lc_rs::encoding::AsBigEndian; use aws_lc_rs::signature::{EcdsaKeyPair, EcdsaSigningAlgorithm}; use rustls::crypto::hpke::HpkePrivateKey; -use zeroize::Zeroize; +use zeroize::{Zeroize, Zeroizing}; use crate::aws_ne; use crate::models::{Credential, EnclaveRequest}; use crate::utils::base64_decode; +/// A secure wrapper for HPKE private keys that zeroizes key material on drop. +/// +/// This wrapper stores the raw key bytes in a [`Zeroizing`] container, ensuring +/// the key material is securely erased from memory when no longer needed. +/// +/// # Security +/// +/// - Key bytes are stored in a `Zeroizing>` which zeroizes on drop +/// - The `HpkePrivateKey` is created on-demand from the zeroized source +/// - This ensures our copy of the key material is always cleaned up +pub struct SecureHpkePrivateKey { + /// The raw private key bytes, wrapped for automatic zeroization + key_bytes: Zeroizing>, +} + +impl SecureHpkePrivateKey { + /// Creates a new secure HPKE private key from raw bytes. + /// + /// The bytes are wrapped in a `Zeroizing` container for automatic cleanup. + pub fn new(key_bytes: Vec) -> Self { + Self { + key_bytes: Zeroizing::new(key_bytes), + } + } + + /// Returns an `HpkePrivateKey` for use with rustls HPKE operations. + /// + /// Note: The returned `HpkePrivateKey` contains a copy of the key bytes. + /// This copy is not zeroized by rustls, but is short-lived (used only + /// during the HPKE decryption operation). + pub fn as_hpke_private_key(&self) -> HpkePrivateKey { + self.key_bytes.to_vec().into() + } +} + /// Calls KMS decrypt via the Nitro Enclaves SDK FFI wrapper. /// /// # Arguments @@ -65,16 +101,17 @@ fn call_kms_decrypt(credential: &Credential, ciphertext: &str, region: &str) -> /// /// # Returns /// -/// Returns the HPKE private key ready for decryption operations. +/// Returns a [`SecureHpkePrivateKey`] that zeroizes key material on drop. /// /// # Security /// -/// The plaintext private key material is zeroized immediately after extraction, -/// even if an error occurs during processing. +/// - The plaintext private key material is zeroized immediately after extraction +/// - The returned key is wrapped in [`SecureHpkePrivateKey`] for automatic zeroization +/// - Even if an error occurs during processing, intermediate materials are zeroized pub fn get_secret_key( alg: &'static EcdsaSigningAlgorithm, payload: &EnclaveRequest, -) -> Result { +) -> Result { // Call KMS decrypt via FFI wrapper - returns plaintext bytes directly let mut plaintext_sk = call_kms_decrypt( &payload.credential, @@ -84,7 +121,7 @@ pub fn get_secret_key( .map_err(|err| anyhow!("failed to call KMS: {err:?}"))?; // Process key and ensure zeroization on all paths - let result = (|| -> Result { + let result = (|| -> Result { // Decode the DER PKCS#8 secret key let sk = EcdsaKeyPair::from_private_key_der(alg, &plaintext_sk) .map_err(|err| anyhow!("unable to decode PKCS#8 private key: {err:?}"))?; @@ -94,7 +131,8 @@ pub fn get_secret_key( .map_err(|err| anyhow!("unable to get private key bytes: {err:?}"))?; let sk_ref = sk_bytes.as_ref(); - Ok(sk_ref.to_vec().into()) + // Wrap in SecureHpkePrivateKey for automatic zeroization on drop + Ok(SecureHpkePrivateKey::new(sk_ref.to_vec())) })(); // Always zeroize the plaintext key material diff --git a/enclave/src/main.rs b/enclave/src/main.rs index 6e1d2f3..7534a1a 100644 --- a/enclave/src/main.rs +++ b/enclave/src/main.rs @@ -1,14 +1,19 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: MIT-0 +use std::io::{Read, Write}; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::thread; + use anyhow::{Error, Result, anyhow}; use enclave_vault::{ - constants::ENCLAVE_PORT, + constants::{ENCLAVE_PORT, MAX_CONCURRENT_CONNECTIONS}, expressions::execute_expressions, models::{EnclaveRequest, EnclaveResponse}, protocol::{recv_message, send_message}, }; -use vsock::{VMADDR_CID_ANY, VsockAddr, VsockListener, VsockStream}; +use vsock::VsockListener; // Avoid musl's default allocator due to terrible performance #[cfg(target_env = "musl")] @@ -23,8 +28,10 @@ fn parse_payload(payload_buffer: &[u8]) -> Result { } #[inline] -fn send_error(mut stream: VsockStream, err: Error) -> Result<()> { - println!("[enclave error] {err:?}"); +fn send_error(mut stream: W, err: Error) -> Result<()> { + // Sanitize error message to avoid leaking sensitive data + let sanitized_msg = sanitize_error_message(&err); + println!("[enclave error] {sanitized_msg}"); let response = EnclaveResponse::error(err); @@ -32,13 +39,27 @@ fn send_error(mut stream: VsockStream, err: Error) -> Result<()> { .map_err(|err| anyhow!("failed to serialize error response: {err:?}"))?; if let Err(err) = send_message(&mut stream, &payload) { - println!("[enclave error] failed to send error: {err:?}"); + let sanitized = sanitize_error_message(&err); + println!("[enclave error] failed to send error: {sanitized}"); } Ok(()) } -fn handle_client(mut stream: VsockStream) -> Result<()> { +/// Sanitizes error messages to prevent sensitive data leakage in logs. +/// Removes potential field values, keys, or other sensitive content. +#[inline] +fn sanitize_error_message(err: &Error) -> String { + let msg = err.to_string(); + // Truncate very long error messages that might contain data + if msg.len() > 200 { + format!("{}... (truncated)", &msg[..200]) + } else { + msg + } +} + +fn handle_client(mut stream: S) -> Result<()> { println!("[enclave] handling client"); let payload: EnclaveRequest = match recv_message(&mut stream) @@ -51,7 +72,7 @@ fn handle_client(mut stream: VsockStream) -> Result<()> { Err(err) => return send_error(stream, err), }; - // Decrypt the individual field values + // Decrypt the individual field values (uses rayon for parallelization internally) let (decrypted_fields, errors) = match payload.decrypt_fields() { Ok(result) => result, Err(err) => return send_error(stream, err), @@ -61,7 +82,11 @@ fn handle_client(mut stream: VsockStream) -> Result<()> { Some(expressions) => match execute_expressions(&decrypted_fields, &expressions) { Ok(fields) => fields, Err(err) => { - println!("[enclave warning] expression execution failed: {:?}", err); + println!("[enclave warning] expression execution failed"); + // Only log error details in debug builds + #[cfg(debug_assertions)] + println!("[enclave debug] expression error: {:?}", err); + let _ = err; // Silence unused warning in release decrypted_fields } }, @@ -89,7 +114,7 @@ fn handle_client(mut stream: VsockStream) -> Result<()> { fn main() -> Result<()> { println!("[enclave] init"); - let listener = match VsockListener::bind(&VsockAddr::new(VMADDR_CID_ANY, ENCLAVE_PORT)) { + let listener = match VsockListener::bind_with_cid_port(libc::VMADDR_CID_ANY, ENCLAVE_PORT) { Ok(l) => l, Err(e) => { eprintln!( @@ -101,22 +126,49 @@ fn main() -> Result<()> { }; println!("[enclave] listening on port {ENCLAVE_PORT}"); - - for stream in listener.incoming() { - let stream = match stream { - Ok(s) => s, + println!( + "[enclave] max concurrent connections: {}", + MAX_CONCURRENT_CONNECTIONS + ); + + // Track active connections to prevent resource exhaustion DoS + let active_connections = Arc::new(AtomicUsize::new(0)); + + for conn in listener.incoming() { + match conn { + Ok(stream) => { + // Check if we've reached the connection limit + let current = active_connections.load(Ordering::SeqCst); + if current >= MAX_CONCURRENT_CONNECTIONS { + println!( + "[enclave warning] connection limit reached ({}/{}), rejecting", + current, MAX_CONCURRENT_CONNECTIONS + ); + // Drop the stream to close the connection + drop(stream); + continue; + } + + // Increment connection count + active_connections.fetch_add(1, Ordering::SeqCst); + let connections = Arc::clone(&active_connections); + + // Spawn a new thread to handle each client concurrently + thread::spawn(move || { + if let Err(err) = handle_client(stream) { + let sanitized = sanitize_error_message(&err); + println!("[enclave error] {sanitized}"); + } + // Decrement connection count when done + connections.fetch_sub(1, Ordering::SeqCst); + }); + } Err(e) => { println!("[enclave error] failed to accept connection: {:?}", e); continue; } - }; - - if let Err(err) = handle_client(stream) { - println!("[enclave error] {:?}", err); } } - println!("[enclave] finished"); - Ok(()) } diff --git a/enclave/src/models.rs b/enclave/src/models.rs index cef0571..26b8b8f 100644 --- a/enclave/src/models.rs +++ b/enclave/src/models.rs @@ -18,8 +18,9 @@ //! - Field count is limited to prevent resource exhaustion //! - Input validation is performed before processing -use std::collections::BTreeMap; +use std::collections::HashMap; use std::fmt; +use std::sync::Mutex; use anyhow::{Error, Result, anyhow, bail}; use aws_lc_rs::signature::{ @@ -27,11 +28,12 @@ use aws_lc_rs::signature::{ EcdsaSigningAlgorithm, }; use data_encoding::HEXLOWER; +use rayon::prelude::*; use rustls::crypto::aws_lc_rs::hpke::{ DH_KEM_P256_HKDF_SHA256_AES_256, DH_KEM_P384_HKDF_SHA384_AES_256, DH_KEM_P521_HKDF_SHA512_AES_256, }; -use rustls::crypto::hpke::{Hpke, HpkePrivateKey}; +use rustls::crypto::hpke::Hpke; use serde::{Deserialize, Serialize}; use serde_json::Value; use zeroize::ZeroizeOnDrop; @@ -39,7 +41,7 @@ use zeroize::ZeroizeOnDrop; use crate::constants::{ENCODING_BINARY, ENCODING_HEX, MAX_FIELDS, P256, P384, P521}; use crate::hpke::decrypt_value; -use crate::kms::get_secret_key; +use crate::kms::{SecureHpkePrivateKey, get_secret_key}; use crate::utils::base64_decode; /// AWS credentials for KMS access. @@ -76,11 +78,11 @@ impl fmt::Debug for Credential { pub struct ParentRequest { pub vault_id: String, pub region: String, - pub fields: BTreeMap, + pub fields: HashMap, pub suite_id: String, // base64 encoded pub encrypted_private_key: String, // base64 encoded #[serde(skip_serializing_if = "Option::is_none")] - pub expressions: Option>, + pub expressions: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub encoding: Option, } @@ -143,28 +145,33 @@ impl EnclaveRequest { Ok(()) } - fn get_private_key(&self, suite: &Suite) -> Result { + fn get_private_key(&self, suite: &Suite) -> Result { let alg = suite.get_signing_algorithm(); - // Decrypt the KMS secret key - let sk: HpkePrivateKey = get_secret_key(alg, self)?; + // Decrypt the KMS secret key - wrapped in SecureHpkePrivateKey for zeroization + let sk = get_secret_key(alg, self)?; Ok(sk) } - pub fn decrypt_fields(&self) -> Result<(BTreeMap, Vec)> { + pub fn decrypt_fields(&self) -> Result<(HashMap, Vec)> { // Validate all inputs before processing self.validate()?; let suite: Suite = self.request.suite_id.as_str().try_into()?; let encoding: Encoding = self.request.encoding.as_ref().try_into()?; - let private_key = self.get_private_key(&suite)?; + // Get private key wrapped in SecureHpkePrivateKey for automatic zeroization + let secure_private_key = self.get_private_key(&suite)?; println!("[enclave] decrypted KMS secret key"); + // Get the HpkePrivateKey for use with rustls HPKE operations + // Note: This creates a short-lived copy; the secure wrapper's copy is zeroized on drop + let private_key = secure_private_key.as_hpke_private_key(); + let hpke_suite = suite.get_hpke_suite(); let info = self.request.vault_id.as_bytes(); - let mut errors: Vec = Vec::new(); + let errors: Mutex> = Mutex::new(Vec::new()); // Sensitive context logging gated behind debug builds only #[cfg(debug_assertions)] @@ -173,33 +180,69 @@ impl EnclaveRequest { println!("[enclave] encoding: {:?}", encoding); } - // Single loop with encoding-based parsing - let mut decrypted_fields = BTreeMap::new(); - for (field, value) in &self.request.fields { - let encrypted_data = encoding.parse(value.as_str(), &suite)?; - - let decrypted = decrypt_value(hpke_suite, &private_key, info, field, encrypted_data) - .unwrap_or_else(|error| { - errors.push(error); - Value::Null - }); - decrypted_fields.insert(field.to_string(), decrypted); - } + // First pass: parse all encrypted data (sequential, may fail early) + let parsed_fields: Vec<(&String, EncryptedData)> = self + .request + .fields + .iter() + .map(|(field, value)| { + let encrypted_data = encoding.parse(value.as_str(), &suite)?; + Ok((field, encrypted_data)) + }) + .collect::>>()?; + + // Second pass: decrypt in parallel (CPU-intensive operations) + let decrypted_fields: HashMap = parsed_fields + .into_par_iter() + .map(|(field, encrypted_data)| { + let decrypted = + decrypt_value(hpke_suite, &private_key, info, field, encrypted_data) + .unwrap_or_else(|error| { + // Handle mutex lock - log if poisoned (indicates a panic occurred) + match errors.lock() { + Ok(mut err_vec) => err_vec.push(error), + Err(poisoned) => { + // Mutex is poisoned - a thread panicked. Log and recover. + eprintln!( + "[enclave critical] mutex poisoned during decryption - \ + a thread may have panicked" + ); + // Recover the data and continue + poisoned.into_inner().push(error); + } + } + Value::Null + }); + (field.clone(), decrypted) + }) + .collect(); + + // Extract errors from mutex - handle poisoning with logging + let final_errors = match errors.into_inner() { + Ok(errs) => errs, + Err(poisoned) => { + eprintln!( + "[enclave critical] mutex poisoned during final error extraction - \ + a thread may have panicked during decryption" + ); + poisoned.into_inner() + } + }; - Ok((decrypted_fields, errors)) + Ok((decrypted_fields, final_errors)) } } #[derive(Debug, Clone, Default, Serialize, Deserialize)] pub struct EnclaveResponse { #[serde(skip_serializing_if = "Option::is_none")] - pub fields: Option>, + pub fields: Option>, #[serde(skip_serializing_if = "Option::is_none")] pub errors: Option>, } impl EnclaveResponse { - pub fn new(fields: BTreeMap, errors: Option>) -> Self { + pub fn new(fields: HashMap, errors: Option>) -> Self { let errors = errors.map(|errors| errors.iter().map(|e| e.to_string()).collect()); Self { @@ -1092,7 +1135,7 @@ mod tests { #[test] fn test_enclave_response_serialization_with_string_fields() { - let mut fields = BTreeMap::new(); + let mut fields = HashMap::new(); fields.insert("name".to_string(), Value::String("Bob".to_string())); fields.insert( "email".to_string(), @@ -1110,7 +1153,7 @@ mod tests { #[test] fn test_enclave_response_serialization_with_integer_fields() { - let mut fields = BTreeMap::new(); + let mut fields = HashMap::new(); fields.insert("age".to_string(), Value::Number(46.into())); fields.insert("count".to_string(), Value::Number(100.into())); @@ -1123,7 +1166,7 @@ mod tests { #[test] fn test_enclave_response_serialization_with_boolean_fields() { - let mut fields = BTreeMap::new(); + let mut fields = HashMap::new(); fields.insert("is_active".to_string(), Value::Bool(true)); fields.insert("is_verified".to_string(), Value::Bool(false)); @@ -1136,7 +1179,7 @@ mod tests { #[test] fn test_enclave_response_serialization_with_null_fields() { - let mut fields = BTreeMap::new(); + let mut fields = HashMap::new(); fields.insert("missing".to_string(), Value::Null); let response = EnclaveResponse::new(fields.clone(), None); @@ -1149,7 +1192,7 @@ mod tests { #[test] fn test_enclave_response_serialization_with_mixed_fields() { // This tests the typical output from CEL expressions - let mut fields = BTreeMap::new(); + let mut fields = HashMap::new(); fields.insert("name".to_string(), Value::String("Bob".to_string())); fields.insert("age".to_string(), Value::Number(46.into())); fields.insert("is_empty".to_string(), Value::Bool(false)); @@ -1169,7 +1212,7 @@ mod tests { #[test] fn test_enclave_response_serialization_with_errors() { - let mut fields = BTreeMap::new(); + let mut fields = HashMap::new(); fields.insert("name".to_string(), Value::String("Bob".to_string())); fields.insert("failed_field".to_string(), Value::Null); @@ -1196,7 +1239,7 @@ mod tests { #[test] fn test_enclave_response_serialization_produces_valid_json() { - let mut fields = BTreeMap::new(); + let mut fields = HashMap::new(); fields.insert("name".to_string(), Value::String("Bob".to_string())); let response = EnclaveResponse::new(fields, None); @@ -1210,7 +1253,7 @@ mod tests { #[test] fn test_enclave_response_serialization_skips_none_fields() { - let mut fields = BTreeMap::new(); + let mut fields = HashMap::new(); fields.insert("name".to_string(), Value::String("Bob".to_string())); let response = EnclaveResponse::new(fields, None); @@ -1236,7 +1279,7 @@ mod tests { use serde_json::json; // Test with string fields - let mut fields = BTreeMap::new(); + let mut fields = HashMap::new(); fields.insert("name".to_string(), Value::String("Bob".to_string())); fields.insert( "email".to_string(), @@ -1262,7 +1305,7 @@ mod tests { fn test_to_string_equals_json_macro_with_mixed_types() { use serde_json::json; - let mut fields = BTreeMap::new(); + let mut fields = HashMap::new(); fields.insert("name".to_string(), Value::String("Bob".to_string())); fields.insert("age".to_string(), Value::Number(46.into())); fields.insert("is_active".to_string(), Value::Bool(true)); @@ -1286,7 +1329,7 @@ mod tests { fn test_to_string_equals_json_macro_with_errors() { use serde_json::json; - let mut fields = BTreeMap::new(); + let mut fields = HashMap::new(); fields.insert("name".to_string(), Value::String("Bob".to_string())); let errors = vec![anyhow!("test error")]; @@ -1326,7 +1369,7 @@ mod tests { fn test_to_string_handles_special_characters_correctly() { use serde_json::json; - let mut fields = BTreeMap::new(); + let mut fields = HashMap::new(); // Test various special characters that need JSON escaping fields.insert( "with_quotes".to_string(), @@ -1375,7 +1418,7 @@ mod tests { #[test] fn test_to_string_no_double_escaping() { // Specifically test that quotes aren't double-escaped - let mut fields = BTreeMap::new(); + let mut fields = HashMap::new(); let original_value = r#"{"nested": "json"}"#; fields.insert( "json_string".to_string(), @@ -1471,7 +1514,7 @@ mod tests { request: ParentRequest { vault_id: "v_test123".to_string(), region: "us-east-1".to_string(), - fields: BTreeMap::new(), + fields: HashMap::new(), suite_id: "SFBLRQARAAIAAg==".to_string(), encrypted_private_key: "test_key".to_string(), expressions: None,