diff --git a/src/lib.rs b/src/lib.rs index 1926761..84ea220 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,16 +25,25 @@ struct ClientSetupConnection { _w: W, } +macro_rules! unexpected_message { + ($($tt:tt)*) => { + Err(ErrorKind::UnexpectedMessage(format!($($tt)*)).into()) + }; +} + impl ClientSetupConnection { fn establish(mut stream: W, host: &str) -> Result { let secret = x25519_dalek::EphemeralSecret::random_from_rng(rand::thread_rng()); let public = x25519_dalek::PublicKey::from(&secret); + let legacy_session_id = rand::random::<[u8; 32]>(); + let cipher_suites = vec![proto::CipherSuite::TLS_AES_128_GCM_SHA256]; + let handshake = proto::Handshake::ClientHello { legacy_version: proto::LEGACY_TLSV12, random: rand::random(), - legacy_session_id: rand::random::<[u8; 32]>().to_vec().into(), - cipher_suites: vec![proto::CipherSuite::TLS_AES_128_GCM_SHA256].into(), + legacy_session_id: legacy_session_id.to_vec().into(), + cipher_suites: cipher_suites.clone().into(), legacy_compressions_methods: vec![0].into(), extensions: vec![ proto::ExtensionCH::ServerName { @@ -86,13 +95,89 @@ impl ClientSetupConnection { let out = proto::TLSPlaintext::read(&mut stream)?; dbg!(&out); - if matches!(out, TLSPlaintext::Handshake { handshake } if handshake.is_hello_retry_request()) - { - println!("hello retry request, the server doesnt like us :("); + let proto::TLSPlaintext::Handshake { + handshake: + proto::Handshake::ServerHello { + legacy_version, + random, + legacy_session_id_echo, + cipher_suite, + legacy_compression_method, + extensions, + }, + } = out + else { + return Err( + ErrorKind::UnexpectedMessage(format!("expected ServerHello, got {out:?}")).into(), + ); + }; + + if random.is_hello_retry_request() { + return Err(ErrorKind::HelloRetryRequest.into()); } - // let res: proto::TLSPlaintext = proto::Value::read(&mut stream.get_mut())?; - // dbg!(res); + if legacy_version != proto::LEGACY_TLSV12 { + return unexpected_message!( + "unexpected TLS version in legacy_version field: {legacy_version:x?}" + ); + } + + if legacy_session_id_echo.as_ref() != legacy_session_id { + return unexpected_message!( + "server did not echo the legacy_session_id: {legacy_session_id_echo:?}" + ); + } + + if !cipher_suites.contains(&cipher_suite) { + return unexpected_message!( + "cipher suite from server not sent in client hello: {cipher_suite:?}" + ); + } + + if legacy_compression_method != 0 { + return unexpected_message!( + "legacy compression method MUST be zero: {legacy_compression_method}" + ); + } + + let mut supported_versions = false; + let mut server_key = None; + + for ext in extensions.as_ref() { + match ext { + proto::ExtensionSH::PreSharedKey => todo!(), + proto::ExtensionSH::SupportedVersions { selected_version } => { + if *selected_version != proto::TLSV13 { + return unexpected_message!("server returned non-TLS 1.3 version: {selected_version}"); + } + supported_versions = true; + }, + proto::ExtensionSH::Cookie { .. } => todo!(), + proto::ExtensionSH::KeyShare { key_share } => { + let entry = key_share.unwrap_server_hello(); + match entry { + proto::KeyShareEntry::X25519 { len, key_exchange } => { + if *len != 32 { + return unexpected_message!("key length for X25519 key share must be 32: {len}"); + } + server_key = Some(key_exchange); + }, + } + }, + } + } + + if !supported_versions { + return unexpected_message!("server did not send supported_versions extension"); + } + + let Some(server_key) = server_key else { + return unexpected_message!("server did not send its key"); + }; + let server_key = x25519_dalek::PublicKey::from(*server_key); + let dh_shared_secret = secret.diffie_hellman(&server_key); + + println!("we have established a shared secret. dont leak it!! anywhere here is it: {:x?}", dh_shared_secret.as_bytes()); todo!() } @@ -106,6 +191,8 @@ pub struct Error { #[derive(Debug)] pub enum ErrorKind { InvalidFrame(Box), + HelloRetryRequest, + UnexpectedMessage(String), Io(io::Error), } diff --git a/src/proto.rs b/src/proto.rs index a4c0fd1..c19a559 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -358,6 +358,15 @@ pub enum ServerHelloKeyshare { ServerHello(KeyShareEntry), } +impl ServerHelloKeyshare { + pub fn unwrap_server_hello(&self) -> &KeyShareEntry { + match self { + Self::HelloRetryRequest(_) => panic!("unexpected hello retry request, expected server hello"), + Self::ServerHello(entry) => entry, + } + } +} + impl Value for ServerHelloKeyshare { fn write(&self, w: &mut W) -> io::Result<()> { match self { @@ -382,9 +391,9 @@ impl Value for ServerHelloKeyshare { } } -impl Handshake { +impl ServerHelloRandom { pub fn is_hello_retry_request(&self) -> bool { - matches!(self, Handshake::ServerHello { random, .. } if random.0 == HELLO_RETRY_REQUEST) + self.0 == HELLO_RETRY_REQUEST } } diff --git a/src/proto/ser_de.rs b/src/proto/ser_de.rs index 5ca6e7a..b611274 100644 --- a/src/proto/ser_de.rs +++ b/src/proto/ser_de.rs @@ -234,6 +234,12 @@ impl From> for List { } } +impl AsRef<[T]> for List { + fn as_ref(&self) -> &[T] { + self.0.as_ref() + } +} + impl Debug for List { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_list().entries(self.0.iter()).finish() diff --git a/tests/expect.rs b/tests/expect.rs index 060651c..fa573e5 100644 --- a/tests/expect.rs +++ b/tests/expect.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use std::io::{Read, Write}; struct ExpectServer { @@ -62,5 +64,5 @@ fn connect() { Expect::Client(vec![0]), // TODO: do this ]); - let conn = tls::ClientConnection::establish(&mut expect, "example.com").unwrap(); + tls::ClientConnection::establish(&mut expect, "example.com").unwrap(); }