mirror of
https://github.com/Noratrieb/cluelessh.git
synced 2026-01-14 16:35:06 +01:00
move orchestration logic into ssh-tokio
This commit is contained in:
parent
9532065b16
commit
ea28daca0c
11 changed files with 477 additions and 209 deletions
15
Cargo.lock
generated
15
Cargo.lock
generated
|
|
@ -1210,6 +1210,7 @@ dependencies = [
|
||||||
"rpassword",
|
"rpassword",
|
||||||
"ssh-agent-client",
|
"ssh-agent-client",
|
||||||
"ssh-protocol",
|
"ssh-protocol",
|
||||||
|
"ssh-tokio",
|
||||||
"ssh-transport",
|
"ssh-transport",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
|
@ -1282,11 +1283,25 @@ dependencies = [
|
||||||
name = "ssh-protocol"
|
name = "ssh-protocol"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"rand",
|
||||||
"ssh-connection",
|
"ssh-connection",
|
||||||
"ssh-transport",
|
"ssh-transport",
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ssh-tokio"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"eyre",
|
||||||
|
"futures",
|
||||||
|
"ssh-connection",
|
||||||
|
"ssh-protocol",
|
||||||
|
"ssh-transport",
|
||||||
|
"tokio",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ssh-transport"
|
name = "ssh-transport"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
|
||||||
|
|
@ -180,6 +180,7 @@ async fn handle_connection(
|
||||||
ChannelRequest::Env { .. } => {}
|
ChannelRequest::Env { .. } => {}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
ChannelUpdateKind::OpenFailed { .. } => todo!(),
|
||||||
ChannelUpdateKind::Data { data } => {
|
ChannelUpdateKind::Data { data } => {
|
||||||
let is_eof = data.contains(&0x04 /*EOF, Ctrl-D*/);
|
let is_eof = data.contains(&0x04 /*EOF, Ctrl-D*/);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ edition = "2021"
|
||||||
ssh-protocol = { path = "../../lib/ssh-protocol" }
|
ssh-protocol = { path = "../../lib/ssh-protocol" }
|
||||||
ssh-transport = { path = "../../lib/ssh-transport" }
|
ssh-transport = { path = "../../lib/ssh-transport" }
|
||||||
ssh-agent-client = { path = "../../lib/ssh-agent-client" }
|
ssh-agent-client = { path = "../../lib/ssh-agent-client" }
|
||||||
|
ssh-tokio = { path = "../../lib/ssh-tokio" }
|
||||||
|
|
||||||
clap = { version = "4.5.15", features = ["derive"] }
|
clap = { version = "4.5.15", features = ["derive"] }
|
||||||
eyre = "0.6.12"
|
eyre = "0.6.12"
|
||||||
|
|
|
||||||
|
|
@ -1,32 +1,16 @@
|
||||||
use std::{collections::HashSet, io::Write};
|
use std::{collections::HashSet, sync::Arc};
|
||||||
|
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
|
||||||
use eyre::{bail, Context, ContextCompat, OptionExt};
|
use eyre::{bail, Context, ContextCompat, OptionExt, Result};
|
||||||
use rand::RngCore;
|
use ssh_tokio::client::{PendingChannel, SignatureResult};
|
||||||
use ssh_transport::{key::PublicKey, numbers, parse::Writer, peer_error};
|
use ssh_transport::{key::PublicKey, numbers, parse::Writer};
|
||||||
use tokio::{
|
use tokio::net::TcpStream;
|
||||||
io::{AsyncReadExt, AsyncWriteExt},
|
use tracing::{debug, error};
|
||||||
net::TcpStream,
|
|
||||||
};
|
|
||||||
use tracing::{debug, error, info};
|
|
||||||
|
|
||||||
use ssh_protocol::{
|
use ssh_protocol::connection::{ChannelOpen, ChannelOperationKind, ChannelRequest};
|
||||||
connection::{
|
|
||||||
ChannelNumber, ChannelOpen, ChannelOperation, ChannelOperationKind, ChannelRequest,
|
|
||||||
},
|
|
||||||
transport::{self},
|
|
||||||
ChannelUpdate, ChannelUpdateKind, SshStatus,
|
|
||||||
};
|
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
struct ThreadRngRand;
|
|
||||||
impl ssh_protocol::transport::SshRng for ThreadRngRand {
|
|
||||||
fn fill_bytes(&mut self, dest: &mut [u8]) {
|
|
||||||
rand::thread_rng().fill_bytes(dest);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(clap::Parser, Debug)]
|
#[derive(clap::Parser, Debug)]
|
||||||
struct Args {
|
struct Args {
|
||||||
#[arg(short = 'p', long, default_value_t = 22)]
|
#[arg(short = 'p', long, default_value_t = 22)]
|
||||||
|
|
@ -37,22 +21,6 @@ struct Args {
|
||||||
command: Vec<String>,
|
command: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
enum Operation {
|
|
||||||
PasswordEntered(std::io::Result<String>),
|
|
||||||
Signature {
|
|
||||||
key_alg_name: &'static str,
|
|
||||||
public_key: Vec<u8>,
|
|
||||||
signature: Vec<u8>,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: state machine everything including auth
|
|
||||||
enum ClientState {
|
|
||||||
Start,
|
|
||||||
WaitingForOpen(ChannelNumber),
|
|
||||||
WaitingForPty(ChannelNumber),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> eyre::Result<()> {
|
async fn main() -> eyre::Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
@ -77,183 +45,112 @@ async fn main() -> eyre::Result<()> {
|
||||||
Some(user) => user,
|
Some(user) => user,
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut attempted_public_keys = HashSet::new();
|
let conn = TcpStream::connect(&format!("{}:{}", args.destination, args.port))
|
||||||
|
|
||||||
let mut conn = TcpStream::connect(&format!("{}:{}", args.destination, args.port))
|
|
||||||
.await
|
.await
|
||||||
.wrap_err("connecting")?;
|
.wrap_err("connecting")?;
|
||||||
|
|
||||||
let mut state = ssh_protocol::ClientConnection::new(
|
let username1 = username.clone();
|
||||||
transport::client::ClientConnection::new(ThreadRngRand),
|
let mut tokio_conn = ssh_tokio::client::ClientConnection::connect(
|
||||||
ssh_protocol::auth::ClientAuth::new(username.as_bytes().to_vec()),
|
conn,
|
||||||
);
|
ssh_tokio::client::ClientAuth {
|
||||||
|
username: username.clone(),
|
||||||
|
prompt_password: Arc::new(move || {
|
||||||
|
let username = username1.clone();
|
||||||
|
let destination = args.destination.clone();
|
||||||
|
Box::pin(async {
|
||||||
|
let result = tokio::task::spawn_blocking(move || {
|
||||||
|
let password = rpassword::prompt_password(format!(
|
||||||
|
"{}@{}'s password: ",
|
||||||
|
username, destination
|
||||||
|
));
|
||||||
|
password
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
result.wrap_err("failed to prompt password")
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
sign_pubkey: Arc::new(move |session_identifier| {
|
||||||
|
let session_identifier = session_identifier.to_vec();
|
||||||
|
let mut attempted_public_keys = HashSet::new();
|
||||||
|
let username = username.clone();
|
||||||
|
Box::pin(async move {
|
||||||
|
// TODO: support agentless manual key opening
|
||||||
|
// TODO: move
|
||||||
|
let mut agent = ssh_agent_client::SocketAgentConnection::from_env()
|
||||||
|
.await
|
||||||
|
.wrap_err("failed to connect to SSH agent")?;
|
||||||
|
let identities = agent.list_identities().await?;
|
||||||
|
for identity in &identities {
|
||||||
|
let pubkey = PublicKey::from_wire_encoding(&identity.key_blob)
|
||||||
|
.wrap_err("received invalid public key from SSH agent")?;
|
||||||
|
debug!(comment = ?identity.comment, %pubkey, "Found identity");
|
||||||
|
}
|
||||||
|
if identities.len() > 1 {
|
||||||
|
todo!("try identities");
|
||||||
|
}
|
||||||
|
let identity = &identities[0];
|
||||||
|
if !attempted_public_keys.insert(identity.key_blob.clone()) {
|
||||||
|
bail!("authentication denied (publickey)");
|
||||||
|
}
|
||||||
|
let pubkey = PublicKey::from_wire_encoding(&identity.key_blob)?;
|
||||||
|
|
||||||
let mut client_state = ClientState::Start;
|
let mut sign_data = Writer::new();
|
||||||
|
sign_data.string(session_identifier);
|
||||||
|
sign_data.u8(numbers::SSH_MSG_USERAUTH_REQUEST);
|
||||||
|
sign_data.string(&username);
|
||||||
|
sign_data.string("ssh-connection");
|
||||||
|
sign_data.string("publickey");
|
||||||
|
sign_data.bool(true);
|
||||||
|
sign_data.string(pubkey.algorithm_name());
|
||||||
|
sign_data.string(&identity.key_blob);
|
||||||
|
|
||||||
let (send_op, mut recv_op) = tokio::sync::mpsc::channel::<Operation>(10);
|
let data = sign_data.finish();
|
||||||
|
let signature = agent
|
||||||
|
.sign(&identity.key_blob, &data, 0)
|
||||||
|
.await
|
||||||
|
.wrap_err("signing for authentication")?;
|
||||||
|
|
||||||
let mut buf = [0; 1024];
|
Ok(SignatureResult {
|
||||||
|
key_alg_name: pubkey.algorithm_name(),
|
||||||
|
public_key: identity.key_blob.clone(),
|
||||||
|
signature,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let session = tokio_conn.open_channel(ChannelOpen::Session);
|
||||||
|
|
||||||
|
tokio::spawn(async {
|
||||||
|
let result = main_channel(session).await;
|
||||||
|
if let Err(err) = result {
|
||||||
|
error!(?err);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
if let Some(auth) = state.auth() {
|
tokio_conn.progress().await?;
|
||||||
for req in auth.user_requests() {
|
|
||||||
match req {
|
|
||||||
ssh_protocol::auth::ClientUserRequest::Password => {
|
|
||||||
let username = username.clone();
|
|
||||||
let destination = args.destination.clone();
|
|
||||||
let send_op = send_op.clone();
|
|
||||||
std::thread::spawn(move || {
|
|
||||||
let password = rpassword::prompt_password(format!(
|
|
||||||
"{}@{}'s password: ",
|
|
||||||
username, destination
|
|
||||||
));
|
|
||||||
let _ = send_op.blocking_send(Operation::PasswordEntered(password));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
ssh_protocol::auth::ClientUserRequest::PrivateKeySign {
|
|
||||||
session_identifier,
|
|
||||||
} => {
|
|
||||||
// TODO: support agentless manual key opening
|
|
||||||
// TODO: move
|
|
||||||
let mut agent = ssh_agent_client::SocketAgentConnection::from_env()
|
|
||||||
.await
|
|
||||||
.wrap_err("failed to connect to SSH agent")?;
|
|
||||||
let identities = agent.list_identities().await?;
|
|
||||||
for identity in &identities {
|
|
||||||
let pubkey = PublicKey::from_wire_encoding(&identity.key_blob)
|
|
||||||
.wrap_err("received invalid public key from SSH agent")?;
|
|
||||||
debug!(comment = ?identity.comment, %pubkey, "Found identity");
|
|
||||||
}
|
|
||||||
if identities.len() > 1 {
|
|
||||||
todo!("try identities");
|
|
||||||
}
|
|
||||||
let identity = &identities[0];
|
|
||||||
if !attempted_public_keys.insert(identity.key_blob.clone()) {
|
|
||||||
bail!("authentication denied (publickey)");
|
|
||||||
}
|
|
||||||
let pubkey = PublicKey::from_wire_encoding(&identity.key_blob)?;
|
|
||||||
|
|
||||||
let mut sign_data = Writer::new();
|
|
||||||
sign_data.string(session_identifier);
|
|
||||||
sign_data.u8(numbers::SSH_MSG_USERAUTH_REQUEST);
|
|
||||||
sign_data.string(&username);
|
|
||||||
sign_data.string("ssh-connection");
|
|
||||||
sign_data.string("publickey");
|
|
||||||
sign_data.bool(true);
|
|
||||||
sign_data.string(pubkey.algorithm_name());
|
|
||||||
sign_data.string(&identity.key_blob);
|
|
||||||
|
|
||||||
let data = sign_data.finish();
|
|
||||||
let signature = agent
|
|
||||||
.sign(&identity.key_blob, &data, 0)
|
|
||||||
.await
|
|
||||||
.wrap_err("signing for authentication")?;
|
|
||||||
|
|
||||||
send_op
|
|
||||||
.send(Operation::Signature {
|
|
||||||
key_alg_name: pubkey.algorithm_name(),
|
|
||||||
public_key: identity.key_blob.clone(),
|
|
||||||
signature,
|
|
||||||
})
|
|
||||||
.await?;
|
|
||||||
}
|
|
||||||
ssh_protocol::auth::ClientUserRequest::Banner(banner) => {
|
|
||||||
let banner = String::from_utf8_lossy(&banner);
|
|
||||||
std::io::stdout().write(&banner.as_bytes())?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(channels) = state.channels() {
|
|
||||||
if let ClientState::Start = client_state {
|
|
||||||
let number = channels.create_channel(ChannelOpen::Session);
|
|
||||||
client_state = ClientState::WaitingForOpen(number);
|
|
||||||
}
|
|
||||||
|
|
||||||
while let Some(update) = channels.next_channel_update() {
|
|
||||||
match &update.kind {
|
|
||||||
ChannelUpdateKind::Open(_) => match client_state {
|
|
||||||
ClientState::WaitingForOpen(number) => {
|
|
||||||
if number != update.number {
|
|
||||||
bail!("unexpected channel opened by server");
|
|
||||||
}
|
|
||||||
client_state = ClientState::WaitingForPty(update.number);
|
|
||||||
channels.do_operation(number.construct_op(
|
|
||||||
ChannelOperationKind::Request(ChannelRequest::PtyReq {
|
|
||||||
want_reply: true,
|
|
||||||
term: "xterm-256color".to_owned(),
|
|
||||||
width_chars: 70,
|
|
||||||
height_rows: 10,
|
|
||||||
width_px: 0,
|
|
||||||
height_px: 0,
|
|
||||||
term_modes: vec![],
|
|
||||||
}),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
_ => bail!("unexpected channel opened by server"),
|
|
||||||
},
|
|
||||||
ChannelUpdateKind::Success => {}
|
|
||||||
ChannelUpdateKind::Failure => bail!("operation failed"),
|
|
||||||
ChannelUpdateKind::Request(_) => todo!(),
|
|
||||||
ChannelUpdateKind::Data { .. } => todo!(),
|
|
||||||
ChannelUpdateKind::ExtendedData { .. } => todo!(),
|
|
||||||
ChannelUpdateKind::Eof => todo!(),
|
|
||||||
ChannelUpdateKind::Closed => todo!(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make sure that we send all queues messages before going into the select, waiting for things to happen.
|
|
||||||
state.progress();
|
|
||||||
while let Some(msg) = state.next_msg_to_send() {
|
|
||||||
conn.write_all(&msg.to_bytes())
|
|
||||||
.await
|
|
||||||
.wrap_err("writing response")?;
|
|
||||||
}
|
|
||||||
|
|
||||||
tokio::select! {
|
|
||||||
read = conn.read(&mut buf) => {
|
|
||||||
let read = read.wrap_err("reading from connection")?;
|
|
||||||
if read == 0 {
|
|
||||||
info!("Did not read any bytes from TCP stream, EOF");
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
if let Err(err) = state.recv_bytes(&buf[..read]) {
|
|
||||||
match err {
|
|
||||||
SshStatus::PeerError(err) => {
|
|
||||||
error!(?err, "disconnecting client after invalid operation");
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
SshStatus::Disconnect => {
|
|
||||||
error!("Received disconnect from server");
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
op = recv_op.recv() => {
|
|
||||||
match op {
|
|
||||||
Some(Operation::PasswordEntered(password)) => {
|
|
||||||
if let Some(auth) = state.auth() {
|
|
||||||
auth.send_password(&password?);
|
|
||||||
} else {
|
|
||||||
debug!("Ignoring entered password as the state has moved on");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some(Operation::Signature{
|
|
||||||
key_alg_name, public_key, signature,
|
|
||||||
}) => {
|
|
||||||
if let Some(auth) = state.auth() {
|
|
||||||
auth.send_signature(key_alg_name, &public_key, &signature);
|
|
||||||
} else {
|
|
||||||
debug!("Ignoring signature as the state has moved on");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
None => {}
|
|
||||||
}
|
|
||||||
state.progress();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn main_channel(channel: PendingChannel) -> Result<()> {
|
||||||
|
let Ok(mut channel) = channel.wait_ready().await else {
|
||||||
|
bail!("failed to create channel");
|
||||||
|
};
|
||||||
|
|
||||||
|
channel
|
||||||
|
.send_operation(ChannelOperationKind::Request(ChannelRequest::PtyReq {
|
||||||
|
want_reply: true,
|
||||||
|
term: "xterm-256color".to_owned(),
|
||||||
|
width_chars: 70,
|
||||||
|
height_rows: 10,
|
||||||
|
width_px: 0,
|
||||||
|
height_px: 0,
|
||||||
|
term_modes: vec![],
|
||||||
|
}))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,7 @@ pub enum ChannelUpdateKind {
|
||||||
Success,
|
Success,
|
||||||
Failure,
|
Failure,
|
||||||
Open(ChannelOpen),
|
Open(ChannelOpen),
|
||||||
|
OpenFailed { code: u32, message: String },
|
||||||
Request(ChannelRequest),
|
Request(ChannelRequest),
|
||||||
Data { data: Vec<u8> },
|
Data { data: Vec<u8> },
|
||||||
ExtendedData { code: u32, data: Vec<u8> },
|
ExtendedData { code: u32, data: Vec<u8> },
|
||||||
|
|
@ -259,6 +260,31 @@ impl ChannelsState {
|
||||||
|
|
||||||
debug!(channel_type = %"session", %our_number, "Successfully opened channel");
|
debug!(channel_type = %"session", %our_number, "Successfully opened channel");
|
||||||
}
|
}
|
||||||
|
numbers::SSH_MSG_CHANNEL_OPEN_FAILURE => {
|
||||||
|
let our_channel = p.u32()?;
|
||||||
|
let our_number = ChannelNumber(our_channel);
|
||||||
|
let Some(&ChannelState::AwaitingConfirmation { .. }) =
|
||||||
|
self.channels.get(&our_number)
|
||||||
|
else {
|
||||||
|
return Err(peer_error!("unknown channel: {our_channel}"));
|
||||||
|
};
|
||||||
|
|
||||||
|
let reason_code = p.u32()?;
|
||||||
|
let reason_msg = p.utf8_string()?;
|
||||||
|
let _language_tag = p.utf8_string()?;
|
||||||
|
|
||||||
|
debug!(%our_number, %reason_code, %reason_msg, "Failed to open channel");
|
||||||
|
|
||||||
|
self.channel_updates.push_back(ChannelUpdate {
|
||||||
|
number: our_number,
|
||||||
|
kind: ChannelUpdateKind::OpenFailed {
|
||||||
|
code: reason_code,
|
||||||
|
message: reason_msg.to_owned(),
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
self.channels.remove(&our_number);
|
||||||
|
}
|
||||||
numbers::SSH_MSG_CHANNEL_WINDOW_ADJUST => {
|
numbers::SSH_MSG_CHANNEL_WINDOW_ADJUST => {
|
||||||
let our_channel = p.u32()?;
|
let our_channel = p.u32()?;
|
||||||
let our_channel = self.validate_channel(our_channel)?;
|
let our_channel = self.validate_channel(our_channel)?;
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
rand = "0.8.5"
|
||||||
ssh-connection = { path = "../ssh-connection" }
|
ssh-connection = { path = "../ssh-connection" }
|
||||||
ssh-transport = { path = "../ssh-transport" }
|
ssh-transport = { path = "../ssh-transport" }
|
||||||
tracing.workspace = true
|
tracing.workspace = true
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,14 @@ pub use ssh_transport as transport;
|
||||||
pub use ssh_transport::{Result, SshStatus};
|
pub use ssh_transport::{Result, SshStatus};
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
|
pub struct ThreadRngRand;
|
||||||
|
impl transport::SshRng for ThreadRngRand {
|
||||||
|
fn fill_bytes(&mut self, dest: &mut [u8]) {
|
||||||
|
use rand::RngCore;
|
||||||
|
rand::thread_rng().fill_bytes(dest);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct ServerConnection {
|
pub struct ServerConnection {
|
||||||
transport: ssh_transport::server::ServerConnection,
|
transport: ssh_transport::server::ServerConnection,
|
||||||
state: ServerConnectionState,
|
state: ServerConnectionState,
|
||||||
|
|
|
||||||
13
lib/ssh-tokio/Cargo.toml
Normal file
13
lib/ssh-tokio/Cargo.toml
Normal file
|
|
@ -0,0 +1,13 @@
|
||||||
|
[package]
|
||||||
|
name = "ssh-tokio"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
eyre = "0.6.12"
|
||||||
|
ssh-transport = { path = "../ssh-transport" }
|
||||||
|
ssh-connection = { path = "../ssh-connection" }
|
||||||
|
ssh-protocol = { path = "../ssh-protocol" }
|
||||||
|
tokio = { version = "1.39.3", features = ["net"] }
|
||||||
|
tracing.workspace = true
|
||||||
|
futures = "0.3.30"
|
||||||
5
lib/ssh-tokio/README.md
Normal file
5
lib/ssh-tokio/README.md
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
# ssh-tokio
|
||||||
|
|
||||||
|
Adapter layer for async Tokio programs.
|
||||||
|
|
||||||
|
Exposes channels as MPSC-like structs.
|
||||||
300
lib/ssh-tokio/src/client.rs
Normal file
300
lib/ssh-tokio/src/client.rs
Normal file
|
|
@ -0,0 +1,300 @@
|
||||||
|
use ssh_connection::{ChannelNumber, ChannelOpen, ChannelOperation, ChannelOperationKind};
|
||||||
|
use std::{collections::HashMap, pin::Pin, sync::Arc};
|
||||||
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||||
|
|
||||||
|
use eyre::{bail, ContextCompat, OptionExt, Result, WrapErr};
|
||||||
|
use futures::future::BoxFuture;
|
||||||
|
use ssh_protocol::{ChannelUpdateKind, SshStatus};
|
||||||
|
use tokio::io::{AsyncRead, AsyncWrite};
|
||||||
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
|
pub struct ClientConnection<S> {
|
||||||
|
stream: Pin<Box<S>>,
|
||||||
|
buf: [u8; 1024],
|
||||||
|
|
||||||
|
proto: ssh_protocol::ClientConnection,
|
||||||
|
operations_send: tokio::sync::mpsc::Sender<Operation>,
|
||||||
|
operations_recv: tokio::sync::mpsc::Receiver<Operation>,
|
||||||
|
|
||||||
|
/// Cloned and passed on to channels.
|
||||||
|
channel_ops_send: tokio::sync::mpsc::Sender<ChannelOperation>,
|
||||||
|
channel_ops_recv: tokio::sync::mpsc::Receiver<ChannelOperation>,
|
||||||
|
|
||||||
|
channels: HashMap<ChannelNumber, ChannelState>,
|
||||||
|
|
||||||
|
auth: ClientAuth,
|
||||||
|
}
|
||||||
|
|
||||||
|
enum ChannelState {
|
||||||
|
Pending {
|
||||||
|
ready_send: tokio::sync::oneshot::Sender<Result<(), String>>,
|
||||||
|
updates_send: tokio::sync::mpsc::Sender<ChannelUpdateKind>,
|
||||||
|
},
|
||||||
|
Ready(tokio::sync::mpsc::Sender<ChannelUpdateKind>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ClientAuth {
|
||||||
|
pub username: String,
|
||||||
|
pub prompt_password: Arc<dyn Fn() -> BoxFuture<'static, Result<String>> + Send + Sync>,
|
||||||
|
pub sign_pubkey:
|
||||||
|
Arc<dyn Fn(&[u8]) -> BoxFuture<'static, Result<SignatureResult>> + Send + Sync>,
|
||||||
|
}
|
||||||
|
|
||||||
|
enum Operation {
|
||||||
|
PasswordEntered(Result<String>),
|
||||||
|
Signature(Result<SignatureResult>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct SignatureResult {
|
||||||
|
pub key_alg_name: &'static str,
|
||||||
|
pub public_key: Vec<u8>,
|
||||||
|
pub signature: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct PendingChannel {
|
||||||
|
ready_recv: tokio::sync::oneshot::Receiver<Result<(), String>>,
|
||||||
|
channel: Channel,
|
||||||
|
}
|
||||||
|
pub struct Channel {
|
||||||
|
number: ChannelNumber,
|
||||||
|
updates_recv: tokio::sync::mpsc::Receiver<ChannelUpdateKind>,
|
||||||
|
ops_send: tokio::sync::mpsc::Sender<ChannelOperation>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: AsyncRead + AsyncWrite> ClientConnection<S> {
|
||||||
|
pub async fn connect(stream: S, auth: ClientAuth) -> Result<Self> {
|
||||||
|
let (operations_send, operations_recv) = tokio::sync::mpsc::channel(15);
|
||||||
|
let (channel_ops_send, channel_ops_recv) = tokio::sync::mpsc::channel(15);
|
||||||
|
|
||||||
|
let mut this = Self {
|
||||||
|
stream: Box::pin(stream),
|
||||||
|
buf: [0; 1024],
|
||||||
|
operations_send,
|
||||||
|
operations_recv,
|
||||||
|
channel_ops_send,
|
||||||
|
channel_ops_recv,
|
||||||
|
channels: HashMap::new(),
|
||||||
|
proto: ssh_protocol::ClientConnection::new(
|
||||||
|
ssh_transport::client::ClientConnection::new(ssh_protocol::ThreadRngRand),
|
||||||
|
ssh_protocol::auth::ClientAuth::new(auth.username.as_bytes().to_vec()),
|
||||||
|
),
|
||||||
|
auth,
|
||||||
|
};
|
||||||
|
|
||||||
|
while !this.proto.is_open() {
|
||||||
|
this.progress().await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(this)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Executes one loop iteration of the main loop.
|
||||||
|
// IMPORTANT: no operations on this struct should ever block the main loop, except this one.
|
||||||
|
pub async fn progress(&mut self) -> Result<()> {
|
||||||
|
if let Some(auth) = self.proto.auth() {
|
||||||
|
for req in auth.user_requests() {
|
||||||
|
match req {
|
||||||
|
ssh_protocol::auth::ClientUserRequest::Password => {
|
||||||
|
let send = self.operations_send.clone();
|
||||||
|
let prompt_password = self.auth.prompt_password.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let password = prompt_password().await;
|
||||||
|
let _ = send.send(Operation::PasswordEntered(password)).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
ssh_protocol::auth::ClientUserRequest::PrivateKeySign {
|
||||||
|
session_identifier,
|
||||||
|
} => {
|
||||||
|
let send = self.operations_send.clone();
|
||||||
|
let sign_pubkey = self.auth.sign_pubkey.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let signature_result = sign_pubkey(&session_identifier).await;
|
||||||
|
let _ = send.send(Operation::Signature(signature_result)).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
ssh_protocol::auth::ClientUserRequest::Banner(_) => {
|
||||||
|
warn!("ignoring banner as it's not implemented...");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(channels) = self.proto.channels() {
|
||||||
|
while let Some(update) = channels.next_channel_update() {
|
||||||
|
match &update.kind {
|
||||||
|
ChannelUpdateKind::Open(_) => {
|
||||||
|
let channel = self
|
||||||
|
.channels
|
||||||
|
.get_mut(&update.number)
|
||||||
|
.wrap_err("unknown channel")?;
|
||||||
|
match channel {
|
||||||
|
ChannelState::Pending { updates_send, .. } => {
|
||||||
|
let updates_send = updates_send.clone();
|
||||||
|
let old = self
|
||||||
|
.channels
|
||||||
|
.insert(update.number, ChannelState::Ready(updates_send));
|
||||||
|
match old.unwrap() {
|
||||||
|
ChannelState::Pending { ready_send, .. } => {
|
||||||
|
let _ = ready_send.send(Ok(()));
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ChannelState::Ready(_) => {
|
||||||
|
bail!("attemping to open channel twice: {}", update.number);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ChannelUpdateKind::OpenFailed { message, .. } => {
|
||||||
|
let channel = self
|
||||||
|
.channels
|
||||||
|
.get_mut(&update.number)
|
||||||
|
.wrap_err("unknown channel")?;
|
||||||
|
match channel {
|
||||||
|
ChannelState::Pending { .. } => {
|
||||||
|
let old = self.channels.remove(&update.number);
|
||||||
|
match old.unwrap() {
|
||||||
|
ChannelState::Pending { ready_send, .. } => {
|
||||||
|
let _ = ready_send.send(Err(message.clone()));
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ChannelState::Ready(_) => {
|
||||||
|
bail!("attemping to open channel twice: {}", update.number);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let channel = self
|
||||||
|
.channels
|
||||||
|
.get_mut(&update.number)
|
||||||
|
.wrap_err("unknown channel")?;
|
||||||
|
match channel {
|
||||||
|
ChannelState::Pending { .. } => bail!("channel not ready yet"),
|
||||||
|
ChannelState::Ready(updates_send) => {
|
||||||
|
let _ = updates_send.send(update.kind).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure that we send all queues messages before going into the select, waiting for things to happen.
|
||||||
|
self.send_off_data().await?;
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
read = self.stream.read(&mut self.buf) => {
|
||||||
|
let read = read.wrap_err("reading from connection")?;
|
||||||
|
if read == 0 {
|
||||||
|
info!("Did not read any bytes from TCP stream, EOF");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
if let Err(err) = self.proto.recv_bytes(&self.buf[..read]) {
|
||||||
|
match err {
|
||||||
|
SshStatus::PeerError(err) => {
|
||||||
|
bail!("disconnecting client after invalid operation: {err}");
|
||||||
|
}
|
||||||
|
SshStatus::Disconnect => {
|
||||||
|
bail!("Received disconnect from server");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
channel_op = self.channel_ops_recv.recv() => {
|
||||||
|
let channels = self.proto.channels().expect("connection not ready");
|
||||||
|
if let Some(channel_op) = channel_op {
|
||||||
|
channels.do_operation(channel_op);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op = self.operations_recv.recv() => {
|
||||||
|
match op {
|
||||||
|
Some(Operation::PasswordEntered(password)) => {
|
||||||
|
if let Some(auth) = self.proto.auth() {
|
||||||
|
auth.send_password(&password?);
|
||||||
|
} else {
|
||||||
|
debug!("Ignoring entered password as the state has moved on");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(Operation::Signature(result)) => {
|
||||||
|
let result = result?;
|
||||||
|
if let Some(auth) = self.proto.auth() {
|
||||||
|
auth.send_signature(result.key_alg_name, &result.public_key, &result.signature);
|
||||||
|
} else {
|
||||||
|
debug!("Ignoring signature as the state has moved on");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => {}
|
||||||
|
}
|
||||||
|
self.send_off_data().await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_off_data(&mut self) -> Result<()> {
|
||||||
|
self.proto.progress();
|
||||||
|
while let Some(msg) = self.proto.next_msg_to_send() {
|
||||||
|
self.stream
|
||||||
|
.write_all(&msg.to_bytes())
|
||||||
|
.await
|
||||||
|
.wrap_err("writing response")?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn open_channel(&mut self, kind: ChannelOpen) -> PendingChannel {
|
||||||
|
let Some(channels) = self.proto.channels() else {
|
||||||
|
panic!("connection not ready yet")
|
||||||
|
};
|
||||||
|
let (updates_send, updates_recv) = tokio::sync::mpsc::channel(10);
|
||||||
|
let (ready_send, ready_recv) = tokio::sync::oneshot::channel();
|
||||||
|
|
||||||
|
let number = channels.create_channel(kind);
|
||||||
|
|
||||||
|
self.channels.insert(
|
||||||
|
number,
|
||||||
|
ChannelState::Pending {
|
||||||
|
ready_send,
|
||||||
|
updates_send,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
PendingChannel {
|
||||||
|
ready_recv,
|
||||||
|
channel: Channel {
|
||||||
|
number,
|
||||||
|
updates_recv,
|
||||||
|
ops_send: self.channel_ops_send.clone(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PendingChannel {
|
||||||
|
pub async fn wait_ready(self) -> Result<Channel, Option<String>> {
|
||||||
|
match self.ready_recv.await {
|
||||||
|
Ok(Ok(())) => Ok(self.channel),
|
||||||
|
Ok(Err(err)) => Err(Some(err)),
|
||||||
|
Err(_) => Err(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Channel {
|
||||||
|
pub async fn send_operation(&mut self, op: ChannelOperationKind) -> Result<()> {
|
||||||
|
self.ops_send
|
||||||
|
.send(self.number.construct_op(op))
|
||||||
|
.await
|
||||||
|
.map_err(Into::into)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn next_update(&mut self) -> Result<ChannelUpdateKind> {
|
||||||
|
self.updates_recv
|
||||||
|
.recv()
|
||||||
|
.await
|
||||||
|
.ok_or_eyre("channel has been closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
1
lib/ssh-tokio/src/lib.rs
Normal file
1
lib/ssh-tokio/src/lib.rs
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
pub mod client;
|
||||||
Loading…
Add table
Add a link
Reference in a new issue