improvements to privsep

This commit is contained in:
nora 2024-08-28 22:29:31 +02:00
parent 85c4480938
commit cbf00dc6ff
10 changed files with 689 additions and 282 deletions

5
Cargo.lock generated
View file

@ -471,6 +471,7 @@ dependencies = [
"cluelessh-transport", "cluelessh-transport",
"eyre", "eyre",
"futures", "futures",
"libc",
"postcard", "postcard",
"rustix", "rustix",
"serde", "serde",
@ -965,9 +966,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.155" version = "0.2.158"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"

View file

@ -12,7 +12,7 @@ tokio = { version = "1.39.2", features = ["full"] }
tracing.workspace = true tracing.workspace = true
eyre.workspace = true eyre.workspace = true
tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json"] } tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json"] }
rustix = { version = "0.38.34", features = ["pty", "termios", "procfs", "process", "stdio"] } rustix = { version = "0.38.34", features = ["pty", "termios", "procfs", "process", "stdio", "net"] }
users = "0.11.0" users = "0.11.0"
futures = "0.3.30" futures = "0.3.30"
thiserror = "1.0.63" thiserror = "1.0.63"
@ -21,6 +21,7 @@ serde = { version = "1.0.209", features = ["derive"] }
toml = "0.8.19" toml = "0.8.19"
clap = { version = "4.5.16", features = ["derive"] } clap = { version = "4.5.16", features = ["derive"] }
postcard = { version = "1.0.10", features = ["alloc"] } postcard = { version = "1.0.10", features = ["alloc"] }
libc = "0.2.158"
[lints] [lints]
workspace = true workspace = true

View file

@ -11,3 +11,8 @@ host_keys = [
] ]
password_login = false password_login = false
banner = "welcome to my server!!!\r\ni hope you enjoy your stay.\r\n" banner = "welcome to my server!!!\r\ni hope you enjoy your stay.\r\n"
[security]
unprivileged_uid = 355353
unprivileged_gid = 355353
#unprivileged_user = "sshd"

View file

@ -9,10 +9,13 @@ use cluelessh_keys::{
use cluelessh_protocol::auth::{CheckPubkey, VerifySignature}; use cluelessh_protocol::auth::{CheckPubkey, VerifySignature};
use eyre::eyre; use eyre::eyre;
use tracing::debug; use tracing::debug;
use users::os::unix::UserExt; use users::{os::unix::UserExt, User};
/// A known-authorized public key for a user. /// A known-authorized public key for a user.
pub struct UserPublicKey(PublicKeyWithComment); pub struct UserPublicKey {
key: PublicKeyWithComment,
user: User,
}
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum AuthError { pub enum AuthError {
@ -46,27 +49,29 @@ impl UserPublicKey {
let authorized_keys = AuthorizedKeys::parse(&file)?; let authorized_keys = AuthorizedKeys::parse(&file)?;
if let Some(key) = authorized_keys.contains(provided_key) { if let Some(key) = authorized_keys.contains(provided_key) {
Ok(Self(key.clone())) Ok(Self {
key: key.clone(),
user,
})
} else { } else {
Err(AuthError::UnauthorizedPublicKey) Err(AuthError::UnauthorizedPublicKey)
} }
} }
pub fn verify_signature(&self, data: &[u8], signature: &[u8]) -> bool { pub fn verify_signature(&self, data: &[u8], signature: &[u8]) -> bool {
self.0.key.verify_signature(data, signature) self.key.key.verify_signature(data, signature)
} }
} }
pub async fn verify_signature(auth: VerifySignature) -> eyre::Result<bool> { pub async fn verify_signature(auth: VerifySignature) -> eyre::Result<Option<User>> {
let Ok(public_key) = PublicKey::from_wire_encoding(&auth.pubkey) else { let Ok(public_key) = PublicKey::from_wire_encoding(&auth.pubkey) else {
return Ok(false); return Ok(None);
}; };
if auth.pubkey_alg_name != public_key.algorithm_name() { if auth.pubkey_alg_name != public_key.algorithm_name() {
return Ok(false); return Ok(None);
} }
let result: std::result::Result<UserPublicKey, AuthError> = let result = UserPublicKey::for_user_and_key(auth.user.clone(), &public_key).await;
UserPublicKey::for_user_and_key(auth.user.clone(), &public_key).await;
debug!(user = %auth.user, err = ?result.as_ref().err(), "Attempting publickey signature"); debug!(user = %auth.user, err = ?result.as_ref().err(), "Attempting publickey signature");
@ -81,16 +86,16 @@ pub async fn verify_signature(auth: VerifySignature) -> eyre::Result<bool> {
); );
if user_key.verify_signature(&sign_data, &auth.signature) { if user_key.verify_signature(&sign_data, &auth.signature) {
Ok(true) Ok(Some(user_key.user))
} else { } else {
Ok(false) Ok(None)
} }
} }
Err( Err(
AuthError::UnknownUser AuthError::UnknownUser
| AuthError::UnauthorizedPublicKey | AuthError::UnauthorizedPublicKey
| AuthError::NoAuthorizedKeys(_), | AuthError::NoAuthorizedKeys(_),
) => Ok(false), ) => Ok(None),
Err(AuthError::InvalidAuthorizedKeys(err)) => Err(eyre!(err)), Err(AuthError::InvalidAuthorizedKeys(err)) => Err(eyre!(err)),
} }
} }

View file

