This commit is contained in:
nora 2024-08-12 00:37:53 +02:00
parent 83e7ef1727
commit cb504817d3
2 changed files with 51 additions and 22 deletions

View file

@ -8,7 +8,7 @@ use tokio::{
use tracing::{debug, error, info}; use tracing::{debug, error, info};
use ssh_protocol::{ use ssh_protocol::{
connection::{ChannelOpen, ChannelOperation, ChannelOperationKind, ChannelRequestKind}, connection::{ChannelOpen, ChannelOperationKind, ChannelRequestKind},
transport::{self, ThreadRngRand}, transport::{self, ThreadRngRand},
ChannelUpdateKind, ServerConnection, SshStatus, ChannelUpdateKind, ServerConnection, SshStatus,
}; };
@ -95,8 +95,20 @@ async fn handle_connection(next: (TcpStream, SocketAddr)) -> Result<()> {
}, },
ChannelUpdateKind::Request(req) => { ChannelUpdateKind::Request(req) => {
match req.kind { match req.kind {
ChannelRequestKind::PtyReq { .. } => {} ChannelRequestKind::PtyReq { .. } => {
ChannelRequestKind::Shell => {} if req.want_reply {
state.do_operation(
update.number.construct_op(ChannelOperationKind::Success),
);
}
}
ChannelRequestKind::Shell => {
if req.want_reply {
state.do_operation(
update.number.construct_op(ChannelOperationKind::Success),
);
}
}
}; };
if req.want_reply { if req.want_reply {
// TODO: sent the reply. // TODO: sent the reply.
@ -106,18 +118,13 @@ async fn handle_connection(next: (TcpStream, SocketAddr)) -> Result<()> {
let is_eof = data.contains(&0x03 /*EOF, Ctrl-C*/); let is_eof = data.contains(&0x03 /*EOF, Ctrl-C*/);
// echo :3 // echo :3
state.do_operation(ChannelOperation { state
number: update.number, .do_operation(update.number.construct_op(ChannelOperationKind::Data(data)));
kind: ChannelOperationKind::Data(data),
});
if is_eof { if is_eof {
debug!(channel = ?update.number, "Received EOF, closing channel"); debug!(channel = ?update.number, "Received EOF, closing channel");
state.do_operation(ChannelOperation { state.do_operation(update.number.construct_op(ChannelOperationKind::Close));
number: update.number,
kind: ChannelOperationKind::Close,
});
} }
} }
ChannelUpdateKind::ExtendedData { .. } | ChannelUpdateKind::Eof => { /* ignore */ } ChannelUpdateKind::ExtendedData { .. } | ChannelUpdateKind::Eof => { /* ignore */ }

View file

@ -5,12 +5,16 @@ use ssh_transport::client_error;
use ssh_transport::packet::Packet; use ssh_transport::packet::Packet;
use ssh_transport::Result; use ssh_transport::Result;
/// A channel number (on our side).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ChannelNumber(pub u32);
pub struct ServerChannelsState { pub struct ServerChannelsState {
packets_to_send: VecDeque<Packet>, packets_to_send: VecDeque<Packet>,
channel_updates: VecDeque<ChannelUpdate>, channel_updates: VecDeque<ChannelUpdate>,
channels: HashMap<u32, Channel>, channels: HashMap<ChannelNumber, Channel>,
next_channel_id: u32, next_channel_id: ChannelNumber,
} }
struct Channel { struct Channel {
@ -23,7 +27,7 @@ struct Channel {
/// An update from a channel. /// An update from a channel.
/// The receiver-equivalent of [`ChannelOperation`]. /// The receiver-equivalent of [`ChannelOperation`].
pub struct ChannelUpdate { pub struct ChannelUpdate {
pub number: u32, pub number: ChannelNumber,
pub kind: ChannelUpdateKind, pub kind: ChannelUpdateKind,
} }
pub enum ChannelUpdateKind { pub enum ChannelUpdateKind {
@ -56,10 +60,17 @@ pub enum ChannelRequestKind {
Shell, Shell,
} }
impl ChannelNumber {
#[must_use]
pub fn construct_op(self, kind: ChannelOperationKind) -> ChannelOperation {
ChannelOperation { number: self, kind }
}
}
/// An operation to do on a channel. /// An operation to do on a channel.
/// The sender-equivalent of [`ChannelUpdate`]. /// The sender-equivalent of [`ChannelUpdate`].
pub struct ChannelOperation { pub struct ChannelOperation {
pub number: u32, pub number: ChannelNumber,
pub kind: ChannelOperationKind, pub kind: ChannelOperationKind,
} }
@ -76,7 +87,7 @@ impl ServerChannelsState {
packets_to_send: VecDeque::new(), packets_to_send: VecDeque::new(),
channels: HashMap::new(), channels: HashMap::new(),
channel_updates: VecDeque::new(), channel_updates: VecDeque::new(),
next_channel_id: 0, next_channel_id: ChannelNumber(0),
} }
} }
@ -116,13 +127,14 @@ impl ServerChannelsState {
}; };
let our_number = self.next_channel_id; let our_number = self.next_channel_id;
self.next_channel_id = self.next_channel_id.checked_add(1).ok_or_else(|| { self.next_channel_id =
ChannelNumber(self.next_channel_id.0.checked_add(1).ok_or_else(|| {
client_error!("created too many channels, overflowed the counter") client_error!("created too many channels, overflowed the counter")
})?; })?);
self.packets_to_send self.packets_to_send
.push_back(Packet::new_msg_channel_open_confirmation( .push_back(Packet::new_msg_channel_open_confirmation(
our_number, our_number.0,
sender_channel, sender_channel,
initial_window_size, initial_window_size,
max_packet_size, max_packet_size,
@ -145,6 +157,7 @@ impl ServerChannelsState {
} }
Packet::SSH_MSG_CHANNEL_DATA => { Packet::SSH_MSG_CHANNEL_DATA => {
let our_channel = packet.u32()?; let our_channel = packet.u32()?;
let our_channel = self.validate_channel(our_channel)?;
let data = packet.string()?; let data = packet.string()?;
let _ = self.channel(our_channel)?; let _ = self.channel(our_channel)?;
@ -159,6 +172,7 @@ impl ServerChannelsState {
Packet::SSH_MSG_CHANNEL_CLOSE => { Packet::SSH_MSG_CHANNEL_CLOSE => {
// <https://datatracker.ietf.org/doc/html/rfc4254#section-5.3> // <https://datatracker.ietf.org/doc/html/rfc4254#section-5.3>
let our_channel = packet.u32()?; let our_channel = packet.u32()?;
let our_channel = self.validate_channel(our_channel)?;
let channel = self.channel(our_channel)?; let channel = self.channel(our_channel)?;
if !channel.we_closed { if !channel.we_closed {
let close = Packet::new_msg_channel_close(channel.peer_channel); let close = Packet::new_msg_channel_close(channel.peer_channel);
@ -176,6 +190,7 @@ impl ServerChannelsState {
} }
Packet::SSH_MSG_CHANNEL_REQUEST => { Packet::SSH_MSG_CHANNEL_REQUEST => {
let our_channel = packet.u32()?; let our_channel = packet.u32()?;
let our_channel = self.validate_channel(our_channel)?;
let request_type = packet.utf8_string()?; let request_type = packet.utf8_string()?;
let want_reply = packet.bool()?; let want_reply = packet.bool()?;
@ -286,9 +301,16 @@ impl ServerChannelsState {
.push_back(Packet::new_msg_channel_failure(recipient_channel)); .push_back(Packet::new_msg_channel_failure(recipient_channel));
} }
fn channel(&mut self, number: u32) -> Result<&mut Channel> { fn validate_channel(&self, number: u32) -> Result<ChannelNumber> {
if !self.channels.contains_key(&ChannelNumber(number)) {
return Err(client_error!("unknown channel: {number}"));
}
Ok(ChannelNumber(number))
}
fn channel(&mut self, number: ChannelNumber) -> Result<&mut Channel> {
self.channels self.channels
.get_mut(&number) .get_mut(&number)
.ok_or_else(|| client_error!("unknown channel: {number}")) .ok_or_else(|| client_error!("unknown channel: {number:?}"))
} }
} }