migrate cluelessh-faked to cluelessh-tokio

This commit is contained in:
nora 2024-08-25 15:01:22 +02:00
parent 01d6a861f1
commit b6d0675976
11 changed files with 513 additions and 137 deletions

10
Cargo.lock generated
View file

@ -347,6 +347,7 @@ name = "cluelessh-faked"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"cluelessh-protocol", "cluelessh-protocol",
"cluelessh-tokio",
"eyre", "eyre",
"hex-literal", "hex-literal",
"rand", "rand",
@ -425,6 +426,15 @@ dependencies = [
"x25519-dalek", "x25519-dalek",
] ]
[[package]]
name = "cluelesshd"
version = "0.1.0"
dependencies = [
"cluelessh-protocol",
"cluelessh-tokio",
"cluelessh-transport",
]
[[package]] [[package]]
name = "colorchoice" name = "colorchoice"
version = "1.0.2" version = "1.0.2"

View file

@ -8,6 +8,7 @@ eyre = "0.6.12"
hex-literal = "0.4.1" hex-literal = "0.4.1"
rand = "0.8.5" rand = "0.8.5"
cluelessh-protocol = { path = "../../lib/cluelessh-protocol" } cluelessh-protocol = { path = "../../lib/cluelessh-protocol" }
cluelessh-tokio = { path = "../../lib/cluelessh-tokio" }
tokio = { version = "1.39.2", features = ["full"] } tokio = { version = "1.39.2", features = ["full"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json"] }

View file

@ -1,27 +1,19 @@
use std::{collections::HashMap, net::SocketAddr}; use std::{net::SocketAddr, sync::Arc};
use cluelessh_tokio::Channel;
use eyre::{Context, Result}; use eyre::{Context, Result};
use rand::RngCore;
use tokio::{ use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream}, net::{TcpListener, TcpStream},
sync::Mutex,
}; };
use tracing::{debug, error, info, info_span, Instrument}; use tracing::{debug, error, info, info_span, warn, Instrument};
use cluelessh_protocol::{ use cluelessh_protocol::{
connection::{ChannelOpen, ChannelOperationKind, ChannelRequest}, connection::{ChannelKind, ChannelOperationKind, ChannelRequest},
transport::{self}, ChannelUpdateKind, SshStatus,
ChannelUpdateKind, ServerConnection, SshStatus,
}; };
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
struct ThreadRngRand;
impl cluelessh_protocol::transport::SshRng for ThreadRngRand {
fn fill_bytes(&mut self, dest: &mut [u8]) {
rand::thread_rng().fill_bytes(dest);
}
}
#[tokio::main] #[tokio::main]
async fn main() -> eyre::Result<()> { async fn main() -> eyre::Result<()> {
let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
@ -45,14 +37,16 @@ async fn main() -> eyre::Result<()> {
let listener = TcpListener::bind(addr).await.wrap_err("binding listener")?; let listener = TcpListener::bind(addr).await.wrap_err("binding listener")?;
let mut listener = cluelessh_tokio::server::ServerListener::new(listener);
loop { loop {
let next = listener.accept().await?; let next = listener.accept().await?;
let span = info_span!("connection", addr = %next.1); let span = info_span!("connection", addr = %next.peer_addr());
tokio::spawn( tokio::spawn(
async { async move {
let mut total_sent_data = Vec::new(); let total_sent_data = Arc::new(Mutex::new(Vec::new()));
if let Err(err) = handle_connection(next, &mut total_sent_data).await { if let Err(err) = handle_connection(next, total_sent_data.clone()).await {
if let Some(err) = err.downcast_ref::<std::io::Error>() { if let Some(err) = err.downcast_ref::<std::io::Error>() {
if err.kind() == std::io::ErrorKind::ConnectionReset { if err.kind() == std::io::ErrorKind::ConnectionReset {
return; return;
@ -63,6 +57,7 @@ async fn main() -> eyre::Result<()> {
} }
// Limit stdin to 500 characters. // Limit stdin to 500 characters.
let total_sent_data = total_sent_data.lock().await;
let stdin = String::from_utf8_lossy(&total_sent_data); let stdin = String::from_utf8_lossy(&total_sent_data);
let stdin = if let Some((idx, _)) = stdin.char_indices().nth(500) { let stdin = if let Some((idx, _)) = stdin.char_indices().nth(500) {
&stdin[..idx] &stdin[..idx]
@ -78,46 +73,18 @@ async fn main() -> eyre::Result<()> {
} }
async fn handle_connection( async fn handle_connection(
next: (TcpStream, SocketAddr), mut conn: cluelessh_tokio::server::ServerConnection<TcpStream>,
total_sent_data: &mut Vec<u8>, total_sent_data: Arc<Mutex<Vec<u8>>>,
) -> Result<()> { ) -> Result<()> {
let (mut conn, addr) = next; info!(addr = %conn.peer_addr(), "Received a new connection");
info!(%addr, "Received a new connection");
/*let rng = vec![
0x14, 0xa2, 0x04, 0xa5, 0x4b, 0x2f, 0x5f, 0xa7, 0xff, 0x53, 0x13, 0x67, 0x57, 0x67, 0xbc,
0x55, 0x3f, 0xc0, 0x6c, 0x0d, 0x07, 0x8f, 0xe2, 0x75, 0x95, 0x18, 0x4b, 0xd2, 0xcb, 0xd0,
0x64, 0x06, 0x14, 0xa2, 0x04, 0xa5, 0x4b, 0x2f, 0x5f, 0xa7, 0xff, 0x53, 0x13, 0x67, 0x57,
0x67, 0xbc, 0x55, 0x3f, 0xc0, 0x6c, 0x0d, 0x07, 0x8f, 0xe2, 0x75, 0x95, 0x18, 0x4b, 0xd2,
0xcb, 0xd0, 0x64, 0x06, 0x67, 0xbc, 0x55, 0x3f, 0xc0, 0x6c, 0x0d, 0x07, 0x8f, 0xe2, 0x75,
0x95, 0x18, 0x4b, 0xd2, 0xcb, 0xd0, 0x64, 0x06,
];
struct HardcodedRng(Vec<u8>);
impl cluelessh_protocol::transport::SshRng for HardcodedRng {
fn fill_bytes(&mut self, dest: &mut [u8]) {
dest.copy_from_slice(&self.0[..dest.len()]);
self.0.splice(0..dest.len(), []);
}
}*/
let mut state = ServerConnection::new(transport::server::ServerConnection::new(ThreadRngRand));
let mut session_channels = HashMap::new();
loop { loop {
let mut buf = [0; 1024]; match conn.progress().await {
let read = conn Ok(()) => {}
.read(&mut buf) Err(cluelessh_tokio::server::Error::ServerError(err)) => {
.await return Err(err);
.wrap_err("reading from connection")?;
if read == 0 {
info!("Did not read any bytes from TCP stream, EOF");
return Ok(());
} }
Err(cluelessh_tokio::server::Error::SshStatus(status)) => match status {
if let Err(err) = state.recv_bytes(&buf[..read]) {
match err {
SshStatus::PeerError(err) => { SshStatus::PeerError(err) => {
info!(?err, "disconnecting client after invalid operation"); info!(?err, "disconnecting client after invalid operation");
return Ok(()); return Ok(());
@ -126,28 +93,40 @@ async fn handle_connection(
info!("Received disconnect from client"); info!("Received disconnect from client");
return Ok(()); return Ok(());
} }
},
}
while let Some(channel) = conn.next_new_channel() {
if *channel.kind() == ChannelKind::Session {
let total_sent_data = total_sent_data.clone();
tokio::spawn(async {
let _ = handle_session_channel(channel, total_sent_data).await;
});
} else {
warn!("Trying to open non-session channel");
}
}
} }
} }
while let Some(update) = state.next_channel_update() { async fn handle_session_channel(
//eprintln!("{:?}", update); mut channel: Channel,
match update.kind { total_sent_data: Arc<Mutex<Vec<u8>>>,
ChannelUpdateKind::Open(kind) => match kind { ) -> Result<()> {
ChannelOpen::Session => { loop {
session_channels.insert(update.number, ()); match channel.next_update().await {
} Ok(update) => match update {
},
ChannelUpdateKind::Request(req) => { ChannelUpdateKind::Request(req) => {
let success = update.number.construct_op(ChannelOperationKind::Success); let success = ChannelOperationKind::Success;
match req { match req {
ChannelRequest::PtyReq { want_reply, .. } => { ChannelRequest::PtyReq { want_reply, .. } => {
if want_reply { if want_reply {
state.do_operation(success); channel.send(success).await?;
} }
} }
ChannelRequest::Shell { want_reply } => { ChannelRequest::Shell { want_reply } => {
if want_reply { if want_reply {
state.do_operation(success); channel.send(success).await?;
} }
} }
ChannelRequest::Exec { ChannelRequest::Exec {
@ -155,26 +134,20 @@ async fn handle_connection(
command, command,
} => { } => {
if want_reply { if want_reply {
state.do_operation(success); channel.send(success).await?;
} }
let result = execute_command(&command); let result = execute_command(&command);
state.do_operation( channel
update .send(ChannelOperationKind::Data(result.stdout))
.number .await?;
.construct_op(ChannelOperationKind::Data(result.stdout)), channel
); .send(ChannelOperationKind::Request(ChannelRequest::ExitStatus {
state.do_operation(update.number.construct_op(
ChannelOperationKind::Request(ChannelRequest::ExitStatus {
status: result.status, status: result.status,
}), }))
)); .await?;
state.do_operation( channel.send(ChannelOperationKind::Eof).await?;
update.number.construct_op(ChannelOperationKind::Eof), channel.send(ChannelOperationKind::Close).await?;
);
state.do_operation(
update.number.construct_op(ChannelOperationKind::Close),
);
} }
ChannelRequest::ExitStatus { .. } => {} ChannelRequest::ExitStatus { .. } => {}
ChannelRequest::Env { .. } => {} ChannelRequest::Env { .. } => {}
@ -185,44 +158,37 @@ async fn handle_connection(
let is_eof = data.contains(&0x04 /*EOF, Ctrl-D*/); let is_eof = data.contains(&0x04 /*EOF, Ctrl-D*/);
// echo :3 // echo :3
state.do_operation( channel
update .send(ChannelOperationKind::Data(data.clone()))
.number .await?;
.construct_op(ChannelOperationKind::Data(data.clone())),
);
let mut total_sent_data = total_sent_data.lock().await;
// arbitrary limit // arbitrary limit
if total_sent_data.len() < 50_000 { if total_sent_data.len() < 50_000 {
total_sent_data.extend_from_slice(&data); total_sent_data.extend_from_slice(&data);
} else { } else {
info!(channel = %update.number, "Reached stdin limit"); info!("Reached stdin limit");
state.do_operation(update.number.construct_op(ChannelOperationKind::Data( channel
b"Thanks Hayley!\n".to_vec(), .send(ChannelOperationKind::Data(b"Thanks Hayley!\n".to_vec()))
))); .await?;
state.do_operation(update.number.construct_op(ChannelOperationKind::Close)); channel.send(ChannelOperationKind::Close).await?;
} }
if is_eof { if is_eof {
debug!(channel = %update.number, "Received Ctrl-D, closing channel"); debug!("Received Ctrl-D, closing channel");
state.do_operation(update.number.construct_op(ChannelOperationKind::Eof)); channel.send(ChannelOperationKind::Eof).await?;
state.do_operation(update.number.construct_op(ChannelOperationKind::Close)); channel.send(ChannelOperationKind::Close).await?;
} }
} }
ChannelUpdateKind::ExtendedData { .. } ChannelUpdateKind::Open(_)
| ChannelUpdateKind::Closed
| ChannelUpdateKind::ExtendedData { .. }
| ChannelUpdateKind::Eof | ChannelUpdateKind::Eof
| ChannelUpdateKind::Success | ChannelUpdateKind::Success
| ChannelUpdateKind::Failure => { /* ignore */ } | ChannelUpdateKind::Failure => { /* ignore */ }
ChannelUpdateKind::Closed => { },
session_channels.remove(&update.number); Err(err) => return Err(err),
}
}
}
while let Some(msg) = state.next_msg_to_send() {
conn.write_all(&msg.to_bytes())
.await
.wrap_err("writing response")?;
} }
} }
} }

View file

@ -8,7 +8,7 @@ use cluelessh_transport::{key::PublicKey, numbers, parse::Writer};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tracing::{debug, error}; use tracing::{debug, error};
use cluelessh_protocol::connection::{ChannelOpen, ChannelOperationKind, ChannelRequest}; use cluelessh_protocol::connection::{ChannelKind, ChannelOperationKind, ChannelRequest};
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
#[derive(clap::Parser, Debug)] #[derive(clap::Parser, Debug)]
@ -121,7 +121,7 @@ async fn main() -> eyre::Result<()> {
) )
.await?; .await?;
let session = tokio_conn.open_channel(ChannelOpen::Session); let session = tokio_conn.open_channel(ChannelKind::Session);
tokio::spawn(async { tokio::spawn(async {
let result = main_channel(session).await; let result = main_channel(session).await;

View file

@ -0,0 +1,9 @@
[package]
name = "cluelesshd"
version = "0.1.0"
edition = "2021"
[dependencies]
cluelessh-protocol = { path = "../../lib/cluelessh-protocol" }
cluelessh-tokio = { path = "../../lib/cluelessh-tokio" }
cluelessh-transport = { path = "../../lib/cluelessh-transport" }

View file

@ -0,0 +1,3 @@
fn main() {
println!("Hello, world!");
}

View file

@ -32,7 +32,7 @@ enum ChannelState {
our_window_size: u32, our_window_size: u32,
/// For validation only. /// For validation only.
our_max_packet_size: u32, our_max_packet_size: u32,
update_message: ChannelOpen, update_message: ChannelKind,
}, },
Open(Channel), Open(Channel),
} }
@ -71,7 +71,7 @@ pub struct ChannelUpdate {
pub enum ChannelUpdateKind { pub enum ChannelUpdateKind {
Success, Success,
Failure, Failure,
Open(ChannelOpen), Open(ChannelKind),
OpenFailed { code: u32, message: String }, OpenFailed { code: u32, message: String },
Request(ChannelRequest), Request(ChannelRequest),
Data { data: Vec<u8> }, Data { data: Vec<u8> },
@ -80,7 +80,7 @@ pub enum ChannelUpdateKind {
Closed, Closed,
} }
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChannelOpen { pub enum ChannelKind {
Session, Session,
} }
#[derive(Debug)] #[derive(Debug)]
@ -173,7 +173,7 @@ impl ChannelsState {
debug!(%channel_type, %sender_channel, "Receving channel open"); debug!(%channel_type, %sender_channel, "Receving channel open");
let update_message = match channel_type { let update_message = match channel_type {
"session" => ChannelOpen::Session, "session" => ChannelKind::Session,
_ => { _ => {
self.packets_to_send self.packets_to_send
.push_back(Packet::new_msg_channel_open_failure( .push_back(Packet::new_msg_channel_open_failure(
@ -512,7 +512,7 @@ impl ChannelsState {
} }
/// Create a new channel /// Create a new channel
pub fn create_channel(&mut self, kind: ChannelOpen) -> ChannelNumber { pub fn create_channel(&mut self, kind: ChannelKind) -> ChannelNumber {
let our_number = self.next_channel_id; let our_number = self.next_channel_id;
self.next_channel_id = ChannelNumber( self.next_channel_id = ChannelNumber(
self.next_channel_id self.next_channel_id
@ -521,7 +521,7 @@ impl ChannelsState {
.expect("created too many channels"), .expect("created too many channels"),
); );
assert_eq!(kind, ChannelOpen::Session, "TODO"); assert_eq!(kind, ChannelKind::Session, "TODO");
let our_window_size = 2097152; // same as OpenSSH let our_window_size = 2097152; // same as OpenSSH
let our_max_packet_size = 32768; // same as OpenSSH let our_max_packet_size = 32768; // same as OpenSSH

View file

@ -44,8 +44,9 @@ impl ServerConnection {
self.transport.send_plaintext_packet(to_send); self.transport.send_plaintext_packet(to_send);
} }
if auth.is_authenticated() { if auth.is_authenticated() {
self.state = self.state = ServerConnectionState::Open(
ServerConnectionState::Open(cluelessh_connection::ChannelsState::new(true)); cluelessh_connection::ChannelsState::new(true),
);
} }
} }
ServerConnectionState::Open(con) => { ServerConnectionState::Open(con) => {
@ -94,6 +95,20 @@ impl ServerConnection {
} }
} }
} }
pub fn channels(&mut self) -> Option<&mut cluelessh_connection::ChannelsState> {
match &mut self.state {
ServerConnectionState::Open(channels) => Some(channels),
_ => None,
}
}
pub fn auth(&mut self) -> Option<&mut auth::BadAuth> {
match &mut self.state {
ServerConnectionState::Auth(auth) => Some(auth),
_ => None,
}
}
} }
pub struct ClientConnection { pub struct ClientConnection {
@ -108,7 +123,10 @@ enum ClientConnectionState {
} }
impl ClientConnection { impl ClientConnection {
pub fn new(transport: cluelessh_transport::client::ClientConnection, auth: auth::ClientAuth) -> Self { pub fn new(
transport: cluelessh_transport::client::ClientConnection,
auth: auth::ClientAuth,
) -> Self {
Self { Self {
transport, transport,
state: ClientConnectionState::Setup(Some(auth)), state: ClientConnectionState::Setup(Some(auth)),
@ -139,8 +157,9 @@ impl ClientConnection {
self.transport.send_plaintext_packet(to_send); self.transport.send_plaintext_packet(to_send);
} }
if auth.is_authenticated() { if auth.is_authenticated() {
self.state = self.state = ClientConnectionState::Open(
ClientConnectionState::Open(cluelessh_connection::ChannelsState::new(false)); cluelessh_connection::ChannelsState::new(false),
);
} }
} }
ClientConnectionState::Open(con) => { ClientConnectionState::Open(con) => {
@ -227,6 +246,18 @@ pub mod auth {
is_authenticated: bool, is_authenticated: bool,
} }
pub enum ServerRequest {
VerifyPassword {
user: String,
password: String,
},
VerifyPubkey {
session_identifier: [u8; 32],
user: String,
pubkey: Vec<u8>,
},
}
impl BadAuth { impl BadAuth {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
@ -320,6 +351,10 @@ pub mod auth {
self.is_authenticated self.is_authenticated
} }
pub fn server_requests(&mut self) -> impl Iterator<Item = ServerRequest> + '_ {
[].into_iter()
}
fn queue_packet(&mut self, packet: Packet) { fn queue_packet(&mut self, packet: Packet) {
self.packets_to_send.push_back(packet); self.packets_to_send.push_back(packet);
} }

View file

@ -1,13 +1,15 @@
use cluelessh_connection::{ChannelNumber, ChannelOpen, ChannelOperation, ChannelOperationKind}; use cluelessh_connection::{ChannelKind, ChannelNumber, ChannelOperation, ChannelOperationKind};
use std::{collections::HashMap, pin::Pin, sync::Arc}; use std::{collections::HashMap, pin::Pin, sync::Arc};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use cluelessh_protocol::{ChannelUpdateKind, SshStatus};
use eyre::{bail, ContextCompat, OptionExt, Result, WrapErr}; use eyre::{bail, ContextCompat, OptionExt, Result, WrapErr};
use futures::future::BoxFuture; use futures::future::BoxFuture;
use cluelessh_protocol::{ChannelUpdateKind, SshStatus};
use tokio::io::{AsyncRead, AsyncWrite}; use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use crate::Channel;
pub struct ClientConnection<S> { pub struct ClientConnection<S> {
stream: Pin<Box<S>>, stream: Pin<Box<S>>,
buf: [u8; 1024], buf: [u8; 1024],
@ -55,11 +57,6 @@ pub struct PendingChannel {
ready_recv: tokio::sync::oneshot::Receiver<Result<(), String>>, ready_recv: tokio::sync::oneshot::Receiver<Result<(), String>>,
channel: Channel, channel: Channel,
} }
pub struct Channel {
number: ChannelNumber,
updates_recv: tokio::sync::mpsc::Receiver<ChannelUpdateKind>,
ops_send: tokio::sync::mpsc::Sender<ChannelOperation>,
}
impl<S: AsyncRead + AsyncWrite> ClientConnection<S> { impl<S: AsyncRead + AsyncWrite> ClientConnection<S> {
pub async fn connect(stream: S, auth: ClientAuth) -> Result<Self> { pub async fn connect(stream: S, auth: ClientAuth) -> Result<Self> {
@ -75,7 +72,9 @@ impl<S: AsyncRead + AsyncWrite> ClientConnection<S> {
channel_ops_recv, channel_ops_recv,
channels: HashMap::new(), channels: HashMap::new(),
proto: cluelessh_protocol::ClientConnection::new( proto: cluelessh_protocol::ClientConnection::new(
cluelessh_transport::client::ClientConnection::new(cluelessh_protocol::ThreadRngRand), cluelessh_transport::client::ClientConnection::new(
cluelessh_protocol::ThreadRngRand,
),
cluelessh_protocol::auth::ClientAuth::new(auth.username.as_bytes().to_vec()), cluelessh_protocol::auth::ClientAuth::new(auth.username.as_bytes().to_vec()),
), ),
auth, auth,
@ -245,14 +244,14 @@ impl<S: AsyncRead + AsyncWrite> ClientConnection<S> {
Ok(()) Ok(())
} }
pub fn open_channel(&mut self, kind: ChannelOpen) -> PendingChannel { pub fn open_channel(&mut self, kind: ChannelKind) -> PendingChannel {
let Some(channels) = self.proto.channels() else { let Some(channels) = self.proto.channels() else {
panic!("connection not ready yet") panic!("connection not ready yet")
}; };
let (updates_send, updates_recv) = tokio::sync::mpsc::channel(10); let (updates_send, updates_recv) = tokio::sync::mpsc::channel(10);
let (ready_send, ready_recv) = tokio::sync::oneshot::channel(); let (ready_send, ready_recv) = tokio::sync::oneshot::channel();
let number = channels.create_channel(kind); let number = channels.create_channel(kind.clone());
self.channels.insert( self.channels.insert(
number, number,
@ -268,6 +267,7 @@ impl<S: AsyncRead + AsyncWrite> ClientConnection<S> {
number, number,
updates_recv, updates_recv,
ops_send: self.channel_ops_send.clone(), ops_send: self.channel_ops_send.clone(),
kind,
}, },
} }
} }
@ -290,11 +290,4 @@ impl Channel {
.await .await
.map_err(Into::into) .map_err(Into::into)
} }
pub async fn next_update(&mut self) -> Result<ChannelUpdateKind> {
self.updates_recv
.recv()
.await
.ok_or_eyre("channel has been closed")
}
} }

View file

@ -1 +1,33 @@
pub mod client; pub mod client;
pub mod server;
use cluelessh_connection::{ChannelKind, ChannelNumber, ChannelOperation, ChannelOperationKind};
use cluelessh_protocol::ChannelUpdateKind;
use eyre::{OptionExt, Result};
pub struct Channel {
number: ChannelNumber,
updates_recv: tokio::sync::mpsc::Receiver<ChannelUpdateKind>,
ops_send: tokio::sync::mpsc::Sender<ChannelOperation>,
kind: ChannelKind,
}
impl Channel {
pub async fn send(&mut self, op: ChannelOperationKind) -> Result<()> {
self.ops_send
.send(self.number.construct_op(op))
.await
.map_err(Into::into)
}
pub async fn next_update(&mut self) -> Result<ChannelUpdateKind> {
self.updates_recv
.recv()
.await
.ok_or_eyre("channel has been closed")
}
pub fn kind(&self) -> &ChannelKind {
&self.kind
}
}

View file

@ -0,0 +1,327 @@
use cluelessh_connection::{ChannelKind, ChannelNumber, ChannelOperation};
use std::{
collections::{HashMap, VecDeque},
net::SocketAddr,
pin::Pin,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
};
use cluelessh_protocol::{ChannelUpdateKind, SshStatus};
use eyre::{eyre, ContextCompat, Result, WrapErr};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use crate::Channel;
pub struct ServerListener {
listener: TcpListener,
// todo ratelimits etc
}
pub struct ServerConnection<S> {
stream: Pin<Box<S>>,
peer_addr: SocketAddr,
buf: [u8; 1024],
proto: cluelessh_protocol::ServerConnection,
operations_send: tokio::sync::mpsc::Sender<Operation>,
operations_recv: tokio::sync::mpsc::Receiver<Operation>,
/// Cloned and passed on to channels.
channel_ops_send: tokio::sync::mpsc::Sender<ChannelOperation>,
channel_ops_recv: tokio::sync::mpsc::Receiver<ChannelOperation>,
channels: HashMap<ChannelNumber, ChannelState>,
/// New channels opened by the peer.
new_channels: VecDeque<Channel>,
}
enum ChannelState {
Pending {
ready_send: tokio::sync::oneshot::Sender<Result<(), String>>,
updates_send: tokio::sync::mpsc::Sender<ChannelUpdateKind>,
},
Ready(tokio::sync::mpsc::Sender<ChannelUpdateKind>),
}
enum Operation {
VerifyPassword {
user: String,
password: String,
},
VerifyPubkey {
session_identifier: [u8; 32],
user: String,
pubkey: Vec<u8>,
},
}
pub struct SignatureResult {
pub key_alg_name: &'static str,
pub public_key: Vec<u8>,
pub signature: Vec<u8>,
}
pub struct PendingChannel {
ready_recv: tokio::sync::oneshot::Receiver<Result<(), String>>,
channel: Channel,
}
pub enum Error {
SshStatus(SshStatus),
ServerError(eyre::Report),
}
impl From<eyre::Report> for Error {
fn from(value: eyre::Report) -> Self {
Self::ServerError(value)
}
}
impl ServerListener {
pub fn new(listener: TcpListener) -> Self {
Self { listener }
}
pub async fn accept(&mut self) -> Result<ServerConnection<TcpStream>> {
let (conn, peer_addr) = self.listener.accept().await?;
Ok(ServerConnection::new(conn, peer_addr))
}
}
impl<S: AsyncRead + AsyncWrite> ServerConnection<S> {
pub fn new(stream: S, peer_addr: SocketAddr) -> Self {
let (operations_send, operations_recv) = tokio::sync::mpsc::channel(15);
let (channel_ops_send, channel_ops_recv) = tokio::sync::mpsc::channel(15);
Self {
stream: Box::pin(stream),
peer_addr,
buf: [0; 1024],
operations_send,
operations_recv,
channel_ops_send,
channel_ops_recv,
channels: HashMap::new(),
proto: cluelessh_protocol::ServerConnection::new(
cluelessh_transport::server::ServerConnection::new(
cluelessh_protocol::ThreadRngRand,
),
),
new_channels: VecDeque::new(),
}
}
pub fn peer_addr(&self) -> SocketAddr {
self.peer_addr
}
/// Executes one loop iteration of the main loop.
// IMPORTANT: no operations on this struct should ever block the main loop, except this one.
pub async fn progress(&mut self) -> Result<(), Error> {
if let Some(auth) = self.proto.auth() {
for req in auth.server_requests() {
match req {
cluelessh_protocol::auth::ServerRequest::VerifyPassword { user, password } => {
let send = self.operations_send.clone();
tokio::spawn(async move {
let _ = send
.send(Operation::VerifyPassword { user, password })
.await;
});
}
cluelessh_protocol::auth::ServerRequest::VerifyPubkey {
session_identifier,
pubkey,
user,
} => {
let send = self.operations_send.clone();
tokio::spawn(async move {
let _ = send
.send(Operation::VerifyPubkey {
session_identifier,
user,
pubkey,
})
.await;
});
}
}
}
}
if let Some(channels) = self.proto.channels() {
while let Some(update) = channels.next_channel_update() {
match &update.kind {
ChannelUpdateKind::Open(channel_kind) => {
let channel = self.channels.get_mut(&update.number);
match channel {
// We opened.
Some(ChannelState::Pending { updates_send, .. }) => {
let updates_send = updates_send.clone();
let old = self
.channels
.insert(update.number, ChannelState::Ready(updates_send));
match old.unwrap() {
ChannelState::Pending { ready_send, .. } => {
let _ = ready_send.send(Ok(()));
}
_ => unreachable!(),
}
}
Some(ChannelState::Ready(_)) => {
return Err(Error::ServerError(eyre!(
"attemping to open channel twice: {}",
update.number
)))
}
// They opened.
None => {
let (updates_send, updates_recv) = tokio::sync::mpsc::channel(10);
let number = update.number;
self.channels
.insert(number, ChannelState::Ready(updates_send));
let channel = Channel {
number,
updates_recv,
ops_send: self.channel_ops_send.clone(),
kind: channel_kind.clone(),
};
self.new_channels.push_back(channel);
}
}
}
ChannelUpdateKind::OpenFailed { message, .. } => {
let channel = self
.channels
.get_mut(&update.number)
.wrap_err("unknown channel")?;
match channel {
ChannelState::Pending { .. } => {
let old = self.channels.remove(&update.number);
match old.unwrap() {
ChannelState::Pending { ready_send, .. } => {
let _ = ready_send.send(Err(message.clone()));
}
_ => unreachable!(),
}
}
ChannelState::Ready(_) => {
return Err(Error::ServerError(eyre!(
"attemping to open channel twice: {}",
update.number
)))
}
}
}
_ => {
let channel = self
.channels
.get_mut(&update.number)
.wrap_err("unknown channel")?;
match channel {
ChannelState::Pending { .. } => {
return Err(Error::ServerError(eyre!("channel not ready yet")))
}
ChannelState::Ready(updates_send) => {
let _ = updates_send.send(update.kind).await;
}
}
}
}
}
}
// Make sure that we send all queued messages before going into the select, waiting for things to happen.
self.send_off_data().await?;
tokio::select! {
read = self.stream.read(&mut self.buf) => {
let read = read.wrap_err("reading from connection")?;
if read == 0 {
info!("Did not read any bytes from TCP stream, EOF");
return Ok(());
}
if let Err(err) = self.proto.recv_bytes(&self.buf[..read]) {
return Err(Error::SshStatus(err));
}
}
channel_op = self.channel_ops_recv.recv() => {
let channels = self.proto.channels().expect("connection not ready");
if let Some(channel_op) = channel_op {
channels.do_operation(channel_op);
}
}
op = self.operations_recv.recv() => {
match op {
Some(Operation::VerifyPubkey { .. }) => todo!(),
Some(Operation::VerifyPassword { .. }) => todo!(),
None => {}
}
self.send_off_data().await?;
}
}
Ok(())
}
async fn send_off_data(&mut self) -> Result<()> {
self.proto.progress();
while let Some(msg) = self.proto.next_msg_to_send() {
self.stream
.write_all(&msg.to_bytes())
.await
.wrap_err("writing response")?;
}
Ok(())
}
pub fn open_channel(&mut self, kind: ChannelKind) -> PendingChannel {
let Some(channels) = self.proto.channels() else {
panic!("connection not ready yet")
};
let (updates_send, updates_recv) = tokio::sync::mpsc::channel(10);
let (ready_send, ready_recv) = tokio::sync::oneshot::channel();
let number = channels.create_channel(kind.clone());
self.channels.insert(
number,
ChannelState::Pending {
ready_send,
updates_send,
},
);
PendingChannel {
ready_recv,
channel: Channel {
number,
updates_recv,
ops_send: self.channel_ops_send.clone(),
kind,
},
}
}
pub fn next_new_channel(&mut self) -> Option<Channel> {
self.new_channels.pop_front()
}
}
impl PendingChannel {
pub async fn wait_ready(self) -> Result<Channel, Option<String>> {
match self.ready_recv.await {
Ok(Ok(())) => Ok(self.channel),
Ok(Err(err)) => Err(Some(err)),
Err(_) => Err(None),
}
}
}