cleanup RPC

This commit is contained in:
nora 2024-08-29 01:27:06 +02:00
parent 4e9eb447db
commit ee00db37e1

View file

@ -1,12 +1,12 @@
//! [`postcard`]-based RPC between the different processes.
use std::fmt::Debug;
use std::io;
use std::io::IoSlice;
use std::io::IoSliceMut;
use std::os::fd::AsFd;
use std::os::fd::BorrowedFd;
use std::os::fd::OwnedFd;
use std::os::unix::net::UnixDatagram;
use std::process::Stdio;
use cluelessh_keys::private::PlaintextPrivateKey;
@ -18,7 +18,6 @@ use eyre::bail;
use eyre::ensure;
use eyre::eyre;
use eyre::Context;
use eyre::OptionExt;
use eyre::Result;
use rustix::net::RecvAncillaryBuffer;
use rustix::net::RecvAncillaryMessage;
@ -29,11 +28,11 @@ use rustix::net::SendFlags;
use rustix::termios::Winsize;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use tokio::io::Interest;
use tokio::net::UnixDatagram;
use tokio::process::Child;
use tokio::process::Command;
use tokio::sync::mpsc;
use tracing::debug;
use tracing::error;
use tracing::trace;
use users::os::unix::UserExt;
use users::User;
@ -125,7 +124,6 @@ pub struct Client {
pub struct Server {
server: UnixDatagram,
client: UnixDatagram,
server_recv_recv: mpsc::Receiver<(Request, Vec<OwnedFd>)>,
host_keys: Vec<PlaintextPrivateKey>,
authenticated_user: Option<users::User>,
@ -133,38 +131,14 @@ pub struct Server {
shell_process: Option<Child>,
}
fn server_thread(
socket: OwnedFd,
server_recv_send: mpsc::Sender<(Request, Vec<OwnedFd>)>,
) -> Result<()> {
let socket = std::os::unix::net::UnixDatagram::from(socket);
socket.set_nonblocking(false)?;
loop {
let (req, fds) = blocking_receive_with_fds::<Request>(socket.as_fd())?;
server_recv_send.blocking_send((req, fds))?;
}
}
impl Server {
pub fn new(host_keys: Vec<PlaintextPrivateKey>) -> Result<Self> {
let (server, client) = UnixDatagram::pair().wrap_err("creating socketpair")?;
let (server_recv_send, server_recv_recv) = mpsc::channel(3);
let server_for_thread = server.as_fd().try_clone_to_owned()?;
std::thread::spawn(move || {
if let Err(err) = server_thread(server_for_thread, server_recv_send) {
error!(?err, "Server RPC recv thread error");
}
});
Ok(Self {
server,
client,
host_keys,
server_recv_recv,
authenticated_user: None,
pty_user: None,
shell_process: None,
@ -177,11 +151,7 @@ impl Server {
pub async fn process(&mut self) -> Result<()> {
loop {
let recv = self
.server_recv_recv
.recv()
.await
.ok_or_eyre("RPC thread error")?;
let recv = receive_with_fds::<Request>(&self.server).await?;
self.receive_message(recv.0, recv.1).await?;
}
}
@ -281,8 +251,8 @@ impl Server {
)
.await;
let (controller, user) = match result {
Ok(pty) => (vec![pty.controller], Ok(pty.user_pty)),
let (controller, user) = match &result {
Ok(pty) => (vec![pty.controller.as_fd()], Ok(pty.user_pty.try_clone()?)),
Err(err) => (vec![], Err(err)),
};
@ -290,7 +260,7 @@ impl Server {
ShellResponse {
result: user.as_ref().map(drop).map_err(ToString::to_string),
},
controller,
&controller,
)
.await?;
@ -321,7 +291,11 @@ impl Server {
ShellResponse {
result: result.as_ref().map(drop).map_err(Clone::clone),
},
result.unwrap_or_default(),
&result
.unwrap_or_default()
.iter()
.map(|fd| fd.as_fd())
.collect::<Vec<_>>(),
)
.await?;
}
@ -412,16 +386,11 @@ impl Server {
}
async fn respond(&self, resp: impl Serialize) -> Result<()> {
self.respond_ancillary(resp, vec![]).await
self.respond_ancillary(resp, &[]).await
}
async fn respond_ancillary(&self, resp: impl Serialize, fds: Vec<OwnedFd>) -> Result<()> {
send_with_fds(
self.server.as_fd().try_clone_to_owned()?,
postcard::to_allocvec(&resp)?,
fds,
)
.await?;
async fn respond_ancillary(&self, resp: impl Serialize, fds: &[BorrowedFd<'_>]) -> Result<()> {
send_with_fds(&self.server, &postcard::to_allocvec(&resp)?, fds).await?;
Ok(())
}
@ -429,13 +398,13 @@ impl Server {
impl Client {
pub fn from_fd(fd: OwnedFd) -> Result<Self> {
let socket = std::os::unix::net::UnixDatagram::from(fd);
let socket = UnixDatagram::from_std(std::os::unix::net::UnixDatagram::from(fd))?;
Ok(Self { socket })
}
pub async fn sign(&self, hash: [u8; 32], public_key: PublicKey) -> Result<Signature> {
let resp = self
.request_response::<SignResponse>(&Request::Sign { hash, public_key }, vec![])
.request_response::<SignResponse>(&Request::Sign { hash, public_key })
.await?;
resp.signature.map_err(|err| eyre!(err))
@ -449,15 +418,12 @@ impl Client {
pubkey: Vec<u8>,
) -> Result<bool> {
let resp = self
.request_response::<CheckPubkeyResponse>(
&Request::CheckPubkey {
.request_response::<CheckPubkeyResponse>(&Request::CheckPubkey {
user,
session_identifier,
pubkey_alg_name,
pubkey,
},
vec![],
)
})
.await?;
resp.is_ok.map_err(|err| eyre!(err))
@ -472,16 +438,13 @@ impl Client {
signature: Vec<u8>,
) -> Result<bool> {
let resp = self
.request_response::<VerifySignatureResponse>(
&Request::VerifySignature {
.request_response::<VerifySignatureResponse>(&Request::VerifySignature {
user,
session_identifier,
pubkey_alg_name,
pubkey,
signature,
},
vec![],
)
})
.await?;
resp.is_ok.map_err(|err| eyre!(err))
@ -495,16 +458,13 @@ impl Client {
height_px: u32,
term_modes: Vec<u8>,
) -> Result<Vec<OwnedFd>> {
self.send_request(
&Request::PtyReq(PtyRequest {
self.send_request(&Request::PtyReq(PtyRequest {
height_rows,
width_chars,
width_px,
height_px,
term_modes,
}),
vec![],
)
}))
.await?;
let (resp, fds) = self.recv_response_ancillary::<PtyResponse>().await?;
@ -519,14 +479,11 @@ impl Client {
pty_term: Option<String>,
env: Vec<(String, String)>,
) -> Result<Vec<OwnedFd>> {
self.send_request(
&Request::Shell(ShellRequest {
self.send_request(&Request::Shell(ShellRequest {
pty_term,
command,
env,
}),
vec![],
)
}))
.await?;
let (resp, fds) = self.recv_response_ancillary::<ShellResponse>().await?;
@ -536,7 +493,7 @@ impl Client {
}
pub async fn wait(&self) -> Result<Option<i32>> {
self.request_response::<WaitResponse>(&Request::Wait, vec![])
self.request_response::<WaitResponse>(&Request::Wait)
.await
.and_then(|resp| resp.result.map_err(|err| eyre!(err)))
}
@ -544,31 +501,23 @@ impl Client {
async fn request_response<R: DeserializeOwned + Debug + Send + 'static>(
&self,
req: &Request,
fds: Vec<OwnedFd>,
) -> Result<R> {
self.send_request(req, fds).await?;
self.send_request(req).await?;
Ok(self.recv_response_ancillary::<R>().await?.0)
}
async fn send_request(&self, req: &Request, fds: Vec<OwnedFd>) -> Result<()> {
// TODO: remove support for ancillary?
async fn send_request(&self, req: &Request) -> Result<()> {
let data = postcard::to_allocvec(&req)?;
let socket = self.socket.as_fd().try_clone_to_owned()?;
send_with_fds(socket, data, fds).await?;
send_with_fds(&self.socket, &data, &[]).await?;
Ok(())
}
async fn recv_response_ancillary<R: DeserializeOwned + Debug + Send + 'static>(
&self,
) -> Result<(R, Vec<OwnedFd>)> {
let socket =
std::os::unix::net::UnixDatagram::from(self.socket.as_fd().try_clone_to_owned()?);
let (resp, fds) =
tokio::task::spawn_blocking(move || blocking_receive_with_fds(socket.as_fd()))
.await?
let (resp, fds) = receive_with_fds(&self.socket)
.await
.wrap_err("failed to recv")?;
trace!(?resp, ?fds, "Received RPC response");
@ -577,40 +526,45 @@ impl Client {
}
}
async fn send_with_fds(socket: OwnedFd, data: Vec<u8>, fds: Vec<OwnedFd>) -> Result<()> {
tokio::task::spawn_blocking(move || {
async fn send_with_fds(socket: &UnixDatagram, data: &[u8], fds: &[BorrowedFd<'_>]) -> Result<()> {
socket
.async_io(Interest::WRITABLE, || {
let mut space = [0; rustix::cmsg_space!(ScmRights(3))]; //we send up to 3 fds at once
let mut ancillary = SendAncillaryBuffer::new(&mut space);
let fds = fds.iter().map(|fd| fd.as_fd()).collect::<Vec<_>>();
ancillary.push(SendAncillaryMessage::ScmRights(fds.as_slice()));
ancillary.push(SendAncillaryMessage::ScmRights(fds));
rustix::net::sendmsg(
socket,
&[IoSlice::new(&data)],
&[IoSlice::new(data)],
&mut ancillary,
SendFlags::empty(),
)
.map_err(|errno| io::Error::from(errno))?;
Ok(())
})
.await?
.wrap_err("failed to send")
.map(drop)
.await
.wrap_err("failed to write to socket")
}
fn blocking_receive_with_fds<R: DeserializeOwned>(
blocking_socket: BorrowedFd<'_>,
) -> Result<(R, Vec<OwnedFd>)> {
async fn receive_with_fds<R: DeserializeOwned>(socket: &UnixDatagram) -> Result<(R, Vec<OwnedFd>)> {
let mut data = [0; 1024];
let mut space = [0; rustix::cmsg_space!(ScmRights(3))]; // maximum size
let mut cmesg_buf = RecvAncillaryBuffer::new(&mut space);
let mut fds = Vec::new();
let read = rustix::net::recvmsg(
blocking_socket,
let read = socket
.async_io(Interest::READABLE, || {
rustix::net::recvmsg(
socket,
&mut [IoSliceMut::new(&mut data)],
&mut cmesg_buf,
RecvFlags::empty(),
)?;
)
.map_err(|errno| io::Error::from(errno))
})
.await?;
let mut fds = Vec::new();
let data = postcard::from_bytes::<R>(&data[..read.bytes]).wrap_err("invalid request")?;
for msg in cmesg_buf.drain() {