From 553f2613968ee0574d5c0419131575f01ccd46b7 Mon Sep 17 00:00:00 2001 From: Noratrieb <48135649+Noratrieb@users.noreply.github.com> Date: Sat, 10 Aug 2024 17:42:49 +0200 Subject: [PATCH] more cleanup --- ssh-transport/src/keys.rs | 42 ++++++++++++---- ssh-transport/src/packet.rs | 97 +++++++++++++++++-------------------- 2 files changed, 77 insertions(+), 62 deletions(-) diff --git a/ssh-transport/src/keys.rs b/ssh-transport/src/keys.rs index 3f114ec..6f46bb3 100644 --- a/ssh-transport/src/keys.rs +++ b/ssh-transport/src/keys.rs @@ -1,7 +1,4 @@ -use chacha20poly1305::{ - aead::{Aead, AeadCore}, - ChaCha20Poly1305, KeyInit, -}; +use chacha20poly1305::{ChaCha20Poly1305, KeyInit}; use sha2::Digest; use crate::Result; @@ -12,15 +9,26 @@ pub(crate) struct Session { encryption_key_server_to_client: SshChaCha20Poly1305, } +pub(crate) trait Decryptor: Send + Sync + 'static { + fn decrypt_len(&mut self, bytes: &mut [u8; 4], packet_number: u64); + fn decrypt_packet(&mut self, bytes: &mut [u8], packet_number: u64); + fn rekey(&mut self, h: [u8; 32], k: [u8; 32]) -> Result<(), ()>; +} + +pub(crate) struct Plaintext; +impl Decryptor for Plaintext { + fn decrypt_len(&mut self, _: &mut [u8; 4], _: u64) {} + fn decrypt_packet(&mut self, _: &mut [u8], _: u64) {} + fn rekey(&mut self, _: [u8; 32], _: [u8; 32]) -> Result<(), ()> { + Err(()) + } +} + impl Session { pub(crate) fn new(h: [u8; 32], k: [u8; 32]) -> Self { Self::from_keys(h, h, k) } - pub(crate) fn rekey(&mut self, h: [u8; 32], k: [u8; 32]) { - *self = Self::from_keys(self.session_id, h, k); - } - /// fn from_keys(session_id: [u8; 32], h: [u8; 32], k: [u8; 32]) -> Self { let encryption_key_client_to_server = @@ -38,11 +46,22 @@ impl Session { // integrity_key_server_to_client: derive("F").into(), } } +} - pub(crate) fn decrypt_len(&mut self, bytes: &mut [u8], packet_number: u64) { +impl Decryptor for Session { + fn decrypt_len(&mut self, bytes: &mut [u8; 4], packet_number: u64) { self.encryption_key_client_to_server .decrypt_len(bytes, packet_number); } + + fn decrypt_packet(&mut self, bytes: &mut [u8], packet_number: u64) { + self.encryption_key_client_to_server.decrypt_packet(bytes, packet_number); + } + + fn rekey(&mut self, h: [u8; 32], k: [u8; 32]) -> Result<(), ()> { + *self = Self::from_keys(self.session_id, h, k); + Ok(()) + } } /// Derive a key from the shared secret K and exchange hash H. @@ -68,7 +87,6 @@ fn derive_key( } hash.update(&output[..(i * sha2len)]); - output[(i * sha2len)..][..sha2len].copy_from_slice(&hash.finalize()) } @@ -112,4 +130,8 @@ impl SshChaCha20Poly1305 { SshChaCha20::new(&self.header_key.into(), &packet_number.to_be_bytes().into()); cipher.apply_keystream(bytes); } + + fn decrypt_packet(&mut self, bytes: &mut [u8], packet_number: u64) { + todo!() + } } diff --git a/ssh-transport/src/packet.rs b/ssh-transport/src/packet.rs index 6f2550d..21366c0 100644 --- a/ssh-transport/src/packet.rs +++ b/ssh-transport/src/packet.rs @@ -1,26 +1,23 @@ use std::collections::VecDeque; use crate::client_error; -use crate::keys::Session; +use crate::keys::{Decryptor, Plaintext, Session}; use crate::parse::{MpInt, NameList, Parser, Writer}; use crate::Result; /// Frames the byte stream into packets. pub(crate) struct PacketTransport { - state: PacketTransportState, + decrytor: Box, + next_packet: PacketParser, packets: VecDeque, next_recv_seq_nr: u64, } -enum PacketTransportState { - Plaintext(PacketParser), - Keyed { session: Session }, -} - impl PacketTransport { pub(crate) fn new() -> Self { PacketTransport { - state: PacketTransportState::Plaintext(PacketParser::new()), + decrytor: Box::new(Plaintext), + next_packet: PacketParser::new(), packets: VecDeque::new(), next_recv_seq_nr: 0, } @@ -39,41 +36,22 @@ impl PacketTransport { } pub(crate) fn set_key(&mut self, h: [u8; 32], k: [u8; 32]) { - match &mut self.state { - PacketTransportState::Plaintext(_) => { - self.state = PacketTransportState::Keyed { - session: Session::new(h, k), - } - } - PacketTransportState::Keyed { session } => session.rekey(h, k), + if let Err(()) = self.decrytor.rekey(h, k) { + self.decrytor = Box::new(Session::new(h, k)); } } fn recv_bytes_step(&mut self, bytes: &[u8]) -> Result> { // TODO: This might not work if we buffer two packets where one changes keys in between? - match &mut self.state { - PacketTransportState::Plaintext(packet) => { - let result = packet.recv_bytes(bytes, ())?; - if let Some((consumed, result)) = result { - self.packets.push_back(result); - self.next_recv_seq_nr = self.next_recv_seq_nr.wrapping_add(1); - *packet = PacketParser::new(); - return Ok(Some(consumed)); - } - } - PacketTransportState::Keyed { session } => { - let mut len = [0_u8; 4]; - let Some(len_bytes) = bytes.get(0..4) else { - return Err(client_error!( - "packet too short, not enough bytes for length" - )); - }; - len.copy_from_slice(len_bytes); - session.decrypt_len(&mut len, self.next_recv_seq_nr); - let len = u32::from_be_bytes(len); - dbg!(len); - // TODO: dont assume we get it all as one.... AAaAAA - } + + let result = + self.next_packet + .recv_bytes(bytes, &mut *self.decrytor, self.next_recv_seq_nr)?; + if let Some((consumed, result)) = result { + self.packets.push_back(result); + self.next_recv_seq_nr = self.next_recv_seq_nr.wrapping_add(1); + self.next_packet = PacketParser::new(); + return Ok(Some(consumed)); } Ok(None) @@ -292,13 +270,25 @@ impl PacketParser { data: Vec::new(), } } - fn recv_bytes(&mut self, bytes: &[u8], mac: ()) -> Result> { - let Some((consumed, data)) = self.recv_bytes_inner(bytes, mac)? else { + fn recv_bytes( + &mut self, + bytes: &[u8], + decrytor: &mut dyn Decryptor, + next_seq_nr: u64, + ) -> Result> { + let Some((consumed, mut data)) = self.recv_bytes_inner(bytes, decrytor, next_seq_nr)? + else { return Ok(None); }; + decrytor.decrypt_packet(&mut data, next_seq_nr); Ok(Some((consumed, Packet::from_raw(&data)?))) } - fn recv_bytes_inner(&mut self, mut bytes: &[u8], _mac: ()) -> Result)>> { + fn recv_bytes_inner( + &mut self, + mut bytes: &[u8], + decrytor: &mut dyn Decryptor, + next_seq_nr: u64, + ) -> Result)>> { let mut consumed = 0; let packet_length = match self.packet_length { Some(packet_length) => packet_length, @@ -311,7 +301,10 @@ impl PacketParser { return Ok(None); } - let packet_length = u32::from_be_bytes(self.data.as_slice().try_into().unwrap()); + let mut encrypted_len = self.data.as_slice().try_into().unwrap(); + decrytor.decrypt_len(&mut encrypted_len, next_seq_nr); + + let packet_length = u32::from_be_bytes(encrypted_len); let packet_length = packet_length.try_into().unwrap(); self.data.clear(); @@ -337,8 +330,8 @@ impl PacketParser { } } #[cfg(test)] - fn test_recv_bytes(&mut self, bytes: &[u8], mac: ()) -> Option<(usize, Vec)> { - self.recv_bytes_inner(bytes, mac).unwrap() + fn test_recv_bytes(&mut self, bytes: &[u8]) -> Option<(usize, Vec)> { + self.recv_bytes_inner(bytes, &mut Plaintext, 0).unwrap() } } @@ -359,9 +352,9 @@ mod tests { #[test] fn packet_parser() { let mut p = PacketParser::new(); - p.test_recv_bytes(&2_u32.to_be_bytes(), ()).unwrap_none(); - p.test_recv_bytes(&[1], ()).unwrap_none(); - let (consumed, data) = p.test_recv_bytes(&[2], ()).unwrap(); + p.test_recv_bytes(&2_u32.to_be_bytes()).unwrap_none(); + p.test_recv_bytes(&[1]).unwrap_none(); + let (consumed, data) = p.test_recv_bytes(&[2]).unwrap(); assert_eq!(consumed, 1); assert_eq!(data, &[1, 2]); } @@ -370,11 +363,11 @@ mod tests { fn packet_parser_split_len() { let mut p = PacketParser::new(); let len = &2_u32.to_be_bytes(); - p.test_recv_bytes(&len[0..2], ()).unwrap_none(); - p.test_recv_bytes(&len[2..4], ()).unwrap_none(); + p.test_recv_bytes(&len[0..2]).unwrap_none(); + p.test_recv_bytes(&len[2..4]).unwrap_none(); - p.test_recv_bytes(&[1], ()).unwrap_none(); - let (consumed, data) = p.test_recv_bytes(&[2], ()).unwrap(); + p.test_recv_bytes(&[1]).unwrap_none(); + let (consumed, data) = p.test_recv_bytes(&[2]).unwrap(); assert_eq!(consumed, 1); assert_eq!(data, &[1, 2]); } @@ -382,7 +375,7 @@ mod tests { #[test] fn packet_parser_all() { let mut p = PacketParser::new(); - let (consumed, data) = p.test_recv_bytes(&[0, 0, 0, 2, 1, 2], ()).unwrap(); + let (consumed, data) = p.test_recv_bytes(&[0, 0, 0, 2, 1, 2]).unwrap(); assert_eq!(consumed, 6); assert_eq!(data, &[1, 2]); }