diff --git a/src/main.rs b/src/main.rs index a636850..f91398d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,7 +13,9 @@ use tracing_subscriber::EnvFilter; #[tokio::main] async fn main() -> eyre::Result<()> { tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env()) + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")), + ) .init(); let addr = "0.0.0.0:2222".parse::().unwrap(); diff --git a/ssh-transport/src/keys.rs b/ssh-transport/src/keys.rs index 9d82f62..f7368ff 100644 --- a/ssh-transport/src/keys.rs +++ b/ssh-transport/src/keys.rs @@ -3,8 +3,8 @@ use sha2::Digest; use subtle::ConstantTimeEq; use crate::{ - packet::{Packet, RawPacket}, - Result, + packet::{EncryptedPacket, MsgKind, Packet, RawPacket}, + Msg, Result, }; pub(crate) struct Session { @@ -13,19 +13,25 @@ pub(crate) struct Session { encryption_key_server_to_client: SshChaCha20Poly1305, } -pub(crate) trait Decryptor: Send + Sync + 'static { +pub(crate) trait Keys: Send + Sync + 'static { fn decrypt_len(&mut self, bytes: &mut [u8; 4], packet_number: u64); fn decrypt_packet(&mut self, raw_packet: RawPacket, packet_number: u64) -> Result; + + fn encrypt_packet_to_msg(&mut self, packet: Packet, packet_number: u64) -> Msg; + fn additional_mac_len(&self) -> usize; fn rekey(&mut self, h: [u8; 32], k: [u8; 32]) -> Result<(), ()>; } pub(crate) struct Plaintext; -impl Decryptor for Plaintext { +impl Keys for Plaintext { fn decrypt_len(&mut self, _: &mut [u8; 4], _: u64) {} fn decrypt_packet(&mut self, raw: RawPacket, _: u64) -> Result { Packet::from_raw(&raw.rest()) } + fn encrypt_packet_to_msg(&mut self, packet: Packet, _: u64) -> Msg { + Msg(MsgKind::PlaintextPacket(packet)) + } fn additional_mac_len(&self) -> usize { 0 } @@ -58,7 +64,7 @@ impl Session { } } -impl Decryptor for Session { +impl Keys for Session { fn decrypt_len(&mut self, bytes: &mut [u8; 4], packet_number: u64) { self.encryption_key_client_to_server .decrypt_len(bytes, packet_number); @@ -69,6 +75,13 @@ impl Decryptor for Session { .decrypt_packet(bytes, packet_number) } + fn encrypt_packet_to_msg(&mut self, packet: Packet, packet_number: u64) -> Msg { + let packet = self + .encryption_key_server_to_client + .encrypt_packet(packet, packet_number); + Msg(MsgKind::EncryptedPacket(packet)) + } + fn additional_mac_len(&self) -> usize { poly1305::BLOCK_SIZE } @@ -170,11 +183,52 @@ impl SshChaCha20Poly1305 { )); } - cipher.seek(64); + // Advance ChaCha's block counter to 1 + cipher + .seek(::block_size()); let encrypted_packet_content = bytes.content_mut(); cipher.apply_keystream(encrypted_packet_content); Packet::from_raw(encrypted_packet_content) } + + fn encrypt_packet(&mut self, packet: Packet, packet_number: u64) -> EncryptedPacket { + let mut bytes = packet.to_bytes(false); + + dbg!(u32::from_be_bytes(bytes[0..4].try_into().unwrap())); + + // Prepare the main cipher. + let mut main_cipher = ::new( + &self.main_key, + &packet_number.to_be_bytes().into(), + ); + + // Get the poly1305 key first, but don't use it yet! + // We encrypt-then-mac. + let mut poly1305_key = [0; poly1305::KEY_SIZE]; + main_cipher.apply_keystream(&mut poly1305_key); + + // As the first act of encryption, encrypt the length. + // THIS PART IS CORRECT!!! + let mut len_cipher = + SshChaCha20::new(&self.header_key, &packet_number.to_be_bytes().into()); + len_cipher.apply_keystream(&mut bytes[..4]); + + // Advance ChaCha's block counter to 1 + main_cipher + .seek(::block_size()); + // Encrypt the content of the packet, excluding the length and the MAC, which is not pushed yet. + main_cipher.apply_keystream(&mut bytes[4..]); + + // Now, MAC the length || content, and push that to the end. + let mac = poly1305::Poly1305::new(&poly1305_key.into()).compute_unpadded(&bytes); + dbg!(bytes.len()); + + bytes.extend_from_slice(&mac); + + dbg!(bytes.len()); + + EncryptedPacket::from_encrypted_full_bytes(bytes) + } } diff --git a/ssh-transport/src/lib.rs b/ssh-transport/src/lib.rs index da009a4..098a748 100644 --- a/ssh-transport/src/lib.rs +++ b/ssh-transport/src/lib.rs @@ -7,7 +7,7 @@ use std::mem::take; use ed25519_dalek::ed25519::signature::Signer; use packet::{ - DhKeyExchangeInitPacket, DhKeyExchangeInitReplyPacket, KeyExchangeInitPacket, MsgKind, Packet, + DhKeyExchangeInitPacket, DhKeyExchangeInitReplyPacket, KeyExchangeInitPacket, Packet, PacketTransport, SshPublicKey, SshSignature, }; use parse::{MpInt, NameList, Parser, Writer}; @@ -117,7 +117,7 @@ impl ServerConnection { // TODO: care that its SSH 2.0 instead of anythin anything else // The client will not send any more information than this until we respond, so discord the rest of the bytes. let client_identification = received.to_owned(); - self.queue_send_msg(MsgKind::ServerProtocolInfo); + self.packet_transport.queue_send_protocol_info(); self.state = ServerState::KeyExchangeInit { client_identification, }; @@ -205,9 +205,9 @@ impl ServerConnection { let client_identification = take(client_identification); let server_kexinit_payload = server_kexinit.to_bytes(); - self.queue_send_msg(MsgKind::PlaintextPacket(Packet { + self.packet_transport.queue_packet(Packet { payload: server_kexinit_payload.clone(), - })); + }); self.state = ServerState::DhKeyInit { client_identification, client_kexinit: packet.payload, @@ -286,9 +286,9 @@ impl ServerConnection { data: &signature.to_bytes(), }, }; - self.queue_send_msg(MsgKind::PlaintextPacket(Packet { + self.packet_transport.queue_packet(Packet { payload: packet.to_bytes(), - })); + }); self.state = ServerState::NewKeys { h: hash.into(), k: shared_secret.to_bytes(), @@ -301,9 +301,9 @@ impl ServerConnection { let (h, k) = (*h, *k); - self.queue_send_msg(MsgKind::PlaintextPacket(Packet { + self.packet_transport.queue_packet(Packet { payload: vec![Packet::SSH_MSG_NEWKEYS], - })); + }); self.state = ServerState::ServiceRequest {}; self.packet_transport.set_key(h, k); } @@ -320,14 +320,14 @@ impl ServerConnection { } // TODO: encrypt this! - self.queue_send_msg(MsgKind::PlaintextPacket(Packet { + self.packet_transport.queue_packet(Packet { payload: { let mut writer = Writer::new(); writer.u8(Packet::SSH_MSG_SERVICE_ACCEPT); writer.string(service.as_bytes()); writer.finish() }, - })); + }); self.state = ServerState::UserAuthRequest; } ServerState::UserAuthRequest => { @@ -341,10 +341,6 @@ impl ServerConnection { pub fn next_msg_to_send(&mut self) -> Option { self.packet_transport.next_msg_to_send() } - - fn queue_send_msg(&mut self, msg: MsgKind) { - self.packet_transport.queue_send_msg(Msg(msg)); - } } // hardcoded test keys. lol. @@ -380,7 +376,7 @@ use client_error; mod tests { use hex_literal::hex; - use crate::{MsgKind, ServerConnection, SshRng}; + use crate::{packet::MsgKind, ServerConnection, SshRng}; struct NoRng; impl SshRng for NoRng { diff --git a/ssh-transport/src/packet.rs b/ssh-transport/src/packet.rs index ff6605a..6a7f4de 100644 --- a/ssh-transport/src/packet.rs +++ b/ssh-transport/src/packet.rs @@ -1,13 +1,13 @@ use std::collections::VecDeque; use crate::client_error; -use crate::keys::{Decryptor, Plaintext, Session}; +use crate::keys::{Keys, Plaintext, Session}; use crate::parse::{MpInt, NameList, Parser, Writer}; use crate::Result; /// Frames the byte stream into packets. pub(crate) struct PacketTransport { - decrytor: Box, + keys: Box, recv_next_packet: PacketParser, recv_packets: VecDeque, @@ -31,7 +31,7 @@ impl Msg { pub fn to_bytes(self) -> Vec { match self.0 { MsgKind::ServerProtocolInfo => crate::SERVER_IDENTIFICATION.to_vec(), - MsgKind::PlaintextPacket(v) => v.to_bytes(), + MsgKind::PlaintextPacket(v) => v.to_bytes(true), MsgKind::EncryptedPacket(v) => v.to_bytes(), } } @@ -40,7 +40,7 @@ impl Msg { impl PacketTransport { pub(crate) fn new() -> Self { PacketTransport { - decrytor: Box::new(Plaintext), + keys: Box::new(Plaintext), recv_next_packet: PacketParser::new(), recv_packets: VecDeque::new(), @@ -60,29 +60,12 @@ impl PacketTransport { Ok(()) } - pub(crate) fn recv_next_packet(&mut self) -> Option { - self.recv_packets.pop_front() - } - - pub(crate) fn queue_send_msg(&mut self, msg: Msg) { - self.send_packets.push_back(msg); - } - pub(crate) fn next_msg_to_send(&mut self) -> Option { - self.send_packets.pop_front() - } - - pub(crate) fn set_key(&mut self, h: [u8; 32], k: [u8; 32]) { - if let Err(()) = self.decrytor.rekey(h, k) { - self.decrytor = Box::new(Session::new(h, k)); - } - } - fn recv_bytes_step(&mut self, bytes: &[u8]) -> Result> { // TODO: This might not work if we buffer two packets where one changes keys in between? let result = self.recv_next_packet - .recv_bytes(bytes, &mut *self.decrytor, self.recv_next_seq_nr)?; + .recv_bytes(bytes, &mut *self.keys, self.recv_next_seq_nr)?; if let Some((consumed, result)) = result { self.recv_packets.push_back(result); self.recv_next_seq_nr = self.recv_next_seq_nr.wrapping_add(1); @@ -92,6 +75,35 @@ impl PacketTransport { Ok(None) } + + pub(crate) fn queue_packet(&mut self, packet: Packet) { + let seq_nr = self.send_next_seq_nr; + self.send_next_seq_nr = self.send_next_seq_nr.wrapping_add(1); + let msg = self.keys.encrypt_packet_to_msg(packet, seq_nr); + self.queue_send_msg(msg); + } + + pub(crate) fn queue_send_protocol_info(&mut self) { + self.queue_send_msg(Msg(MsgKind::ServerProtocolInfo)); + } + + pub(crate) fn recv_next_packet(&mut self) -> Option { + self.recv_packets.pop_front() + } + + // Private: Make sure all sending goes through variant-specific functions here. + fn queue_send_msg(&mut self, msg: Msg) { + self.send_packets.push_back(msg); + } + pub(crate) fn next_msg_to_send(&mut self) -> Option { + self.send_packets.pop_front() + } + + pub(crate) fn set_key(&mut self, h: [u8; 32], k: [u8; 32]) { + if let Err(()) = self.keys.rekey(h, k) { + self.keys = Box::new(Session::new(h, k)); + } + } } /* @@ -105,9 +117,9 @@ length | padding_length | payload | random padding | MAC ----------------------------------------- "content" -------------------------------------------------- "authenticated" -^^^^^^ encrypted using K2 +^^^^^^ encrypted using K1 ^^^^ plaintext - ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ encrypted using K1 + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ encrypted using K2 */ #[derive(Debug, PartialEq)] @@ -142,26 +154,30 @@ impl Packet { }) } - pub(crate) fn to_bytes(&self) -> Vec { - let mut new = Vec::new(); + pub(crate) fn to_bytes(&self, respect_len_for_padding: bool) -> Vec { + let let_bytes = if respect_len_for_padding { 4 } else { 0 }; - let min_full_length = self.payload.len() + 4 + 1; + // + let min_full_length = self.payload.len() + let_bytes + 1; // The padding must give a factor of 8. let min_padding_len = (min_full_length.next_multiple_of(8) - min_full_length) as u8; // > There MUST be at least four bytes of padding. - // So let's satisfy this by just adding 8. We can always properly randomize it later if desired. - let padding_len = min_padding_len + 8; + let padding_len = if min_padding_len < 4 { + min_padding_len + 8 + } else { + min_padding_len + }; let packet_len = self.payload.len() + (padding_len as usize) + 1; + + let mut new = Vec::new(); new.extend_from_slice(&u32::to_be_bytes(packet_len as u32)); new.extend_from_slice(&[padding_len]); new.extend_from_slice(&self.payload); new.extend(std::iter::repeat(0).take(padding_len as usize)); - // mac... - assert!((4 + 1 + self.payload.len() + (padding_len as usize)) % 8 == 0); - assert!(new.len() % 8 == 0); + assert!((let_bytes + 1 + self.payload.len() + (padding_len as usize)) % 8 == 0); new } @@ -325,9 +341,9 @@ impl<'a> DhKeyExchangeInitReplyPacket<'a> { } pub(crate) struct RawPacket { - len: usize, - mac_len: usize, - raw: Vec, + pub len: usize, + pub mac_len: usize, + pub raw: Vec, } impl RawPacket { pub(crate) fn rest(&self) -> &[u8] { @@ -358,7 +374,7 @@ impl PacketParser { fn recv_bytes( &mut self, bytes: &[u8], - decrytor: &mut dyn Decryptor, + decrytor: &mut dyn Keys, next_seq_nr: u64, ) -> Result> { let Some((consumed, data)) = self.recv_bytes_inner(bytes, decrytor, next_seq_nr)? else { @@ -370,7 +386,7 @@ impl PacketParser { fn recv_bytes_inner( &mut self, mut bytes: &[u8], - decrytor: &mut dyn Decryptor, + keys: &mut dyn Keys, next_seq_nr: u64, ) -> Result> { let mut consumed = 0; @@ -391,11 +407,11 @@ impl PacketParser { let mut len_to_decrypt = [0_u8; 4]; len_to_decrypt.copy_from_slice(self.raw_data.as_slice()); - decrytor.decrypt_len(&mut len_to_decrypt, next_seq_nr); + keys.decrypt_len(&mut len_to_decrypt, next_seq_nr); let packet_length = u32::from_be_bytes(len_to_decrypt); let packet_length: usize = packet_length.try_into().unwrap(); - let packet_length = packet_length + decrytor.additional_mac_len(); + let packet_length = packet_length + keys.additional_mac_len(); self.packet_length = Some(packet_length); @@ -417,7 +433,7 @@ impl PacketParser { consumed, RawPacket { raw: std::mem::take(&mut self.raw_data), - mac_len: decrytor.additional_mac_len(), + mac_len: keys.additional_mac_len(), len: packet_length, }, )))