From 43c1696465a6b49964894e02f9a74258e1cb2fb0 Mon Sep 17 00:00:00 2001 From: Noratrieb <48135649+Noratrieb@users.noreply.github.com> Date: Mon, 12 Aug 2024 19:09:29 +0200 Subject: [PATCH] pluggable encryption --- ssh-transport/src/keys.rs | 155 ++++++++++++++++++++++++++++-------- ssh-transport/src/lib.rs | 76 +++++++++++------- ssh-transport/src/packet.rs | 24 +++++- 3 files changed, 190 insertions(+), 65 deletions(-) diff --git a/ssh-transport/src/keys.rs b/ssh-transport/src/keys.rs index 122a567..2227bad 100644 --- a/ssh-transport/src/keys.rs +++ b/ssh-transport/src/keys.rs @@ -8,14 +8,23 @@ use crate::{ Msg, Result, SshRng, }; +pub trait AlgorithmName { + fn name(&self) -> &'static str; +} + #[derive(Clone, Copy)] pub struct KexAlgorithm { - pub name: &'static str, + name: &'static str, pub exchange: fn( client_public_key: &[u8], random: &mut (dyn SshRng + Send + Sync), ) -> Result, } +impl AlgorithmName for KexAlgorithm { + fn name(&self) -> &'static str { + self.name + } +} pub struct KexAlgorithmOutput { /// K pub shared_secret: Vec, @@ -69,15 +78,47 @@ pub const KEX_ECDH_SHA2_NISTP256: KexAlgorithm = KexAlgorithm { }, }; +#[derive(Clone, Copy)] +pub struct EncryptionAlgorithm { + name: &'static str, + decrypt_len: fn(keys: &[u8], bytes: &mut [u8], packet_number: u64), + decrypt_packet: fn(keys: &[u8], bytes: RawPacket, packet_number: u64) -> Result, + encrypt_packet: fn(keys: &[u8], packet: Packet, packet_number: u64) -> EncryptedPacket, +} +impl AlgorithmName for EncryptionAlgorithm { + fn name(&self) -> &'static str { + self.name + } +} +pub const ENC_CHACHA20POLY1305: EncryptionAlgorithm = EncryptionAlgorithm { + name: "chacha20-poly1305@openssh.com", + decrypt_len: |keys, bytes, packet_number| { + let alg = SshChaCha20Poly1305::from_keys(keys); + alg.decrypt_len(bytes, packet_number) + }, + decrypt_packet: |keys, bytes, packet_number| { + let alg = SshChaCha20Poly1305::from_keys(keys); + alg.decrypt_packet(bytes, packet_number) + }, + encrypt_packet: |keys, packet, packet_number| { + let alg = SshChaCha20Poly1305::from_keys(keys); + alg.encrypt_packet(packet, packet_number) + }, +}; + pub struct AlgorithmNegotiation { - pub supported: Vec<(&'static str, T)>, + pub supported: Vec, } -impl AlgorithmNegotiation { - pub fn find<'a>(&self, client_supports: &str) -> Result { +impl AlgorithmNegotiation { + pub fn find<'a>(mut self, client_supports: &str) -> Result { for client_alg in client_supports.split(',') { - if let Some(alg) = self.supported.iter().find(|alg| alg.0 == client_alg) { - return Ok(alg.1); + if let Some(alg) = self + .supported + .iter() + .position(|alg| alg.name() == client_alg) + { + return Ok(self.supported.remove(alg)); } } @@ -89,8 +130,10 @@ impl AlgorithmNegotiation { pub(crate) struct Session { session_id: [u8; 32], - encryption_key_client_to_server: SshChaCha20Poly1305, - encryption_key_server_to_client: SshChaCha20Poly1305, + encryption_key_client_to_server: [u8; 64], + encryption_client_to_server: EncryptionAlgorithm, + encryption_key_server_to_client: [u8; 64], + encryption_server_to_client: EncryptionAlgorithm, } pub(crate) trait Keys: Send + Sync + 'static { @@ -100,7 +143,13 @@ pub(crate) trait Keys: Send + Sync + 'static { fn encrypt_packet_to_msg(&mut self, packet: Packet, packet_number: u64) -> Msg; fn additional_mac_len(&self) -> usize; - fn rekey(&mut self, h: [u8; 32], k: &[u8]) -> Result<(), ()>; + fn rekey( + &mut self, + h: [u8; 32], + k: &[u8], + encryption_client_to_server: EncryptionAlgorithm, + encryption_server_to_client: EncryptionAlgorithm, + ) -> Result<(), ()>; } pub(crate) struct Plaintext; @@ -115,29 +164,52 @@ impl Keys for Plaintext { fn additional_mac_len(&self) -> usize { 0 } - fn rekey(&mut self, _: [u8; 32], _: &[u8]) -> Result<(), ()> { + fn rekey( + &mut self, + _: [u8; 32], + _: &[u8], + _: EncryptionAlgorithm, + _: EncryptionAlgorithm, + ) -> Result<(), ()> { Err(()) } } impl Session { - pub(crate) fn new(h: [u8; 32], k: &[u8]) -> Self { - Self::from_keys(h, h, k) + pub(crate) fn new( + h: [u8; 32], + k: &[u8], + encryption_client_to_server: EncryptionAlgorithm, + encryption_server_to_client: EncryptionAlgorithm, + ) -> Self { + Self::from_keys( + h, + h, + k, + encryption_client_to_server, + encryption_server_to_client, + ) } /// - fn from_keys(session_id: [u8; 32], h: [u8; 32], k: &[u8]) -> Self { - let encryption_key_client_to_server = - SshChaCha20Poly1305::new(derive_key(k, h, "C", session_id)); - let encryption_key_server_to_client = - SshChaCha20Poly1305::new(derive_key(k, h, "D", session_id)); + fn from_keys( + session_id: [u8; 32], + h: [u8; 32], + k: &[u8], + encryption_client_to_server: EncryptionAlgorithm, + encryption_server_to_client: EncryptionAlgorithm, + ) -> Self { + let encryption_key_client_to_server = derive_key(k, h, "C", session_id); + let encryption_key_server_to_client = derive_key(k, h, "D", session_id); Self { session_id, // client_to_server_iv: derive("A").into(), // server_to_client_iv: derive("B").into(), encryption_key_client_to_server, + encryption_client_to_server, encryption_key_server_to_client, + encryption_server_to_client, // integrity_key_client_to_server: derive("E").into(), // integrity_key_server_to_client: derive("F").into(), } @@ -146,19 +218,27 @@ impl Session { impl Keys for Session { fn decrypt_len(&mut self, bytes: &mut [u8; 4], packet_number: u64) { - self.encryption_key_client_to_server - .decrypt_len(bytes, packet_number); + (self.encryption_client_to_server.decrypt_len)( + &self.encryption_key_client_to_server, + bytes, + packet_number, + ); } fn decrypt_packet(&mut self, bytes: RawPacket, packet_number: u64) -> Result { - self.encryption_key_client_to_server - .decrypt_packet(bytes, packet_number) + (self.encryption_client_to_server.decrypt_packet)( + &self.encryption_key_client_to_server, + bytes, + packet_number, + ) } fn encrypt_packet_to_msg(&mut self, packet: Packet, packet_number: u64) -> Msg { - let packet = self - .encryption_key_server_to_client - .encrypt_packet(packet, packet_number); + let packet = (self.encryption_server_to_client.encrypt_packet)( + &self.encryption_key_server_to_client, + packet, + packet_number, + ); Msg(MsgKind::EncryptedPacket(packet)) } @@ -166,8 +246,20 @@ impl Keys for Session { poly1305::BLOCK_SIZE } - fn rekey(&mut self, h: [u8; 32], k: &[u8]) -> Result<(), ()> { - *self = Self::from_keys(self.session_id, h, k); + fn rekey( + &mut self, + h: [u8; 32], + k: &[u8], + encryption_client_to_server: EncryptionAlgorithm, + encryption_server_to_client: EncryptionAlgorithm, + ) -> Result<(), ()> { + *self = Self::from_keys( + self.session_id, + h, + k, + encryption_client_to_server, + encryption_server_to_client, + ); Ok(()) } } @@ -225,10 +317,11 @@ struct SshChaCha20Poly1305 { } impl SshChaCha20Poly1305 { - fn new(key: [u8; 64]) -> Self { + fn from_keys(keys: &[u8]) -> Self { + assert_eq!(keys.len(), 64); Self { - main_key: <[u8; 32]>::try_from(&key[..32]).unwrap().into(), - header_key: <[u8; 32]>::try_from(&key[32..]).unwrap().into(), + main_key: <[u8; 32]>::try_from(&keys[..32]).unwrap().into(), + header_key: <[u8; 32]>::try_from(&keys[32..]).unwrap().into(), } } @@ -241,7 +334,7 @@ impl SshChaCha20Poly1305 { cipher.apply_keystream(bytes); } - fn decrypt_packet(&mut self, mut bytes: RawPacket, packet_number: u64) -> Result { + fn decrypt_packet(&self, mut bytes: RawPacket, packet_number: u64) -> Result { // let mut cipher = ::new( @@ -276,7 +369,7 @@ impl SshChaCha20Poly1305 { Packet::from_full(encrypted_packet_content) } - fn encrypt_packet(&mut self, packet: Packet, packet_number: u64) -> EncryptedPacket { + fn encrypt_packet(&self, packet: Packet, packet_number: u64) -> EncryptedPacket { let mut bytes = packet.to_bytes(false); // Prepare the main cipher. diff --git a/ssh-transport/src/lib.rs b/ssh-transport/src/lib.rs index 006de6f..a1b7ab1 100644 --- a/ssh-transport/src/lib.rs +++ b/ssh-transport/src/lib.rs @@ -6,7 +6,7 @@ use core::str; use std::{collections::VecDeque, mem::take}; use ed25519_dalek::ed25519::signature::Signer; -use keys::AlgorithmNegotiation; +use keys::{AlgorithmName, AlgorithmNegotiation, EncryptionAlgorithm}; use packet::{ DhKeyExchangeInitReplyPacket, KeyExchangeEcDhInitPacket, KeyExchangeInitPacket, Packet, PacketTransport, SshPublicKey, SshSignature, @@ -64,10 +64,14 @@ enum ServerState { client_kexinit: Vec, server_kexinit: Vec, kex_algorithm: keys::KexAlgorithm, + encryption_client_to_server: EncryptionAlgorithm, + encryption_server_to_client: EncryptionAlgorithm, }, NewKeys { h: [u8; 32], k: Vec, + encryption_client_to_server: EncryptionAlgorithm, + encryption_server_to_client: EncryptionAlgorithm, }, ServiceRequest, Open, @@ -163,41 +167,37 @@ impl ServerConnection { } => { let kex = KeyExchangeInitPacket::parse(&packet.payload)?; - let require_algorithm = |expected: &'static str, - list: NameList<'_>| - -> Result<&'static str> { - if list.iter().any(|alg| alg == expected) { - Ok(expected) - } else { - Err(client_error!( + let require_algorithm = + |expected: &'static str, list: NameList<'_>| -> Result<&'static str> { + if list.iter().any(|alg| alg == expected) { + Ok(expected) + } else { + Err(client_error!( "client does not support algorithm {expected}. supported: {list:?}", )) - } - }; + } + }; let kex_algorithms = AlgorithmNegotiation { - supported: vec![( - keys::KEX_CURVE_25519_SHA256.name, - keys::KEX_CURVE_25519_SHA256, - ), ( - keys::KEX_ECDH_SHA2_NISTP256.name, - keys::KEX_ECDH_SHA2_NISTP256, - )], + supported: vec![keys::KEX_CURVE_25519_SHA256, keys::KEX_ECDH_SHA2_NISTP256], }; let kex_algorithm = kex_algorithms.find(kex.kex_algorithms.0)?; let server_host_key_algorithm = require_algorithm("ssh-ed25519", kex.server_host_key_algorithms)?; + let encryption_algorithms_client_to_server = AlgorithmNegotiation { + supported: vec![keys::ENC_CHACHA20POLY1305], + }; + let encryption_algorithms_server_to_client = AlgorithmNegotiation { + supported: vec![keys::ENC_CHACHA20POLY1305], + }; + // TODO: support aes256-gcm@openssh.com - let encryption_algorithm_client_to_server = require_algorithm( - "chacha20-poly1305@openssh.com", - kex.encryption_algorithms_client_to_server, - )?; - let encryption_algorithm_server_to_client = require_algorithm( - "chacha20-poly1305@openssh.com", - kex.encryption_algorithms_server_to_client, - )?; + let encryption_client_to_server = encryption_algorithms_client_to_server + .find(kex.encryption_algorithms_client_to_server.0)?; + let encryption_server_to_client = encryption_algorithms_server_to_client + .find(kex.encryption_algorithms_server_to_client.0)?; let mac_algorithm_client_to_server = require_algorithm("hmac-sha2-256", kex.mac_algorithms_client_to_server)?; let mac_algorithm_server_to_client = @@ -218,13 +218,13 @@ impl ServerConnection { let server_kexinit = KeyExchangeInitPacket { cookie: [0; 16], - kex_algorithms: NameList::one(kex_algorithm.name), + kex_algorithms: NameList::one(kex_algorithm.name()), server_host_key_algorithms: NameList::one(server_host_key_algorithm), encryption_algorithms_client_to_server: NameList::one( - encryption_algorithm_client_to_server, + encryption_client_to_server.name(), ), encryption_algorithms_server_to_client: NameList::one( - encryption_algorithm_server_to_client, + encryption_server_to_client.name(), ), mac_algorithms_client_to_server: NameList::one( mac_algorithm_client_to_server, @@ -253,6 +253,8 @@ impl ServerConnection { client_kexinit: packet.payload, server_kexinit: server_kexinit_payload, kex_algorithm, + encryption_client_to_server, + encryption_server_to_client, }; } ServerState::DhKeyInit { @@ -260,6 +262,8 @@ impl ServerConnection { client_kexinit, server_kexinit, kex_algorithm, + encryption_client_to_server, + encryption_server_to_client, } => { // TODO: move to keys.rs let dh = KeyExchangeEcDhInitPacket::parse(&packet.payload)?; @@ -333,9 +337,16 @@ impl ServerConnection { self.state = ServerState::NewKeys { h: hash.into(), k: shared_secret, + encryption_client_to_server: *encryption_client_to_server, + encryption_server_to_client: *encryption_server_to_client, }; } - ServerState::NewKeys { h, k } => { + ServerState::NewKeys { + h, + k, + encryption_client_to_server, + encryption_server_to_client, + } => { if packet.payload != [Packet::SSH_MSG_NEWKEYS] { return Err(client_error!("did not send SSH_MSG_NEWKEYS")); } @@ -343,7 +354,12 @@ impl ServerConnection { self.packet_transport.queue_packet(Packet { payload: vec![Packet::SSH_MSG_NEWKEYS], }); - self.packet_transport.set_key(*h, k); + self.packet_transport.set_key( + *h, + k, + *encryption_client_to_server, + *encryption_server_to_client, + ); self.state = ServerState::ServiceRequest {}; } ServerState::ServiceRequest => { diff --git a/ssh-transport/src/packet.rs b/ssh-transport/src/packet.rs index bea8066..47ad7b5 100644 --- a/ssh-transport/src/packet.rs +++ b/ssh-transport/src/packet.rs @@ -3,7 +3,7 @@ mod ctors; use std::collections::VecDeque; use crate::client_error; -use crate::keys::{Keys, Plaintext, Session}; +use crate::keys::{EncryptionAlgorithm, Keys, Plaintext, Session}; use crate::parse::{NameList, Parser, Writer}; use crate::Result; @@ -101,9 +101,25 @@ impl PacketTransport { self.send_packets.pop_front() } - pub(crate) fn set_key(&mut self, h: [u8; 32], k: &[u8]) { - if let Err(()) = self.keys.rekey(h, k) { - self.keys = Box::new(Session::new(h, k)); + pub(crate) fn set_key( + &mut self, + h: [u8; 32], + k: &[u8], + encryption_client_to_server: EncryptionAlgorithm, + encryption_server_to_client: EncryptionAlgorithm, + ) { + if let Err(()) = self.keys.rekey( + h, + k, + encryption_client_to_server, + encryption_server_to_client, + ) { + self.keys = Box::new(Session::new( + h, + k, + encryption_client_to_server, + encryption_server_to_client, + )); } } }