@ -14,6 +14,7 @@ pub struct Config {
pub log_level: String, pub log_level: String,
pub net: NetConfig, pub net: NetConfig,
pub auth: AuthConfig, pub auth: AuthConfig,
pub security: SecurityConfig,
} }
#[derive(Clone, Serialize, Deserialize)] #[derive(Clone, Serialize, Deserialize)]
@ -34,6 +35,19 @@ pub struct AuthConfig {
pub banner: Option<String>, pub banner: Option<String>,
} }
#[derive(Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct SecurityConfig {
/// A hardcoded uid for an unprivileged user.
/// Mostly useful for testing.
pub unprivileged_uid: Option<u32>,
/// A hardcoded gid for an unprivileged user.
/// Mostly useful for testing.
pub unprivileged_gid: Option<u32>,
/// The username of an unprivileged user.
pub unprivileged_user: Option<String>,
}
impl Config { impl Config {
pub fn find(args: &Args) -> Result<Self> { pub fn find(args: &Args) -> Result<Self> {
let path = std::env::var("CLUELESSHD_CONFIG") let path = std::env::var("CLUELESSHD_CONFIG")

View file

@ -1,14 +1,13 @@
use std::{ use std::{
io,
os::fd::{BorrowedFd, FromRawFd, OwnedFd}, os::fd::{BorrowedFd, FromRawFd, OwnedFd},
pin::Pin, pin::Pin,
process::{ExitStatus, Stdio},
sync::Arc, sync::Arc,
}; };
use crate::{ use crate::{
pty::{self, Pty}, pty::{self, Pty},
rpc, MemFd, SerializedConnectionState, PRIVSEP_CONNECTION_STATE_FD, rpc, MemFd, SerializedConnectionState, PRIVSEP_CONNECTION_RPC_CLIENT_FD,
PRIVSEP_CONNECTION_STATE_FD, PRIVSEP_CONNECTION_STREAM_FD,
}; };
use cluelessh_protocol::{ use cluelessh_protocol::{
connection::{ChannelKind, ChannelOperationKind, ChannelRequest}, connection::{ChannelKind, ChannelOperationKind, ChannelRequest},
@ -18,17 +17,15 @@ use cluelessh_tokio::{
server::{ServerAuth, ServerConnection}, server::{ServerAuth, ServerConnection},
Channel, Channel,
}; };
use eyre::{bail, OptionExt, Result, WrapErr}; use eyre::{bail, ensure, Result, WrapErr};
use rustix::termios::Winsize; use rustix::termios::Winsize;
use tokio::{ use tokio::{
fs::File, fs::File,
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpStream, net::TcpStream,
process::Command,
sync::mpsc, sync::mpsc,
}; };
use tracing::{debug, error, info, info_span, warn, Instrument}; use tracing::{debug, error, info, info_span, warn, Instrument};
use users::os::unix::UserExt as _;
pub async fn connection() -> Result<()> { pub async fn connection() -> Result<()> {
rustix::fs::fcntl_getfd(unsafe { BorrowedFd::borrow_raw(PRIVSEP_CONNECTION_STATE_FD) }) rustix::fs::fcntl_getfd(unsafe { BorrowedFd::borrow_raw(PRIVSEP_CONNECTION_STATE_FD) })
@ -38,22 +35,42 @@ pub async fn connection() -> Result<()> {
.wrap_err("failed to open memfd")?; .wrap_err("failed to open memfd")?;
let state = memfd.read().wrap_err("failed to read state")?; let state = memfd.read().wrap_err("failed to read state")?;
let config = state.config; crate::setup_tracing(&state.config);
crate::setup_tracing(&config);
let span = info_span!("connection", addr = %state.peer_addr); let span = info_span!("connection", addr = %state.peer_addr);
let stream = unsafe { std::net::TcpStream::from_raw_fd(state.stream_fd) }; connection_inner(state).instrument(span).await
}
async fn connection_inner(state: SerializedConnectionState) -> Result<()> {
let config = state.config;
if let Some(uid) = state.setgid {
debug!(?uid, "Setting GID to drop privileges");
let result = unsafe { libc::setgid(uid) };
if result == -1 {
return Err(std::io::Error::last_os_error()).wrap_err("failed to setgid");
}
}
if let Some(uid) = state.setuid {
debug!(?uid, "Setting UID to drop privileges");
let result = unsafe { libc::setuid(uid) };
if result == -1 {
return Err(std::io::Error::last_os_error()).wrap_err("failed to setuid");
}
}
let stream = unsafe { std::net::TcpStream::from_raw_fd(PRIVSEP_CONNECTION_STREAM_FD) };
let stream = TcpStream::from_std(stream)?; let stream = TcpStream::from_std(stream)?;
let host_keys = state.pub_host_keys; let host_keys = state.pub_host_keys;
let transport_config = cluelessh_transport::server::ServerConfig { host_keys }; let transport_config = cluelessh_transport::server::ServerConfig { host_keys };
let rpc_client = unsafe { OwnedFd::from_raw_fd(state.rpc_client_fd) }; let rpc_client = unsafe { OwnedFd::from_raw_fd(PRIVSEP_CONNECTION_RPC_CLIENT_FD) };
let rpc_client1 = Arc::new(rpc::Client::from_fd(rpc_client)?); let rpc_client1 = Arc::new(rpc::Client::from_fd(rpc_client)?);
let rpc_client2 = rpc_client1.clone(); let rpc_client2 = rpc_client1.clone();
let rpc_client3 = rpc_client1.clone(); let rpc_client3 = rpc_client1.clone();
let rpc_client4 = rpc_client1.clone();
let auth_verify = ServerAuth { let auth_verify = ServerAuth {
verify_password: config.auth.password_login.then(|| todo!("password login")), verify_password: config.auth.password_login.then(|| todo!("password login")),
@ -93,26 +110,23 @@ pub async fn connection() -> Result<()> {
let server_conn = ServerConnection::new(stream, state.peer_addr, auth_verify, transport_config); let server_conn = ServerConnection::new(stream, state.peer_addr, auth_verify, transport_config);
connection_inner(server_conn).instrument(span).await; if let Err(err) = handle_connection(server_conn, rpc_client4).await {
Ok(())
}
async fn connection_inner(server_conn: ServerConnection<TcpStream>) {
if let Err(err) = handle_connection(server_conn).await {
if let Some(err) = err.downcast_ref::<std::io::Error>() { if let Some(err) = err.downcast_ref::<std::io::Error>() {
if err.kind() == std::io::ErrorKind::ConnectionReset { if err.kind() == std::io::ErrorKind::ConnectionReset {
return; return Ok(());
} }
} }
error!(?err, "error handling connection"); error!(?err, "error handling connection");
} }
info!("Finished connection"); info!("Finished connection");
Ok(())
} }
async fn handle_connection( async fn handle_connection(
mut conn: cluelessh_tokio::server::ServerConnection<TcpStream>, mut conn: cluelessh_tokio::server::ServerConnection<TcpStream>,
rpc_client: Arc<rpc::Client>,
) -> Result<()> { ) -> Result<()> {
info!(addr = %conn.peer_addr(), "Received a new connection"); info!(addr = %conn.peer_addr(), "Received a new connection");
@ -145,9 +159,10 @@ async fn handle_connection(
} }
while let Some(channel) = conn.next_new_channel() { while let Some(channel) = conn.next_new_channel() {
let user = conn.inner().authenticated_user().unwrap().to_owned(); let _user = conn.inner().authenticated_user().unwrap().to_owned();
if *channel.kind() == ChannelKind::Session { if *channel.kind() == ChannelKind::Session {
let channel_task = tokio::spawn(handle_session_channel(user, channel)); let channel_task =
tokio::spawn(handle_session_channel(channel, rpc_client.clone()));
channel_tasks.push(Box::pin(async { channel_tasks.push(Box::pin(async {
let result = channel_task.await; let result = channel_task.await;
result.wrap_err("task panicked").and_then(|result| result) result.wrap_err("task panicked").and_then(|result| result)
@ -160,14 +175,15 @@ async fn handle_connection(
} }
struct SessionState { struct SessionState {
user: String,
pty: Option<Pty>, pty: Option<Pty>,
channel: Channel, channel: Channel,
process_exit_send: mpsc::Sender<Result<ExitStatus, io::Error>>, process_exit_send: mpsc::Sender<Result<Option<i32>>>,
process_exit_recv: mpsc::Receiver<Result<ExitStatus, io::Error>>, process_exit_recv: mpsc::Receiver<Result<Option<i32>>>,
envs: Vec<(String, String)>, envs: Vec<(String, String)>,
rpc_client: Arc<rpc::Client>,
//// stdin //// stdin
writer: Option<Pin<Box<dyn AsyncWrite + Send + Sync>>>, writer: Option<Pin<Box<dyn AsyncWrite + Send + Sync>>>,
/// stdout /// stdout
@ -176,16 +192,18 @@ struct SessionState {
reader_ext: Option<Pin<Box<dyn AsyncRead + Send + Sync>>>, reader_ext: Option<Pin<Box<dyn AsyncRead + Send + Sync>>>,
} }
async fn handle_session_channel(user: String, channel: Channel) -> Result<()> { async fn handle_session_channel(channel: Channel, rpc_client: Arc<rpc::Client>) -> Result<()> {
let (process_exit_send, process_exit_recv) = tokio::sync::mpsc::channel(1); let (process_exit_send, process_exit_recv) = tokio::sync::mpsc::channel(1);
let mut state = SessionState { let mut state = SessionState {
user,
pty: None, pty: None,
channel, channel,
process_exit_send, process_exit_send,
process_exit_recv, process_exit_recv,
envs: Vec::new(), envs: Vec::new(),
rpc_client,
writer: None, writer: None,
reader: None, reader: None,
reader_ext: None, reader_ext: None,
@ -227,7 +245,7 @@ async fn handle_session_channel(user: String, channel: Channel) -> Result<()> {
// TODO: also handle exit-signal // TODO: also handle exit-signal
state.channel state.channel
.send(ChannelOperationKind::Request(ChannelRequest::ExitStatus { .send(ChannelOperationKind::Request(ChannelRequest::ExitStatus {
status: exit.code().unwrap_or(0) as u32, status: exit.unwrap_or(1) as u32,
})) }))
.await?; .await?;
state.channel.send(ChannelOperationKind::Close).await?; state.channel.send(ChannelOperationKind::Close).await?;
@ -393,46 +411,36 @@ impl SessionState {
} }
async fn shell(&mut self, shell_command: Option<&str>) -> Result<()> { async fn shell(&mut self, shell_command: Option<&str>) -> Result<()> {
let user = self.user.clone(); let pty = match &self.pty {
let user = tokio::task::spawn_blocking(move || users::get_user_by_name(&user)) Some(pty) => Some(pty.user_fd()?),
.await? None => None,
.ok_or_eyre("failed to find user")?; };
let shell = user.shell(); let mut fds = self
.rpc_client
.exec(
shell_command.map(ToOwned::to_owned),
pty,
self.pty.as_ref().map(|pty| pty.term()).unwrap_or_default(),
self.envs.clone(),
)
.await?;
let mut cmd = Command::new(shell); if self.pty.is_some() {
if let Some(shell_command) = shell_command { ensure!(
cmd.arg("-c"); fds.len() == 0,
cmd.arg(shell_command); "RPC Server sent back FDs despite being in PTY mode"
} );
cmd.env_clear();
if let Some(pty) = &self.pty {
pty.start_session_for_command(&mut cmd)?;
} else { } else {
cmd.stdin(Stdio::piped()); ensure!(
cmd.stdout(Stdio::piped()); fds.len() == 3,
cmd.stderr(Stdio::piped()); "RPC Server sent back the wrong amount of FDs: {}",
} fds.len()
);
// TODO: **user** home directory let stdin = File::from_std(std::fs::File::from(fds.remove(0)));
cmd.current_dir(user.home_dir()); let stdout = File::from_std(std::fs::File::from(fds.remove(0)));
cmd.env("USER", user.name()); let stderr = File::from_std(std::fs::File::from(fds.remove(0)));
cmd.uid(user.uid());
cmd.gid(user.primary_group_id());
for (k, v) in &self.envs {
cmd.env(k, v);
}
debug!(cmd = %shell.display(), uid = %user.uid(), gid = %user.primary_group_id(), "Executing process");
let mut shell = cmd.spawn()?;
if self.pty.is_none() {
let stdin = shell.stdin.take().unwrap();
let stdout = shell.stdout.take().unwrap();
let stderr = shell.stderr.take().unwrap();
self.writer = Some(Box::pin(stdin)); self.writer = Some(Box::pin(stdin));
self.reader = Some(Box::pin(stdout)); self.reader = Some(Box::pin(stdout));
@ -440,8 +448,9 @@ impl SessionState {
} }
let process_exit_send = self.process_exit_send.clone(); let process_exit_send = self.process_exit_send.clone();
let client = self.rpc_client.clone();
tokio::spawn(async move { tokio::spawn(async move {
let result = shell.wait().await; let result = client.wait().await;
let _ = process_exit_send.send(result).await; let _ = process_exit_send.send(result).await;
}); });
debug!("Successfully spawned shell"); debug!("Successfully spawned shell");

View file

@ -8,17 +8,19 @@ use std::{
io::{Read, Seek, SeekFrom}, io::{Read, Seek, SeekFrom},
marker::PhantomData, marker::PhantomData,
net::SocketAddr, net::SocketAddr,
os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd}, os::fd::{AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd},
path::PathBuf, path::PathBuf,
process::Stdio, process::Stdio,
sync::Arc,
}; };
use clap::Parser; use clap::Parser;
use cluelessh_keys::{host_keys::HostKeySet, private::EncryptedPrivateKeys, public::PublicKey}; use cluelessh_keys::{
use cluelessh_tokio::server::{ServerAuth, SignWithHostKey}; host_keys::HostKeySet,
private::{EncryptedPrivateKeys, PlaintextPrivateKey},
public::PublicKey,
};
use config::Config; use config::Config;
use eyre::{bail, Context, OptionExt, Result}; use eyre::{bail, eyre, Context, Result};
use rustix::fs::MemfdFlags; use rustix::fs::MemfdFlags;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
@ -33,6 +35,33 @@ struct Args {
config: Option<PathBuf>, config: Option<PathBuf>,
} }
#[tokio::main(flavor = "current_thread")]
async fn main() -> eyre::Result<()> {
match std::env::var("CLUELESSH_PRIVSEP_PROCESS") {
Ok(privsep_process) => match privsep_process.as_str() {
"connection" => connection::connection().await,
_ => bail!("unknown CLUELESSH_PRIVSEP_PROCESS: {privsep_process}"),
},
Err(_) => {
// Initial setup
let args = Args::parse();
let config = config::Config::find(&args)?;
setup_tracing(&config);
let addr: SocketAddr = SocketAddr::new(config.net.ip, config.net.port);
info!(%addr, "Starting server");
let listener = TcpListener::bind(addr)
.await
.wrap_err_with(|| format!("trying to listen on {addr}"))?;
main_process(config, listener).await
}
}
}
struct MemFd<T> { struct MemFd<T> {
fd: std::fs::File, fd: std::fs::File,
_data: PhantomData<T>, _data: PhantomData<T>,
@ -68,46 +97,48 @@ impl<T: serde::Serialize + serde::de::DeserializeOwned> MemFd<T> {
} }
} }
#[tokio::main(flavor = "current_thread")]
async fn main() -> eyre::Result<()> {
match std::env::var("CLUELESSH_PRIVSEP_PROCESS") {
Ok(privsep_process) => match privsep_process.as_str() {
"connection" => connection::connection().await,
_ => bail!("unknown CLUELESSH_PRIVSEP_PROCESS: {privsep_process}"),
},
Err(_) => {
// Initial setup
let args = Args::parse();
let config = config::Config::find(&args)?;
setup_tracing(&config);
let addr: SocketAddr = SocketAddr::new(config.net.ip, config.net.port);
info!(%addr, "Starting server");
let listener = TcpListener::bind(addr)
.await
.wrap_err_with(|| format!("trying to listen on {addr}"))?;
main_process(config, listener).await
}
}
}
const PRIVSEP_CONNECTION_STATE_FD: RawFd = 3; const PRIVSEP_CONNECTION_STATE_FD: RawFd = 3;
const PRIVSEP_CONNECTION_STREAM_FD: RawFd = 4;
const PRIVSEP_CONNECTION_RPC_CLIENT_FD: RawFd = 5;
/// The connection state passed to the child in the STATE_FD /// The connection state passed to the child in the STATE_FD
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
struct SerializedConnectionState { struct SerializedConnectionState {
stream_fd: RawFd,
peer_addr: SocketAddr, peer_addr: SocketAddr,
pub_host_keys: Vec<PublicKey>, pub_host_keys: Vec<PublicKey>,
config: Config, config: Config,
rpc_client_fd: RawFd,
setuid: Option<u32>,
setgid: Option<u32>,
} }
async fn main_process(config: Config, listener: TcpListener) -> Result<()> { async fn main_process(config: Config, listener: TcpListener) -> Result<()> {
let user = match &config.security.unprivileged_user {
Some(user) => Some(
users::get_user_by_name(user).ok_or_else(|| eyre!("unprivileged {user} not found"))?,
),
None => None,
};
let is_root = rustix::process::getuid().is_root();
if !is_root {
info!("Not running as root, disabling unprivileged setuid");
}
let setuid = match (is_root, &config.security.unprivileged_uid, &user) {
(false, _, _) => None,
(true, Some(uid), _) => Some(*uid),
(true, None, Some(user)) => Some(user.uid()),
(true, None, None) => None,
};
let setgid = match (is_root, &config.security.unprivileged_gid, &user) {
(false, _, _) => None,
(true, Some(uid), _) => Some(*uid),
(true, None, Some(user)) => Some(user.primary_group_id()),
(true, None, None) => None,
};
let host_keys = load_host_keys(&config.auth.host_keys).await?.into_keys(); let host_keys = load_host_keys(&config.auth.host_keys).await?.into_keys();
if host_keys.is_empty() { if host_keys.is_empty() {
@ -119,41 +150,21 @@ async fn main_process(config: Config, listener: TcpListener) -> Result<()> {
.map(|key| key.private_key.public_key()) .map(|key| key.private_key.public_key())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let auth_operations = ServerAuth {
verify_password: config
.auth
.clone()
.password_login
.then(|| todo!("password login")),
verify_signature: Some(Arc::new(|auth| Box::pin(auth::verify_signature(auth)))),
check_pubkey: Some(Arc::new(|auth| Box::pin(auth::check_pubkey(auth)))),
auth_banner: config.auth.clone().banner,
sign_with_hostkey: Arc::new(move |msg: SignWithHostKey| {
let host_keys = host_keys.clone();
Box::pin(async move {
let private = host_keys
.iter()
.find(|privkey| privkey.private_key.public_key() == msg.public_key)
.ok_or_eyre("missing private key")?;
Ok(private.private_key.sign(&msg.hash))
})
}),
};
loop { loop {
let (next_stream, peer_addr) = listener.accept().await?; let (next_stream, peer_addr) = listener.accept().await?;
let config = config.clone(); let config = config.clone();
let pub_host_keys = pub_host_keys.clone(); let pub_host_keys = pub_host_keys.clone();
let auth_operations = auth_operations.clone(); let host_keys = host_keys.clone();
tokio::spawn(async move { tokio::spawn(async move {
let err = spawn_connection_child( let err = spawn_connection_child(
next_stream, next_stream,
peer_addr, peer_addr,
pub_host_keys, pub_host_keys,
config, config,
auth_operations, host_keys,
setuid,
setgid,
) )
.await; .await;
if let Err(err) = err { if let Err(err) = err {
@ -168,23 +179,22 @@ async fn spawn_connection_child(
peer_addr: SocketAddr, peer_addr: SocketAddr,
pub_host_keys: Vec<PublicKey>, pub_host_keys: Vec<PublicKey>,
config: Config, config: Config,
auth_operations: ServerAuth, host_keys: Vec<PlaintextPrivateKey>,
setuid: Option<u32>,
setgid: Option<u32>,
) -> Result<()> { ) -> Result<()> {
let stream_fd = stream.as_fd(); let stream_fd = stream.as_raw_fd();
let rpc_server = rpc::Server::new(auth_operations).wrap_err("creating RPC server")?; let mut rpc_server = rpc::Server::new(host_keys).wrap_err("creating RPC server")?;
// dup to avoid cloexec let rpc_client_fd = rpc_server.client_fd().as_raw_fd();
// TODO: we should probably do this in the child? not that it matters that much.
let stream_fd = rustix::io::dup(stream_fd).wrap_err("duping tcp stream")?;
let rpc_client_fd = rustix::io::dup(rpc_server.client_fd()).wrap_err("duping tcp stream")?;
let config_fd = MemFd::new(&SerializedConnectionState { let state_fd = MemFd::new(&SerializedConnectionState {
stream_fd: stream_fd.as_raw_fd(),
peer_addr, peer_addr,
pub_host_keys, pub_host_keys,
config, config,
rpc_client_fd: rpc_client_fd.as_raw_fd(), setuid,
setgid,
})?; })?;
let exe = std::env::current_exe().wrap_err("failed to get current executable path")?; let exe = std::env::current_exe().wrap_err("failed to get current executable path")?;
@ -195,30 +205,59 @@ async fn spawn_connection_child(
.stderr(Stdio::inherit()); .stderr(Stdio::inherit());
unsafe { unsafe {
let fd = config_fd.fd.as_raw_fd(); let state_fd = state_fd.fd.as_raw_fd();
cmd.pre_exec(move || { cmd.pre_exec(move || {
let mut state_fd = OwnedFd::from_raw_fd(PRIVSEP_CONNECTION_STATE_FD); let mut new_state_fd = OwnedFd::from_raw_fd(PRIVSEP_CONNECTION_STATE_FD);
rustix::io::dup2(BorrowedFd::borrow_raw(fd), &mut state_fd)?; let mut new_stream_fd = OwnedFd::from_raw_fd(PRIVSEP_CONNECTION_STREAM_FD);
// Ensure that it stays open in the child. let mut new_rpc_client_fd = OwnedFd::from_raw_fd(PRIVSEP_CONNECTION_RPC_CLIENT_FD);
std::mem::forget(state_fd);
rustix::io::dup2(BorrowedFd::borrow_raw(state_fd), &mut new_state_fd)?;
rustix::io::dup2(BorrowedFd::borrow_raw(stream_fd), &mut new_stream_fd)?;
rustix::io::dup2(
BorrowedFd::borrow_raw(rpc_client_fd),
&mut new_rpc_client_fd,
)?;
// Ensure that all FDs are closed except stdout (for logging), and the 3 arguments.
drop(rustix::stdio::take_stdin());
drop(rustix::stdio::take_stderr());
let result = libc::close_range(
(PRIVSEP_CONNECTION_RPC_CLIENT_FD as u32) + 1,
std::ffi::c_uint::MAX,
0,
);
if result == -1 {
return Err(std::io::Error::last_os_error());
}
// Ensure our new FDs stay open, as they will be acquired in the new process.
std::mem::forget((new_state_fd, new_stream_fd, new_rpc_client_fd));
Ok(()) Ok(())
}); });
} }
let mut listen_child = cmd.spawn().wrap_err("failed to spawn listener process")?; let mut listen_child = cmd.spawn().wrap_err("failed to spawn listener process")?;
loop { let mut exited = false;
tokio::select! {
server_err = rpc_server.process() => { tokio::select! {
error!(err = ?server_err, "RPC server error"); server_err = rpc_server.process() => {
} error!(err = ?server_err, "RPC server error");
status = listen_child.wait() => { }
let status = status?; status = listen_child.wait() => {
if !status.success() { let status = status?;
bail!("connection child process failed: {}", status); if !status.success() {
} bail!("connection child process failed: {}", status);
break;
} }
exited = true;
}
}
if !exited {
let status = listen_child.wait().await?;
if !status.success() {
bail!("connection child process failed: {}", status);
} }
} }
@ -259,6 +298,7 @@ async fn load_host_key(key_path: &PathBuf, host_keys: &mut HostKeySet) -> Result
} }
fn setup_tracing(config: &Config) { fn setup_tracing(config: &Config) {
// Log to stdout
let env_filter = let env_filter =
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log_level)); EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new(&config.log_level));

View file

@ -16,13 +16,13 @@ pub struct Pty {
controller: OwnedFd, controller: OwnedFd,
user_pty: OwnedFd, user_pty: OwnedFd,
user_pty_name: String,
} }
impl Pty { impl Pty {
pub async fn new(term: String, winsize: Winsize, modes: Vec<u8>) -> Result<Self> { pub async fn new(term: String, winsize: Winsize, modes: Vec<u8>) -> Result<Self> {
tokio::task::spawn_blocking(move || Self::new_blocking(term, winsize, modes)).await? tokio::task::spawn_blocking(move || Self::new_blocking(term, winsize, modes)).await?
} }
pub fn new_blocking(term: String, winsize: Winsize, modes: Vec<u8>) -> Result<Self> { pub fn new_blocking(term: String, winsize: Winsize, modes: Vec<u8>) -> Result<Self> {
// Create new PTY: // Create new PTY:
let controller = rustix::pty::openpt(OpenptFlags::RDWR | OpenptFlags::NOCTTY) let controller = rustix::pty::openpt(OpenptFlags::RDWR | OpenptFlags::NOCTTY)
@ -50,34 +50,45 @@ impl Pty {
term, term,
controller, controller,
user_pty, user_pty,
user_pty_name,
}) })
} }
pub fn term(&self) -> String {
self.term.clone()
}
pub fn user_fd(&self) -> Result<OwnedFd> {
self.user_pty.try_clone().wrap_err("cloning PTY user")
}
pub fn controller(&self) -> BorrowedFd<'_> { pub fn controller(&self) -> BorrowedFd<'_> {
self.controller.as_fd() self.controller.as_fd()
} }
}
pub fn start_session_for_command(&self, cmd: &mut Command) -> Result<()> {
let user_pty = self.user_pty.try_clone()?; pub fn start_session_for_command(user_pty: OwnedFd, term: String, cmd: &mut Command) -> Result<()> {
unsafe { let ttyname = rustix::termios::ttyname(&user_pty, Vec::new())?;
cmd.pre_exec(move || { let tty_name = std::str::from_utf8(ttyname.as_bytes())
rustix::pty::grantpt(&user_pty)?; .wrap_err("pty name is invalid UTF-8")?
let pid = rustix::process::setsid()?; .to_owned();
rustix::process::ioctl_tiocsctty(&user_pty)?; // Set as the current controlling tty
rustix::termios::tcsetpgrp(&user_pty, pid)?; // Set current process as tty controller unsafe {
cmd.pre_exec(move || {
// Setup stdio with PTY. rustix::pty::grantpt(&user_pty)?;
rustix::stdio::dup2_stdin(&user_pty)?; let pid = rustix::process::setsid()?;
rustix::stdio::dup2_stdout(&user_pty)?; rustix::process::ioctl_tiocsctty(&user_pty)?; // Set as the current controlling tty
rustix::stdio::dup2_stderr(&user_pty)?; rustix::termios::tcsetpgrp(&user_pty, pid)?; // Set current process as tty controller
Ok(()) // Setup stdio with PTY.
}); rustix::stdio::dup2_stdin(&user_pty)?;
cmd.env("TERM", &self.term); rustix::stdio::dup2_stdout(&user_pty)?;
cmd.env("SSH_TTY", &self.user_pty_name); rustix::stdio::dup2_stderr(&user_pty)?;
}
Ok(())
Ok(()) });
} cmd.env("TERM", term);
cmd.env("SSH_TTY", tty_name);
}
Ok(())
} }

View file

@ -1,23 +1,42 @@
//! [`postcard`]-based RPC between the different processes. //! [`postcard`]-based RPC between the different processes.
use std::fmt::Debug;
use std::io::IoSlice;
use std::io::IoSliceMut;
use std::os::fd::AsFd; use std::os::fd::AsFd;
use std::os::fd::BorrowedFd; use std::os::fd::BorrowedFd;
use std::os::fd::OwnedFd; use std::os::fd::OwnedFd;
use std::os::unix::net::UnixDatagram;
use std::process::Stdio;
use cluelessh_keys::private::PlaintextPrivateKey;
use cluelessh_keys::public::PublicKey; use cluelessh_keys::public::PublicKey;
use cluelessh_keys::signature::Signature; use cluelessh_keys::signature::Signature;
use cluelessh_protocol::auth::CheckPubkey; use cluelessh_protocol::auth::CheckPubkey;
use cluelessh_protocol::auth::VerifySignature; use cluelessh_protocol::auth::VerifySignature;
use cluelessh_tokio::server::ServerAuth; use eyre::bail;
use cluelessh_tokio::server::SignWithHostKey;
use eyre::eyre; use eyre::eyre;
use eyre::Context; use eyre::Context;
use eyre::OptionExt;
use eyre::Result; use eyre::Result;
use rustix::net::RecvAncillaryBuffer;
use rustix::net::RecvAncillaryMessage;
use rustix::net::RecvFlags;
use rustix::net::SendAncillaryBuffer;
use rustix::net::SendAncillaryMessage;
use rustix::net::SendFlags;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::net::UnixDatagram; use tokio::process::Child;
use tokio::process::Command;
use tokio::sync::mpsc;
use tracing::debug;
use tracing::error;
use tracing::trace;
use users::os::unix::UserExt;
use users::User;
#[derive(Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
enum Request { enum Request {
Sign { Sign {
hash: [u8; 32], hash: [u8; 32],
@ -36,22 +55,54 @@ enum Request {
pubkey_alg_name: String, pubkey_alg_name: String,
pubkey: Vec<u8>, pubkey: Vec<u8>,
}, },
/// Executes a command on the host.
/// IMPORTANT: This is the critical operation, and we must ensure that it is secure.
/// To ensure that even a compromised auth process cannot escalate privileges via this RPC,
/// the RPC server keeps track of the authenciated user
Shell(ShellRequest),
Wait,
} }
#[derive(Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct ShellRequest {
/// Whether a PTY is used.
/// If true, the PTY fd is passed as ancillary data.
/// If false, the response will contain the 3 stdio fds
/// as ancillary data.
pty: Option<ShellRequestPty>,
command: Option<String>,
env: Vec<(String, String)>,
}
#[derive(Debug, Serialize, Deserialize)]
struct ShellRequestPty {
term: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct SignResponse { struct SignResponse {
signature: Result<Signature, String>, signature: Result<Signature, String>,
} }
#[derive(Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct VerifySignatureResponse { struct VerifySignatureResponse {
is_ok: Result<bool, String>, is_ok: Result<bool, String>,
} }
#[derive(Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct CheckPubkeyResponse { struct CheckPubkeyResponse {
is_ok: Result<bool, String>, is_ok: Result<bool, String>,
} }
#[derive(Debug, Serialize, Deserialize)]
struct ShellResponse {
result: Result<(), String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct WaitResponse {
result: Result<Option<i32>, String>,
}
pub struct Client { pub struct Client {
socket: UnixDatagram, socket: UnixDatagram,
@ -60,17 +111,48 @@ pub struct Client {
pub struct Server { pub struct Server {
server: UnixDatagram, server: UnixDatagram,
client: UnixDatagram, client: UnixDatagram,
auth_operations: ServerAuth, server_recv_recv: mpsc::Receiver<(Request, Vec<OwnedFd>)>,
host_keys: Vec<PlaintextPrivateKey>,
authenticated_user: Option<users::User>,
/// We keep the owned FDs here around to avoid a race condition where the child would
/// think stdout is closed before the client process opens it.
shell_process: Option<(Child, Vec<OwnedFd>)>,
}
fn server_thread(
socket: OwnedFd,
server_recv_send: mpsc::Sender<(Request, Vec<OwnedFd>)>,
) -> Result<()> {
let socket = std::os::unix::net::UnixDatagram::from(socket);
socket.set_nonblocking(false)?;
loop {
let (req, fds) = blocking_receive_with_fds::<Request>(socket.as_fd())?;
server_recv_send.blocking_send((req, fds))?;
}
} }
impl Server { impl Server {
pub fn new(auth_operations: ServerAuth) -> Result<Self> { pub fn new(host_keys: Vec<PlaintextPrivateKey>) -> Result<Self> {
let (server, client) = UnixDatagram::pair().wrap_err("creating socketpair")?; let (server, client) = UnixDatagram::pair().wrap_err("creating socketpair")?;
let (server_recv_send, server_recv_recv) = mpsc::channel(3);
let server_for_thread = server.as_fd().try_clone_to_owned()?;
std::thread::spawn(move || {
if let Err(err) = server_thread(server_for_thread, server_recv_send) {
error!(?err, "Server RPC recv thread error");
}
});
Ok(Self { Ok(Self {
server, server,
client, client,
auth_operations, host_keys,
server_recv_recv,
authenticated_user: None,
shell_process: None,
}) })
} }
@ -78,101 +160,237 @@ impl Server {
self.client.as_fd() self.client.as_fd()
} }
pub async fn process(&self) -> Result<()> { pub async fn process(&mut self) -> Result<()> {
let mut req = [0; 1024];
loop { loop {
let read = self let recv = self
.server .server_recv_recv
.recv(&mut req) .recv()
.await .await
.wrap_err("receiving response")?; .ok_or_eyre("RPC thread error")?;
self.receive_message(recv.0, recv.1).await?;
}
}
let req = postcard::from_bytes::<Request>(&req[..read]).wrap_err("invalid request")?; async fn receive_message(&mut self, req: Request, mut fds: Vec<OwnedFd>) -> Result<()> {
trace!(?req, ?fds, "Received RPC message");
match req { match req {
Request::Sign { hash, public_key } => { Request::Sign { hash, public_key } => {
let signature = (self.auth_operations.sign_with_hostkey)(SignWithHostKey { let Some(private) = self
hash, .host_keys
public_key, .iter()
.find(|privkey| privkey.private_key.public_key() == public_key)
else {
self.respond(SignResponse {
signature: Err("missing private key".to_owned()),
}) })
.await .await?;
.map_err(|err| err.to_string());
self.respond(SignResponse { signature }).await?; return Ok(());
};
let signature = private.private_key.sign(&hash);
self.respond(SignResponse {
signature: Ok(signature),
})
.await?;
}
Request::VerifySignature {
user,
session_identifier,
pubkey_alg_name,
pubkey,
signature,
} => {
if self.authenticated_user.is_some() {
self.respond(VerifySignatureResponse {
is_ok: Err("user already authenticated".to_owned()),
})
.await?;
} }
Request::VerifySignature { let is_ok = crate::auth::verify_signature(VerifySignature {
user, user,
session_identifier, session_identifier,
pubkey_alg_name, pubkey_alg_name,
pubkey, pubkey,
signature, signature,
} => { })
let Some(verify_signature) = &self.auth_operations.verify_signature else { .await
self.respond(VerifySignatureResponse { .map_err(|err| err.to_string())
is_ok: Err("public key login not supported".into()), .map(|user| match user {
}) Some(user) => {
.await?; self.authenticated_user = Some(user);
continue; true
}; }
let is_ok = verify_signature(VerifySignature { None => false,
user, });
session_identifier,
pubkey_alg_name,
pubkey,
signature,
})
.await
.map_err(|err| err.to_string());
self.respond(VerifySignatureResponse { is_ok }).await?; self.respond(VerifySignatureResponse { is_ok }).await?;
} }
Request::CheckPubkey { Request::CheckPubkey {
user,
session_identifier,
pubkey_alg_name,
pubkey,
} => {
let is_ok = crate::auth::check_pubkey(CheckPubkey {
user, user,
session_identifier, session_identifier,
pubkey_alg_name, pubkey_alg_name,
pubkey, pubkey,
} => { })
let Some(check_pubkey) = &self.auth_operations.check_pubkey else { .await
self.respond(VerifySignatureResponse { .map_err(|err| err.to_string());
is_ok: Err("public key login not supported".into()),
}) self.respond(CheckPubkeyResponse { is_ok }).await?;
.await?; }
continue; Request::Shell(req) => {
}; if self.shell_process.is_some() {
let is_ok = check_pubkey(CheckPubkey { self.respond(ShellResponse {
user, result: Err("process already running".to_owned()),
session_identifier,
pubkey_alg_name,
pubkey,
}) })
.await?;
return Ok(());
}
let Some(user) = self.authenticated_user.clone() else {
self.respond(ShellResponse {
result: Err("unauthenticated".to_owned()),
})
.await?;
return Ok(());
};
let result = self
.shell(&mut fds, &user, req)
.await .await
.map_err(|err| err.to_string()); .map_err(|err| err.to_string());
self.respond(CheckPubkeyResponse { is_ok }).await?; self.respond_ancillary(
} ShellResponse {
result: result.as_ref().map(drop).map_err(Clone::clone),
},
result.unwrap_or_default(),
)
.await?;
} }
Request::Wait => match &mut self.shell_process {
None => {
self.respond(WaitResponse {
result: Err("no child running".to_owned()),
})
.await?;
}
Some(child) => {
let result = child.0.wait().await;
self.respond(WaitResponse {
result: result
.map(|status| status.code())
.map_err(|err| err.to_string()),
})
.await?;
// implicitly drop stdio
self.shell_process = None;
}
},
} }
Ok(())
}
async fn shell(
&mut self,
fds: &mut Vec<OwnedFd>,
user: &User,
req: ShellRequest,
) -> Result<Vec<OwnedFd>> {
let shell = user.shell();
let mut cmd = Command::new(shell);
if let Some(shell_command) = req.command {
cmd.arg("-c");
cmd.arg(shell_command);
}
cmd.env_clear();
let has_pty = req.pty.is_some();
if let Some(pty) = req.pty {
if fds.len() != 1 {
bail!("invalid request: shell with PTY must send one FD");
}
let user_pty = fds.remove(0);
crate::pty::start_session_for_command(user_pty, pty.term, &mut cmd)?;
} else {
cmd.stdin(Stdio::piped());
cmd.stdout(Stdio::piped());
cmd.stderr(Stdio::piped());
}
cmd.current_dir(user.home_dir());
cmd.env("USER", user.name());
cmd.uid(user.uid());
cmd.gid(user.primary_group_id());
for (k, v) in req.env {
cmd.env(k, v);
}
debug!(cmd = %shell.display(), uid = %user.uid(), gid = %user.primary_group_id(), "Executing process");
let mut shell = cmd.spawn()?;
// See Server::shell_process
let mut fds1 = Vec::new();
let mut fds2 = Vec::new();
if !has_pty {
let stdin = shell.stdin.take().unwrap().into_owned_fd()?;
let stdout = shell.stdout.take().unwrap().into_owned_fd()?;
let stderr = shell.stderr.take().unwrap().into_owned_fd()?;
fds1.push(stdin.try_clone()?);
fds2.push(stdin);
fds1.push(stdout.try_clone()?);
fds2.push(stdout);
fds1.push(stderr.try_clone()?);
fds2.push(stderr);
}
self.shell_process = Some((shell, vec![]));
Ok(fds1)
} }
async fn respond(&self, resp: impl Serialize) -> Result<()> { async fn respond(&self, resp: impl Serialize) -> Result<()> {
self.server self.respond_ancillary(resp, vec![]).await
.send(&postcard::to_allocvec(&resp)?) }
.await
.wrap_err("sending response")?; async fn respond_ancillary(&self, resp: impl Serialize, fds: Vec<OwnedFd>) -> Result<()> {
send_with_fds(
self.server.as_fd().try_clone_to_owned()?,
postcard::to_allocvec(&resp)?,
fds,
)
.await?;
Ok(()) Ok(())
} }
} }
impl Client { impl Client {
pub fn from_fd(fd: OwnedFd) -> Result<Self> { pub fn from_fd(fd: OwnedFd) -> Result<Self> {
let socket = UnixDatagram::from_std(std::os::unix::net::UnixDatagram::from(fd))?; let socket = std::os::unix::net::UnixDatagram::from(fd);
Ok(Self { socket }) Ok(Self { socket })
} }
pub async fn sign(&self, hash: [u8; 32], public_key: PublicKey) -> Result<Signature> { pub async fn sign(&self, hash: [u8; 32], public_key: PublicKey) -> Result<Signature> {
let resp = self let resp = self
.request_response::<SignResponse>(&Request::Sign { hash, public_key }) .request_response::<SignResponse>(&Request::Sign { hash, public_key }, vec![])
.await?; .await?;
resp.signature.map_err(|err| eyre!(err)) resp.signature.map_err(|err| eyre!(err))
@ -186,12 +404,15 @@ impl Client {
pubkey: Vec<u8>, pubkey: Vec<u8>,
) -> Result<bool> { ) -> Result<bool> {
let resp = self let resp = self
.request_response::<CheckPubkeyResponse>(&Request::CheckPubkey { .request_response::<CheckPubkeyResponse>(
user, &Request::CheckPubkey {
session_identifier, user,
pubkey_alg_name, session_identifier,
pubkey, pubkey_alg_name,
}) pubkey,
},
vec![],
)
.await?; .await?;
resp.is_ok.map_err(|err| eyre!(err)) resp.is_ok.map_err(|err| eyre!(err))
@ -206,34 +427,133 @@ impl Client {
signature: Vec<u8>, signature: Vec<u8>,
) -> Result<bool> { ) -> Result<bool> {
let resp = self let resp = self
.request_response::<VerifySignatureResponse>(&Request::VerifySignature { .request_response::<VerifySignatureResponse>(
user, &Request::VerifySignature {
session_identifier, user,
pubkey_alg_name, session_identifier,
pubkey, pubkey_alg_name,
signature, pubkey,
}) signature,
},
vec![],
)
.await?; .await?;
resp.is_ok.map_err(|err| eyre!(err)) resp.is_ok.map_err(|err| eyre!(err))
} }
async fn request_response<Resp: DeserializeOwned>(&self, req: &Request) -> Result<Resp> { pub async fn exec(
self.socket &self,
.send(&postcard::to_allocvec(&req)?) command: Option<String>,
pty: Option<OwnedFd>,
term: String,
env: Vec<(String, String)>,
) -> Result<Vec<OwnedFd>> {
let has_pty = pty.is_some();
let fds = match pty {
Some(fd) => vec![fd],
None => vec![],
};
self.send_request(
&Request::Shell(ShellRequest {
pty: has_pty.then_some(ShellRequestPty { term }),
command,
env,
}),
fds,
)
.await?;
let (resp, fds) = self.recv_response_ancillary::<ShellResponse>().await?;
resp.result.map_err(|err| eyre!(err))?;
Ok(fds)
}
pub async fn wait(&self) -> Result<Option<i32>> {
self.request_response::<WaitResponse>(&Request::Wait, vec![])
.await .await
.wrap_err("sending request")?; .and_then(|resp| resp.result.map_err(|err| eyre!(err)))
}
let mut resp = [0; 1024]; async fn request_response<R: DeserializeOwned + Debug + Send + 'static>(
let read = self &self,
.socket req: &Request,
.recv(&mut resp) fds: Vec<OwnedFd>,
.await ) -> Result<R> {
.wrap_err("receiving response")?; self.send_request(req, fds).await?;
Ok(self.recv_response_ancillary::<R>().await?.0)
}
let resp = async fn send_request(&self, req: &Request, fds: Vec<OwnedFd>) -> Result<()> {
postcard::from_bytes::<Resp>(&resp[..read]).wrap_err("invalid signature response")?; let data = postcard::to_allocvec(&req)?;
Ok(resp) let socket = self.socket.as_fd().try_clone_to_owned()?;
send_with_fds(socket, data, fds).await?;
Ok(())
}
async fn recv_response_ancillary<R: DeserializeOwned + Debug + Send + 'static>(
&self,
) -> Result<(R, Vec<OwnedFd>)> {
let socket =
std::os::unix::net::UnixDatagram::from(self.socket.as_fd().try_clone_to_owned()?);
let (resp, fds) =
tokio::task::spawn_blocking(move || blocking_receive_with_fds(socket.as_fd()))
.await?
.wrap_err("failed to recv")?;
trace!(?resp, ?fds, "Received RPC response");
Ok((resp, fds))
} }
} }
async fn send_with_fds(socket: OwnedFd, data: Vec<u8>, fds: Vec<OwnedFd>) -> Result<()> {
tokio::task::spawn_blocking(move || {
let mut space = [0; rustix::cmsg_space!(ScmRights(3))]; //we send up to 3 fds at once
let mut ancillary = SendAncillaryBuffer::new(&mut space);
let fds = fds.iter().map(|fd| fd.as_fd()).collect::<Vec<_>>();
ancillary.push(SendAncillaryMessage::ScmRights(fds.as_slice()));
rustix::net::sendmsg(
socket,
&[IoSlice::new(&data)],
&mut ancillary,
SendFlags::empty(),
)
})
.await?
.wrap_err("failed to send")
.map(drop)
}
fn blocking_receive_with_fds<R: DeserializeOwned>(
blocking_socket: BorrowedFd<'_>,
) -> Result<(R, Vec<OwnedFd>)> {
let mut data = [0; 1024];
let mut space = [0; rustix::cmsg_space!(ScmRights(3))]; // maximum size
let mut cmesg_buf = RecvAncillaryBuffer::new(&mut space);
let mut fds = Vec::new();
let read = rustix::net::recvmsg(
blocking_socket,
&mut [IoSliceMut::new(&mut data)],
&mut cmesg_buf,
RecvFlags::empty(),
)?;
let data = postcard::from_bytes::<R>(&data[..read.bytes]).wrap_err("invalid request")?;
for msg in cmesg_buf.drain() {
match msg {
RecvAncillaryMessage::ScmRights(fd) => fds.extend(fd),
_ => bail!("unexpected ancillery msg"),
}
}
Ok((data, fds))
}

View file

@ -18,6 +18,7 @@ pub fn signature_data(session_id: [u8; 32], username: &str, pubkey: &PublicKey)
s.finish() s.finish()
} }
#[derive(Debug)]
pub enum Signature { pub enum Signature {
Ed25519 { signature: ed25519_dalek::Signature }, Ed25519 { signature: ed25519_dalek::Signature },
EcdsaSha2NistP256 { signature: p256::ecdsa::Signature }, EcdsaSha2NistP256 { signature: p256::ecdsa::Signature },