diff --git a/bin/cluelesshd/src/rpc.rs b/bin/cluelesshd/src/rpc.rs index 21d09c2..770edcf 100644 --- a/bin/cluelesshd/src/rpc.rs +++ b/bin/cluelesshd/src/rpc.rs @@ -1,12 +1,12 @@ //! [`postcard`]-based RPC between the different processes. use std::fmt::Debug; +use std::io; use std::io::IoSlice; use std::io::IoSliceMut; use std::os::fd::AsFd; use std::os::fd::BorrowedFd; use std::os::fd::OwnedFd; -use std::os::unix::net::UnixDatagram; use std::process::Stdio; use cluelessh_keys::private::PlaintextPrivateKey; @@ -18,7 +18,6 @@ use eyre::bail; use eyre::ensure; use eyre::eyre; use eyre::Context; -use eyre::OptionExt; use eyre::Result; use rustix::net::RecvAncillaryBuffer; use rustix::net::RecvAncillaryMessage; @@ -29,11 +28,11 @@ use rustix::net::SendFlags; use rustix::termios::Winsize; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; +use tokio::io::Interest; +use tokio::net::UnixDatagram; use tokio::process::Child; use tokio::process::Command; -use tokio::sync::mpsc; use tracing::debug; -use tracing::error; use tracing::trace; use users::os::unix::UserExt; use users::User; @@ -125,7 +124,6 @@ pub struct Client { pub struct Server { server: UnixDatagram, client: UnixDatagram, - server_recv_recv: mpsc::Receiver<(Request, Vec)>, host_keys: Vec, authenticated_user: Option, @@ -133,38 +131,14 @@ pub struct Server { shell_process: Option, } -fn server_thread( - socket: OwnedFd, - server_recv_send: mpsc::Sender<(Request, Vec)>, -) -> Result<()> { - let socket = std::os::unix::net::UnixDatagram::from(socket); - socket.set_nonblocking(false)?; - - loop { - let (req, fds) = blocking_receive_with_fds::(socket.as_fd())?; - server_recv_send.blocking_send((req, fds))?; - } -} - impl Server { pub fn new(host_keys: Vec) -> Result { let (server, client) = UnixDatagram::pair().wrap_err("creating socketpair")?; - let (server_recv_send, server_recv_recv) = mpsc::channel(3); - - let server_for_thread = server.as_fd().try_clone_to_owned()?; - - std::thread::spawn(move || { - if let Err(err) = server_thread(server_for_thread, server_recv_send) { - error!(?err, "Server RPC recv thread error"); - } - }); - Ok(Self { server, client, host_keys, - server_recv_recv, authenticated_user: None, pty_user: None, shell_process: None, @@ -177,11 +151,7 @@ impl Server { pub async fn process(&mut self) -> Result<()> { loop { - let recv = self - .server_recv_recv - .recv() - .await - .ok_or_eyre("RPC thread error")?; + let recv = receive_with_fds::(&self.server).await?; self.receive_message(recv.0, recv.1).await?; } } @@ -281,8 +251,8 @@ impl Server { ) .await; - let (controller, user) = match result { - Ok(pty) => (vec![pty.controller], Ok(pty.user_pty)), + let (controller, user) = match &result { + Ok(pty) => (vec![pty.controller.as_fd()], Ok(pty.user_pty.try_clone()?)), Err(err) => (vec![], Err(err)), }; @@ -290,7 +260,7 @@ impl Server { ShellResponse { result: user.as_ref().map(drop).map_err(ToString::to_string), }, - controller, + &controller, ) .await?; @@ -321,7 +291,11 @@ impl Server { ShellResponse { result: result.as_ref().map(drop).map_err(Clone::clone), }, - result.unwrap_or_default(), + &result + .unwrap_or_default() + .iter() + .map(|fd| fd.as_fd()) + .collect::>(), ) .await?; } @@ -412,16 +386,11 @@ impl Server { } async fn respond(&self, resp: impl Serialize) -> Result<()> { - self.respond_ancillary(resp, vec![]).await + self.respond_ancillary(resp, &[]).await } - async fn respond_ancillary(&self, resp: impl Serialize, fds: Vec) -> Result<()> { - send_with_fds( - self.server.as_fd().try_clone_to_owned()?, - postcard::to_allocvec(&resp)?, - fds, - ) - .await?; + async fn respond_ancillary(&self, resp: impl Serialize, fds: &[BorrowedFd<'_>]) -> Result<()> { + send_with_fds(&self.server, &postcard::to_allocvec(&resp)?, fds).await?; Ok(()) } @@ -429,13 +398,13 @@ impl Server { impl Client { pub fn from_fd(fd: OwnedFd) -> Result { - let socket = std::os::unix::net::UnixDatagram::from(fd); + let socket = UnixDatagram::from_std(std::os::unix::net::UnixDatagram::from(fd))?; Ok(Self { socket }) } pub async fn sign(&self, hash: [u8; 32], public_key: PublicKey) -> Result { let resp = self - .request_response::(&Request::Sign { hash, public_key }, vec![]) + .request_response::(&Request::Sign { hash, public_key }) .await?; resp.signature.map_err(|err| eyre!(err)) @@ -449,15 +418,12 @@ impl Client { pubkey: Vec, ) -> Result { let resp = self - .request_response::( - &Request::CheckPubkey { - user, - session_identifier, - pubkey_alg_name, - pubkey, - }, - vec![], - ) + .request_response::(&Request::CheckPubkey { + user, + session_identifier, + pubkey_alg_name, + pubkey, + }) .await?; resp.is_ok.map_err(|err| eyre!(err)) @@ -472,16 +438,13 @@ impl Client { signature: Vec, ) -> Result { let resp = self - .request_response::( - &Request::VerifySignature { - user, - session_identifier, - pubkey_alg_name, - pubkey, - signature, - }, - vec![], - ) + .request_response::(&Request::VerifySignature { + user, + session_identifier, + pubkey_alg_name, + pubkey, + signature, + }) .await?; resp.is_ok.map_err(|err| eyre!(err)) @@ -495,16 +458,13 @@ impl Client { height_px: u32, term_modes: Vec, ) -> Result> { - self.send_request( - &Request::PtyReq(PtyRequest { - height_rows, - width_chars, - width_px, - height_px, - term_modes, - }), - vec![], - ) + self.send_request(&Request::PtyReq(PtyRequest { + height_rows, + width_chars, + width_px, + height_px, + term_modes, + })) .await?; let (resp, fds) = self.recv_response_ancillary::().await?; @@ -519,14 +479,11 @@ impl Client { pty_term: Option, env: Vec<(String, String)>, ) -> Result> { - self.send_request( - &Request::Shell(ShellRequest { - pty_term, - command, - env, - }), - vec![], - ) + self.send_request(&Request::Shell(ShellRequest { + pty_term, + command, + env, + })) .await?; let (resp, fds) = self.recv_response_ancillary::().await?; @@ -536,7 +493,7 @@ impl Client { } pub async fn wait(&self) -> Result> { - self.request_response::(&Request::Wait, vec![]) + self.request_response::(&Request::Wait) .await .and_then(|resp| resp.result.map_err(|err| eyre!(err))) } @@ -544,32 +501,24 @@ impl Client { async fn request_response( &self, req: &Request, - fds: Vec, ) -> Result { - self.send_request(req, fds).await?; + self.send_request(req).await?; Ok(self.recv_response_ancillary::().await?.0) } - async fn send_request(&self, req: &Request, fds: Vec) -> Result<()> { - // TODO: remove support for ancillary? + async fn send_request(&self, req: &Request) -> Result<()> { let data = postcard::to_allocvec(&req)?; - let socket = self.socket.as_fd().try_clone_to_owned()?; - - send_with_fds(socket, data, fds).await?; + send_with_fds(&self.socket, &data, &[]).await?; Ok(()) } async fn recv_response_ancillary( &self, ) -> Result<(R, Vec)> { - let socket = - std::os::unix::net::UnixDatagram::from(self.socket.as_fd().try_clone_to_owned()?); - - let (resp, fds) = - tokio::task::spawn_blocking(move || blocking_receive_with_fds(socket.as_fd())) - .await? - .wrap_err("failed to recv")?; + let (resp, fds) = receive_with_fds(&self.socket) + .await + .wrap_err("failed to recv")?; trace!(?resp, ?fds, "Received RPC response"); @@ -577,40 +526,45 @@ impl Client { } } -async fn send_with_fds(socket: OwnedFd, data: Vec, fds: Vec) -> Result<()> { - tokio::task::spawn_blocking(move || { - let mut space = [0; rustix::cmsg_space!(ScmRights(3))]; //we send up to 3 fds at once - let mut ancillary = SendAncillaryBuffer::new(&mut space); - let fds = fds.iter().map(|fd| fd.as_fd()).collect::>(); +async fn send_with_fds(socket: &UnixDatagram, data: &[u8], fds: &[BorrowedFd<'_>]) -> Result<()> { + socket + .async_io(Interest::WRITABLE, || { + let mut space = [0; rustix::cmsg_space!(ScmRights(3))]; //we send up to 3 fds at once + let mut ancillary = SendAncillaryBuffer::new(&mut space); - ancillary.push(SendAncillaryMessage::ScmRights(fds.as_slice())); - rustix::net::sendmsg( - socket, - &[IoSlice::new(&data)], - &mut ancillary, - SendFlags::empty(), - ) - }) - .await? - .wrap_err("failed to send") - .map(drop) + ancillary.push(SendAncillaryMessage::ScmRights(fds)); + rustix::net::sendmsg( + socket, + &[IoSlice::new(data)], + &mut ancillary, + SendFlags::empty(), + ) + .map_err(|errno| io::Error::from(errno))?; + Ok(()) + }) + .await + .wrap_err("failed to write to socket") } -fn blocking_receive_with_fds( - blocking_socket: BorrowedFd<'_>, -) -> Result<(R, Vec)> { +async fn receive_with_fds(socket: &UnixDatagram) -> Result<(R, Vec)> { let mut data = [0; 1024]; let mut space = [0; rustix::cmsg_space!(ScmRights(3))]; // maximum size let mut cmesg_buf = RecvAncillaryBuffer::new(&mut space); + let read = socket + .async_io(Interest::READABLE, || { + rustix::net::recvmsg( + socket, + &mut [IoSliceMut::new(&mut data)], + &mut cmesg_buf, + RecvFlags::empty(), + ) + .map_err(|errno| io::Error::from(errno)) + }) + .await?; + let mut fds = Vec::new(); - let read = rustix::net::recvmsg( - blocking_socket, - &mut [IoSliceMut::new(&mut data)], - &mut cmesg_buf, - RecvFlags::empty(), - )?; let data = postcard::from_bytes::(&data[..read.bytes]).wrap_err("invalid request")?; for msg in cmesg_buf.drain() {