RPC improvements

This commit is contained in:
nora 2024-08-29 01:58:39 +02:00
parent ee00db37e1
commit 533b8cda1e
2 changed files with 107 additions and 140 deletions

View file

@ -98,7 +98,7 @@ async fn connection_inner(state: SerializedConnectionState) -> Result<()> {
let rpc_client = rpc_client2.clone(); let rpc_client = rpc_client2.clone();
Box::pin(async move { Box::pin(async move {
rpc_client rpc_client
.check_pubkey( .check_public_key(
msg.user, msg.user,
msg.session_identifier, msg.session_identifier,
msg.pubkey_alg_name, msg.pubkey_alg_name,
@ -475,20 +475,14 @@ impl SessionState {
height_px: u32, height_px: u32,
term_modes: Vec<u8>, term_modes: Vec<u8>,
) -> Result<()> { ) -> Result<()> {
let mut fd = self let controller = self
.rpc_client .rpc_client
.pty_req(width_chars, height_rows, width_px, height_px, term_modes) .pty_req(width_chars, height_rows, width_px, height_px, term_modes)
.await?; .await?;
ensure!(
fd.len() == 1,
"Incorrect amount of FDs received: {}",
fd.len()
);
self.pty_term = Some(term); self.pty_term = Some(term);
let controller = fd.remove(0);
self.writer = Some(Box::pin(File::from_std(std::fs::File::from( self.writer = Some(Box::pin(File::from_std(std::fs::File::from(
controller.try_clone()?, controller.try_clone()?,
)))); ))));

View file

@ -39,10 +39,25 @@ use users::User;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
enum Request { enum Request {
// TODO: This is a bit... not good, it's not good.
// It can be used to sign any arbitrary message, or any arbitary exchange!
// I think we need to let the monitor do the DH Key Exchange.
// Basically, it should generate the private key for the exchange (and give that to the client)
// and then when signing, we compute the shared secret ourselves for the hash.
// This should ensure that the connection process cannot sign anything except an SSH kex has
// but only with our specific chosen shared secret, which should make it entirely useless for anything else.
Sign { Sign {
hash: [u8; 32], hash: [u8; 32],
public_key: PublicKey, public_key: PublicKey,
}, },
CheckPublicKey {
user: String,
session_identifier: [u8; 32],
pubkey_alg_name: String,
pubkey: Vec<u8>,
},
/// Verify that the public key signature for the user is okay.
/// If it is okay, store the user so we can later spawn a process as them.
VerifySignature { VerifySignature {
user: String, user: String,
session_identifier: [u8; 32], session_identifier: [u8; 32],
@ -50,18 +65,14 @@ enum Request {
pubkey: Vec<u8>, pubkey: Vec<u8>,
signature: Vec<u8>, signature: Vec<u8>,
}, },
CheckPubkey { /// Request a PTY. We create a new PTY and give the client an FD to the controller.
user: String,
session_identifier: [u8; 32],
pubkey_alg_name: String,
pubkey: Vec<u8>,
},
PtyReq(PtyRequest), PtyReq(PtyRequest),
/// Executes a command on the host. /// Executes a command on the host.
/// IMPORTANT: This is the critical operation, and we must ensure that it is secure. /// IMPORTANT: This is the critical operation, and we must ensure that it is secure.
/// To ensure that even a compromised auth process cannot escalate privileges via this RPC, /// To ensure that even a compromised auth process cannot escalate privileges via this RPC,
/// the RPC server keeps track of the authenciated user /// the RPC server keeps track of the authenciated user
Shell(ShellRequest), Shell(ShellRequest),
/// Wait for the currently running command to finish.
Wait, Wait,
} }
@ -88,34 +99,14 @@ struct ShellRequestPty {
term: String, term: String,
} }
#[derive(Debug, Serialize, Deserialize)] type SignResponse = Signature;
struct SignResponse { type VerifySignatureResponse = bool;
signature: Result<Signature, String>, type CheckPublicKeyResponse = bool;
} type ShellResponse = ();
type PtyReqResponse = ();
type WaitResponse = Option<i32>;
#[derive(Debug, Serialize, Deserialize)] type ResponseResult<T> = Result<T, String>;
struct VerifySignatureResponse {
is_ok: Result<bool, String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct CheckPubkeyResponse {
is_ok: Result<bool, String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct ShellResponse {
result: Result<(), String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct PtyResponse {
result: Result<(), String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct WaitResponse {
result: Result<Option<i32>, String>,
}
pub struct Client { pub struct Client {
socket: UnixDatagram, socket: UnixDatagram,
@ -151,13 +142,14 @@ impl Server {
pub async fn process(&mut self) -> Result<()> { pub async fn process(&mut self) -> Result<()> {
loop { loop {
let recv = receive_with_fds::<Request>(&self.server).await?; let (recv, fds) = receive_with_fds::<Request>(&self.server).await?;
self.receive_message(recv.0, recv.1).await?; ensure!(fds.is_empty(), "Client sent FDs in request");
self.receive_message(recv).await?;
} }
} }
async fn receive_message(&mut self, req: Request, fds: Vec<OwnedFd>) -> Result<()> { async fn receive_message(&mut self, req: Request) -> Result<()> {
trace!(?req, ?fds, "Received RPC message"); trace!(?req, "Received RPC message");
match req { match req {
Request::Sign { hash, public_key } => { Request::Sign { hash, public_key } => {
@ -166,20 +158,31 @@ impl Server {
.iter() .iter()
.find(|privkey| privkey.private_key.public_key() == public_key) .find(|privkey| privkey.private_key.public_key() == public_key)
else { else {
self.respond(SignResponse { self.respond_err("missing private key".to_owned()).await?;
signature: Err("missing private key".to_owned()),
})
.await?;
return Ok(()); return Ok(());
}; };
let signature = private.private_key.sign(&hash); let signature = private.private_key.sign(&hash);
self.respond(SignResponse { self.respond::<SignResponse>(Ok(signature)).await?;
signature: Ok(signature), }
Request::CheckPublicKey {
user,
session_identifier,
pubkey_alg_name,
pubkey,
} => {
let is_ok = crate::auth::check_pubkey(CheckPubkey {
user,
session_identifier,
pubkey_alg_name,
pubkey,
}) })
.await?; .await
.map_err(|err| err.to_string());
self.respond::<CheckPublicKeyResponse>(is_ok).await?;
} }
Request::VerifySignature { Request::VerifySignature {
user, user,
@ -189,10 +192,8 @@ impl Server {
signature, signature,
} => { } => {
if self.authenticated_user.is_some() { if self.authenticated_user.is_some() {
self.respond(VerifySignatureResponse { self.respond_err("user already authenticated".to_owned())
is_ok: Err("user already authenticated".to_owned()), .await?;
})
.await?;
} }
let is_ok = crate::auth::verify_signature(VerifySignature { let is_ok = crate::auth::verify_signature(VerifySignature {
user, user,
@ -211,31 +212,11 @@ impl Server {
None => false, None => false,
}); });
self.respond(VerifySignatureResponse { is_ok }).await?; self.respond::<VerifySignatureResponse>(is_ok).await?;
}
Request::CheckPubkey {
user,
session_identifier,
pubkey_alg_name,
pubkey,
} => {
let is_ok = crate::auth::check_pubkey(CheckPubkey {
user,
session_identifier,
pubkey_alg_name,
pubkey,
})
.await
.map_err(|err| err.to_string());
self.respond(CheckPubkeyResponse { is_ok }).await?;
} }
Request::PtyReq(req) => { Request::PtyReq(req) => {
if self.pty_user.is_some() { if self.pty_user.is_some() {
self.respond(ShellResponse { self.respond_err("already requests pty".to_owned()).await?;
result: Err("already requests pty".to_owned()),
})
.await?;
return Ok(()); return Ok(());
} }
@ -256,10 +237,8 @@ impl Server {
Err(err) => (vec![], Err(err)), Err(err) => (vec![], Err(err)),
}; };
self.respond_ancillary( self.respond_ancillary::<PtyReqResponse>(
ShellResponse { user.as_ref().map(drop).map_err(ToString::to_string),
result: user.as_ref().map(drop).map_err(ToString::to_string),
},
&controller, &controller,
) )
.await?; .await?;
@ -268,29 +247,22 @@ impl Server {
} }
Request::Shell(req) => { Request::Shell(req) => {
if self.shell_process.is_some() { if self.shell_process.is_some() {
self.respond(ShellResponse { self.respond_err("process already running".to_owned())
result: Err("process already running".to_owned()), .await?;
})
.await?;
return Ok(()); return Ok(());
} }
let Some(user) = self.authenticated_user.clone() else { let Some(user) = self.authenticated_user.clone() else {
self.respond(ShellResponse { self.respond_err("unauthenticated".to_owned()).await?;
result: Err("unauthenticated".to_owned()),
})
.await?;
return Ok(()); return Ok(());
}; };
let result = self.shell(&user, req).await.map_err(|err| err.to_string()); let result = self.shell(&user, req).await.map_err(|err| err.to_string());
self.respond_ancillary( self.respond_ancillary::<ShellResponse>(
ShellResponse { result.as_ref().map(drop).map_err(Clone::clone),
result: result.as_ref().map(drop).map_err(Clone::clone),
},
&result &result
.unwrap_or_default() .unwrap_or_default()
.iter() .iter()
@ -301,19 +273,16 @@ impl Server {
} }
Request::Wait => match &mut self.shell_process { Request::Wait => match &mut self.shell_process {
None => { None => {
self.respond(WaitResponse { self.respond_err("no child running".to_owned()).await?;
result: Err("no child running".to_owned()),
})
.await?;
} }
Some(child) => { Some(child) => {
let result = child.wait().await; let result = child.wait().await;
self.respond(WaitResponse { self.respond::<WaitResponse>(
result: result result
.map(|status| status.code()) .map(|status| status.code())
.map_err(|err| err.to_string()), .map_err(|err| err.to_string()),
}) )
.await?; .await?;
// implicitly drop stdio // implicitly drop stdio
@ -385,11 +354,19 @@ impl Server {
Ok(fds1) Ok(fds1)
} }
async fn respond(&self, resp: impl Serialize) -> Result<()> { async fn respond_err(&self, resp: String) -> Result<()> {
self.respond::<()>(Err(resp)).await
}
async fn respond<T: Serialize>(&self, resp: ResponseResult<T>) -> Result<()> {
self.respond_ancillary(resp, &[]).await self.respond_ancillary(resp, &[]).await
} }
async fn respond_ancillary(&self, resp: impl Serialize, fds: &[BorrowedFd<'_>]) -> Result<()> { async fn respond_ancillary<T: Serialize>(
&self,
resp: ResponseResult<T>,
fds: &[BorrowedFd<'_>],
) -> Result<()> {
send_with_fds(&self.server, &postcard::to_allocvec(&resp)?, fds).await?; send_with_fds(&self.server, &postcard::to_allocvec(&resp)?, fds).await?;
Ok(()) Ok(())
@ -403,30 +380,24 @@ impl Client {
} }
pub async fn sign(&self, hash: [u8; 32], public_key: PublicKey) -> Result<Signature> { pub async fn sign(&self, hash: [u8; 32], public_key: PublicKey) -> Result<Signature> {
let resp = self self.request_response::<SignResponse>(&Request::Sign { hash, public_key })
.request_response::<SignResponse>(&Request::Sign { hash, public_key }) .await
.await?;
resp.signature.map_err(|err| eyre!(err))
} }
pub async fn check_pubkey( pub async fn check_public_key(
&self, &self,
user: String, user: String,
session_identifier: [u8; 32], session_identifier: [u8; 32],
pubkey_alg_name: String, pubkey_alg_name: String,
pubkey: Vec<u8>, pubkey: Vec<u8>,
) -> Result<bool> { ) -> Result<bool> {
let resp = self self.request_response::<CheckPublicKeyResponse>(&Request::CheckPublicKey {
.request_response::<CheckPubkeyResponse>(&Request::CheckPubkey { user,
user, session_identifier,
session_identifier, pubkey_alg_name,
pubkey_alg_name, pubkey,
pubkey, })
}) .await
.await?;
resp.is_ok.map_err(|err| eyre!(err))
} }
pub async fn verify_signature( pub async fn verify_signature(
@ -437,17 +408,14 @@ impl Client {
pubkey: Vec<u8>, pubkey: Vec<u8>,
signature: Vec<u8>, signature: Vec<u8>,
) -> Result<bool> { ) -> Result<bool> {
let resp = self self.request_response::<VerifySignatureResponse>(&Request::VerifySignature {
.request_response::<VerifySignatureResponse>(&Request::VerifySignature { user,
user, session_identifier,
session_identifier, pubkey_alg_name,
pubkey_alg_name, pubkey,
pubkey, signature,
signature, })
}) .await
.await?;
resp.is_ok.map_err(|err| eyre!(err))
} }
pub async fn pty_req( pub async fn pty_req(
@ -457,7 +425,7 @@ impl Client {
width_px: u32, width_px: u32,
height_px: u32, height_px: u32,
term_modes: Vec<u8>, term_modes: Vec<u8>,
) -> Result<Vec<OwnedFd>> { ) -> Result<OwnedFd> {
self.send_request(&Request::PtyReq(PtyRequest { self.send_request(&Request::PtyReq(PtyRequest {
height_rows, height_rows,
width_chars, width_chars,
@ -467,10 +435,16 @@ impl Client {
})) }))
.await?; .await?;
let (resp, fds) = self.recv_response_ancillary::<PtyResponse>().await?; let (_, mut fds) = self.recv_response_ancillary::<PtyReqResponse>().await?;
resp.result.map_err(|err| eyre!(err))?; ensure!(
fds.len() == 1,
"Incorrect amount of FDs received: {}",
fds.len()
);
Ok(fds) let controller = fds.remove(0);
Ok(controller)
} }
pub async fn shell( pub async fn shell(
@ -486,16 +460,13 @@ impl Client {
})) }))
.await?; .await?;
let (resp, fds) = self.recv_response_ancillary::<ShellResponse>().await?; let (_, fds) = self.recv_response_ancillary::<ShellResponse>().await?;
resp.result.map_err(|err| eyre!(err))?;
Ok(fds) Ok(fds)
} }
pub async fn wait(&self) -> Result<Option<i32>> { pub async fn wait(&self) -> Result<Option<i32>> {
self.request_response::<WaitResponse>(&Request::Wait) self.request_response::<WaitResponse>(&Request::Wait).await
.await
.and_then(|resp| resp.result.map_err(|err| eyre!(err)))
} }
async fn request_response<R: DeserializeOwned + Debug + Send + 'static>( async fn request_response<R: DeserializeOwned + Debug + Send + 'static>(
@ -516,12 +487,14 @@ impl Client {
async fn recv_response_ancillary<R: DeserializeOwned + Debug + Send + 'static>( async fn recv_response_ancillary<R: DeserializeOwned + Debug + Send + 'static>(
&self, &self,
) -> Result<(R, Vec<OwnedFd>)> { ) -> Result<(R, Vec<OwnedFd>)> {
let (resp, fds) = receive_with_fds(&self.socket) let (resp, fds) = receive_with_fds::<ResponseResult<R>>(&self.socket)
.await .await
.wrap_err("failed to recv")?; .wrap_err("failed to recv")?;
trace!(?resp, ?fds, "Received RPC response"); trace!(?resp, ?fds, "Received RPC response");
let resp = resp.map_err(|err| eyre!(err))?;
Ok((resp, fds)) Ok((resp, fds))
} }
} }