rewrite to extract packet parser

This commit is contained in:
nora 2024-08-10 01:56:21 +02:00
parent e1b2a6e1e5
commit be5b437548

View file

@ -25,7 +25,7 @@ macro_rules! client_error {
};
}
use core::str;
use std::mem::take;
use std::{collections::VecDeque, mem::take};
use client_error;
use ed25519_dalek::ed25519::signature::Signer;
@ -39,6 +39,7 @@ pub const SERVER_IDENTIFICATION: &[u8] = b"SSH-2.0-OpenSSH_9.7\r\n";
pub struct ServerConnection {
state: ServerState,
packet_transport: PacketTransport,
send_queue: Vec<Msg>,
rng: Box<dyn SshRng + Send + Sync>,
}
@ -48,18 +49,14 @@ enum ServerState {
received: Vec<u8>,
},
KeyExchangeInit {
client_packet: PacketParser,
client_identification: Vec<u8>,
},
DhKeyInit {
client_packet: PacketParser,
client_identification: Vec<u8>,
client_kexinit: Vec<u8>,
server_kexinit: Vec<u8>,
},
NewKeys {
client_packet: PacketParser,
},
NewKeys,
ServiceRequest {},
}
@ -99,6 +96,10 @@ impl ServerConnection {
state: ServerState::ProtoExchange {
received: Vec::new(),
},
packet_transport: PacketTransport {
state: PacketTransportState::Plaintext(PacketParser::new()),
packets: VecDeque::new(),
},
send_queue: Vec::new(),
rng: Box::new(rng),
}
@ -116,6 +117,10 @@ impl ServerConnection {
Ok(())
}
fn recv_bytes_step(&mut self, bytes: &[u8]) -> Result<Option<usize>> {
if !matches!(self.state, ServerState::ProtoExchange { .. }) {
self.packet_transport.recv_bytes(bytes);
}
let result = match &mut self.state {
ServerState::ProtoExchange { received } => {
// TODO: get rid of this allocation :(
@ -126,17 +131,15 @@ impl ServerConnection {
let client_identification = received.to_owned();
self.queue_msg(MsgKind::ServerProtocolInfo);
self.state = ServerState::KeyExchangeInit {
client_packet: PacketParser::new(),
client_identification,
};
}
None
}
ServerState::KeyExchangeInit {
client_packet: packet,
client_identification,
} => match packet.recv_bytes(bytes, ())? {
Some((consumed, data)) => {
} => match self.packet_transport.next_packet() {
Some(data) => {
let kex = KeyExchangeInitPacket::parse(&data.payload)?;
let require_algorithm =
@ -212,23 +215,21 @@ impl ServerConnection {
payload: server_kexinit_payload.clone(),
}));
self.state = ServerState::DhKeyInit {
client_packet: PacketParser::new(),
client_identification,
client_kexinit: data.payload,
server_kexinit: server_kexinit_payload,
};
Some(consumed)
None
}
None => None,
},
ServerState::DhKeyInit {
client_packet: packet,
client_identification,
client_kexinit,
server_kexinit,
} => match packet.recv_bytes(bytes, ())? {
Some((consumed, data)) => {
} => match self.packet_transport.next_packet() {
Some(data) => {
let dh = DhKeyExchangeInitPacket::parse(&data.payload)?;
let secret =
@ -310,16 +311,15 @@ impl ServerConnection {
self.queue_msg(MsgKind::Packet(Packet {
payload: packet.to_bytes(),
}));
self.state = ServerState::NewKeys {
client_packet: PacketParser::new(),
};
self.state = ServerState::NewKeys;
// TODO: set keys for transport
Some(consumed)
None
}
None => None,
},
ServerState::NewKeys { client_packet } => match client_packet.recv_bytes(bytes, ())? {
Some((consumed, data)) => {
ServerState::NewKeys => match self.packet_transport.next_packet() {
Some(data) => {
if data.payload != &[Packet::SSH_MSG_NEWKEYS] {
return Err(client_error!("did not send SSH_MSG_NEWKEYS"));
}
@ -329,7 +329,7 @@ impl ServerConnection {
}));
self.state = ServerState::ServiceRequest {};
Some(consumed)
None
}
None => None,
},
@ -366,6 +366,48 @@ impl Msg {
}
}
/// Frames the byte stream into packets.
struct PacketTransport {
state: PacketTransportState,
packets: VecDeque<Packet>,
}
enum PacketTransportState {
Plaintext(PacketParser),
Keyed { key: () },
}
impl PacketTransport {
fn recv_bytes(&mut self, mut bytes: &[u8]) -> Result<()> {
while let Some(consumed) = self.recv_bytes_step(bytes)? {
bytes = &bytes[consumed..];
if bytes.is_empty() {
break;
}
}
Ok(())
}
fn next_packet(&mut self) -> Option<Packet> {
self.packets.pop_front()
}
fn recv_bytes_step(&mut self, bytes: &[u8]) -> Result<Option<usize>> {
// TODO: This might not work if we buffer two packets where one changes keys in between?
match &mut self.state {
PacketTransportState::Plaintext(packet) => {
let result = packet.recv_bytes(bytes, ())?;
if let Some((consumed, result)) = result {
self.packets.push_back(result);
*packet = PacketParser::new();
return Ok(Some(consumed));
}
}
PacketTransportState::Keyed { key } => todo!(),
}
Ok(None)
}
}
#[derive(Debug, PartialEq)]
struct Packet {
payload: Vec<u8>,
@ -564,6 +606,8 @@ impl<'a> DhKeyExchangeInitReplyPacket<'a> {
}
}
struct EncryptedPacketParser {}
struct PacketParser {
// The length of the packet.
packet_length: Option<usize>,