From ea28daca0c70f92ffab3393dfa7e0fec0eed425f Mon Sep 17 00:00:00 2001 From: Noratrieb <48135649+Noratrieb@users.noreply.github.com> Date: Fri, 23 Aug 2024 17:54:49 +0200 Subject: [PATCH] move orchestration logic into ssh-tokio --- Cargo.lock | 15 ++ bin/fakesshd/src/main.rs | 1 + bin/ssh/Cargo.toml | 1 + bin/ssh/src/main.rs | 315 ++++++++++++---------------------- lib/ssh-connection/src/lib.rs | 26 +++ lib/ssh-protocol/Cargo.toml | 1 + lib/ssh-protocol/src/lib.rs | 8 + lib/ssh-tokio/Cargo.toml | 13 ++ lib/ssh-tokio/README.md | 5 + lib/ssh-tokio/src/client.rs | 300 ++++++++++++++++++++++++++++++++ lib/ssh-tokio/src/lib.rs | 1 + 11 files changed, 477 insertions(+), 209 deletions(-) create mode 100644 lib/ssh-tokio/Cargo.toml create mode 100644 lib/ssh-tokio/README.md create mode 100644 lib/ssh-tokio/src/client.rs create mode 100644 lib/ssh-tokio/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 6f3de3b..310ef7c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1210,6 +1210,7 @@ dependencies = [ "rpassword", "ssh-agent-client", "ssh-protocol", + "ssh-tokio", "ssh-transport", "tokio", "tracing", @@ -1282,11 +1283,25 @@ dependencies = [ name = "ssh-protocol" version = "0.1.0" dependencies = [ + "rand", "ssh-connection", "ssh-transport", "tracing", ] +[[package]] +name = "ssh-tokio" +version = "0.1.0" +dependencies = [ + "eyre", + "futures", + "ssh-connection", + "ssh-protocol", + "ssh-transport", + "tokio", + "tracing", +] + [[package]] name = "ssh-transport" version = "0.1.0" diff --git a/bin/fakesshd/src/main.rs b/bin/fakesshd/src/main.rs index d5cae60..703b4d6 100644 --- a/bin/fakesshd/src/main.rs +++ b/bin/fakesshd/src/main.rs @@ -180,6 +180,7 @@ async fn handle_connection( ChannelRequest::Env { .. } => {} }; } + ChannelUpdateKind::OpenFailed { .. } => todo!(), ChannelUpdateKind::Data { data } => { let is_eof = data.contains(&0x04 /*EOF, Ctrl-D*/); diff --git a/bin/ssh/Cargo.toml b/bin/ssh/Cargo.toml index d34fc78..f4f193f 100644 --- a/bin/ssh/Cargo.toml +++ b/bin/ssh/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" ssh-protocol = { path = "../../lib/ssh-protocol" } ssh-transport = { path = "../../lib/ssh-transport" } ssh-agent-client = { path = "../../lib/ssh-agent-client" } +ssh-tokio = { path = "../../lib/ssh-tokio" } clap = { version = "4.5.15", features = ["derive"] } eyre = "0.6.12" diff --git a/bin/ssh/src/main.rs b/bin/ssh/src/main.rs index 440b17c..c0d5474 100644 --- a/bin/ssh/src/main.rs +++ b/bin/ssh/src/main.rs @@ -1,32 +1,16 @@ -use std::{collections::HashSet, io::Write}; +use std::{collections::HashSet, sync::Arc}; use clap::Parser; -use eyre::{bail, Context, ContextCompat, OptionExt}; -use rand::RngCore; -use ssh_transport::{key::PublicKey, numbers, parse::Writer, peer_error}; -use tokio::{ - io::{AsyncReadExt, AsyncWriteExt}, - net::TcpStream, -}; -use tracing::{debug, error, info}; +use eyre::{bail, Context, ContextCompat, OptionExt, Result}; +use ssh_tokio::client::{PendingChannel, SignatureResult}; +use ssh_transport::{key::PublicKey, numbers, parse::Writer}; +use tokio::net::TcpStream; +use tracing::{debug, error}; -use ssh_protocol::{ - connection::{ - ChannelNumber, ChannelOpen, ChannelOperation, ChannelOperationKind, ChannelRequest, - }, - transport::{self}, - ChannelUpdate, ChannelUpdateKind, SshStatus, -}; +use ssh_protocol::connection::{ChannelOpen, ChannelOperationKind, ChannelRequest}; use tracing_subscriber::EnvFilter; -struct ThreadRngRand; -impl ssh_protocol::transport::SshRng for ThreadRngRand { - fn fill_bytes(&mut self, dest: &mut [u8]) { - rand::thread_rng().fill_bytes(dest); - } -} - #[derive(clap::Parser, Debug)] struct Args { #[arg(short = 'p', long, default_value_t = 22)] @@ -37,22 +21,6 @@ struct Args { command: Vec, } -enum Operation { - PasswordEntered(std::io::Result), - Signature { - key_alg_name: &'static str, - public_key: Vec, - signature: Vec, - }, -} - -// TODO: state machine everything including auth -enum ClientState { - Start, - WaitingForOpen(ChannelNumber), - WaitingForPty(ChannelNumber), -} - #[tokio::main] async fn main() -> eyre::Result<()> { let args = Args::parse(); @@ -77,183 +45,112 @@ async fn main() -> eyre::Result<()> { Some(user) => user, }; - let mut attempted_public_keys = HashSet::new(); - - let mut conn = TcpStream::connect(&format!("{}:{}", args.destination, args.port)) + let conn = TcpStream::connect(&format!("{}:{}", args.destination, args.port)) .await .wrap_err("connecting")?; - let mut state = ssh_protocol::ClientConnection::new( - transport::client::ClientConnection::new(ThreadRngRand), - ssh_protocol::auth::ClientAuth::new(username.as_bytes().to_vec()), - ); + let username1 = username.clone(); + let mut tokio_conn = ssh_tokio::client::ClientConnection::connect( + conn, + ssh_tokio::client::ClientAuth { + username: username.clone(), + prompt_password: Arc::new(move || { + let username = username1.clone(); + let destination = args.destination.clone(); + Box::pin(async { + let result = tokio::task::spawn_blocking(move || { + let password = rpassword::prompt_password(format!( + "{}@{}'s password: ", + username, destination + )); + password + }) + .await?; + result.wrap_err("failed to prompt password") + }) + }), + sign_pubkey: Arc::new(move |session_identifier| { + let session_identifier = session_identifier.to_vec(); + let mut attempted_public_keys = HashSet::new(); + let username = username.clone(); + Box::pin(async move { + // TODO: support agentless manual key opening + // TODO: move + let mut agent = ssh_agent_client::SocketAgentConnection::from_env() + .await + .wrap_err("failed to connect to SSH agent")?; + let identities = agent.list_identities().await?; + for identity in &identities { + let pubkey = PublicKey::from_wire_encoding(&identity.key_blob) + .wrap_err("received invalid public key from SSH agent")?; + debug!(comment = ?identity.comment, %pubkey, "Found identity"); + } + if identities.len() > 1 { + todo!("try identities"); + } + let identity = &identities[0]; + if !attempted_public_keys.insert(identity.key_blob.clone()) { + bail!("authentication denied (publickey)"); + } + let pubkey = PublicKey::from_wire_encoding(&identity.key_blob)?; - let mut client_state = ClientState::Start; + let mut sign_data = Writer::new(); + sign_data.string(session_identifier); + sign_data.u8(numbers::SSH_MSG_USERAUTH_REQUEST); + sign_data.string(&username); + sign_data.string("ssh-connection"); + sign_data.string("publickey"); + sign_data.bool(true); + sign_data.string(pubkey.algorithm_name()); + sign_data.string(&identity.key_blob); - let (send_op, mut recv_op) = tokio::sync::mpsc::channel::(10); + let data = sign_data.finish(); + let signature = agent + .sign(&identity.key_blob, &data, 0) + .await + .wrap_err("signing for authentication")?; - let mut buf = [0; 1024]; + Ok(SignatureResult { + key_alg_name: pubkey.algorithm_name(), + public_key: identity.key_blob.clone(), + signature, + }) + }) + }), + }, + ) + .await?; + + let session = tokio_conn.open_channel(ChannelOpen::Session); + + tokio::spawn(async { + let result = main_channel(session).await; + if let Err(err) = result { + error!(?err); + } + }); loop { - if let Some(auth) = state.auth() { - for req in auth.user_requests() { - match req { - ssh_protocol::auth::ClientUserRequest::Password => { - let username = username.clone(); - let destination = args.destination.clone(); - let send_op = send_op.clone(); - std::thread::spawn(move || { - let password = rpassword::prompt_password(format!( - "{}@{}'s password: ", - username, destination - )); - let _ = send_op.blocking_send(Operation::PasswordEntered(password)); - }); - } - ssh_protocol::auth::ClientUserRequest::PrivateKeySign { - session_identifier, - } => { - // TODO: support agentless manual key opening - // TODO: move - let mut agent = ssh_agent_client::SocketAgentConnection::from_env() - .await - .wrap_err("failed to connect to SSH agent")?; - let identities = agent.list_identities().await?; - for identity in &identities { - let pubkey = PublicKey::from_wire_encoding(&identity.key_blob) - .wrap_err("received invalid public key from SSH agent")?; - debug!(comment = ?identity.comment, %pubkey, "Found identity"); - } - if identities.len() > 1 { - todo!("try identities"); - } - let identity = &identities[0]; - if !attempted_public_keys.insert(identity.key_blob.clone()) { - bail!("authentication denied (publickey)"); - } - let pubkey = PublicKey::from_wire_encoding(&identity.key_blob)?; - - let mut sign_data = Writer::new(); - sign_data.string(session_identifier); - sign_data.u8(numbers::SSH_MSG_USERAUTH_REQUEST); - sign_data.string(&username); - sign_data.string("ssh-connection"); - sign_data.string("publickey"); - sign_data.bool(true); - sign_data.string(pubkey.algorithm_name()); - sign_data.string(&identity.key_blob); - - let data = sign_data.finish(); - let signature = agent - .sign(&identity.key_blob, &data, 0) - .await - .wrap_err("signing for authentication")?; - - send_op - .send(Operation::Signature { - key_alg_name: pubkey.algorithm_name(), - public_key: identity.key_blob.clone(), - signature, - }) - .await?; - } - ssh_protocol::auth::ClientUserRequest::Banner(banner) => { - let banner = String::from_utf8_lossy(&banner); - std::io::stdout().write(&banner.as_bytes())?; - } - } - } - } - - if let Some(channels) = state.channels() { - if let ClientState::Start = client_state { - let number = channels.create_channel(ChannelOpen::Session); - client_state = ClientState::WaitingForOpen(number); - } - - while let Some(update) = channels.next_channel_update() { - match &update.kind { - ChannelUpdateKind::Open(_) => match client_state { - ClientState::WaitingForOpen(number) => { - if number != update.number { - bail!("unexpected channel opened by server"); - } - client_state = ClientState::WaitingForPty(update.number); - channels.do_operation(number.construct_op( - ChannelOperationKind::Request(ChannelRequest::PtyReq { - want_reply: true, - term: "xterm-256color".to_owned(), - width_chars: 70, - height_rows: 10, - width_px: 0, - height_px: 0, - term_modes: vec![], - }), - )); - } - _ => bail!("unexpected channel opened by server"), - }, - ChannelUpdateKind::Success => {} - ChannelUpdateKind::Failure => bail!("operation failed"), - ChannelUpdateKind::Request(_) => todo!(), - ChannelUpdateKind::Data { .. } => todo!(), - ChannelUpdateKind::ExtendedData { .. } => todo!(), - ChannelUpdateKind::Eof => todo!(), - ChannelUpdateKind::Closed => todo!(), - } - } - } - - // Make sure that we send all queues messages before going into the select, waiting for things to happen. - state.progress(); - while let Some(msg) = state.next_msg_to_send() { - conn.write_all(&msg.to_bytes()) - .await - .wrap_err("writing response")?; - } - - tokio::select! { - read = conn.read(&mut buf) => { - let read = read.wrap_err("reading from connection")?; - if read == 0 { - info!("Did not read any bytes from TCP stream, EOF"); - return Ok(()); - } - if let Err(err) = state.recv_bytes(&buf[..read]) { - match err { - SshStatus::PeerError(err) => { - error!(?err, "disconnecting client after invalid operation"); - return Ok(()); - } - SshStatus::Disconnect => { - error!("Received disconnect from server"); - return Ok(()); - } - } - } - } - op = recv_op.recv() => { - match op { - Some(Operation::PasswordEntered(password)) => { - if let Some(auth) = state.auth() { - auth.send_password(&password?); - } else { - debug!("Ignoring entered password as the state has moved on"); - } - } - Some(Operation::Signature{ - key_alg_name, public_key, signature, - }) => { - if let Some(auth) = state.auth() { - auth.send_signature(key_alg_name, &public_key, &signature); - } else { - debug!("Ignoring signature as the state has moved on"); - } - } - None => {} - } - state.progress(); - } - } + tokio_conn.progress().await?; } } + +async fn main_channel(channel: PendingChannel) -> Result<()> { + let Ok(mut channel) = channel.wait_ready().await else { + bail!("failed to create channel"); + }; + + channel + .send_operation(ChannelOperationKind::Request(ChannelRequest::PtyReq { + want_reply: true, + term: "xterm-256color".to_owned(), + width_chars: 70, + height_rows: 10, + width_px: 0, + height_px: 0, + term_modes: vec![], + })) + .await?; + + Ok(()) +} diff --git a/lib/ssh-connection/src/lib.rs b/lib/ssh-connection/src/lib.rs index 85e5d1b..125fcfc 100644 --- a/lib/ssh-connection/src/lib.rs +++ b/lib/ssh-connection/src/lib.rs @@ -72,6 +72,7 @@ pub enum ChannelUpdateKind { Success, Failure, Open(ChannelOpen), + OpenFailed { code: u32, message: String }, Request(ChannelRequest), Data { data: Vec }, ExtendedData { code: u32, data: Vec }, @@ -259,6 +260,31 @@ impl ChannelsState { debug!(channel_type = %"session", %our_number, "Successfully opened channel"); } + numbers::SSH_MSG_CHANNEL_OPEN_FAILURE => { + let our_channel = p.u32()?; + let our_number = ChannelNumber(our_channel); + let Some(&ChannelState::AwaitingConfirmation { .. }) = + self.channels.get(&our_number) + else { + return Err(peer_error!("unknown channel: {our_channel}")); + }; + + let reason_code = p.u32()?; + let reason_msg = p.utf8_string()?; + let _language_tag = p.utf8_string()?; + + debug!(%our_number, %reason_code, %reason_msg, "Failed to open channel"); + + self.channel_updates.push_back(ChannelUpdate { + number: our_number, + kind: ChannelUpdateKind::OpenFailed { + code: reason_code, + message: reason_msg.to_owned(), + }, + }); + + self.channels.remove(&our_number); + } numbers::SSH_MSG_CHANNEL_WINDOW_ADJUST => { let our_channel = p.u32()?; let our_channel = self.validate_channel(our_channel)?; diff --git a/lib/ssh-protocol/Cargo.toml b/lib/ssh-protocol/Cargo.toml index bbef8b6..d3e9ea9 100644 --- a/lib/ssh-protocol/Cargo.toml +++ b/lib/ssh-protocol/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +rand = "0.8.5" ssh-connection = { path = "../ssh-connection" } ssh-transport = { path = "../ssh-transport" } tracing.workspace = true diff --git a/lib/ssh-protocol/src/lib.rs b/lib/ssh-protocol/src/lib.rs index 0a4a952..dc581a9 100644 --- a/lib/ssh-protocol/src/lib.rs +++ b/lib/ssh-protocol/src/lib.rs @@ -7,6 +7,14 @@ pub use ssh_transport as transport; pub use ssh_transport::{Result, SshStatus}; use tracing::debug; +pub struct ThreadRngRand; +impl transport::SshRng for ThreadRngRand { + fn fill_bytes(&mut self, dest: &mut [u8]) { + use rand::RngCore; + rand::thread_rng().fill_bytes(dest); + } +} + pub struct ServerConnection { transport: ssh_transport::server::ServerConnection, state: ServerConnectionState, diff --git a/lib/ssh-tokio/Cargo.toml b/lib/ssh-tokio/Cargo.toml new file mode 100644 index 0000000..a3b7d79 --- /dev/null +++ b/lib/ssh-tokio/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "ssh-tokio" +version = "0.1.0" +edition = "2021" + +[dependencies] +eyre = "0.6.12" +ssh-transport = { path = "../ssh-transport" } +ssh-connection = { path = "../ssh-connection" } +ssh-protocol = { path = "../ssh-protocol" } +tokio = { version = "1.39.3", features = ["net"] } +tracing.workspace = true +futures = "0.3.30" diff --git a/lib/ssh-tokio/README.md b/lib/ssh-tokio/README.md new file mode 100644 index 0000000..dbd7d7d --- /dev/null +++ b/lib/ssh-tokio/README.md @@ -0,0 +1,5 @@ +# ssh-tokio + +Adapter layer for async Tokio programs. + +Exposes channels as MPSC-like structs. diff --git a/lib/ssh-tokio/src/client.rs b/lib/ssh-tokio/src/client.rs new file mode 100644 index 0000000..8b06bbe --- /dev/null +++ b/lib/ssh-tokio/src/client.rs @@ -0,0 +1,300 @@ +use ssh_connection::{ChannelNumber, ChannelOpen, ChannelOperation, ChannelOperationKind}; +use std::{collections::HashMap, pin::Pin, sync::Arc}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +use eyre::{bail, ContextCompat, OptionExt, Result, WrapErr}; +use futures::future::BoxFuture; +use ssh_protocol::{ChannelUpdateKind, SshStatus}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tracing::{debug, info, warn}; + +pub struct ClientConnection { + stream: Pin>, + buf: [u8; 1024], + + proto: ssh_protocol::ClientConnection, + operations_send: tokio::sync::mpsc::Sender, + operations_recv: tokio::sync::mpsc::Receiver, + + /// Cloned and passed on to channels. + channel_ops_send: tokio::sync::mpsc::Sender, + channel_ops_recv: tokio::sync::mpsc::Receiver, + + channels: HashMap, + + auth: ClientAuth, +} + +enum ChannelState { + Pending { + ready_send: tokio::sync::oneshot::Sender>, + updates_send: tokio::sync::mpsc::Sender, + }, + Ready(tokio::sync::mpsc::Sender), +} + +pub struct ClientAuth { + pub username: String, + pub prompt_password: Arc BoxFuture<'static, Result> + Send + Sync>, + pub sign_pubkey: + Arc BoxFuture<'static, Result> + Send + Sync>, +} + +enum Operation { + PasswordEntered(Result), + Signature(Result), +} + +pub struct SignatureResult { + pub key_alg_name: &'static str, + pub public_key: Vec, + pub signature: Vec, +} + +pub struct PendingChannel { + ready_recv: tokio::sync::oneshot::Receiver>, + channel: Channel, +} +pub struct Channel { + number: ChannelNumber, + updates_recv: tokio::sync::mpsc::Receiver, + ops_send: tokio::sync::mpsc::Sender, +} + +impl ClientConnection { + pub async fn connect(stream: S, auth: ClientAuth) -> Result { + let (operations_send, operations_recv) = tokio::sync::mpsc::channel(15); + let (channel_ops_send, channel_ops_recv) = tokio::sync::mpsc::channel(15); + + let mut this = Self { + stream: Box::pin(stream), + buf: [0; 1024], + operations_send, + operations_recv, + channel_ops_send, + channel_ops_recv, + channels: HashMap::new(), + proto: ssh_protocol::ClientConnection::new( + ssh_transport::client::ClientConnection::new(ssh_protocol::ThreadRngRand), + ssh_protocol::auth::ClientAuth::new(auth.username.as_bytes().to_vec()), + ), + auth, + }; + + while !this.proto.is_open() { + this.progress().await?; + } + + Ok(this) + } + + /// Executes one loop iteration of the main loop. + // IMPORTANT: no operations on this struct should ever block the main loop, except this one. + pub async fn progress(&mut self) -> Result<()> { + if let Some(auth) = self.proto.auth() { + for req in auth.user_requests() { + match req { + ssh_protocol::auth::ClientUserRequest::Password => { + let send = self.operations_send.clone(); + let prompt_password = self.auth.prompt_password.clone(); + tokio::spawn(async move { + let password = prompt_password().await; + let _ = send.send(Operation::PasswordEntered(password)).await; + }); + } + ssh_protocol::auth::ClientUserRequest::PrivateKeySign { + session_identifier, + } => { + let send = self.operations_send.clone(); + let sign_pubkey = self.auth.sign_pubkey.clone(); + tokio::spawn(async move { + let signature_result = sign_pubkey(&session_identifier).await; + let _ = send.send(Operation::Signature(signature_result)).await; + }); + } + ssh_protocol::auth::ClientUserRequest::Banner(_) => { + warn!("ignoring banner as it's not implemented..."); + } + } + } + } + + if let Some(channels) = self.proto.channels() { + while let Some(update) = channels.next_channel_update() { + match &update.kind { + ChannelUpdateKind::Open(_) => { + let channel = self + .channels + .get_mut(&update.number) + .wrap_err("unknown channel")?; + match channel { + ChannelState::Pending { updates_send, .. } => { + let updates_send = updates_send.clone(); + let old = self + .channels + .insert(update.number, ChannelState::Ready(updates_send)); + match old.unwrap() { + ChannelState::Pending { ready_send, .. } => { + let _ = ready_send.send(Ok(())); + } + _ => unreachable!(), + } + } + ChannelState::Ready(_) => { + bail!("attemping to open channel twice: {}", update.number); + } + } + } + ChannelUpdateKind::OpenFailed { message, .. } => { + let channel = self + .channels + .get_mut(&update.number) + .wrap_err("unknown channel")?; + match channel { + ChannelState::Pending { .. } => { + let old = self.channels.remove(&update.number); + match old.unwrap() { + ChannelState::Pending { ready_send, .. } => { + let _ = ready_send.send(Err(message.clone())); + } + _ => unreachable!(), + } + } + ChannelState::Ready(_) => { + bail!("attemping to open channel twice: {}", update.number); + } + } + } + _ => { + let channel = self + .channels + .get_mut(&update.number) + .wrap_err("unknown channel")?; + match channel { + ChannelState::Pending { .. } => bail!("channel not ready yet"), + ChannelState::Ready(updates_send) => { + let _ = updates_send.send(update.kind).await; + } + } + } + } + } + } + + // Make sure that we send all queues messages before going into the select, waiting for things to happen. + self.send_off_data().await?; + + tokio::select! { + read = self.stream.read(&mut self.buf) => { + let read = read.wrap_err("reading from connection")?; + if read == 0 { + info!("Did not read any bytes from TCP stream, EOF"); + return Ok(()); + } + if let Err(err) = self.proto.recv_bytes(&self.buf[..read]) { + match err { + SshStatus::PeerError(err) => { + bail!("disconnecting client after invalid operation: {err}"); + } + SshStatus::Disconnect => { + bail!("Received disconnect from server"); + } + } + } + } + channel_op = self.channel_ops_recv.recv() => { + let channels = self.proto.channels().expect("connection not ready"); + if let Some(channel_op) = channel_op { + channels.do_operation(channel_op); + } + } + op = self.operations_recv.recv() => { + match op { + Some(Operation::PasswordEntered(password)) => { + if let Some(auth) = self.proto.auth() { + auth.send_password(&password?); + } else { + debug!("Ignoring entered password as the state has moved on"); + } + } + Some(Operation::Signature(result)) => { + let result = result?; + if let Some(auth) = self.proto.auth() { + auth.send_signature(result.key_alg_name, &result.public_key, &result.signature); + } else { + debug!("Ignoring signature as the state has moved on"); + } + } + None => {} + } + self.send_off_data().await?; + } + } + + Ok(()) + } + + async fn send_off_data(&mut self) -> Result<()> { + self.proto.progress(); + while let Some(msg) = self.proto.next_msg_to_send() { + self.stream + .write_all(&msg.to_bytes()) + .await + .wrap_err("writing response")?; + } + Ok(()) + } + + pub fn open_channel(&mut self, kind: ChannelOpen) -> PendingChannel { + let Some(channels) = self.proto.channels() else { + panic!("connection not ready yet") + }; + let (updates_send, updates_recv) = tokio::sync::mpsc::channel(10); + let (ready_send, ready_recv) = tokio::sync::oneshot::channel(); + + let number = channels.create_channel(kind); + + self.channels.insert( + number, + ChannelState::Pending { + ready_send, + updates_send, + }, + ); + + PendingChannel { + ready_recv, + channel: Channel { + number, + updates_recv, + ops_send: self.channel_ops_send.clone(), + }, + } + } +} + +impl PendingChannel { + pub async fn wait_ready(self) -> Result> { + match self.ready_recv.await { + Ok(Ok(())) => Ok(self.channel), + Ok(Err(err)) => Err(Some(err)), + Err(_) => Err(None), + } + } +} + +impl Channel { + pub async fn send_operation(&mut self, op: ChannelOperationKind) -> Result<()> { + self.ops_send + .send(self.number.construct_op(op)) + .await + .map_err(Into::into) + } + + pub async fn next_update(&mut self) -> Result { + self.updates_recv + .recv() + .await + .ok_or_eyre("channel has been closed") + } +} diff --git a/lib/ssh-tokio/src/lib.rs b/lib/ssh-tokio/src/lib.rs new file mode 100644 index 0000000..b9babe5 --- /dev/null +++ b/lib/ssh-tokio/src/lib.rs @@ -0,0 +1 @@ +pub mod client;