cleanup RPC

This commit is contained in:
nora 2024-08-29 01:27:06 +02:00
parent 4e9eb447db
commit ee00db37e1

View file

@ -1,12 +1,12 @@
//! [`postcard`]-based RPC between the different processes. //! [`postcard`]-based RPC between the different processes.
use std::fmt::Debug; use std::fmt::Debug;
use std::io;
use std::io::IoSlice; use std::io::IoSlice;
use std::io::IoSliceMut; use std::io::IoSliceMut;
use std::os::fd::AsFd; use std::os::fd::AsFd;
use std::os::fd::BorrowedFd; use std::os::fd::BorrowedFd;
use std::os::fd::OwnedFd; use std::os::fd::OwnedFd;
use std::os::unix::net::UnixDatagram;
use std::process::Stdio; use std::process::Stdio;
use cluelessh_keys::private::PlaintextPrivateKey; use cluelessh_keys::private::PlaintextPrivateKey;
@ -18,7 +18,6 @@ use eyre::bail;
use eyre::ensure; use eyre::ensure;
use eyre::eyre; use eyre::eyre;
use eyre::Context; use eyre::Context;
use eyre::OptionExt;
use eyre::Result; use eyre::Result;
use rustix::net::RecvAncillaryBuffer; use rustix::net::RecvAncillaryBuffer;
use rustix::net::RecvAncillaryMessage; use rustix::net::RecvAncillaryMessage;
@ -29,11 +28,11 @@ use rustix::net::SendFlags;
use rustix::termios::Winsize; use rustix::termios::Winsize;
use serde::de::DeserializeOwned; use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokio::io::Interest;
use tokio::net::UnixDatagram;
use tokio::process::Child; use tokio::process::Child;
use tokio::process::Command; use tokio::process::Command;
use tokio::sync::mpsc;
use tracing::debug; use tracing::debug;
use tracing::error;
use tracing::trace; use tracing::trace;
use users::os::unix::UserExt; use users::os::unix::UserExt;
use users::User; use users::User;
@ -125,7 +124,6 @@ pub struct Client {
pub struct Server { pub struct Server {
server: UnixDatagram, server: UnixDatagram,
client: UnixDatagram, client: UnixDatagram,
server_recv_recv: mpsc::Receiver<(Request, Vec<OwnedFd>)>,
host_keys: Vec<PlaintextPrivateKey>, host_keys: Vec<PlaintextPrivateKey>,
authenticated_user: Option<users::User>, authenticated_user: Option<users::User>,
@ -133,38 +131,14 @@ pub struct Server {
shell_process: Option<Child>, shell_process: Option<Child>,
} }
fn server_thread(
socket: OwnedFd,
server_recv_send: mpsc::Sender<(Request, Vec<OwnedFd>)>,
) -> Result<()> {
let socket = std::os::unix::net::UnixDatagram::from(socket);
socket.set_nonblocking(false)?;
loop {
let (req, fds) = blocking_receive_with_fds::<Request>(socket.as_fd())?;
server_recv_send.blocking_send((req, fds))?;
}
}
impl Server { impl Server {
pub fn new(host_keys: Vec<PlaintextPrivateKey>) -> Result<Self> { pub fn new(host_keys: Vec<PlaintextPrivateKey>) -> Result<Self> {
let (server, client) = UnixDatagram::pair().wrap_err("creating socketpair")?; let (server, client) = UnixDatagram::pair().wrap_err("creating socketpair")?;
let (server_recv_send, server_recv_recv) = mpsc::channel(3);
let server_for_thread = server.as_fd().try_clone_to_owned()?;
std::thread::spawn(move || {
if let Err(err) = server_thread(server_for_thread, server_recv_send) {
error!(?err, "Server RPC recv thread error");
}
});
Ok(Self { Ok(Self {
server, server,
client, client,
host_keys, host_keys,
server_recv_recv,
authenticated_user: None, authenticated_user: None,
pty_user: None, pty_user: None,
shell_process: None, shell_process: None,
@ -177,11 +151,7 @@ impl Server {
pub async fn process(&mut self) -> Result<()> { pub async fn process(&mut self) -> Result<()> {
loop { loop {
let recv = self let recv = receive_with_fds::<Request>(&self.server).await?;
.server_recv_recv
.recv()
.await
.ok_or_eyre("RPC thread error")?;
self.receive_message(recv.0, recv.1).await?; self.receive_message(recv.0, recv.1).await?;
} }
} }
@ -281,8 +251,8 @@ impl Server {
) )
.await; .await;
let (controller, user) = match result { let (controller, user) = match &result {
Ok(pty) => (vec![pty.controller], Ok(pty.user_pty)), Ok(pty) => (vec![pty.controller.as_fd()], Ok(pty.user_pty.try_clone()?)),
Err(err) => (vec![], Err(err)), Err(err) => (vec![], Err(err)),
}; };
@ -290,7 +260,7 @@ impl Server {
ShellResponse { ShellResponse {
result: user.as_ref().map(drop).map_err(ToString::to_string), result: user.as_ref().map(drop).map_err(ToString::to_string),
}, },
controller, &controller,
) )
.await?; .await?;
@ -321,7 +291,11 @@ impl Server {
ShellResponse { ShellResponse {
result: result.as_ref().map(drop).map_err(Clone::clone), result: result.as_ref().map(drop).map_err(Clone::clone),
}, },
result.unwrap_or_default(), &result
.unwrap_or_default()
.iter()
.map(|fd| fd.as_fd())
.collect::<Vec<_>>(),
) )
.await?; .await?;
} }
@ -412,16 +386,11 @@ impl Server {
} }
async fn respond(&self, resp: impl Serialize) -> Result<()> { async fn respond(&self, resp: impl Serialize) -> Result<()> {
self.respond_ancillary(resp, vec![]).await self.respond_ancillary(resp, &[]).await
} }
async fn respond_ancillary(&self, resp: impl Serialize, fds: Vec<OwnedFd>) -> Result<()> { async fn respond_ancillary(&self, resp: impl Serialize, fds: &[BorrowedFd<'_>]) -> Result<()> {
send_with_fds( send_with_fds(&self.server, &postcard::to_allocvec(&resp)?, fds).await?;
self.server.as_fd().try_clone_to_owned()?,
postcard::to_allocvec(&resp)?,
fds,
)
.await?;
Ok(()) Ok(())
} }
@ -429,13 +398,13 @@ impl Server {
impl Client { impl Client {
pub fn from_fd(fd: OwnedFd) -> Result<Self> { pub fn from_fd(fd: OwnedFd) -> Result<Self> {
let socket = std::os::unix::net::UnixDatagram::from(fd); let socket = UnixDatagram::from_std(std::os::unix::net::UnixDatagram::from(fd))?;
Ok(Self { socket }) Ok(Self { socket })
} }
pub async fn sign(&self, hash: [u8; 32], public_key: PublicKey) -> Result<Signature> { pub async fn sign(&self, hash: [u8; 32], public_key: PublicKey) -> Result<Signature> {
let resp = self let resp = self
.request_response::<SignResponse>(&Request::Sign { hash, public_key }, vec![]) .request_response::<SignResponse>(&Request::Sign { hash, public_key })
.await?; .await?;
resp.signature.map_err(|err| eyre!(err)) resp.signature.map_err(|err| eyre!(err))
@ -449,15 +418,12 @@ impl Client {
pubkey: Vec<u8>, pubkey: Vec<u8>,
) -> Result<bool> { ) -> Result<bool> {
let resp = self let resp = self
.request_response::<CheckPubkeyResponse>( .request_response::<CheckPubkeyResponse>(&Request::CheckPubkey {
&Request::CheckPubkey {
user, user,
session_identifier, session_identifier,
pubkey_alg_name, pubkey_alg_name,
pubkey, pubkey,
}, })
vec![],
)
.await?; .await?;
resp.is_ok.map_err(|err| eyre!(err)) resp.is_ok.map_err(|err| eyre!(err))
@ -472,16 +438,13 @@ impl Client {
signature: Vec<u8>, signature: Vec<u8>,
) -> Result<bool> { ) -> Result<bool> {
let resp = self let resp = self
.request_response::<VerifySignatureResponse>( .request_response::<VerifySignatureResponse>(&Request::VerifySignature {
&Request::VerifySignature {
user, user,
session_identifier, session_identifier,
pubkey_alg_name, pubkey_alg_name,
pubkey, pubkey,
signature, signature,
}, })
vec![],
)
.await?; .await?;
resp.is_ok.map_err(|err| eyre!(err)) resp.is_ok.map_err(|err| eyre!(err))
@ -495,16 +458,13 @@ impl Client {
height_px: u32, height_px: u32,
term_modes: Vec<u8>, term_modes: Vec<u8>,
) -> Result<Vec<OwnedFd>> { ) -> Result<Vec<OwnedFd>> {
self.send_request( self.send_request(&Request::PtyReq(PtyRequest {
&Request::PtyReq(PtyRequest {
height_rows, height_rows,
width_chars, width_chars,
width_px, width_px,
height_px, height_px,
term_modes, term_modes,
}), }))
vec![],
)
.await?; .await?;
let (resp, fds) = self.recv_response_ancillary::<PtyResponse>().await?; let (resp, fds) = self.recv_response_ancillary::<PtyResponse>().await?;
@ -519,14 +479,11 @@ impl Client {
pty_term: Option<String>, pty_term: Option<String>,
env: Vec<(String, String)>, env: Vec<(String, String)>,
) -> Result<Vec<OwnedFd>> { ) -> Result<Vec<OwnedFd>> {
self.send_request( self.send_request(&Request::Shell(ShellRequest {
&Request::Shell(ShellRequest {
pty_term, pty_term,
command, command,
env, env,
}), }))
vec![],
)
.await?; .await?;
let (resp, fds) = self.recv_response_ancillary::<ShellResponse>().await?; let (resp, fds) = self.recv_response_ancillary::<ShellResponse>().await?;
@ -536,7 +493,7 @@ impl Client {
} }
pub async fn wait(&self) -> Result<Option<i32>> { pub async fn wait(&self) -> Result<Option<i32>> {
self.request_response::<WaitResponse>(&Request::Wait, vec![]) self.request_response::<WaitResponse>(&Request::Wait)
.await .await
.and_then(|resp| resp.result.map_err(|err| eyre!(err))) .and_then(|resp| resp.result.map_err(|err| eyre!(err)))
} }
@ -544,31 +501,23 @@ impl Client {
async fn request_response<R: DeserializeOwned + Debug + Send + 'static>( async fn request_response<R: DeserializeOwned + Debug + Send + 'static>(
&self, &self,
req: &Request, req: &Request,
fds: Vec<OwnedFd>,
) -> Result<R> { ) -> Result<R> {
self.send_request(req, fds).await?; self.send_request(req).await?;
Ok(self.recv_response_ancillary::<R>().await?.0) Ok(self.recv_response_ancillary::<R>().await?.0)
} }
async fn send_request(&self, req: &Request, fds: Vec<OwnedFd>) -> Result<()> { async fn send_request(&self, req: &Request) -> Result<()> {
// TODO: remove support for ancillary?
let data = postcard::to_allocvec(&req)?; let data = postcard::to_allocvec(&req)?;
let socket = self.socket.as_fd().try_clone_to_owned()?; send_with_fds(&self.socket, &data, &[]).await?;
send_with_fds(socket, data, fds).await?;
Ok(()) Ok(())
} }
async fn recv_response_ancillary<R: DeserializeOwned + Debug + Send + 'static>( async fn recv_response_ancillary<R: DeserializeOwned + Debug + Send + 'static>(
&self, &self,
) -> Result<(R, Vec<OwnedFd>)> { ) -> Result<(R, Vec<OwnedFd>)> {
let socket = let (resp, fds) = receive_with_fds(&self.socket)
std::os::unix::net::UnixDatagram::from(self.socket.as_fd().try_clone_to_owned()?); .await
let (resp, fds) =
tokio::task::spawn_blocking(move || blocking_receive_with_fds(socket.as_fd()))
.await?
.wrap_err("failed to recv")?; .wrap_err("failed to recv")?;
trace!(?resp, ?fds, "Received RPC response"); trace!(?resp, ?fds, "Received RPC response");
@ -577,40 +526,45 @@ impl Client {
} }
} }
async fn send_with_fds(socket: OwnedFd, data: Vec<u8>, fds: Vec<OwnedFd>) -> Result<()> { async fn send_with_fds(socket: &UnixDatagram, data: &[u8], fds: &[BorrowedFd<'_>]) -> Result<()> {
tokio::task::spawn_blocking(move || { socket
.async_io(Interest::WRITABLE, || {
let mut space = [0; rustix::cmsg_space!(ScmRights(3))]; //we send up to 3 fds at once let mut space = [0; rustix::cmsg_space!(ScmRights(3))]; //we send up to 3 fds at once
let mut ancillary = SendAncillaryBuffer::new(&mut space); let mut ancillary = SendAncillaryBuffer::new(&mut space);
let fds = fds.iter().map(|fd| fd.as_fd()).collect::<Vec<_>>();
ancillary.push(SendAncillaryMessage::ScmRights(fds.as_slice())); ancillary.push(SendAncillaryMessage::ScmRights(fds));
rustix::net::sendmsg( rustix::net::sendmsg(
socket, socket,
&[IoSlice::new(&data)], &[IoSlice::new(data)],
&mut ancillary, &mut ancillary,
SendFlags::empty(), SendFlags::empty(),
) )
.map_err(|errno| io::Error::from(errno))?;
Ok(())
}) })
.await? .await
.wrap_err("failed to send") .wrap_err("failed to write to socket")
.map(drop)
} }
fn blocking_receive_with_fds<R: DeserializeOwned>( async fn receive_with_fds<R: DeserializeOwned>(socket: &UnixDatagram) -> Result<(R, Vec<OwnedFd>)> {
blocking_socket: BorrowedFd<'_>,
) -> Result<(R, Vec<OwnedFd>)> {
let mut data = [0; 1024]; let mut data = [0; 1024];
let mut space = [0; rustix::cmsg_space!(ScmRights(3))]; // maximum size let mut space = [0; rustix::cmsg_space!(ScmRights(3))]; // maximum size
let mut cmesg_buf = RecvAncillaryBuffer::new(&mut space); let mut cmesg_buf = RecvAncillaryBuffer::new(&mut space);
let mut fds = Vec::new(); let read = socket
.async_io(Interest::READABLE, || {
let read = rustix::net::recvmsg( rustix::net::recvmsg(
blocking_socket, socket,
&mut [IoSliceMut::new(&mut data)], &mut [IoSliceMut::new(&mut data)],
&mut cmesg_buf, &mut cmesg_buf,
RecvFlags::empty(), RecvFlags::empty(),
)?; )
.map_err(|errno| io::Error::from(errno))
})
.await?;
let mut fds = Vec::new();
let data = postcard::from_bytes::<R>(&data[..read.bytes]).wrap_err("invalid request")?; let data = postcard::from_bytes::<R>(&data[..read.bytes]).wrap_err("invalid request")?;
for msg in cmesg_buf.drain() { for msg in cmesg_buf.drain() {