refactor sending into transport

This commit is contained in:
nora 2024-08-11 01:12:13 +02:00
parent caf989de97
commit 86c73b4a97
4 changed files with 101 additions and 55 deletions

View file

@ -7,15 +7,17 @@ use std::mem::take;
use ed25519_dalek::ed25519::signature::Signer;
use packet::{
DhKeyExchangeInitPacket, DhKeyExchangeInitReplyPacket, KeyExchangeInitPacket, Packet,
DhKeyExchangeInitPacket, DhKeyExchangeInitReplyPacket, KeyExchangeInitPacket, MsgKind, Packet,
PacketTransport, SshPublicKey, SshSignature,
};
use parse::{MpInt, NameList, Parser, Writer};
use rand::RngCore;
use sha2::Digest;
use tracing::{debug, info};
use tracing::debug;
use x25519_dalek::{EphemeralSecret, PublicKey};
pub use packet::Msg;
#[derive(Debug)]
pub enum SshError {
/// The client did something wrong.
@ -41,7 +43,6 @@ pub const SERVER_IDENTIFICATION: &[u8] = b"SSH-2.0-OpenSSH_9.7\r\n";
pub struct ServerConnection {
state: ServerState,
packet_transport: PacketTransport,
send_queue: Vec<Msg>,
rng: Box<dyn SshRng + Send + Sync>,
}
@ -103,7 +104,6 @@ impl ServerConnection {
received: Vec::new(),
},
packet_transport: PacketTransport::new(),
send_queue: Vec::new(),
rng: Box::new(rng),
}
}
@ -117,7 +117,7 @@ impl ServerConnection {
// TODO: care that its SSH 2.0 instead of anythin anything else
// The client will not send any more information than this until we respond, so discord the rest of the bytes.
let client_identification = received.to_owned();
self.queue_msg(MsgKind::ServerProtocolInfo);
self.queue_send_msg(MsgKind::ServerProtocolInfo);
self.state = ServerState::KeyExchangeInit {
client_identification,
};
@ -128,7 +128,7 @@ impl ServerConnection {
self.packet_transport.recv_bytes(bytes)?;
while let Some(packet) = self.packet_transport.next_packet() {
while let Some(packet) = self.packet_transport.recv_next_packet() {
match &mut self.state {
ServerState::ProtoExchange { .. } => unreachable!("handled above"),
ServerState::KeyExchangeInit {
@ -205,7 +205,7 @@ impl ServerConnection {
let client_identification = take(client_identification);
let server_kexinit_payload = server_kexinit.to_bytes();
self.queue_msg(MsgKind::Packet(Packet {
self.queue_send_msg(MsgKind::PlaintextPacket(Packet {
payload: server_kexinit_payload.clone(),
}));
self.state = ServerState::DhKeyInit {
@ -259,9 +259,9 @@ impl ServerConnection {
hash_string(&mut hash, client_kexinit); // I_C
hash_string(&mut hash, server_kexinit); // I_S
add_hash(&mut hash, &pub_hostkey.to_bytes()); // K_S
// For normal DH as in RFC4253, e and f are mpints.
// But for ECDH as defined in RFC5656, Q_C and Q_S are strings.
// <https://datatracker.ietf.org/doc/html/rfc5656#section-4>
// For normal DH as in RFC4253, e and f are mpints.
// But for ECDH as defined in RFC5656, Q_C and Q_S are strings.
// <https://datatracker.ietf.org/doc/html/rfc5656#section-4>
hash_string(&mut hash, &client_public_key.0); // Q_C
hash_string(&mut hash, server_public_key.as_bytes()); // Q_S
hash_mpint(&mut hash, shared_secret.as_bytes()); // K
@ -286,7 +286,7 @@ impl ServerConnection {
data: &signature.to_bytes(),
},
};
self.queue_msg(MsgKind::Packet(Packet {
self.queue_send_msg(MsgKind::PlaintextPacket(Packet {
payload: packet.to_bytes(),
}));
self.state = ServerState::NewKeys {
@ -301,7 +301,7 @@ impl ServerConnection {
let (h, k) = (*h, *k);
self.queue_msg(MsgKind::Packet(Packet {
self.queue_send_msg(MsgKind::PlaintextPacket(Packet {
payload: vec![Packet::SSH_MSG_NEWKEYS],
}));
self.state = ServerState::ServiceRequest {};
@ -320,7 +320,7 @@ impl ServerConnection {
}
// TODO: encrypt this!
self.queue_msg(MsgKind::Packet(Packet {
self.queue_send_msg(MsgKind::PlaintextPacket(Packet {
payload: {
let mut writer = Writer::new();
writer.u8(Packet::SSH_MSG_SERVICE_ACCEPT);
@ -338,31 +338,12 @@ impl ServerConnection {
Ok(())
}
pub fn next_message_to_send(&mut self) -> Option<Msg> {
self.send_queue.pop()
pub fn next_msg_to_send(&mut self) -> Option<Msg> {
self.packet_transport.next_msg_to_send()
}
fn queue_msg(&mut self, msg: MsgKind) {
self.send_queue.push(Msg(msg));
}
}
#[derive(Debug)]
pub struct Msg(MsgKind);
#[derive(Debug, PartialEq)]
enum MsgKind {
ServerProtocolInfo,
Packet(Packet),
}
impl Msg {
// TODO: MAKE THIS ZERO ALLOC AAAAAA
pub fn to_bytes(self) -> Vec<u8> {
match self.0 {
MsgKind::ServerProtocolInfo => SERVER_IDENTIFICATION.to_vec(),
MsgKind::Packet(v) => v.to_bytes(),
}
fn queue_send_msg(&mut self, msg: MsgKind) {
self.packet_transport.queue_send_msg(Msg(msg));
}
}
@ -420,7 +401,7 @@ mod tests {
fn protocol_exchange() {
let mut con = ServerConnection::new(NoRng);
con.recv_bytes(b"SSH-2.0-OpenSSH_9.7\r\n").unwrap();
let msg = con.next_message_to_send().unwrap();
let msg = con.next_msg_to_send().unwrap();
assert_eq!(msg.0, MsgKind::ServerProtocolInfo);
}
@ -429,7 +410,7 @@ mod tests {
let mut con = ServerConnection::new(NoRng);
con.recv_bytes(b"SSH-2.0-").unwrap();
con.recv_bytes(b"OpenSSH_9.7\r\n").unwrap();
let msg = con.next_message_to_send().unwrap();
let msg = con.next_msg_to_send().unwrap();
assert_eq!(msg.0, MsgKind::ServerProtocolInfo);
}
@ -536,7 +517,7 @@ mod tests {
let mut con = ServerConnection::new(HardcodedRng(rng));
for part in conversation {
con.recv_bytes(&part.client).unwrap();
let bytes = con.next_message_to_send().unwrap().to_bytes();
let bytes = con.next_msg_to_send().unwrap().to_bytes();
assert_eq!(part.server, bytes);
}
}