rewrite to extract packet parser

This commit is contained in:
nora 2024-08-10 01:56:21 +02:00
parent e1b2a6e1e5
commit be5b437548

View file

@ -25,7 +25,7 @@ macro_rules! client_error {
}; };
} }
use core::str; use core::str;
use std::mem::take; use std::{collections::VecDeque, mem::take};
use client_error; use client_error;
use ed25519_dalek::ed25519::signature::Signer; use ed25519_dalek::ed25519::signature::Signer;
@ -39,6 +39,7 @@ pub const SERVER_IDENTIFICATION: &[u8] = b"SSH-2.0-OpenSSH_9.7\r\n";
pub struct ServerConnection { pub struct ServerConnection {
state: ServerState, state: ServerState,
packet_transport: PacketTransport,
send_queue: Vec<Msg>, send_queue: Vec<Msg>,
rng: Box<dyn SshRng + Send + Sync>, rng: Box<dyn SshRng + Send + Sync>,
} }
@ -48,18 +49,14 @@ enum ServerState {
received: Vec<u8>, received: Vec<u8>,
}, },
KeyExchangeInit { KeyExchangeInit {
client_packet: PacketParser,
client_identification: Vec<u8>, client_identification: Vec<u8>,
}, },
DhKeyInit { DhKeyInit {
client_packet: PacketParser,
client_identification: Vec<u8>, client_identification: Vec<u8>,
client_kexinit: Vec<u8>, client_kexinit: Vec<u8>,
server_kexinit: Vec<u8>, server_kexinit: Vec<u8>,
}, },
NewKeys { NewKeys,
client_packet: PacketParser,
},
ServiceRequest {}, ServiceRequest {},
} }
@ -99,6 +96,10 @@ impl ServerConnection {
state: ServerState::ProtoExchange { state: ServerState::ProtoExchange {
received: Vec::new(), received: Vec::new(),
}, },
packet_transport: PacketTransport {
state: PacketTransportState::Plaintext(PacketParser::new()),
packets: VecDeque::new(),
},
send_queue: Vec::new(), send_queue: Vec::new(),
rng: Box::new(rng), rng: Box::new(rng),
} }
@ -116,6 +117,10 @@ impl ServerConnection {
Ok(()) Ok(())
} }
fn recv_bytes_step(&mut self, bytes: &[u8]) -> Result<Option<usize>> { fn recv_bytes_step(&mut self, bytes: &[u8]) -> Result<Option<usize>> {
if !matches!(self.state, ServerState::ProtoExchange { .. }) {
self.packet_transport.recv_bytes(bytes);
}
let result = match &mut self.state { let result = match &mut self.state {
ServerState::ProtoExchange { received } => { ServerState::ProtoExchange { received } => {
// TODO: get rid of this allocation :( // TODO: get rid of this allocation :(
@ -126,17 +131,15 @@ impl ServerConnection {
let client_identification = received.to_owned(); let client_identification = received.to_owned();
self.queue_msg(MsgKind::ServerProtocolInfo); self.queue_msg(MsgKind::ServerProtocolInfo);
self.state = ServerState::KeyExchangeInit { self.state = ServerState::KeyExchangeInit {
client_packet: PacketParser::new(),
client_identification, client_identification,
}; };
} }
None None
} }
ServerState::KeyExchangeInit { ServerState::KeyExchangeInit {
client_packet: packet,
client_identification, client_identification,
} => match packet.recv_bytes(bytes, ())? { } => match self.packet_transport.next_packet() {
Some((consumed, data)) => { Some(data) => {
let kex = KeyExchangeInitPacket::parse(&data.payload)?; let kex = KeyExchangeInitPacket::parse(&data.payload)?;
let require_algorithm = let require_algorithm =
@ -212,23 +215,21 @@ impl ServerConnection {
payload: server_kexinit_payload.clone(), payload: server_kexinit_payload.clone(),
})); }));
self.state = ServerState::DhKeyInit { self.state = ServerState::DhKeyInit {
client_packet: PacketParser::new(),
client_identification, client_identification,
client_kexinit: data.payload, client_kexinit: data.payload,
server_kexinit: server_kexinit_payload, server_kexinit: server_kexinit_payload,
}; };
Some(consumed) None
} }
None => None, None => None,
}, },
ServerState::DhKeyInit { ServerState::DhKeyInit {
client_packet: packet,
client_identification, client_identification,
client_kexinit, client_kexinit,
server_kexinit, server_kexinit,
} => match packet.recv_bytes(bytes, ())? { } => match self.packet_transport.next_packet() {
Some((consumed, data)) => { Some(data) => {
let dh = DhKeyExchangeInitPacket::parse(&data.payload)?; let dh = DhKeyExchangeInitPacket::parse(&data.payload)?;
let secret = let secret =
@ -310,16 +311,15 @@ impl ServerConnection {
self.queue_msg(MsgKind::Packet(Packet { self.queue_msg(MsgKind::Packet(Packet {
payload: packet.to_bytes(), payload: packet.to_bytes(),
})); }));
self.state = ServerState::NewKeys { self.state = ServerState::NewKeys;
client_packet: PacketParser::new(), // TODO: set keys for transport
};
Some(consumed) None
} }
None => None, None => None,
}, },
ServerState::NewKeys { client_packet } => match client_packet.recv_bytes(bytes, ())? { ServerState::NewKeys => match self.packet_transport.next_packet() {
Some((consumed, data)) => { Some(data) => {
if data.payload != &[Packet::SSH_MSG_NEWKEYS] { if data.payload != &[Packet::SSH_MSG_NEWKEYS] {
return Err(client_error!("did not send SSH_MSG_NEWKEYS")); return Err(client_error!("did not send SSH_MSG_NEWKEYS"));
} }
@ -329,7 +329,7 @@ impl ServerConnection {
})); }));
self.state = ServerState::ServiceRequest {}; self.state = ServerState::ServiceRequest {};
Some(consumed) None
} }
None => None, None => None,
}, },
@ -366,6 +366,48 @@ impl Msg {
} }
} }
/// Frames the byte stream into packets.
struct PacketTransport {
state: PacketTransportState,
packets: VecDeque<Packet>,
}
enum PacketTransportState {
Plaintext(PacketParser),
Keyed { key: () },
}
impl PacketTransport {
fn recv_bytes(&mut self, mut bytes: &[u8]) -> Result<()> {
while let Some(consumed) = self.recv_bytes_step(bytes)? {
bytes = &bytes[consumed..];
if bytes.is_empty() {
break;
}
}
Ok(())
}
fn next_packet(&mut self) -> Option<Packet> {
self.packets.pop_front()
}
fn recv_bytes_step(&mut self, bytes: &[u8]) -> Result<Option<usize>> {
// TODO: This might not work if we buffer two packets where one changes keys in between?
match &mut self.state {
PacketTransportState::Plaintext(packet) => {
let result = packet.recv_bytes(bytes, ())?;
if let Some((consumed, result)) = result {
self.packets.push_back(result);
*packet = PacketParser::new();
return Ok(Some(consumed));
}
}
PacketTransportState::Keyed { key } => todo!(),
}
Ok(None)
}
}
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
struct Packet { struct Packet {
payload: Vec<u8>, payload: Vec<u8>,
@ -564,6 +606,8 @@ impl<'a> DhKeyExchangeInitReplyPacket<'a> {
} }
} }
struct EncryptedPacketParser {}
struct PacketParser { struct PacketParser {
// The length of the packet. // The length of the packet.
packet_length: Option<usize>, packet_length: Option<usize>,