refactor sending into transport

This commit is contained in:
nora 2024-08-11 01:12:13 +02:00
parent caf989de97
commit 86c73b4a97
4 changed files with 101 additions and 55 deletions

View file

@ -62,7 +62,7 @@ async fn handle_connection(next: (TcpStream, SocketAddr)) -> Result<()> {
} }
} }
while let Some(msg) = state.next_message_to_send() { while let Some(msg) = state.next_msg_to_send() {
conn.write_all(&msg.to_bytes()) conn.write_all(&msg.to_bytes())
.await .await
.wrap_err("writing response")?; .wrap_err("writing response")?;

View file

@ -154,12 +154,12 @@ impl SshChaCha20Poly1305 {
); );
let tag_offset = bytes.full_packet().len() - 16; let tag_offset = bytes.full_packet().len() - 16;
let data_to_mac = &bytes.full_packet()[..tag_offset]; let authenticated = &bytes.full_packet()[..tag_offset];
let mac = { let mac = {
let mut poly1305_key = [0; poly1305::KEY_SIZE]; let mut poly1305_key = [0; poly1305::KEY_SIZE];
cipher.apply_keystream(&mut poly1305_key); cipher.apply_keystream(&mut poly1305_key);
poly1305::Poly1305::new(&poly1305_key.into()).compute_unpadded(data_to_mac) poly1305::Poly1305::new(&poly1305_key.into()).compute_unpadded(authenticated)
}; };
let read_tag = poly1305::Tag::from_slice(&bytes.full_packet()[tag_offset..]); let read_tag = poly1305::Tag::from_slice(&bytes.full_packet()[tag_offset..]);

View file

@ -7,15 +7,17 @@ use std::mem::take;
use ed25519_dalek::ed25519::signature::Signer; use ed25519_dalek::ed25519::signature::Signer;
use packet::{ use packet::{
DhKeyExchangeInitPacket, DhKeyExchangeInitReplyPacket, KeyExchangeInitPacket, Packet, DhKeyExchangeInitPacket, DhKeyExchangeInitReplyPacket, KeyExchangeInitPacket, MsgKind, Packet,
PacketTransport, SshPublicKey, SshSignature, PacketTransport, SshPublicKey, SshSignature,
}; };
use parse::{MpInt, NameList, Parser, Writer}; use parse::{MpInt, NameList, Parser, Writer};
use rand::RngCore; use rand::RngCore;
use sha2::Digest; use sha2::Digest;
use tracing::{debug, info}; use tracing::debug;
use x25519_dalek::{EphemeralSecret, PublicKey}; use x25519_dalek::{EphemeralSecret, PublicKey};
pub use packet::Msg;
#[derive(Debug)] #[derive(Debug)]
pub enum SshError { pub enum SshError {
/// The client did something wrong. /// The client did something wrong.
@ -41,7 +43,6 @@ pub const SERVER_IDENTIFICATION: &[u8] = b"SSH-2.0-OpenSSH_9.7\r\n";
pub struct ServerConnection { pub struct ServerConnection {
state: ServerState, state: ServerState,
packet_transport: PacketTransport, packet_transport: PacketTransport,
send_queue: Vec<Msg>,
rng: Box<dyn SshRng + Send + Sync>, rng: Box<dyn SshRng + Send + Sync>,
} }
@ -103,7 +104,6 @@ impl ServerConnection {
received: Vec::new(), received: Vec::new(),
}, },
packet_transport: PacketTransport::new(), packet_transport: PacketTransport::new(),
send_queue: Vec::new(),
rng: Box::new(rng), rng: Box::new(rng),
} }
} }
@ -117,7 +117,7 @@ impl ServerConnection {
// TODO: care that its SSH 2.0 instead of anythin anything else // TODO: care that its SSH 2.0 instead of anythin anything else
// The client will not send any more information than this until we respond, so discord the rest of the bytes. // The client will not send any more information than this until we respond, so discord the rest of the bytes.
let client_identification = received.to_owned(); let client_identification = received.to_owned();
self.queue_msg(MsgKind::ServerProtocolInfo); self.queue_send_msg(MsgKind::ServerProtocolInfo);
self.state = ServerState::KeyExchangeInit { self.state = ServerState::KeyExchangeInit {
client_identification, client_identification,
}; };
@ -128,7 +128,7 @@ impl ServerConnection {
self.packet_transport.recv_bytes(bytes)?; self.packet_transport.recv_bytes(bytes)?;
while let Some(packet) = self.packet_transport.next_packet() { while let Some(packet) = self.packet_transport.recv_next_packet() {
match &mut self.state { match &mut self.state {
ServerState::ProtoExchange { .. } => unreachable!("handled above"), ServerState::ProtoExchange { .. } => unreachable!("handled above"),
ServerState::KeyExchangeInit { ServerState::KeyExchangeInit {
@ -205,7 +205,7 @@ impl ServerConnection {
let client_identification = take(client_identification); let client_identification = take(client_identification);
let server_kexinit_payload = server_kexinit.to_bytes(); let server_kexinit_payload = server_kexinit.to_bytes();
self.queue_msg(MsgKind::Packet(Packet { self.queue_send_msg(MsgKind::PlaintextPacket(Packet {
payload: server_kexinit_payload.clone(), payload: server_kexinit_payload.clone(),
})); }));
self.state = ServerState::DhKeyInit { self.state = ServerState::DhKeyInit {
@ -286,7 +286,7 @@ impl ServerConnection {
data: &signature.to_bytes(), data: &signature.to_bytes(),
}, },
}; };
self.queue_msg(MsgKind::Packet(Packet { self.queue_send_msg(MsgKind::PlaintextPacket(Packet {
payload: packet.to_bytes(), payload: packet.to_bytes(),
})); }));
self.state = ServerState::NewKeys { self.state = ServerState::NewKeys {
@ -301,7 +301,7 @@ impl ServerConnection {
let (h, k) = (*h, *k); let (h, k) = (*h, *k);
self.queue_msg(MsgKind::Packet(Packet { self.queue_send_msg(MsgKind::PlaintextPacket(Packet {
payload: vec![Packet::SSH_MSG_NEWKEYS], payload: vec![Packet::SSH_MSG_NEWKEYS],
})); }));
self.state = ServerState::ServiceRequest {}; self.state = ServerState::ServiceRequest {};
@ -320,7 +320,7 @@ impl ServerConnection {
} }
// TODO: encrypt this! // TODO: encrypt this!
self.queue_msg(MsgKind::Packet(Packet { self.queue_send_msg(MsgKind::PlaintextPacket(Packet {
payload: { payload: {
let mut writer = Writer::new(); let mut writer = Writer::new();
writer.u8(Packet::SSH_MSG_SERVICE_ACCEPT); writer.u8(Packet::SSH_MSG_SERVICE_ACCEPT);
@ -338,31 +338,12 @@ impl ServerConnection {
Ok(()) Ok(())
} }
pub fn next_message_to_send(&mut self) -> Option<Msg> { pub fn next_msg_to_send(&mut self) -> Option<Msg> {
self.send_queue.pop() self.packet_transport.next_msg_to_send()
} }
fn queue_msg(&mut self, msg: MsgKind) { fn queue_send_msg(&mut self, msg: MsgKind) {
self.send_queue.push(Msg(msg)); self.packet_transport.queue_send_msg(Msg(msg));
}
}
#[derive(Debug)]
pub struct Msg(MsgKind);
#[derive(Debug, PartialEq)]
enum MsgKind {
ServerProtocolInfo,
Packet(Packet),
}
impl Msg {
// TODO: MAKE THIS ZERO ALLOC AAAAAA
pub fn to_bytes(self) -> Vec<u8> {
match self.0 {
MsgKind::ServerProtocolInfo => SERVER_IDENTIFICATION.to_vec(),
MsgKind::Packet(v) => v.to_bytes(),
}
} }
} }
@ -420,7 +401,7 @@ mod tests {
fn protocol_exchange() { fn protocol_exchange() {
let mut con = ServerConnection::new(NoRng); let mut con = ServerConnection::new(NoRng);
con.recv_bytes(b"SSH-2.0-OpenSSH_9.7\r\n").unwrap(); con.recv_bytes(b"SSH-2.0-OpenSSH_9.7\r\n").unwrap();
let msg = con.next_message_to_send().unwrap(); let msg = con.next_msg_to_send().unwrap();
assert_eq!(msg.0, MsgKind::ServerProtocolInfo); assert_eq!(msg.0, MsgKind::ServerProtocolInfo);
} }
@ -429,7 +410,7 @@ mod tests {
let mut con = ServerConnection::new(NoRng); let mut con = ServerConnection::new(NoRng);
con.recv_bytes(b"SSH-2.0-").unwrap(); con.recv_bytes(b"SSH-2.0-").unwrap();
con.recv_bytes(b"OpenSSH_9.7\r\n").unwrap(); con.recv_bytes(b"OpenSSH_9.7\r\n").unwrap();
let msg = con.next_message_to_send().unwrap(); let msg = con.next_msg_to_send().unwrap();
assert_eq!(msg.0, MsgKind::ServerProtocolInfo); assert_eq!(msg.0, MsgKind::ServerProtocolInfo);
} }
@ -536,7 +517,7 @@ mod tests {
let mut con = ServerConnection::new(HardcodedRng(rng)); let mut con = ServerConnection::new(HardcodedRng(rng));
for part in conversation { for part in conversation {
con.recv_bytes(&part.client).unwrap(); con.recv_bytes(&part.client).unwrap();
let bytes = con.next_message_to_send().unwrap().to_bytes(); let bytes = con.next_msg_to_send().unwrap().to_bytes();
assert_eq!(part.server, bytes); assert_eq!(part.server, bytes);
} }
} }

View file

@ -8,18 +8,46 @@ use crate::Result;
/// Frames the byte stream into packets. /// Frames the byte stream into packets.
pub(crate) struct PacketTransport { pub(crate) struct PacketTransport {
decrytor: Box<dyn Decryptor>, decrytor: Box<dyn Decryptor>,
next_packet: PacketParser, recv_next_packet: PacketParser,
packets: VecDeque<Packet>,
next_recv_seq_nr: u64, recv_packets: VecDeque<Packet>,
recv_next_seq_nr: u64,
send_packets: VecDeque<Msg>,
send_next_seq_nr: u64,
}
#[derive(Debug)]
pub struct Msg(pub(crate) MsgKind);
#[derive(Debug, PartialEq)]
pub(crate) enum MsgKind {
ServerProtocolInfo,
PlaintextPacket(Packet),
EncryptedPacket(EncryptedPacket),
}
impl Msg {
pub fn to_bytes(self) -> Vec<u8> {
match self.0 {
MsgKind::ServerProtocolInfo => crate::SERVER_IDENTIFICATION.to_vec(),
MsgKind::PlaintextPacket(v) => v.to_bytes(),
MsgKind::EncryptedPacket(v) => v.to_bytes(),
}
}
} }
impl PacketTransport { impl PacketTransport {
pub(crate) fn new() -> Self { pub(crate) fn new() -> Self {
PacketTransport { PacketTransport {
decrytor: Box::new(Plaintext), decrytor: Box::new(Plaintext),
next_packet: PacketParser::new(), recv_next_packet: PacketParser::new(),
packets: VecDeque::new(),
next_recv_seq_nr: 0, recv_packets: VecDeque::new(),
recv_next_seq_nr: 0,
send_packets: VecDeque::new(),
send_next_seq_nr: 0,
} }
} }
pub(crate) fn recv_bytes(&mut self, mut bytes: &[u8]) -> Result<()> { pub(crate) fn recv_bytes(&mut self, mut bytes: &[u8]) -> Result<()> {
@ -31,8 +59,16 @@ impl PacketTransport {
} }
Ok(()) Ok(())
} }
pub(crate) fn next_packet(&mut self) -> Option<Packet> {
self.packets.pop_front() pub(crate) fn recv_next_packet(&mut self) -> Option<Packet> {
self.recv_packets.pop_front()
}
pub(crate) fn queue_send_msg(&mut self, msg: Msg) {
self.send_packets.push_back(msg);
}
pub(crate) fn next_msg_to_send(&mut self) -> Option<Msg> {
self.send_packets.pop_front()
} }
pub(crate) fn set_key(&mut self, h: [u8; 32], k: [u8; 32]) { pub(crate) fn set_key(&mut self, h: [u8; 32], k: [u8; 32]) {
@ -45,12 +81,12 @@ impl PacketTransport {
// TODO: This might not work if we buffer two packets where one changes keys in between? // TODO: This might not work if we buffer two packets where one changes keys in between?
let result = let result =
self.next_packet self.recv_next_packet
.recv_bytes(bytes, &mut *self.decrytor, self.next_recv_seq_nr)?; .recv_bytes(bytes, &mut *self.decrytor, self.recv_next_seq_nr)?;
if let Some((consumed, result)) = result { if let Some((consumed, result)) = result {
self.packets.push_back(result); self.recv_packets.push_back(result);
self.next_recv_seq_nr = self.next_recv_seq_nr.wrapping_add(1); self.recv_next_seq_nr = self.recv_next_seq_nr.wrapping_add(1);
self.next_packet = PacketParser::new(); self.recv_next_packet = PacketParser::new();
return Ok(Some(consumed)); return Ok(Some(consumed));
} }
@ -58,6 +94,22 @@ impl PacketTransport {
} }
} }
/*
packet teminology used throughout this crate:
length | padding_length | payload | random padding | MAC
-------------------------------------------------------- "full"
----------------------------------------------- "rest"
------- "payload"
----------------------------------------- "content"
-------------------------------------------------- "authenticated"
^^^^^^ encrypted using K2
^^^^ plaintext
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ encrypted using K1
*/
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub(crate) struct Packet { pub(crate) struct Packet {
pub(crate) payload: Vec<u8>, pub(crate) payload: Vec<u8>,
@ -115,6 +167,19 @@ impl Packet {
} }
} }
#[derive(Debug, PartialEq)]
pub(crate) struct EncryptedPacket {
data: Vec<u8>,
}
impl EncryptedPacket {
pub(crate) fn to_bytes(self) -> Vec<u8> {
self.data
}
pub(crate) fn from_encrypted_full_bytes(data: Vec<u8>) -> Self {
Self { data }
}
}
#[derive(Debug)] #[derive(Debug)]
pub(crate) struct KeyExchangeInitPacket<'a> { pub(crate) struct KeyExchangeInitPacket<'a> {
pub(crate) cookie: [u8; 16], pub(crate) cookie: [u8; 16],