diff --git a/Cargo.lock b/Cargo.lock index daca4d0..1c6c20d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -212,17 +212,11 @@ dependencies = [ name = "fakessh" version = "0.1.0" dependencies = [ - "crypto-bigint", - "ed25519-dalek", "eyre", - "hex-literal", - "rand", - "rand_core", - "sha2", + "ssh-transport", "tokio", "tracing", "tracing-subscriber", - "x25519-dalek", ] [[package]] @@ -631,6 +625,23 @@ dependencies = [ "der", ] +[[package]] +name = "ssh-transport" +version = "0.1.0" +dependencies = [ + "crypto-bigint", + "ed25519-dalek", + "eyre", + "hex-literal", + "rand", + "rand_core", + "sha2", + "tokio", + "tracing", + "tracing-subscriber", + "x25519-dalek", +] + [[package]] name = "subtle" version = "2.6.1" diff --git a/Cargo.toml b/Cargo.toml index d00cea9..6f2ca4d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,19 +1,15 @@ +[workspace] +members = ["ssh-transport"] + [package] name = "fakessh" version = "0.1.0" edition = "2021" [dependencies] -crypto-bigint = "0.5.5" -ed25519-dalek = { version = "2.1.1" } eyre = "0.6.12" -rand = "0.8.5" -rand_core = "0.6.4" -sha2 = "0.10.8" +ssh-transport = { path = "./ssh-transport" } + tokio = { version = "1.39.2", features = ["full"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } -x25519-dalek = "2.0.1" - -[dev-dependencies] -hex-literal = "0.4.1" diff --git a/src/main.rs b/src/main.rs index f1a93ba..d2a2b63 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,14 @@ use std::net::SocketAddr; use eyre::{Context, Result}; -use fakessh::{ServerConnection, SshError, ThreadRngRand}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, net::{TcpListener, TcpStream}, }; use tracing::{error, info}; +use ssh_transport::{ServerConnection, SshError, ThreadRngRand}; + #[tokio::main] async fn main() -> eyre::Result<()> { tracing_subscriber::fmt().init(); diff --git a/ssh-transport/Cargo.toml b/ssh-transport/Cargo.toml new file mode 100644 index 0000000..f2f4922 --- /dev/null +++ b/ssh-transport/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "ssh-transport" +version = "0.1.0" +edition = "2021" + +[dependencies] +crypto-bigint = "0.5.5" +ed25519-dalek = { version = "2.1.1" } +eyre = "0.6.12" +rand = "0.8.5" +rand_core = "0.6.4" +sha2 = "0.10.8" +tracing = "0.1.40" +tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } +x25519-dalek = "2.0.1" + +[dev-dependencies] +hex-literal = "0.4.1" diff --git a/src/lib.rs b/ssh-transport/src/lib.rs similarity index 62% rename from src/lib.rs rename to ssh-transport/src/lib.rs index 733c93d..e4bb287 100644 --- a/src/lib.rs +++ b/ssh-transport/src/lib.rs @@ -1,5 +1,19 @@ +mod packet; mod parse; +use core::str; +use std::mem::take; + +use ed25519_dalek::ed25519::signature::Signer; +use packet::{ + DhKeyExchangeInitPacket, DhKeyExchangeInitReplyPacket, KeyExchangeInitPacket, Packet, + PacketTransport, SshPublicKey, SshSignature, +}; +use parse::{MpInt, NameList}; +use rand::RngCore; +use sha2::Digest; +use x25519_dalek::{EphemeralSecret, PublicKey}; + #[derive(Debug)] pub enum SshError { /// The client did something wrong. @@ -19,21 +33,6 @@ impl From for SshError { } } -macro_rules! client_error { - ($($tt:tt)*) => { - $crate::SshError::ClientError(::std::format!($($tt)*)) - }; -} -use core::str; -use std::{collections::VecDeque, mem::take}; - -use client_error; -use ed25519_dalek::ed25519::signature::Signer; -use parse::{MpInt, NameList, Parser, Writer}; -use rand::RngCore; -use sha2::Digest; -use x25519_dalek::{EphemeralSecret, PublicKey}; - // This is definitely who we are. pub const SERVER_IDENTIFICATION: &[u8] = b"SSH-2.0-OpenSSH_9.7\r\n"; @@ -96,10 +95,7 @@ impl ServerConnection { state: ServerState::ProtoExchange { received: Vec::new(), }, - packet_transport: PacketTransport { - state: PacketTransportState::Plaintext(PacketParser::new()), - packets: VecDeque::new(), - }, + packet_transport: PacketTransport::new(), send_queue: Vec::new(), rng: Box::new(rng), } @@ -125,13 +121,13 @@ impl ServerConnection { self.packet_transport.recv_bytes(bytes)?; - match &mut self.state { - ServerState::ProtoExchange { .. } => unreachable!("handled above"), - ServerState::KeyExchangeInit { - client_identification, - } => match self.packet_transport.next_packet() { - Some(data) => { - let kex = KeyExchangeInitPacket::parse(&data.payload)?; + while let Some(packet) = self.packet_transport.next_packet() { + match &mut self.state { + ServerState::ProtoExchange { .. } => unreachable!("handled above"), + ServerState::KeyExchangeInit { + client_identification, + } => { + let kex = KeyExchangeInitPacket::parse(&packet.payload)?; let require_algorithm = |expected: &'static str, list: NameList<'_>| -> Result<&'static str> { @@ -207,19 +203,16 @@ impl ServerConnection { })); self.state = ServerState::DhKeyInit { client_identification, - client_kexinit: data.payload, + client_kexinit: packet.payload, server_kexinit: server_kexinit_payload, }; } - None => {}, - }, - ServerState::DhKeyInit { - client_identification, - client_kexinit, - server_kexinit, - } => match self.packet_transport.next_packet() { - Some(data) => { - let dh = DhKeyExchangeInitPacket::parse(&data.payload)?; + ServerState::DhKeyInit { + client_identification, + client_kexinit, + server_kexinit, + } => { + let dh = DhKeyExchangeInitPacket::parse(&packet.payload)?; let secret = EphemeralSecret::random_from_rng(SshRngRandAdapter(&mut *self.rng)); @@ -303,11 +296,8 @@ impl ServerConnection { self.state = ServerState::NewKeys; // TODO: set keys for transport } - None => {}, - }, - ServerState::NewKeys => match self.packet_transport.next_packet() { - Some(data) => { - if data.payload != &[Packet::SSH_MSG_NEWKEYS] { + ServerState::NewKeys => { + if packet.payload != &[Packet::SSH_MSG_NEWKEYS] { return Err(client_error!("did not send SSH_MSG_NEWKEYS")); } @@ -316,9 +306,8 @@ impl ServerConnection { })); self.state = ServerState::ServiceRequest {}; } - None => {}, - }, - ServerState::ServiceRequest {} => {}, + ServerState::ServiceRequest {} => {} + } } Ok(()) } @@ -351,312 +340,6 @@ impl Msg { } } -/// Frames the byte stream into packets. -struct PacketTransport { - state: PacketTransportState, - packets: VecDeque, -} - -enum PacketTransportState { - Plaintext(PacketParser), - Keyed { key: () }, -} - -impl PacketTransport { - fn recv_bytes(&mut self, mut bytes: &[u8]) -> Result<()> { - while let Some(consumed) = self.recv_bytes_step(bytes)? { - bytes = &bytes[consumed..]; - if bytes.is_empty() { - break; - } - } - Ok(()) - } - fn next_packet(&mut self) -> Option { - self.packets.pop_front() - } - fn recv_bytes_step(&mut self, bytes: &[u8]) -> Result> { - // TODO: This might not work if we buffer two packets where one changes keys in between? - match &mut self.state { - PacketTransportState::Plaintext(packet) => { - let result = packet.recv_bytes(bytes, ())?; - if let Some((consumed, result)) = result { - self.packets.push_back(result); - *packet = PacketParser::new(); - return Ok(Some(consumed)); - } - } - PacketTransportState::Keyed { key } => todo!(), - } - - Ok(None) - } -} - -#[derive(Debug, PartialEq)] -struct Packet { - payload: Vec, -} -impl Packet { - const SSH_MSG_KEXINIT: u8 = 20; - const SSH_MSG_NEWKEYS: u8 = 21; - const SSH_MSG_KEXDH_INIT: u8 = 30; - const SSH_MSG_KEXDH_REPLY: u8 = 31; - - fn from_raw(bytes: &[u8]) -> Result { - let Some(padding_length) = bytes.get(0) else { - return Err(client_error!("empty packet")); - }; - // TODO: mac? - let Some(payload_len) = (bytes.len() - 1).checked_sub(*padding_length as usize) else { - return Err(client_error!("packet padding longer than packet")); - }; - let payload = &bytes[1..][..payload_len]; - - if (bytes.len() + 4) % 8 != 0 { - return Err(client_error!("full packet length must be multiple of 8")); - } - - Ok(Self { - payload: payload.to_vec(), - }) - } - - fn to_bytes(&self) -> Vec { - let mut new = Vec::new(); - - let min_full_length = self.payload.len() + 4 + 1; - - // The padding must give a factor of 8. - let min_padding_len = (min_full_length.next_multiple_of(8) - min_full_length) as u8; - // > There MUST be at least four bytes of padding. - // So let's satisfy this by just adding 8. We can always properly randomize it later if desired. - let padding_len = min_padding_len + 8; - - let packet_len = self.payload.len() + (padding_len as usize) + 1; - new.extend_from_slice(&u32::to_be_bytes(packet_len as u32)); - new.extend_from_slice(&[padding_len]); - new.extend_from_slice(&self.payload); - new.extend(std::iter::repeat(0).take(padding_len as usize)); - // mac... - - assert!((4 + 1 + self.payload.len() + (padding_len as usize)) % 8 == 0); - assert!(new.len() % 8 == 0); - - new - } -} - -#[derive(Debug)] -struct KeyExchangeInitPacket<'a> { - cookie: [u8; 16], - kex_algorithms: NameList<'a>, - server_host_key_algorithms: NameList<'a>, - encryption_algorithms_client_to_server: NameList<'a>, - encryption_algorithms_server_to_client: NameList<'a>, - mac_algorithms_client_to_server: NameList<'a>, - mac_algorithms_server_to_client: NameList<'a>, - compression_algorithms_client_to_server: NameList<'a>, - compression_algorithms_server_to_client: NameList<'a>, - languages_client_to_server: NameList<'a>, - languages_server_to_client: NameList<'a>, - first_kex_packet_follows: bool, -} - -impl<'a> KeyExchangeInitPacket<'a> { - fn parse(payload: &'a [u8]) -> Result> { - let mut c = Parser::new(payload); - - let kind = c.u8()?; - if kind != Packet::SSH_MSG_KEXINIT { - return Err(client_error!( - "expected SSH_MSG_KEXINIT packet, found {kind}" - )); - } - let cookie = c.read_array::<16>()?; - let kex_algorithms = c.name_list()?; - let server_host_key_algorithms = c.name_list()?; - let encryption_algorithms_client_to_server = c.name_list()?; - let encryption_algorithms_server_to_client = c.name_list()?; - let mac_algorithms_client_to_server = c.name_list()?; - let mac_algorithms_server_to_client = c.name_list()?; - let compression_algorithms_client_to_server = c.name_list()?; - let compression_algorithms_server_to_client = c.name_list()?; - - let languages_client_to_server = c.name_list()?; - let languages_server_to_client = c.name_list()?; - - let first_kex_packet_follows = c.bool()?; - - let _ = c.u32()?; // Reserved. - - Ok(Self { - cookie, - kex_algorithms, - server_host_key_algorithms, - encryption_algorithms_client_to_server, - encryption_algorithms_server_to_client, - mac_algorithms_client_to_server, - mac_algorithms_server_to_client, - compression_algorithms_client_to_server, - compression_algorithms_server_to_client, - languages_client_to_server, - languages_server_to_client, - first_kex_packet_follows, - }) - } - - fn to_bytes(&self) -> Vec { - let mut data = Writer::new(); - - data.u8(Packet::SSH_MSG_KEXINIT); - data.write(&self.cookie); - data.name_list(self.kex_algorithms); - data.name_list(self.server_host_key_algorithms); - data.name_list(self.encryption_algorithms_client_to_server); - data.name_list(self.encryption_algorithms_server_to_client); - data.name_list(self.mac_algorithms_client_to_server); - data.name_list(self.mac_algorithms_server_to_client); - data.name_list(self.compression_algorithms_client_to_server); - data.name_list(self.compression_algorithms_server_to_client); - data.name_list(self.languages_client_to_server); - data.name_list(self.languages_server_to_client); - data.u8(self.first_kex_packet_follows as u8); - data.u32(0); // Reserved. - - data.finish() - } -} - -#[derive(Debug)] -struct DhKeyExchangeInitPacket<'a> { - e: MpInt<'a>, -} -impl<'a> DhKeyExchangeInitPacket<'a> { - fn parse(payload: &'a [u8]) -> Result> { - let mut c = Parser::new(payload); - - let kind = c.u8()?; - if kind != Packet::SSH_MSG_KEXDH_INIT { - return Err(client_error!( - "expected SSH_MSG_KEXDH_INIT packet, found {kind}" - )); - } - let e = c.mpint()?; - Ok(Self { e }) - } -} - -#[derive(Debug)] -struct SshPublicKey<'a> { - format: &'a [u8], - data: &'a [u8], -} -impl SshPublicKey<'_> { - fn to_bytes(&self) -> Vec { - let mut data = Writer::new(); - data.u32((4 + self.format.len() + 4 + self.data.len()) as u32); - // ed25519-specific! - // - data.string(&self.format); - data.string(&self.data); - data.finish() - } -} -#[derive(Debug)] -struct SshSignature<'a> { - format: &'a [u8], - data: &'a [u8], -} - -#[derive(Debug)] -struct DhKeyExchangeInitReplyPacket<'a> { - pubkey: SshPublicKey<'a>, - f: MpInt<'a>, - signature: SshSignature<'a>, -} -impl<'a> DhKeyExchangeInitReplyPacket<'a> { - fn to_bytes(&self) -> Vec { - let mut data = Writer::new(); - - data.u8(Packet::SSH_MSG_KEXDH_REPLY); - data.write(&self.pubkey.to_bytes()); - data.mpint(self.f); - - data.u32((4 + self.signature.format.len() + 4 + self.signature.data.len()) as u32); - // - data.string(&self.signature.format); - data.string(&self.signature.data); - data.finish() - } -} - -struct EncryptedPacketParser {} - -struct PacketParser { - // The length of the packet. - packet_length: Option, - // Before we've read the length fully, this stores the length. - // Afterwards, this stores the packet data *after* the length. - data: Vec, -} -impl PacketParser { - fn new() -> Self { - Self { - packet_length: None, - data: Vec::new(), - } - } - fn recv_bytes(&mut self, bytes: &[u8], mac: ()) -> Result> { - let Some((consumed, data)) = self.recv_bytes_inner(bytes, mac)? else { - return Ok(None); - }; - Ok(Some((consumed, Packet::from_raw(&data)?))) - } - fn recv_bytes_inner(&mut self, mut bytes: &[u8], _mac: ()) -> Result)>> { - let mut consumed = 0; - let packet_length = match self.packet_length { - Some(packet_length) => packet_length, - None => { - let remaining_len = std::cmp::min(bytes.len(), 4 - self.data.len()); - // Try to read the bytes of the length. - self.data.extend_from_slice(&bytes[..remaining_len]); - if self.data.len() < 4 { - // Not enough data yet :(. - return Ok(None); - } - - let packet_length = u32::from_be_bytes(self.data.as_slice().try_into().unwrap()); - let packet_length = packet_length.try_into().unwrap(); - self.data.clear(); - - self.packet_length = Some(packet_length); - - // We have the data. - bytes = &bytes[remaining_len..]; - consumed += remaining_len; - - packet_length - } - }; - - let remaining_len = std::cmp::min(bytes.len(), packet_length - self.data.len()); - self.data.extend_from_slice(&bytes[..remaining_len]); - consumed += remaining_len; - - if self.data.len() == packet_length { - // We have the full data. - Ok(Some((consumed, std::mem::take(&mut self.data)))) - } else { - Ok(None) - } - } - #[cfg(test)] - fn test_recv_bytes(&mut self, bytes: &[u8], mac: ()) -> Option<(usize, Vec)> { - self.recv_bytes_inner(bytes, mac).unwrap() - } -} - // hardcoded test keys. lol. const _PUBKEY: &str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOk5zfpvwNc3MztTTpE90zLI1Ref4AwwRVdSFyJLGbj2 testkey"; @@ -679,21 +362,18 @@ const PRIVKEY_BYTES: &[u8; 32] = &[ 0x0b, 0x9a, 0x4a, 0x44, 0xd5, 0x47, 0xc7, 0x5b, 0x9e, 0x31, 0x7d, 0xa1, 0xd5, 0x75, 0x27, 0x99, ]; +macro_rules! client_error { + ($($tt:tt)*) => { + $crate::SshError::ClientError(::std::format!($($tt)*)) + }; +} +use client_error; + #[cfg(test)] mod tests { use hex_literal::hex; - use crate::{MsgKind, PacketParser, ServerConnection, SshRng}; - - trait OptionExt { - fn unwrap_none(self); - } - impl OptionExt for Option { - #[track_caller] - fn unwrap_none(self) { - assert!(self.is_none()); - } - } + use crate::{MsgKind, ServerConnection, SshRng}; struct NoRng; impl SshRng for NoRng { @@ -727,37 +407,6 @@ mod tests { assert_eq!(msg.0, MsgKind::ServerProtocolInfo); } - #[test] - fn packet_parser() { - let mut p = PacketParser::new(); - p.test_recv_bytes(&2_u32.to_be_bytes(), ()).unwrap_none(); - p.test_recv_bytes(&[1], ()).unwrap_none(); - let (consumed, data) = p.test_recv_bytes(&[2], ()).unwrap(); - assert_eq!(consumed, 1); - assert_eq!(data, &[1, 2]); - } - - #[test] - fn packet_parser_split_len() { - let mut p = PacketParser::new(); - let len = &2_u32.to_be_bytes(); - p.test_recv_bytes(&len[0..2], ()).unwrap_none(); - p.test_recv_bytes(&len[2..4], ()).unwrap_none(); - - p.test_recv_bytes(&[1], ()).unwrap_none(); - let (consumed, data) = p.test_recv_bytes(&[2], ()).unwrap(); - assert_eq!(consumed, 1); - assert_eq!(data, &[1, 2]); - } - - #[test] - fn packet_parser_all() { - let mut p = PacketParser::new(); - let (consumed, data) = p.test_recv_bytes(&[0, 0, 0, 2, 1, 2], ()).unwrap(); - assert_eq!(consumed, 6); - assert_eq!(data, &[1, 2]); - } - #[test] fn handshake() { #[rustfmt::skip] diff --git a/ssh-transport/src/packet.rs b/ssh-transport/src/packet.rs new file mode 100644 index 0000000..ad07935 --- /dev/null +++ b/ssh-transport/src/packet.rs @@ -0,0 +1,361 @@ +use std::collections::VecDeque; + +use crate::client_error; +use crate::parse::{MpInt, NameList, Parser, Writer}; +use crate::Result; + +/// Frames the byte stream into packets. +pub(crate) struct PacketTransport { + state: PacketTransportState, + packets: VecDeque, +} + +enum PacketTransportState { + Plaintext(PacketParser), + Keyed { key: () }, +} + +impl PacketTransport { + pub(crate) fn new() -> Self { + PacketTransport { + state: PacketTransportState::Plaintext(PacketParser::new()), + packets: VecDeque::new(), + } + } + pub(crate) fn recv_bytes(&mut self, mut bytes: &[u8]) -> Result<()> { + while let Some(consumed) = self.recv_bytes_step(bytes)? { + bytes = &bytes[consumed..]; + if bytes.is_empty() { + break; + } + } + Ok(()) + } + pub(crate) fn next_packet(&mut self) -> Option { + self.packets.pop_front() + } + fn recv_bytes_step(&mut self, bytes: &[u8]) -> Result> { + // TODO: This might not work if we buffer two packets where one changes keys in between? + match &mut self.state { + PacketTransportState::Plaintext(packet) => { + let result = packet.recv_bytes(bytes, ())?; + if let Some((consumed, result)) = result { + self.packets.push_back(result); + *packet = PacketParser::new(); + return Ok(Some(consumed)); + } + } + PacketTransportState::Keyed { key } => todo!(), + } + + Ok(None) + } +} + +#[derive(Debug, PartialEq)] +pub(crate) struct Packet { + pub(crate) payload: Vec, +} +impl Packet { + pub(crate) const SSH_MSG_KEXINIT: u8 = 20; + pub(crate) const SSH_MSG_NEWKEYS: u8 = 21; + pub(crate) const SSH_MSG_KEXDH_INIT: u8 = 30; + pub(crate) const SSH_MSG_KEXDH_REPLY: u8 = 31; + + fn from_raw(bytes: &[u8]) -> Result { + let Some(padding_length) = bytes.get(0) else { + return Err(client_error!("empty packet")); + }; + // TODO: mac? + let Some(payload_len) = (bytes.len() - 1).checked_sub(*padding_length as usize) else { + return Err(client_error!("packet padding longer than packet")); + }; + let payload = &bytes[1..][..payload_len]; + + if (bytes.len() + 4) % 8 != 0 { + return Err(client_error!("full packet length must be multiple of 8")); + } + + Ok(Self { + payload: payload.to_vec(), + }) + } + + pub(crate) fn to_bytes(&self) -> Vec { + let mut new = Vec::new(); + + let min_full_length = self.payload.len() + 4 + 1; + + // The padding must give a factor of 8. + let min_padding_len = (min_full_length.next_multiple_of(8) - min_full_length) as u8; + // > There MUST be at least four bytes of padding. + // So let's satisfy this by just adding 8. We can always properly randomize it later if desired. + let padding_len = min_padding_len + 8; + + let packet_len = self.payload.len() + (padding_len as usize) + 1; + new.extend_from_slice(&u32::to_be_bytes(packet_len as u32)); + new.extend_from_slice(&[padding_len]); + new.extend_from_slice(&self.payload); + new.extend(std::iter::repeat(0).take(padding_len as usize)); + // mac... + + assert!((4 + 1 + self.payload.len() + (padding_len as usize)) % 8 == 0); + assert!(new.len() % 8 == 0); + + new + } +} + +#[derive(Debug)] +pub(crate) struct KeyExchangeInitPacket<'a> { + pub(crate) cookie: [u8; 16], + pub(crate) kex_algorithms: NameList<'a>, + pub(crate) server_host_key_algorithms: NameList<'a>, + pub(crate) encryption_algorithms_client_to_server: NameList<'a>, + pub(crate) encryption_algorithms_server_to_client: NameList<'a>, + pub(crate) mac_algorithms_client_to_server: NameList<'a>, + pub(crate) mac_algorithms_server_to_client: NameList<'a>, + pub(crate) compression_algorithms_client_to_server: NameList<'a>, + pub(crate) compression_algorithms_server_to_client: NameList<'a>, + pub(crate) languages_client_to_server: NameList<'a>, + pub(crate) languages_server_to_client: NameList<'a>, + pub(crate) first_kex_packet_follows: bool, +} + +impl<'a> KeyExchangeInitPacket<'a> { + pub(crate) fn parse(payload: &'a [u8]) -> Result> { + let mut c = Parser::new(payload); + + let kind = c.u8()?; + if kind != Packet::SSH_MSG_KEXINIT { + return Err(client_error!( + "expected SSH_MSG_KEXINIT packet, found {kind}" + )); + } + let cookie = c.read_array::<16>()?; + let kex_algorithms = c.name_list()?; + let server_host_key_algorithms = c.name_list()?; + let encryption_algorithms_client_to_server = c.name_list()?; + let encryption_algorithms_server_to_client = c.name_list()?; + let mac_algorithms_client_to_server = c.name_list()?; + let mac_algorithms_server_to_client = c.name_list()?; + let compression_algorithms_client_to_server = c.name_list()?; + let compression_algorithms_server_to_client = c.name_list()?; + + let languages_client_to_server = c.name_list()?; + let languages_server_to_client = c.name_list()?; + + let first_kex_packet_follows = c.bool()?; + + let _ = c.u32()?; // Reserved. + + Ok(Self { + cookie, + kex_algorithms, + server_host_key_algorithms, + encryption_algorithms_client_to_server, + encryption_algorithms_server_to_client, + mac_algorithms_client_to_server, + mac_algorithms_server_to_client, + compression_algorithms_client_to_server, + compression_algorithms_server_to_client, + languages_client_to_server, + languages_server_to_client, + first_kex_packet_follows, + }) + } + + pub(crate) fn to_bytes(&self) -> Vec { + let mut data = Writer::new(); + + data.u8(Packet::SSH_MSG_KEXINIT); + data.write(&self.cookie); + data.name_list(self.kex_algorithms); + data.name_list(self.server_host_key_algorithms); + data.name_list(self.encryption_algorithms_client_to_server); + data.name_list(self.encryption_algorithms_server_to_client); + data.name_list(self.mac_algorithms_client_to_server); + data.name_list(self.mac_algorithms_server_to_client); + data.name_list(self.compression_algorithms_client_to_server); + data.name_list(self.compression_algorithms_server_to_client); + data.name_list(self.languages_client_to_server); + data.name_list(self.languages_server_to_client); + data.u8(self.first_kex_packet_follows as u8); + data.u32(0); // Reserved. + + data.finish() + } +} + +#[derive(Debug)] +pub(crate) struct DhKeyExchangeInitPacket<'a> { + pub(crate) e: MpInt<'a>, +} +impl<'a> DhKeyExchangeInitPacket<'a> { + pub(crate) fn parse(payload: &'a [u8]) -> Result> { + let mut c = Parser::new(payload); + + let kind = c.u8()?; + if kind != Packet::SSH_MSG_KEXDH_INIT { + return Err(client_error!( + "expected SSH_MSG_KEXDH_INIT packet, found {kind}" + )); + } + let e = c.mpint()?; + Ok(Self { e }) + } +} + +#[derive(Debug)] +pub(crate) struct SshPublicKey<'a> { + pub(crate) format: &'a [u8], + pub(crate) data: &'a [u8], +} +impl SshPublicKey<'_> { + pub(crate) fn to_bytes(&self) -> Vec { + let mut data = Writer::new(); + data.u32((4 + self.format.len() + 4 + self.data.len()) as u32); + // ed25519-specific! + // + data.string(&self.format); + data.string(&self.data); + data.finish() + } +} +#[derive(Debug)] +pub(crate) struct SshSignature<'a> { + pub(crate) format: &'a [u8], + pub(crate) data: &'a [u8], +} + +#[derive(Debug)] +pub(crate) struct DhKeyExchangeInitReplyPacket<'a> { + pub(crate) pubkey: SshPublicKey<'a>, + pub(crate) f: MpInt<'a>, + pub(crate) signature: SshSignature<'a>, +} +impl<'a> DhKeyExchangeInitReplyPacket<'a> { + pub(crate) fn to_bytes(&self) -> Vec { + let mut data = Writer::new(); + + data.u8(Packet::SSH_MSG_KEXDH_REPLY); + data.write(&self.pubkey.to_bytes()); + data.mpint(self.f); + + data.u32((4 + self.signature.format.len() + 4 + self.signature.data.len()) as u32); + // + data.string(&self.signature.format); + data.string(&self.signature.data); + data.finish() + } +} + +struct PacketParser { + // The length of the packet. + packet_length: Option, + // Before we've read the length fully, this stores the length. + // Afterwards, this stores the packet data *after* the length. + data: Vec, +} +impl PacketParser { + fn new() -> Self { + Self { + packet_length: None, + data: Vec::new(), + } + } + fn recv_bytes(&mut self, bytes: &[u8], mac: ()) -> Result> { + let Some((consumed, data)) = self.recv_bytes_inner(bytes, mac)? else { + return Ok(None); + }; + Ok(Some((consumed, Packet::from_raw(&data)?))) + } + fn recv_bytes_inner(&mut self, mut bytes: &[u8], _mac: ()) -> Result)>> { + let mut consumed = 0; + let packet_length = match self.packet_length { + Some(packet_length) => packet_length, + None => { + let remaining_len = std::cmp::min(bytes.len(), 4 - self.data.len()); + // Try to read the bytes of the length. + self.data.extend_from_slice(&bytes[..remaining_len]); + if self.data.len() < 4 { + // Not enough data yet :(. + return Ok(None); + } + + let packet_length = u32::from_be_bytes(self.data.as_slice().try_into().unwrap()); + let packet_length = packet_length.try_into().unwrap(); + self.data.clear(); + + self.packet_length = Some(packet_length); + + // We have the data. + bytes = &bytes[remaining_len..]; + consumed += remaining_len; + + packet_length + } + }; + + let remaining_len = std::cmp::min(bytes.len(), packet_length - self.data.len()); + self.data.extend_from_slice(&bytes[..remaining_len]); + consumed += remaining_len; + + if self.data.len() == packet_length { + // We have the full data. + Ok(Some((consumed, std::mem::take(&mut self.data)))) + } else { + Ok(None) + } + } + #[cfg(test)] + fn test_recv_bytes(&mut self, bytes: &[u8], mac: ()) -> Option<(usize, Vec)> { + self.recv_bytes_inner(bytes, mac).unwrap() + } +} + +#[cfg(test)] +mod tests { + use crate::packet::PacketParser; + + trait OptionExt { + fn unwrap_none(self); + } + impl OptionExt for Option { + #[track_caller] + fn unwrap_none(self) { + assert!(self.is_none()); + } + } + + #[test] + fn packet_parser() { + let mut p = PacketParser::new(); + p.test_recv_bytes(&2_u32.to_be_bytes(), ()).unwrap_none(); + p.test_recv_bytes(&[1], ()).unwrap_none(); + let (consumed, data) = p.test_recv_bytes(&[2], ()).unwrap(); + assert_eq!(consumed, 1); + assert_eq!(data, &[1, 2]); + } + + #[test] + fn packet_parser_split_len() { + let mut p = PacketParser::new(); + let len = &2_u32.to_be_bytes(); + p.test_recv_bytes(&len[0..2], ()).unwrap_none(); + p.test_recv_bytes(&len[2..4], ()).unwrap_none(); + + p.test_recv_bytes(&[1], ()).unwrap_none(); + let (consumed, data) = p.test_recv_bytes(&[2], ()).unwrap(); + assert_eq!(consumed, 1); + assert_eq!(data, &[1, 2]); + } + + #[test] + fn packet_parser_all() { + let mut p = PacketParser::new(); + let (consumed, data) = p.test_recv_bytes(&[0, 0, 0, 2, 1, 2], ()).unwrap(); + assert_eq!(consumed, 6); + assert_eq!(data, &[1, 2]); + } +} diff --git a/src/parse.rs b/ssh-transport/src/parse.rs similarity index 100% rename from src/parse.rs rename to ssh-transport/src/parse.rs