diff --git a/Cargo.lock b/Cargo.lock index f9b1a2a..f1925ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -911,6 +911,7 @@ version = "0.1.0" dependencies = [ "ssh-transport", "tracing", + "tracing-subscriber", ] [[package]] diff --git a/ssh-connection/Cargo.toml b/ssh-connection/Cargo.toml index d49b91c..ec3f470 100644 --- a/ssh-connection/Cargo.toml +++ b/ssh-connection/Cargo.toml @@ -6,3 +6,6 @@ edition = "2021" [dependencies] ssh-transport = { path = "../ssh-transport" } tracing = "0.1.40" + +[dev-dependencies] +tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } diff --git a/ssh-connection/src/lib.rs b/ssh-connection/src/lib.rs index 74aa02e..2e5c736 100644 --- a/ssh-connection/src/lib.rs +++ b/ssh-connection/src/lib.rs @@ -450,7 +450,6 @@ impl ServerChannelsState { fn send_data(&mut self, channel_number: ChannelNumber, data: &[u8]) { let channel = self.channel(channel_number).unwrap(); - let peer = channel.peer_channel; let mut chunks = data.chunks(channel.peer_max_packet_size as usize); @@ -463,11 +462,12 @@ impl ServerChannelsState { let rest = channel.peer_window_size; let (to_send, to_keep) = data.split_at(rest as usize); - // Send everything we can, which empties the window. - channel.peer_window_size -= rest; - assert_eq!(channel.peer_window_size, 0); - self.packets_to_send - .push_back(Packet::new_msg_channel_data(peer, to_send)); + if !to_send.is_empty() { + // Send everything we can, which empties the window. + channel.peer_window_size -= rest; + assert_eq!(channel.peer_window_size, 0); + self.send_data_packet(channel_number, to_send); + } // It's over, we have exhausted all window space. // Queue the rest of the bytes. @@ -483,11 +483,23 @@ impl ServerChannelsState { } trace!(channel = %channel_number, window = %channel.peer_window_size, "Remaining window on their side"); - self.packets_to_send - .push_back(Packet::new_msg_channel_data(peer, data)); + self.send_data_packet(channel_number, data); } } + /// Send a single data packet. + /// The caller needs to ensure the windowing and packet size requirements are upheld. + fn send_data_packet(&mut self, channel_number: ChannelNumber, data: &[u8]) { + assert!(data.len() > 0, "Trying to send empty data packet"); + + trace!(%channel_number, amount = %data.len(), "Sending channel data"); + let channel = self.channel(channel_number).unwrap(); + let peer = channel.peer_channel; + assert!(channel.peer_max_packet_size >= data.len() as u32); + self.packets_to_send + .push_back(Packet::new_msg_channel_data(peer, data)); + } + fn send_channel_success(&mut self, recipient_channel: u32) { self.packets_to_send .push_back(Packet::new_msg_channel_success(recipient_channel)); @@ -539,7 +551,16 @@ mod tests { use crate::{ChannelNumber, ChannelOperation, ChannelOperationKind, ServerChannelsState}; - fn assert_response(state: &mut ServerChannelsState, types: &[u8]) { + /// If a test fails, add this to the test to get logs. + #[allow(dead_code)] + fn init_test_log() { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .init(); + } + + #[track_caller] + fn assert_response_types(state: &mut ServerChannelsState, types: &[u8]) { let response = state .packets_to_send() .map(|p| numbers::packet_type_to_string(p.packet_type())) @@ -558,7 +579,7 @@ mod tests { b"session", 0, 2048, 1024, )) .unwrap(); - assert_response(state, &[numbers::SSH_MSG_CHANNEL_OPEN_CONFIRMATION]); + assert_response_types(state, &[numbers::SSH_MSG_CHANNEL_OPEN_CONFIRMATION]); } #[test] @@ -572,24 +593,24 @@ mod tests { )) .unwrap(); state.do_operation(ChannelNumber(0).construct_op(ChannelOperationKind::Success)); - assert_response(state, &[numbers::SSH_MSG_CHANNEL_SUCCESS]); + assert_response_types(state, &[numbers::SSH_MSG_CHANNEL_SUCCESS]); state .recv_packet(Packet::new_msg_channel_request_shell(0, b"shell", true)) .unwrap(); state.do_operation(ChannelNumber(0).construct_op(ChannelOperationKind::Success)); - assert_response(state, &[numbers::SSH_MSG_CHANNEL_SUCCESS]); + assert_response_types(state, &[numbers::SSH_MSG_CHANNEL_SUCCESS]); state .recv_packet(Packet::new_msg_channel_data(0, b"hello, world")) .unwrap(); - assert_response(state, &[]); + assert_response_types(state, &[]); state.recv_packet(Packet::new_msg_channel_eof(0)).unwrap(); - assert_response(state, &[]); + assert_response_types(state, &[]); state.recv_packet(Packet::new_msg_channel_close(0)).unwrap(); - assert_response(state, &[numbers::SSH_MSG_CHANNEL_CLOSE]); + assert_response_types(state, &[numbers::SSH_MSG_CHANNEL_CLOSE]); } #[test] @@ -604,7 +625,7 @@ mod tests { number: ChannelNumber(0), kind: ChannelOperationKind::Close, }); - assert_response(state, &[numbers::SSH_MSG_CHANNEL_CLOSE]); + assert_response_types(state, &[numbers::SSH_MSG_CHANNEL_CLOSE]); } #[test] @@ -612,11 +633,78 @@ mod tests { let mut state = &mut ServerChannelsState::new(); open_session_channel(state); state.recv_packet(Packet::new_msg_channel_close(0)).unwrap(); - assert_response(&mut state, &[numbers::SSH_MSG_CHANNEL_CLOSE]); + assert_response_types(&mut state, &[numbers::SSH_MSG_CHANNEL_CLOSE]); state.do_operation(ChannelOperation { number: ChannelNumber(0), kind: ChannelOperationKind::Data(vec![0]), }); - assert_response(state, &[]); + assert_response_types(state, &[]); + } + + #[test] + fn respect_peer_windowing() { + let state = &mut ServerChannelsState::new(); + state + .recv_packet(Packet::new_msg_channel_open_session(b"session", 0, 10, 50)) + .unwrap(); + assert_response_types(state, &[numbers::SSH_MSG_CHANNEL_OPEN_CONFIRMATION]); + + // Send 100 bytes. + state.do_operation( + ChannelNumber(0) + .construct_op(ChannelOperationKind::Data((0_u8..200).collect::>())), + ); + + // 0..10 + assert_response_types(state, &[numbers::SSH_MSG_CHANNEL_DATA]); + + state + .recv_packet(Packet::new_msg_channel_window_adjust(0, 90)) + .unwrap(); + // 10..60, 60..100 + assert_response_types( + state, + &[numbers::SSH_MSG_CHANNEL_DATA, numbers::SSH_MSG_CHANNEL_DATA], + ); + + state + .recv_packet(Packet::new_msg_channel_window_adjust(0, 100)) + .unwrap(); + // 100..150, 150..20 + assert_response_types( + state, + &[numbers::SSH_MSG_CHANNEL_DATA, numbers::SSH_MSG_CHANNEL_DATA], + ); + + state + .recv_packet(Packet::new_msg_channel_window_adjust(0, 100)) + .unwrap(); + assert_response_types(state, &[]); + } + + #[test] + fn send_windowing_adjustments() { + let state = &mut ServerChannelsState::new(); + state + .recv_packet(Packet::new_msg_channel_open_session( + b"session", 0, 2000, 2000, + )) + .unwrap(); + assert_response_types(state, &[numbers::SSH_MSG_CHANNEL_OPEN_CONFIRMATION]); + + state + .recv_packet(Packet::new_msg_channel_data(0, &vec![0; 2000])) + .unwrap(); + assert_response_types(state, &[numbers::SSH_MSG_CHANNEL_WINDOW_ADJUST]); + + // We currently hardcode <1000 for when to send window size adjustments. + state + .recv_packet(Packet::new_msg_channel_data(0, &vec![0; 1000])) + .unwrap(); + assert_response_types(state, &[]); + state + .recv_packet(Packet::new_msg_channel_data(0, &vec![0; 1])) + .unwrap(); + assert_response_types(state, &[numbers::SSH_MSG_CHANNEL_WINDOW_ADJUST]); } }