diff --git a/src/lib.rs b/src/lib.rs index a884ded..4602de6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,9 +3,11 @@ pub mod proto; use std::{ fmt::Debug, io::{self, BufWriter, Read, Write}, - net::{TcpStream, ToSocketAddrs}, + net::TcpStream, }; +use crate::proto::TLSPlaintext; + type Result = std::result::Result; pub struct ClientConnection {} @@ -54,7 +56,12 @@ impl ClientSetupConnection { println!("hello!"); let out = proto::TLSPlaintext::read(stream.get_mut())?; - dbg!(out); + dbg!(&out); + + if matches!(out, TLSPlaintext::Handshake { handshake } if handshake.is_hello_retry_request()) + { + println!("hello retry request, the server doesnt like us :("); + } // let res: proto::TLSPlaintext = proto::Value::read(&mut stream.get_mut())?; // dbg!(res); @@ -76,7 +83,6 @@ pub enum ErrorKind { impl From for Error { fn from(value: io::Error) -> Self { - panic!("io error: {value}"); Self { kind: ErrorKind::Io(value), } @@ -85,7 +91,6 @@ impl From for Error { impl From for Error { fn from(value: ErrorKind) -> Self { - panic!("error: {value:?}"); Self { kind: value } } } diff --git a/src/proto.rs b/src/proto.rs index 3b58dde..762a8da 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -1,14 +1,14 @@ +mod ser_de; + use std::{ fmt::Debug, io::{self, Read, Write}, - marker::PhantomData, - num::TryFromIntError, }; -use byteorder::{BigEndian as B, ReadBytesExt, WriteBytesExt}; - use crate::ErrorKind; +use self::ser_de::{proto_enum, proto_struct, u24, List, Todo, Value}; + #[derive(Debug, Clone, PartialEq, Eq)] pub enum TLSPlaintext { Invalid { @@ -280,355 +280,11 @@ proto_enum! { } } -macro_rules! proto_struct { - {$(#[$meta:meta])* pub struct $name:ident { - $( - $field_name:ident : $field_ty:ty, - )* - }} => { - $(#[$meta])* - pub struct $name { - $( - $field_name: $field_ty, - )* - } - - - impl Value for $name { - fn write(&self, mut w: &mut W) -> io::Result<()> { - $( - Value::write(&self.$field_name, &mut w)?; - )* - Ok(()) - } - - fn read(r: &mut R) -> crate::Result { - let ( $( $field_name ),* ) = ($( { discard!($field_name); Value::read(r)? } ),*); - - Ok(Self { - $( - $field_name, - )* - }) - } - - fn byte_size(&self) -> usize { - $( self.$field_name.byte_size() + )* 0 - } - } - }; -} -use proto_struct; - -macro_rules! proto_enum { - {$(#[$meta:meta])* pub enum $name:ident: $discr_ty:ty $( ,(length: $len_ty:ty) )? { - $( - $KindName:ident $({ - $( - $field_name:ident : $field_ty:ty, - )* - })? = $discriminant:expr, - )* - }} => { - $(#[$meta])* - pub enum $name { - $( - $KindName $({ - $( - $field_name: $field_ty, - )* - })?, - )* - } - - impl Value for $name { - fn write(&self, w: &mut W) -> io::Result<()> { - w.flush()?; - eprintln!("{}", stringify!($name)); - mod discr_consts { - $( - #[allow(non_upper_case_globals)] - pub(super) const $KindName: $discr_ty = $discriminant; - )* - } - - let write_len = |_w: &mut W, _len: usize| -> io::Result<()> { - _w.flush()?; - eprintln!("length"); - $( - <$len_ty>::try_from(_len).unwrap().write(_w)?; - )? - Ok(()) - }; - - match self { - $( - Self::$KindName $( { - $( $field_name, )* - } )? => { - let byte_size = $($( $field_name.byte_size() + )*)? 0; - - Value::write(&discr_consts::$KindName, w)?; - write_len(w, byte_size)?; - - let w = &mut MeasuringWriter(0, w); - - $($( - w.flush()?; - eprintln!("{}", stringify!($field_name)); - Value::write($field_name, w)?; - )*)? - - debug_assert_eq!(w.0, byte_size); - - Ok(()) - } - )* - } - } - - fn read(r: &mut R) -> crate::Result { - mod discr_consts { - $( - #[allow(non_upper_case_globals)] - pub(super) const $KindName: $discr_ty = $discriminant; - )* - } - - let kind: $discr_ty = Value::read(r)?; - - $( - let _len = <$len_ty>::read(r)?; - )? - - match kind { - $( - discr_consts::$KindName => { - #[allow(unused_parens)] - $(let ( $( $field_name ),* ) = ($( { discard!($field_name); Value::read(r)? } ),*);)? - - Ok(Self::$KindName $({ - $( - $field_name, - )* - })*) - }, - )* - - _ => Err(ErrorKind::InvalidFrame(Box::new(format!("invalid discriminant for {}: 0x{kind:x?}", stringify!($name)))).into()), - } - } - - fn byte_size(&self) -> usize { - mod discr_consts { - $( - #[allow(non_upper_case_globals)] - pub(super) const $KindName: $discr_ty = $discriminant; - )* - } - - $( <$len_ty>::default().byte_size() + )? match self { - $( - Self::$KindName $( { - $( $field_name, )* - } )? => { - $( $( $field_name.byte_size() + )* )? discr_consts::$KindName.byte_size() - } - )* - } - } - } - }; -} -use proto_enum; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct Todo; - -impl Value for Todo { - fn write(&self, w: &mut W) -> io::Result<()> { - todo!() - } - - fn read(r: &mut R) -> crate::Result { - todo!() - } - - fn byte_size(&self) -> usize { - todo!() +impl Handshake { + pub fn is_hello_retry_request(&self) -> bool { + matches!(self, Handshake::ServerHello { random, .. } if random == &HELLO_RETRY_REQUEST) } } -#[derive(Clone, PartialEq, Eq)] -pub struct List(Vec, PhantomData); - -impl From> for List { - fn from(value: Vec) -> Self { - Self(value, PhantomData) - } -} - -impl Debug for List { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_list().entries(self.0.iter()).finish() - } -} - -impl + TryFrom + Default> Value for List { - fn read(r: &mut R) -> crate::Result { - let mut remaining_byte_size = Len::read(r)?.into(); - let mut v = Vec::new(); - - while remaining_byte_size > 0 { - let value = T::read(r)?; - remaining_byte_size -= value.byte_size(); - v.push(value); - } - Ok(Self(v, PhantomData)) - } - fn write(&self, w: &mut W) -> io::Result<()> { - let byte_size = self.0.iter().map(Value::byte_size).sum::(); - Len::write( - &byte_size - .try_into() - .unwrap_or_else(|_| panic!("list is too large for domain: {}", self.0.len())), - w, - )?; - for elem in &self.0 { - elem.write(w)?; - } - Ok(()) - } - fn byte_size(&self) -> usize { - Len::byte_size(&Default::default()) + self.0.iter().map(Value::byte_size).sum::() - } -} - -pub trait Value: Sized + std::fmt::Debug { - fn write(&self, w: &mut W) -> io::Result<()>; - fn read(r: &mut R) -> crate::Result; - fn byte_size(&self) -> usize; -} - -impl Value for [V; N] { - fn write(&self, w: &mut W) -> io::Result<()> { - self.iter().try_for_each(|v| Value::write(v, w)) - } - fn read(r: &mut R) -> crate::Result { - // ugly :( - let mut values = Vec::with_capacity(N); - for _ in 0..N { - let value = V::read(r)?; - values.push(value); - } - Ok(values.try_into().unwrap()) - } - fn byte_size(&self) -> usize { - self.iter().map(Value::byte_size).sum() - } -} - -impl Value for u8 { - fn write(&self, w: &mut W) -> io::Result<()> { - w.write_u8(*self) - } - fn read(r: &mut R) -> crate::Result { - r.read_u8().map_err(Into::into) - } - fn byte_size(&self) -> usize { - 1 - } -} - -impl Value for u16 { - fn write(&self, w: &mut W) -> io::Result<()> { - w.write_u16::(*self) - } - fn read(r: &mut R) -> crate::Result { - r.read_u16::().map_err(Into::into) - } - fn byte_size(&self) -> usize { - 2 - } -} - -impl Value for u32 { - fn write(&self, w: &mut W) -> io::Result<()> { - w.write_u32::(*self) - } - - fn read(r: &mut R) -> crate::Result { - r.read_u32::().map_err(Into::into) - } - - fn byte_size(&self) -> usize { - 4 - } -} - -impl Value for (T, U) { - fn write(&self, w: &mut W) -> io::Result<()> { - T::write(&self.0, w)?; - T::write(&self.0, w)?; - Ok(()) - } - - fn read(r: &mut R) -> crate::Result { - Ok((T::read(r)?, U::read(r)?)) - } - - fn byte_size(&self) -> usize { - self.0.byte_size() + self.1.byte_size() - } -} - -#[derive(Debug, Clone, Copy, Default)] -#[allow(non_camel_case_types)] -struct u24(u32); - -impl Value for u24 { - fn write(&self, w: &mut W) -> io::Result<()> { - w.write_u24::(self.0) - } - - fn read(r: &mut R) -> crate::Result { - r.read_u24::().map_err(Into::into).map(u24) - } - - fn byte_size(&self) -> usize { - 3 - } -} - -impl TryFrom for u24 { - type Error = TryFromIntError; - fn try_from(value: usize) -> Result { - let value = u32::try_from(value)?; - if value > 2_u32.pow(24) { - return Err(u32::try_from(usize::MAX).unwrap_err()); - } - Ok(u24(value)) - } -} - -struct MeasuringWriter(usize, W); - -impl Write for MeasuringWriter { - fn write(&mut self, buf: &[u8]) -> io::Result { - let len = self.1.write(buf)?; - self.0 += len; - Ok(len) - } - - fn flush(&mut self) -> io::Result<()> { - self.1.flush() - } -} - -macro_rules! discard { - ($($tt:tt)*) => {}; -} -use discard; - #[cfg(test)] mod tests; diff --git a/src/proto/ser_de.rs b/src/proto/ser_de.rs new file mode 100644 index 0000000..4cf3041 --- /dev/null +++ b/src/proto/ser_de.rs @@ -0,0 +1,358 @@ +use byteorder::{BigEndian as B, ReadBytesExt, WriteBytesExt}; +use std::fmt::Debug; + +macro_rules! proto_struct { + {$(#[$meta:meta])* pub struct $name:ident { + $( + $field_name:ident : $field_ty:ty, + )* + }} => { + $(#[$meta])* + pub struct $name { + $( + $field_name: $field_ty, + )* + } + + + impl crate::proto::ser_de::Value for $name { + fn write(&self, mut w: &mut W) -> io::Result<()> { + $( + crate::proto::ser_de::Value::write(&self.$field_name, &mut w)?; + )* + Ok(()) + } + + fn read(r: &mut R) -> crate::Result { + let ( $( $field_name ),* ) = ($( { crate::proto::ser_de::discard!($field_name); crate::proto::ser_de::Value::read(r)? } ),*); + + Ok(Self { + $( + $field_name, + )* + }) + } + + fn byte_size(&self) -> usize { + $( self.$field_name.byte_size() + )* 0 + } + } + }; +} +use std::{ + io::{self, Read, Write}, + marker::PhantomData, + num::TryFromIntError, +}; + +pub(crate) use proto_struct; + +macro_rules! proto_enum { + {$(#[$meta:meta])* pub enum $name:ident: $discr_ty:ty $( ,(length: $len_ty:ty) )? { + $( + $KindName:ident $({ + $( + $field_name:ident : $field_ty:ty, + )* + })? = $discriminant:expr, + )* + }} => { + $(#[$meta])* + pub enum $name { + $( + $KindName $({ + $( + $field_name: $field_ty, + )* + })?, + )* + } + + impl crate::proto::ser_de::Value for $name { + fn write(&self, w: &mut W) -> io::Result<()> { + w.flush()?; + eprintln!("{}", stringify!($name)); + mod discr_consts { + $( + #[allow(non_upper_case_globals)] + pub(super) const $KindName: $discr_ty = $discriminant; + )* + } + + let write_len = |_w: &mut W, _len: usize| -> io::Result<()> { + _w.flush()?; + eprintln!("length"); + $( + <$len_ty>::try_from(_len).unwrap().write(_w)?; + )? + Ok(()) + }; + + match self { + $( + Self::$KindName $( { + $( $field_name, )* + } )? => { + let byte_size = $($( $field_name.byte_size() + )*)? 0; + + crate::proto::ser_de::Value::write(&discr_consts::$KindName, w)?; + write_len(w, byte_size)?; + + let w = &mut crate::proto::ser_de::MeasuringWriter(0, w); + + $($( + w.flush()?; + eprintln!("{}", stringify!($field_name)); + crate::proto::ser_de::Value::write($field_name, w)?; + )*)? + + debug_assert_eq!(w.0, byte_size); + + Ok(()) + } + )* + } + } + + fn read(r: &mut R) -> crate::Result { + mod discr_consts { + $( + #[allow(non_upper_case_globals)] + pub(super) const $KindName: $discr_ty = $discriminant; + )* + } + + let kind: $discr_ty = crate::proto::ser_de::Value::read(r)?; + + $( + let _len = <$len_ty>::read(r)?; + )? + + match kind { + $( + discr_consts::$KindName => { + #[allow(unused_parens)] + $(let ( $( $field_name ),* ) = ($( { crate::proto::ser_de::discard!($field_name); crate::proto::ser_de::Value::read(r)? } ),*);)? + + Ok(Self::$KindName $({ + $( + $field_name, + )* + })*) + }, + )* + + _ => Err(ErrorKind::InvalidFrame(Box::new(format!("invalid discriminant for {}: 0x{kind:x?}", stringify!($name)))).into()), + } + } + + fn byte_size(&self) -> usize { + mod discr_consts { + $( + #[allow(non_upper_case_globals)] + pub(super) const $KindName: $discr_ty = $discriminant; + )* + } + + $( <$len_ty>::default().byte_size() + )? match self { + $( + Self::$KindName $( { + $( $field_name, )* + } )? => { + $( $( $field_name.byte_size() + )* )? discr_consts::$KindName.byte_size() + } + )* + } + } + } + }; +} +pub(crate) use proto_enum; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct Todo; + +impl Value for Todo { + fn write(&self, _: &mut W) -> io::Result<()> { + todo!() + } + + fn read(_: &mut R) -> crate::Result { + todo!() + } + + fn byte_size(&self) -> usize { + todo!() + } +} + +#[derive(Clone, PartialEq, Eq)] +pub struct List(Vec, PhantomData); + +impl From> for List { + fn from(value: Vec) -> Self { + Self(value, PhantomData) + } +} + +impl Debug for List { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_list().entries(self.0.iter()).finish() + } +} + +impl + TryFrom + Default> Value for List { + fn read(r: &mut R) -> crate::Result { + let mut remaining_byte_size = Len::read(r)?.into(); + let mut v = Vec::new(); + + while remaining_byte_size > 0 { + let value = T::read(r)?; + remaining_byte_size -= value.byte_size(); + v.push(value); + } + Ok(Self(v, PhantomData)) + } + fn write(&self, w: &mut W) -> io::Result<()> { + let byte_size = self.0.iter().map(Value::byte_size).sum::(); + Len::write( + &byte_size + .try_into() + .unwrap_or_else(|_| panic!("list is too large for domain: {}", self.0.len())), + w, + )?; + for elem in &self.0 { + elem.write(w)?; + } + Ok(()) + } + fn byte_size(&self) -> usize { + Len::byte_size(&Default::default()) + self.0.iter().map(Value::byte_size).sum::() + } +} + +pub trait Value: Sized + std::fmt::Debug { + fn write(&self, w: &mut W) -> io::Result<()>; + fn read(r: &mut R) -> crate::Result; + fn byte_size(&self) -> usize; +} + +impl Value for [V; N] { + fn write(&self, w: &mut W) -> io::Result<()> { + self.iter().try_for_each(|v| Value::write(v, w)) + } + fn read(r: &mut R) -> crate::Result { + // ugly :( + let mut values = Vec::with_capacity(N); + for _ in 0..N { + let value = V::read(r)?; + values.push(value); + } + Ok(values.try_into().unwrap()) + } + fn byte_size(&self) -> usize { + self.iter().map(Value::byte_size).sum() + } +} + +impl Value for u8 { + fn write(&self, w: &mut W) -> io::Result<()> { + w.write_u8(*self) + } + fn read(r: &mut R) -> crate::Result { + r.read_u8().map_err(Into::into) + } + fn byte_size(&self) -> usize { + 1 + } +} + +impl Value for u16 { + fn write(&self, w: &mut W) -> io::Result<()> { + w.write_u16::(*self) + } + fn read(r: &mut R) -> crate::Result { + r.read_u16::().map_err(Into::into) + } + fn byte_size(&self) -> usize { + 2 + } +} + +impl Value for u32 { + fn write(&self, w: &mut W) -> io::Result<()> { + w.write_u32::(*self) + } + + fn read(r: &mut R) -> crate::Result { + r.read_u32::().map_err(Into::into) + } + + fn byte_size(&self) -> usize { + 4 + } +} + +impl Value for (T, U) { + fn write(&self, w: &mut W) -> io::Result<()> { + T::write(&self.0, w)?; + T::write(&self.0, w)?; + Ok(()) + } + + fn read(r: &mut R) -> crate::Result { + Ok((T::read(r)?, U::read(r)?)) + } + + fn byte_size(&self) -> usize { + self.0.byte_size() + self.1.byte_size() + } +} + +#[derive(Debug, Clone, Copy, Default)] +#[allow(non_camel_case_types)] +pub struct u24(u32); + +impl Value for u24 { + fn write(&self, w: &mut W) -> io::Result<()> { + w.write_u24::(self.0) + } + + fn read(r: &mut R) -> crate::Result { + r.read_u24::().map_err(Into::into).map(u24) + } + + fn byte_size(&self) -> usize { + 3 + } +} + +impl TryFrom for u24 { + type Error = TryFromIntError; + fn try_from(value: usize) -> Result { + let value = u32::try_from(value)?; + if value > 2_u32.pow(24) { + return Err(u32::try_from(usize::MAX).unwrap_err()); + } + Ok(u24(value)) + } +} + +pub(crate) struct MeasuringWriter(pub(crate) usize, pub(crate) W); + +impl Write for MeasuringWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + let len = self.1.write(buf)?; + self.0 += len; + Ok(len) + } + + fn flush(&mut self) -> io::Result<()> { + self.1.flush() + } +} + +macro_rules! discard { + ($($tt:tt)*) => {}; +} +pub(crate) use discard;