diff --git a/Cargo.lock b/Cargo.lock index 224796c..3ad70b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -110,6 +110,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "atomic-polyfill" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4" +dependencies = [ + "critical-section", +] + [[package]] name = "autocfg" version = "1.3.0" @@ -396,6 +405,7 @@ dependencies = [ "p256", "pem", "rand", + "serde", "thiserror", "tracing", ] @@ -406,6 +416,7 @@ version = "0.1.0" dependencies = [ "cluelessh-connection", "cluelessh-format", + "cluelessh-keys", "cluelessh-transport", "rand", "tracing", @@ -416,6 +427,7 @@ name = "cluelessh-tokio" version = "0.1.0" dependencies = [ "cluelessh-connection", + "cluelessh-keys", "cluelessh-protocol", "cluelessh-transport", "eyre", @@ -451,6 +463,7 @@ dependencies = [ name = "cluelesshd" version = "0.1.0" dependencies = [ + "clap", "cluelessh-format", "cluelessh-keys", "cluelessh-protocol", @@ -458,6 +471,7 @@ dependencies = [ "cluelessh-transport", "eyre", "futures", + "postcard", "rustix", "serde", "thiserror", @@ -468,6 +482,12 @@ dependencies = [ "users", ] +[[package]] +name = "cobs" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15" + [[package]] name = "colorchoice" version = "1.0.2" @@ -489,6 +509,12 @@ dependencies = [ "libc", ] +[[package]] +name = "critical-section" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f64009896348fc5af4222e9cf7d7d82a95a256c634ebcf61c53e4ea461422242" + [[package]] name = "crypto-bigint" version = "0.5.5" @@ -631,6 +657,18 @@ dependencies = [ "zeroize", ] +[[package]] +name = "embedded-io" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced" + +[[package]] +name = "embedded-io" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d" + [[package]] name = "equivalent" version = "1.0.1" @@ -811,12 +849,35 @@ dependencies = [ "subtle", ] +[[package]] +name = "hash32" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67" +dependencies = [ + "byteorder", +] + [[package]] name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +[[package]] +name = "heapless" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f" +dependencies = [ + "atomic-polyfill", + "hash32", + "rustc_version", + "serde", + "spin", + "stable_deref_trait", +] + [[package]] name = "heck" version = "0.5.0" @@ -1111,6 +1172,19 @@ dependencies = [ "universal-hash", ] +[[package]] +name = "postcard" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f7f0a8d620d71c457dd1d47df76bb18960378da56af4527aaa10f515eee732e" +dependencies = [ + "cobs", + "embedded-io 0.4.0", + "embedded-io 0.6.1", + "heapless", + "serde", +] + [[package]] name = "ppv-lite86" version = "0.2.20" @@ -1434,6 +1508,15 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" @@ -1444,6 +1527,12 @@ dependencies = [ "der", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "strsim" version = "0.11.1" diff --git a/bin/cluelessh-faked/src/main.rs b/bin/cluelessh-faked/src/main.rs index 02598df..5dfec82 100644 --- a/bin/cluelessh-faked/src/main.rs +++ b/bin/cluelessh-faked/src/main.rs @@ -3,8 +3,8 @@ mod readline; use std::{net::SocketAddr, sync::Arc}; use cluelessh_keys::private::EncryptedPrivateKeys; -use cluelessh_tokio::{server::ServerAuthVerify, Channel}; -use eyre::{Context, Result}; +use cluelessh_tokio::{server::ServerAuth, Channel}; +use eyre::{Context, OptionExt, Result}; use tokio::{ net::{TcpListener, TcpStream}, sync::Mutex, @@ -40,7 +40,25 @@ async fn main() -> eyre::Result<()> { let listener = TcpListener::bind(addr).await.wrap_err("binding listener")?; - let auth_verify = ServerAuthVerify { + let host_keys = vec![ + EncryptedPrivateKeys::parse(ED25519_PRIVKEY.as_bytes()) + .unwrap() + .decrypt(None) + .unwrap() + .remove(0), + EncryptedPrivateKeys::parse(ECDSA_PRIVKEY.as_bytes()) + .unwrap() + .decrypt(None) + .unwrap() + .remove(0), + ]; + + let pub_host_keys = host_keys + .iter() + .map(|key| key.private_key.public_key()) + .collect::>(); + + let auth_verify = ServerAuth { verify_password: Some(Arc::new(|auth| { Box::pin(async move { info!(password = %auth.password, "Got password"); @@ -59,21 +77,21 @@ async fn main() -> eyre::Result<()> { !! DO NOT ENTER PASSWORDS YOU DON'T WANT STOLEN !!\r\n" .to_owned(), ), + sign_with_hostkey: Arc::new(move |msg| { + let host_keys = host_keys.clone(); + Box::pin(async move { + let private = host_keys + .iter() + .find(|privkey| privkey.private_key.public_key() == msg.public_key) + .ok_or_eyre("missing private key")?; + + Ok(private.private_key.sign(&msg.hash)) + }) + }), }; let transport_config = cluelessh_protocol::transport::server::ServerConfig { - host_keys: vec![ - EncryptedPrivateKeys::parse(ED25519_PRIVKEY.as_bytes()) - .unwrap() - .decrypt(None) - .unwrap() - .remove(0), - EncryptedPrivateKeys::parse(ECDSA_PRIVKEY.as_bytes()) - .unwrap() - .decrypt(None) - .unwrap() - .remove(0), - ], + host_keys: pub_host_keys, }; let mut listener = @@ -187,9 +205,11 @@ async fn handle_session_channel( } let result = execute_command(&command); - channel - .send(ChannelOperationKind::Data(result.stdout)) - .await?; + if !result.stdout.is_empty() { + channel + .send(ChannelOperationKind::Data(result.stdout)) + .await?; + } channel .send(ChannelOperationKind::Request(ChannelRequest::ExitStatus { status: result.status, @@ -221,6 +241,7 @@ async fn handle_session_channel( readline.recv_bytes(&data); let to_write = readline.bytes_to_write(); if !to_write.is_empty() { + // TODO: introduce helper to Channel that allows writing 0 data channel.send(ChannelOperationKind::Data(to_write)).await?; } diff --git a/bin/cluelesshd/Cargo.toml b/bin/cluelesshd/Cargo.toml index 0175822..f9b0d47 100644 --- a/bin/cluelesshd/Cargo.toml +++ b/bin/cluelesshd/Cargo.toml @@ -19,6 +19,8 @@ thiserror = "1.0.63" cluelessh-keys = { version = "0.1.0", path = "../../lib/cluelessh-keys" } serde = { version = "1.0.209", features = ["derive"] } toml = "0.8.19" +clap = { version = "4.5.16", features = ["derive"] } +postcard = { version = "1.0.10", features = ["alloc"] } [lints] workspace = true diff --git a/bin/cluelesshd/src/config.rs b/bin/cluelesshd/src/config.rs index b0e8bd5..6c29670 100644 --- a/bin/cluelesshd/src/config.rs +++ b/bin/cluelesshd/src/config.rs @@ -1,11 +1,13 @@ use eyre::{Context, Result}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::{ net::{IpAddr, Ipv4Addr}, path::PathBuf, }; -#[derive(Deserialize)] +use crate::Args; + +#[derive(Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct Config { #[serde(default = "default_info")] @@ -14,7 +16,7 @@ pub struct Config { pub auth: AuthConfig, } -#[derive(Deserialize)] +#[derive(Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct NetConfig { #[serde(default = "addr_default")] @@ -23,7 +25,7 @@ pub struct NetConfig { pub port: u16, } -#[derive(Deserialize)] +#[derive(Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct AuthConfig { pub host_keys: Vec, @@ -33,15 +35,18 @@ pub struct AuthConfig { } impl Config { - pub fn find() -> Result { - let path = - std::env::var("CLUELESSHD_CONFIG").unwrap_or_else(|_| "cluelesshd.toml".to_owned()); + pub fn find(args: &Args) -> Result { + let path = std::env::var("CLUELESSHD_CONFIG") + .map(PathBuf::from) + .or(args.config.clone().ok_or(std::env::VarError::NotPresent)) + .unwrap_or_else(|_| PathBuf::from("cluelesshd.toml")); let content = std::fs::read_to_string(&path).wrap_err_with(|| { - format!("failed to open config file '{path}', refusing to start. you can change the config file path with the CLUELESSHD_CONFIG environment variable") + format!("failed to open config file '{}', refusing to start. you can change the config file path with the --config arg or the CLUELESSHD_CONFIG environment variable", path.display()) })?; - toml::from_str(&content).wrap_err_with(|| format!("invalid config file '{path}'")) + toml::from_str(&content) + .wrap_err_with(|| format!("invalid config file '{}'", path.display())) } } diff --git a/bin/cluelesshd/src/main.rs b/bin/cluelesshd/src/main.rs index 8d97106..bc609dd 100644 --- a/bin/cluelesshd/src/main.rs +++ b/bin/cluelesshd/src/main.rs @@ -1,22 +1,30 @@ mod auth; mod config; mod pty; +mod rpc; use std::{ - io, + io::{self, Read, Seek, SeekFrom}, + marker::PhantomData, net::SocketAddr, + os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd}, path::PathBuf, pin::Pin, process::{ExitStatus, Stdio}, sync::Arc, }; -use cluelessh_keys::{host_keys::HostKeySet, private::EncryptedPrivateKeys}; -use cluelessh_tokio::{server::ServerAuthVerify, Channel}; -use cluelessh_transport::server::ServerConfig; +use clap::Parser; +use cluelessh_keys::{host_keys::HostKeySet, private::EncryptedPrivateKeys, public::PublicKey}; +use cluelessh_tokio::{ + server::{ServerAuth, ServerConnection, SignWithHostKey}, + Channel, +}; +use config::Config; use eyre::{bail, Context, OptionExt, Result}; use pty::Pty; -use rustix::termios::Winsize; +use rustix::{fs::MemfdFlags, termios::Winsize}; +use serde::{Deserialize, Serialize}; use tokio::{ fs::File, io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, @@ -27,66 +35,324 @@ use tokio::{ use tracing::{debug, error, info, info_span, warn, Instrument}; use cluelessh_protocol::{ + auth::{CheckPubkey, VerifySignature}, connection::{ChannelKind, ChannelOperationKind, ChannelRequest}, ChannelUpdateKind, SshStatus, }; use tracing_subscriber::EnvFilter; use users::os::unix::UserExt; +#[derive(clap::Parser)] +struct Args { + /// The path to the config file + #[arg(long)] + config: Option, +} + +struct MemFd { + fd: std::fs::File, + _data: PhantomData, +} + +impl MemFd { + fn new(data: &T) -> Result { + let fd = rustix::fs::memfd_create("cluelesshd.toml", MemfdFlags::empty()) + .wrap_err("failed to memfd memfd")?; + let mut fd: std::fs::File = std::fs::File::from(fd); + std::io::Write::write_all(&mut fd, &postcard::to_allocvec(data)?) + .wrap_err("failed to write config")?; + + Ok(Self { + fd, + _data: PhantomData, + }) + } + + unsafe fn from_raw_fd(fd: RawFd) -> Result { + let fd = unsafe { std::fs::File::from_raw_fd(fd) }; + Ok(Self { + fd, + _data: PhantomData, + }) + } + + fn read(&mut self) -> Result { + self.fd.seek(SeekFrom::Start(0))?; + let mut data = Vec::new(); + self.fd.read_to_end(&mut data).wrap_err("reading data")?; + postcard::from_bytes(&data).wrap_err("failed to deserialize") + } +} + #[tokio::main(flavor = "current_thread")] async fn main() -> eyre::Result<()> { - let config = config::Config::find()?; + match std::env::var("CLUELESSH_PRIVSEP_PROCESS") { + Ok(privsep_process) => match privsep_process.as_str() { + "connection" => connnection().await, + _ => bail!("unknown CLUELESSH_PRIVSEP_PROCESS: {privsep_process}"), + }, + Err(_) => { + // Initial setup + let args = Args::parse(); + + let config = config::Config::find(&args)?; + + let env_filter = EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new(&config.log_level)); + + tracing_subscriber::fmt().with_env_filter(env_filter).init(); + let addr: SocketAddr = SocketAddr::new(config.net.ip, config.net.port); + info!(%addr, "Starting server"); + + let listener = TcpListener::bind(addr) + .await + .wrap_err_with(|| format!("trying to listen on {addr}"))?; + + main_process(config, listener).await + } + } +} + +const PRIVSEP_CONNECTION_STATE_FD: RawFd = 3; + +/// The connection state passed to the child in the STATE_FD +#[derive(Serialize, Deserialize)] +struct SerializedConnectionState { + stream_fd: RawFd, + peer_addr: SocketAddr, + pub_host_keys: Vec, + config: Config, + rpc_client_fd: RawFd, +} + +async fn connnection() -> Result<()> { + rustix::fs::fcntl_getfd(unsafe { BorrowedFd::borrow_raw(PRIVSEP_CONNECTION_STATE_FD) }) + .unwrap(); + let mut memfd = + unsafe { MemFd::::from_raw_fd(PRIVSEP_CONNECTION_STATE_FD) } + .wrap_err("failed to open memfd")?; + let state = memfd.read().wrap_err("failed to read state")?; + + let config = state.config; let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log_level)); - tracing_subscriber::fmt().with_env_filter(env_filter).init(); - let addr = SocketAddr::new(config.net.ip, config.net.port); - info!(%addr, "Starting server"); + let span = info_span!("connection", addr = %state.peer_addr); - let listener = TcpListener::bind(addr) - .await - .wrap_err_with(|| format!("trying to listen on {addr}"))?; + let stream = unsafe { std::net::TcpStream::from_raw_fd(state.stream_fd) }; + let stream = TcpStream::from_std(stream)?; - let auth_verify = ServerAuthVerify { + let host_keys = state.pub_host_keys; + let transport_config = cluelessh_transport::server::ServerConfig { host_keys }; + + let rpc_client = unsafe { OwnedFd::from_raw_fd(state.rpc_client_fd) }; + let rpc_client1 = Arc::new(rpc::Client::from_fd(rpc_client)?); + let rpc_client2 = rpc_client1.clone(); + let rpc_client3 = rpc_client1.clone(); + + let auth_verify = ServerAuth { verify_password: config.auth.password_login.then(|| todo!("password login")), - verify_signature: Some(Arc::new(|auth| Box::pin(auth::verify_signature(auth)))), - check_pubkey: Some(Arc::new(|auth| Box::pin(auth::check_pubkey(auth)))), + verify_signature: Some(Arc::new(move |msg: VerifySignature| { + let rpc_client = rpc_client1.clone(); + Box::pin(async move { + rpc_client + .verify_signature( + msg.user, + msg.session_identifier, + msg.pubkey_alg_name, + msg.pubkey, + msg.signature, + ) + .await + }) + })), + check_pubkey: Some(Arc::new(move |msg: CheckPubkey| { + let rpc_client = rpc_client2.clone(); + Box::pin(async move { + rpc_client + .check_pubkey( + msg.user, + msg.session_identifier, + msg.pubkey_alg_name, + msg.pubkey, + ) + .await + }) + })), auth_banner: config.auth.banner, + sign_with_hostkey: Arc::new(move |msg: SignWithHostKey| { + let rpc_client = rpc_client3.clone(); + Box::pin(async move { rpc_client.sign(msg.hash, msg.public_key).await }) + }), }; + let server_conn = ServerConnection::new(stream, state.peer_addr, auth_verify, transport_config); + + connection_inner(server_conn).instrument(span).await; + + Ok(()) +} + +async fn connection_inner(server_conn: ServerConnection) { + if let Err(err) = handle_connection(server_conn).await { + if let Some(err) = err.downcast_ref::() { + if err.kind() == std::io::ErrorKind::ConnectionReset { + return; + } + } + + error!(?err, "error handling connection"); + } + info!("Finished connection"); +} + +async fn main_process(config: Config, listener: TcpListener) -> Result<()> { let host_keys = load_host_keys(&config.auth.host_keys).await?.into_keys(); if host_keys.is_empty() { bail!("no host keys found"); } - let config = ServerConfig { host_keys }; + let pub_host_keys = host_keys + .iter() + .map(|key| key.private_key.public_key()) + .collect::>(); - let mut listener = cluelessh_tokio::server::ServerListener::new(listener, auth_verify, config); + let auth_operations = ServerAuth { + verify_password: config + .auth + .clone() + .password_login + .then(|| todo!("password login")), + verify_signature: Some(Arc::new(|auth| Box::pin(auth::verify_signature(auth)))), + check_pubkey: Some(Arc::new(|auth| Box::pin(auth::check_pubkey(auth)))), + auth_banner: config.auth.clone().banner, + sign_with_hostkey: Arc::new(move |msg: SignWithHostKey| { + let host_keys = host_keys.clone(); + Box::pin(async move { + let private = host_keys + .iter() + .find(|privkey| privkey.private_key.public_key() == msg.public_key) + .ok_or_eyre("missing private key")?; + + Ok(private.private_key.sign(&msg.hash)) + }) + }), + }; + + // let server_config = ServerConfig { + // host_keys: pub_host_keys, + // }; loop { - let next = listener.accept().await?; - let span = info_span!("connection", addr = %next.peer_addr()); - tokio::spawn( - async move { - if let Err(err) = handle_connection(next).await { - if let Some(err) = err.downcast_ref::() { - if err.kind() == std::io::ErrorKind::ConnectionReset { - return; - } - } + let (next_stream, peer_addr) = listener.accept().await?; - error!(?err, "error handling connection"); - } - info!("Finished connection"); + // let server_conn = cluelessh_tokio::server::ServerConnection::new( + // next_stream, + // peer_addr, + // auth_verify.clone(), + // server_config.clone(), + // ); + + let config = config.clone(); + let pub_host_keys = pub_host_keys.clone(); + let auth_operations = auth_operations.clone(); + tokio::spawn(async move { + let err = spawn_connection_child( + next_stream, + peer_addr, + pub_host_keys, + config, + auth_operations, + ) + .await; + if let Err(err) = err { + error!(?err, "child failed"); } - .instrument(span), - ); + }); + + //tokio::spawn( + // async move { + // if let Err(err) = handle_connection(server_conn).await { + // if let Some(err) = err.downcast_ref::() { + // if err.kind() == std::io::ErrorKind::ConnectionReset { + // return; + // } + // } + // + // error!(?err, "error handling connection"); + // } + // info!("Finished connection"); + // } + // .instrument(span), + //); } } +async fn spawn_connection_child( + stream: TcpStream, + peer_addr: SocketAddr, + pub_host_keys: Vec, + config: Config, + auth_operations: ServerAuth, +) -> Result<()> { + let stream_fd = stream.as_fd(); + + let rpc_server = rpc::Server::new(auth_operations).wrap_err("creating RPC server")?; + + // dup to avoid cloexec + // TODO: we should probably do this in the child? not that it matters that much. + let stream_fd = rustix::io::dup(stream_fd).wrap_err("duping tcp stream")?; + let rpc_client_fd = rustix::io::dup(rpc_server.client_fd()).wrap_err("duping tcp stream")?; + + let config_fd = MemFd::new(&SerializedConnectionState { + stream_fd: stream_fd.as_raw_fd(), + peer_addr, + pub_host_keys, + config, + rpc_client_fd: rpc_client_fd.as_raw_fd(), + })?; + + let exe = std::env::current_exe().wrap_err("failed to get current executable path")?; + let mut cmd = tokio::process::Command::new(exe); + cmd.env("CLUELESSH_PRIVSEP_PROCESS", "connection") + .stdin(Stdio::null()) + .stdout(Stdio::inherit()) + .stderr(Stdio::inherit()); + + unsafe { + let fd = config_fd.fd.as_raw_fd(); + cmd.pre_exec(move || { + let mut state_fd = OwnedFd::from_raw_fd(PRIVSEP_CONNECTION_STATE_FD); + rustix::io::dup2(BorrowedFd::borrow_raw(fd), &mut state_fd)?; + // Ensure that it stays open in the child. + std::mem::forget(state_fd); + Ok(()) + }); + } + + let mut listen_child = cmd.spawn().wrap_err("failed to spawn listener process")?; + + loop { + tokio::select! { + server_err = rpc_server.process() => { + error!(err = ?server_err, "RPC server error"); + } + status = listen_child.wait() => { + let status = status?; + if !status.success() { + bail!("connection child process failed: {}", status); + } + break; + } + } + } + + Ok(()) +} + async fn load_host_keys(keys: &[PathBuf]) -> Result { let mut host_keys = HostKeySet::new(); diff --git a/bin/cluelesshd/src/rpc.rs b/bin/cluelesshd/src/rpc.rs new file mode 100644 index 0000000..fb418c7 --- /dev/null +++ b/bin/cluelesshd/src/rpc.rs @@ -0,0 +1,239 @@ +//! [`postcard`]-based RPC between the different processes. + +use std::os::fd::AsFd; +use std::os::fd::BorrowedFd; +use std::os::fd::OwnedFd; + +use cluelessh_keys::public::PublicKey; +use cluelessh_keys::signature::Signature; +use cluelessh_protocol::auth::CheckPubkey; +use cluelessh_protocol::auth::VerifySignature; +use cluelessh_tokio::server::ServerAuth; +use cluelessh_tokio::server::SignWithHostKey; +use eyre::eyre; +use eyre::Context; +use eyre::Result; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use tokio::net::UnixDatagram; + +#[derive(Serialize, Deserialize)] +enum Request { + Sign { + hash: [u8; 32], + public_key: PublicKey, + }, + VerifySignature { + user: String, + session_identifier: [u8; 32], + pubkey_alg_name: String, + pubkey: Vec, + signature: Vec, + }, + CheckPubkey { + user: String, + session_identifier: [u8; 32], + pubkey_alg_name: String, + pubkey: Vec, + }, +} + +#[derive(Serialize, Deserialize)] +struct SignResponse { + signature: Result, +} + +#[derive(Serialize, Deserialize)] +struct VerifySignatureResponse { + is_ok: Result, +} + +#[derive(Serialize, Deserialize)] +struct CheckPubkeyResponse { + is_ok: Result, +} + +pub struct Client { + socket: UnixDatagram, +} + +pub struct Server { + server: UnixDatagram, + client: UnixDatagram, + auth_operations: ServerAuth, +} + +impl Server { + pub fn new(auth_operations: ServerAuth) -> Result { + let (server, client) = UnixDatagram::pair().wrap_err("creating socketpair")?; + + Ok(Self { + server, + client, + auth_operations, + }) + } + + pub fn client_fd(&self) -> BorrowedFd<'_> { + self.client.as_fd() + } + + pub async fn process(&self) -> Result<()> { + let mut req = [0; 1024]; + + loop { + let read = self + .server + .recv(&mut req) + .await + .wrap_err("receiving response")?; + + let req = postcard::from_bytes::(&req[..read]).wrap_err("invalid request")?; + + match req { + Request::Sign { hash, public_key } => { + let signature = (self.auth_operations.sign_with_hostkey)(SignWithHostKey { + hash, + public_key, + }) + .await + .map_err(|err| err.to_string()); + + self.respond(SignResponse { signature }).await?; + } + Request::VerifySignature { + user, + session_identifier, + pubkey_alg_name, + pubkey, + signature, + } => { + let Some(verify_signature) = &self.auth_operations.verify_signature else { + self.respond(VerifySignatureResponse { + is_ok: Err("public key login not supported".into()), + }) + .await?; + continue; + }; + let is_ok = verify_signature(VerifySignature { + user, + session_identifier, + pubkey_alg_name, + pubkey, + signature, + }) + .await + .map_err(|err| err.to_string()); + + self.respond(VerifySignatureResponse { is_ok }).await?; + } + Request::CheckPubkey { + user, + session_identifier, + pubkey_alg_name, + pubkey, + } => { + let Some(check_pubkey) = &self.auth_operations.check_pubkey else { + self.respond(VerifySignatureResponse { + is_ok: Err("public key login not supported".into()), + }) + .await?; + continue; + }; + let is_ok = check_pubkey(CheckPubkey { + user, + session_identifier, + pubkey_alg_name, + pubkey, + }) + .await + .map_err(|err| err.to_string()); + + self.respond(CheckPubkeyResponse { is_ok }).await?; + } + } + } + } + + async fn respond(&self, resp: impl Serialize) -> Result<()> { + self.server + .send(&postcard::to_allocvec(&resp)?) + .await + .wrap_err("sending response")?; + Ok(()) + } +} + +impl Client { + pub fn from_fd(fd: OwnedFd) -> Result { + let socket = UnixDatagram::from_std(std::os::unix::net::UnixDatagram::from(fd))?; + Ok(Self { socket }) + } + + pub async fn sign(&self, hash: [u8; 32], public_key: PublicKey) -> Result { + let resp = self + .request_response::(&Request::Sign { hash, public_key }) + .await?; + + resp.signature.map_err(|err| eyre!(err)) + } + + pub async fn check_pubkey( + &self, + user: String, + session_identifier: [u8; 32], + pubkey_alg_name: String, + pubkey: Vec, + ) -> Result { + let resp = self + .request_response::(&Request::CheckPubkey { + user, + session_identifier, + pubkey_alg_name, + pubkey, + }) + .await?; + + resp.is_ok.map_err(|err| eyre!(err)) + } + + pub async fn verify_signature( + &self, + user: String, + session_identifier: [u8; 32], + pubkey_alg_name: String, + pubkey: Vec, + signature: Vec, + ) -> Result { + let resp = self + .request_response::(&Request::VerifySignature { + user, + session_identifier, + pubkey_alg_name, + pubkey, + signature, + }) + .await?; + + resp.is_ok.map_err(|err| eyre!(err)) + } + + async fn request_response(&self, req: &Request) -> Result { + self.socket + .send(&postcard::to_allocvec(&req)?) + .await + .wrap_err("sending request")?; + + let mut resp = [0; 1024]; + let read = self + .socket + .recv(&mut resp) + .await + .wrap_err("receiving response")?; + + let resp = + postcard::from_bytes::(&resp[..read]).wrap_err("invalid signature response")?; + + Ok(resp) + } +} diff --git a/lib/cluelessh-keys/Cargo.toml b/lib/cluelessh-keys/Cargo.toml index 2f32c5c..e9c9064 100644 --- a/lib/cluelessh-keys/Cargo.toml +++ b/lib/cluelessh-keys/Cargo.toml @@ -15,6 +15,7 @@ base64 = "0.22.1" cluelessh-format = { version = "0.1.0", path = "../cluelessh-format" } tracing.workspace = true p256 = "0.13.2" +serde = "1.0.209" [lints] workspace = true diff --git a/lib/cluelessh-keys/src/public.rs b/lib/cluelessh-keys/src/public.rs index 7b408a5..be2bcec 100644 --- a/lib/cluelessh-keys/src/public.rs +++ b/lib/cluelessh-keys/src/public.rs @@ -134,6 +134,45 @@ fn b64encode(bytes: &[u8]) -> String { base64::prelude::BASE64_STANDARD.encode(bytes) } +impl serde::Serialize for PublicKey { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_bytes(&self.to_wire_encoding()) + } +} + +impl<'de> serde::Deserialize<'de> for PublicKey { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de; + + struct Visitor; + impl<'de> de::Visitor<'de> for Visitor { + type Value = PublicKey; + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "bytes encoded as an SSH public key") + } + + fn visit_bytes(self, bytes: &[u8]) -> Result + where + E: de::Error, + { + PublicKey::from_wire_encoding(bytes).map_err(|err| { + serde::de::Error::custom(format_args!( + "invalid value: {}: {err}", + de::Unexpected::Bytes(bytes), + )) + }) + } + } + deserializer.deserialize_bytes(Visitor) + } +} + #[cfg(test)] mod tests { use base64::Engine; diff --git a/lib/cluelessh-keys/src/signature.rs b/lib/cluelessh-keys/src/signature.rs index 39eabd8..37a8823 100644 --- a/lib/cluelessh-keys/src/signature.rs +++ b/lib/cluelessh-keys/src/signature.rs @@ -97,6 +97,45 @@ impl Signature { } } +impl serde::Serialize for Signature { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_bytes(&self.to_wire_encoding()) + } +} + +impl<'de> serde::Deserialize<'de> for Signature { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de; + + struct Visitor; + impl<'de> de::Visitor<'de> for Visitor { + type Value = Signature; + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(formatter, "bytes encoded as an SSH signature") + } + + fn visit_bytes(self, bytes: &[u8]) -> Result + where + E: de::Error, + { + Signature::from_wire_encoding(bytes).map_err(|err| { + serde::de::Error::custom(format_args!( + "invalid value: {}: {err}", + de::Unexpected::Bytes(bytes), + )) + }) + } + } + deserializer.deserialize_bytes(Visitor) + } +} + impl PrivateKey { pub fn sign(&self, data: &[u8]) -> Signature { match self { diff --git a/lib/cluelessh-protocol/Cargo.toml b/lib/cluelessh-protocol/Cargo.toml index 5614c61..c2ded70 100644 --- a/lib/cluelessh-protocol/Cargo.toml +++ b/lib/cluelessh-protocol/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" rand = "0.8.5" cluelessh-connection = { path = "../cluelessh-connection" } cluelessh-transport = { path = "../cluelessh-transport" } +cluelessh-keys = { path = "../cluelessh-keys" } tracing.workspace = true cluelessh-format = { version = "0.1.0", path = "../cluelessh-format" } diff --git a/lib/cluelessh-protocol/src/lib.rs b/lib/cluelessh-protocol/src/lib.rs index 6fc35ba..c8f94e4 100644 --- a/lib/cluelessh-protocol/src/lib.rs +++ b/lib/cluelessh-protocol/src/lib.rs @@ -4,6 +4,8 @@ use std::mem; use auth::AuthOption; use cluelessh_connection::ChannelOperation; +use cluelessh_keys::public::PublicKey; +use cluelessh_keys::signature::Signature; use tracing::debug; // Re-exports @@ -76,6 +78,14 @@ impl ServerConnection { Ok(()) } + pub fn is_waiting_on_signature(&self) -> Option<(&PublicKey, [u8; 32])> { + self.transport.is_waiting_on_signature() + } + + pub fn do_signature(&mut self, signature: Signature) { + self.transport.do_signature(signature); + } + pub fn next_msg_to_send(&mut self) -> Option { self.transport.next_msg_to_send() } diff --git a/lib/cluelessh-tokio/Cargo.toml b/lib/cluelessh-tokio/Cargo.toml index b5a9f3c..386c849 100644 --- a/lib/cluelessh-tokio/Cargo.toml +++ b/lib/cluelessh-tokio/Cargo.toml @@ -8,6 +8,7 @@ eyre.workspace = true cluelessh-transport = { path = "../cluelessh-transport" } cluelessh-connection = { path = "../cluelessh-connection" } cluelessh-protocol = { path = "../cluelessh-protocol" } +cluelessh-keys = { path = "../cluelessh-keys" } tokio = { version = "1.39.3", features = ["net"] } tracing.workspace = true futures = "0.3.30" diff --git a/lib/cluelessh-tokio/src/server.rs b/lib/cluelessh-tokio/src/server.rs index 77affb7..2c66132 100644 --- a/lib/cluelessh-tokio/src/server.rs +++ b/lib/cluelessh-tokio/src/server.rs @@ -1,4 +1,5 @@ use cluelessh_connection::{ChannelKind, ChannelNumber, ChannelOperation}; +use cluelessh_keys::{public::PublicKey, signature::Signature}; use futures::future::BoxFuture; use std::{ collections::{HashMap, HashSet, VecDeque}, @@ -23,7 +24,7 @@ use crate::{Channel, ChannelState, PendingChannel}; pub struct ServerListener { listener: TcpListener, - auth_verify: ServerAuthVerify, + auth_verify: ServerAuth, transport_config: cluelessh_transport::server::ServerConfig, // TODO ratelimits etc } @@ -45,27 +46,35 @@ pub struct ServerConnection { /// New channels opened by the peer. new_channels: VecDeque, - auth_verify: ServerAuthVerify, + signature_in_progress: bool, + auth_verify: ServerAuth, } enum Operation { VerifyPassword(String, Result), CheckPubkey(Result, String, Vec), VerifySignature(String, Result), + SignatureReceived(Result), } pub type AuthFn = Arc BoxFuture<'static, R> + Send + Sync>; #[derive(Clone)] -pub struct ServerAuthVerify { +pub struct ServerAuth { pub verify_password: Option>>, pub verify_signature: Option>>, pub check_pubkey: Option>>, + pub sign_with_hostkey: AuthFn>, pub auth_banner: Option, } fn _assert_send_sync() { fn send() {} - send::(); + send::(); +} + +pub struct SignWithHostKey { + pub hash: [u8; 32], + pub public_key: PublicKey, } pub enum Error { @@ -81,7 +90,7 @@ impl From for Error { impl ServerListener { pub fn new( listener: TcpListener, - auth_verify: ServerAuthVerify, + auth_verify: ServerAuth, transport_config: cluelessh_transport::server::ServerConfig, ) -> Self { Self { @@ -107,7 +116,7 @@ impl ServerConnection { pub fn new( stream: S, peer_addr: SocketAddr, - auth_verify: ServerAuthVerify, + auth_verify: ServerAuth, transport_config: cluelessh_transport::server::ServerConfig, ) -> Self { let (operations_send, operations_recv) = tokio::sync::mpsc::channel(15); @@ -149,6 +158,7 @@ impl ServerConnection { ), new_channels: VecDeque::new(), auth_verify, + signature_in_progress: false, } } @@ -159,6 +169,20 @@ impl ServerConnection { /// Executes one loop iteration of the main loop. // IMPORTANT: no operations on this struct should ever block the main loop, except this one. pub async fn progress(&mut self) -> Result<(), Error> { + if let Some((public_key, hash)) = self.proto.is_waiting_on_signature() { + if !self.signature_in_progress { + self.signature_in_progress = true; + + let send = self.operations_send.clone(); + let public_key = public_key.clone(); + let sign_with_hostkey = self.auth_verify.sign_with_hostkey.clone(); + tokio::spawn(async move { + let result = sign_with_hostkey(SignWithHostKey { public_key, hash }).await; + let _ = send.send(Operation::SignatureReceived(result)).await; + }); + } + } + if let Some(auth) = self.proto.auth() { for req in auth.server_requests() { match req { @@ -329,6 +353,10 @@ impl ServerConnection { Some(Operation::VerifyPassword(user, result)) => if let Some(auth) = self.proto.auth() { auth.verification_result(result?, user); }, + Some(Operation::SignatureReceived(signature)) => { + let signature = signature?; + self.proto.do_signature(signature); + } None => {} } self.send_off_data().await?; diff --git a/lib/cluelessh-transport/src/crypto.rs b/lib/cluelessh-transport/src/crypto.rs index d9824b7..5635c20 100644 --- a/lib/cluelessh-transport/src/crypto.rs +++ b/lib/cluelessh-transport/src/crypto.rs @@ -1,10 +1,6 @@ pub mod encrypt; -use cluelessh_keys::{ - private::{PlaintextPrivateKey, PrivateKey}, - public::PublicKey, - signature::Signature, -}; +use cluelessh_keys::{public::PublicKey, signature::Signature}; use p256::ecdsa::signature::Verifier; use sha2::Digest; @@ -110,26 +106,21 @@ impl AlgorithmName for EncryptionAlgorithm { pub struct EncodedSshSignature(pub Vec); pub struct HostKeySigningAlgorithm { - private_key: Box, + public_key: PublicKey, } impl AlgorithmName for HostKeySigningAlgorithm { fn name(&self) -> &'static str { - self.private_key.algorithm_name() + self.public_key.algorithm_name() } } impl HostKeySigningAlgorithm { - pub fn new(private_key: PrivateKey) -> Self { - Self { - private_key: Box::new(private_key), - } - } - pub fn sign(&self, data: &[u8]) -> Signature { - self.private_key.sign(data) + pub fn new(public_key: PublicKey) -> Self { + Self { public_key } } pub fn public_key(&self) -> PublicKey { - self.private_key.public_key() + self.public_key.clone() } } @@ -253,10 +244,10 @@ pub struct SupportedAlgorithms { impl SupportedAlgorithms { /// A secure default using elliptic curves and AEAD. - pub fn secure(host_keys: &[PlaintextPrivateKey]) -> Self { + pub fn secure(host_keys: &[PublicKey]) -> Self { let supported_host_keys = host_keys .iter() - .map(|key| HostKeySigningAlgorithm::new(key.private_key.clone())) + .map(|key| HostKeySigningAlgorithm::new(key.clone())) .collect(); Self { diff --git a/lib/cluelessh-transport/src/server.rs b/lib/cluelessh-transport/src/server.rs index 3e0bc95..ff5f383 100644 --- a/lib/cluelessh-transport/src/server.rs +++ b/lib/cluelessh-transport/src/server.rs @@ -10,6 +10,8 @@ use crate::Result; use crate::{peer_error, Msg, SshRng, SshStatus}; use cluelessh_format::numbers; use cluelessh_format::{NameList, Reader, Writer}; +use cluelessh_keys::public::PublicKey; +use cluelessh_keys::signature::Signature; use tracing::{debug, info, trace}; // This is definitely who we are. @@ -28,7 +30,7 @@ pub struct ServerConnection { #[derive(Debug, Clone, Default)] pub struct ServerConfig { - pub host_keys: Vec, + pub host_keys: Vec, } enum ServerState { @@ -47,9 +49,21 @@ enum ServerState { encryption_client_to_server: EncryptionAlgorithm, encryption_server_to_client: EncryptionAlgorithm, }, + WaitingForSignature { + /// h + hash: [u8; 32], + pub_hostkey: PublicKey, + /// k + shared_secret: Vec, + server_ephemeral_public_key: Vec, + encryption_client_to_server: EncryptionAlgorithm, + encryption_server_to_client: EncryptionAlgorithm, + }, NewKeys { - h: [u8; 32], - k: Vec, + /// h + hash: [u8; 32], + /// k + shared_secret: Vec, encryption_client_to_server: EncryptionAlgorithm, encryption_server_to_client: EncryptionAlgorithm, }, @@ -242,11 +256,11 @@ impl ServerConnection { } => { let dh = KeyExchangeEcDhInitPacket::parse(&packet.payload)?; - let client_public_key = dh.qc; + let client_ephemeral_public_key = dh.qc; let server_secret = (kex_algorithm.generate_secret)(&mut *self.rng); - let server_public_key = server_secret.pubkey; - let shared_secret = (server_secret.exchange)(client_public_key)?; + let server_ephemeral_public_key = server_secret.pubkey; + let shared_secret = (server_secret.exchange)(client_ephemeral_public_key)?; let pub_hostkey = server_host_key_algorithm.public_key(); let hash = crypto::key_exchange_hash( @@ -255,35 +269,31 @@ impl ServerConnection { client_kexinit, server_kexinit, &pub_hostkey.to_wire_encoding(), - client_public_key, - &server_public_key, + client_ephemeral_public_key, + &server_ephemeral_public_key, &shared_secret, ); - let signature = server_host_key_algorithm.sign(&hash); + // eprintln!("client_ephemeral_public_key: {:x?}", client_ephemeral_public_key); + // eprintln!("server_ephemeral_public_key: {:x?}", server_ephemeral_public_key); + // eprintln!("shared_secret: {:x?}", shared_secret); + // eprintln!("hash: {:x?}", hash); - // eprintln!("client_public_key: {:x?}", client_public_key); - // eprintln!("server_public_key: {:x?}", server_public_key); - // eprintln!("shared_secret: {:x?}", shared_secret); - // eprintln!("hash: {:x?}", hash); - - let packet = Packet::new_msg_kex_ecdh_reply( - &pub_hostkey.to_wire_encoding(), - &server_public_key, - &signature.to_wire_encoding(), - ); - - self.packet_transport.queue_packet(packet); - self.state = ServerState::NewKeys { - h: hash, - k: shared_secret, + self.state = ServerState::WaitingForSignature { + hash, + pub_hostkey, + shared_secret, + server_ephemeral_public_key, encryption_client_to_server: *encryption_client_to_server, encryption_server_to_client: *encryption_server_to_client, }; } + ServerState::WaitingForSignature { .. } => { + return Err(peer_error!("unexpected packet")); + } ServerState::NewKeys { - h, - k, + hash: h, + shared_secret: k, encryption_client_to_server, encryption_server_to_client, } => { @@ -344,6 +354,43 @@ impl ServerConnection { } } + pub fn is_waiting_on_signature(&self) -> Option<(&PublicKey, [u8; 32])> { + match &self.state { + ServerState::WaitingForSignature { + pub_hostkey, hash, .. + } => Some((pub_hostkey, *hash)), + _ => None, + } + } + + pub fn do_signature(&mut self, signature: Signature) { + match &self.state { + ServerState::WaitingForSignature { + hash, + pub_hostkey, + shared_secret, + server_ephemeral_public_key, + encryption_client_to_server, + encryption_server_to_client, + } => { + let packet = Packet::new_msg_kex_ecdh_reply( + &pub_hostkey.to_wire_encoding(), + &server_ephemeral_public_key, + &signature.to_wire_encoding(), + ); + + self.packet_transport.queue_packet(packet); + self.state = ServerState::NewKeys { + hash: *hash, + shared_secret: shared_secret.clone(), + encryption_client_to_server: *encryption_client_to_server, + encryption_server_to_client: *encryption_server_to_client, + }; + } + _ => unreachable!("doing signature while not waiting for it"), + } + } + pub fn next_msg_to_send(&mut self) -> Option { self.packet_transport.next_msg_to_send() }