RPC improvements

This commit is contained in:
nora 2024-08-29 01:58:39 +02:00
parent ee00db37e1
commit 533b8cda1e
2 changed files with 107 additions and 140 deletions

View file

@ -98,7 +98,7 @@ async fn connection_inner(state: SerializedConnectionState) -> Result<()> {
let rpc_client = rpc_client2.clone();
Box::pin(async move {
rpc_client
.check_pubkey(
.check_public_key(
msg.user,
msg.session_identifier,
msg.pubkey_alg_name,
@ -475,20 +475,14 @@ impl SessionState {
height_px: u32,
term_modes: Vec<u8>,
) -> Result<()> {
let mut fd = self
let controller = self
.rpc_client
.pty_req(width_chars, height_rows, width_px, height_px, term_modes)
.await?;
ensure!(
fd.len() == 1,
"Incorrect amount of FDs received: {}",
fd.len()
);
self.pty_term = Some(term);
let controller = fd.remove(0);
self.writer = Some(Box::pin(File::from_std(std::fs::File::from(
controller.try_clone()?,
))));

View file

@ -39,10 +39,25 @@ use users::User;
#[derive(Debug, Serialize, Deserialize)]
enum Request {
// TODO: This is a bit... not good, it's not good.
// It can be used to sign any arbitrary message, or any arbitary exchange!
// I think we need to let the monitor do the DH Key Exchange.
// Basically, it should generate the private key for the exchange (and give that to the client)
// and then when signing, we compute the shared secret ourselves for the hash.
// This should ensure that the connection process cannot sign anything except an SSH kex has
// but only with our specific chosen shared secret, which should make it entirely useless for anything else.
Sign {
hash: [u8; 32],
public_key: PublicKey,
},
CheckPublicKey {
user: String,
session_identifier: [u8; 32],
pubkey_alg_name: String,
pubkey: Vec<u8>,
},
/// Verify that the public key signature for the user is okay.
/// If it is okay, store the user so we can later spawn a process as them.
VerifySignature {
user: String,
session_identifier: [u8; 32],
@ -50,18 +65,14 @@ enum Request {
pubkey: Vec<u8>,
signature: Vec<u8>,
},
CheckPubkey {
user: String,
session_identifier: [u8; 32],
pubkey_alg_name: String,
pubkey: Vec<u8>,
},
/// Request a PTY. We create a new PTY and give the client an FD to the controller.
PtyReq(PtyRequest),
/// Executes a command on the host.
/// IMPORTANT: This is the critical operation, and we must ensure that it is secure.
/// To ensure that even a compromised auth process cannot escalate privileges via this RPC,
/// the RPC server keeps track of the authenciated user
Shell(ShellRequest),
/// Wait for the currently running command to finish.
Wait,
}
@ -88,34 +99,14 @@ struct ShellRequestPty {
term: String,
}
#[derive(Debug, Serialize, Deserialize)]
struct SignResponse {
signature: Result<Signature, String>,
}
type SignResponse = Signature;
type VerifySignatureResponse = bool;
type CheckPublicKeyResponse = bool;
type ShellResponse = ();
type PtyReqResponse = ();
type WaitResponse = Option<i32>;
#[derive(Debug, Serialize, Deserialize)]
struct VerifySignatureResponse {
is_ok: Result<bool, String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct CheckPubkeyResponse {
is_ok: Result<bool, String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct ShellResponse {
result: Result<(), String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct PtyResponse {
result: Result<(), String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct WaitResponse {
result: Result<Option<i32>, String>,
}
type ResponseResult<T> = Result<T, String>;
pub struct Client {
socket: UnixDatagram,
@ -151,13 +142,14 @@ impl Server {
pub async fn process(&mut self) -> Result<()> {
loop {
let recv = receive_with_fds::<Request>(&self.server).await?;
self.receive_message(recv.0, recv.1).await?;
let (recv, fds) = receive_with_fds::<Request>(&self.server).await?;
ensure!(fds.is_empty(), "Client sent FDs in request");
self.receive_message(recv).await?;
}
}
async fn receive_message(&mut self, req: Request, fds: Vec<OwnedFd>) -> Result<()> {
trace!(?req, ?fds, "Received RPC message");
async fn receive_message(&mut self, req: Request) -> Result<()> {
trace!(?req, "Received RPC message");
match req {
Request::Sign { hash, public_key } => {
@ -166,20 +158,31 @@ impl Server {
.iter()
.find(|privkey| privkey.private_key.public_key() == public_key)
else {
self.respond(SignResponse {
signature: Err("missing private key".to_owned()),
})
.await?;
self.respond_err("missing private key".to_owned()).await?;
return Ok(());
};
let signature = private.private_key.sign(&hash);
self.respond(SignResponse {
signature: Ok(signature),
self.respond::<SignResponse>(Ok(signature)).await?;
}
Request::CheckPublicKey {
user,
session_identifier,
pubkey_alg_name,
pubkey,
} => {
let is_ok = crate::auth::check_pubkey(CheckPubkey {
user,
session_identifier,
pubkey_alg_name,
pubkey,
})
.await?;
.await
.map_err(|err| err.to_string());
self.respond::<CheckPublicKeyResponse>(is_ok).await?;
}
Request::VerifySignature {
user,
@ -189,10 +192,8 @@ impl Server {
signature,
} => {
if self.authenticated_user.is_some() {
self.respond(VerifySignatureResponse {
is_ok: Err("user already authenticated".to_owned()),
})
.await?;
self.respond_err("user already authenticated".to_owned())
.await?;
}
let is_ok = crate::auth::verify_signature(VerifySignature {
user,
@ -211,31 +212,11 @@ impl Server {
None => false,
});
self.respond(VerifySignatureResponse { is_ok }).await?;
}
Request::CheckPubkey {
user,
session_identifier,
pubkey_alg_name,
pubkey,
} => {
let is_ok = crate::auth::check_pubkey(CheckPubkey {
user,
session_identifier,
pubkey_alg_name,
pubkey,
})
.await
.map_err(|err| err.to_string());
self.respond(CheckPubkeyResponse { is_ok }).await?;
self.respond::<VerifySignatureResponse>(is_ok).await?;
}
Request::PtyReq(req) => {
if self.pty_user.is_some() {
self.respond(ShellResponse {
result: Err("already requests pty".to_owned()),
})
.await?;
self.respond_err("already requests pty".to_owned()).await?;
return Ok(());
}
@ -256,10 +237,8 @@ impl Server {
Err(err) => (vec![], Err(err)),
};
self.respond_ancillary(
ShellResponse {
result: user.as_ref().map(drop).map_err(ToString::to_string),
},
self.respond_ancillary::<PtyReqResponse>(
user.as_ref().map(drop).map_err(ToString::to_string),
&controller,
)
.await?;
@ -268,29 +247,22 @@ impl Server {
}
Request::Shell(req) => {
if self.shell_process.is_some() {
self.respond(ShellResponse {
result: Err("process already running".to_owned()),
})
.await?;
self.respond_err("process already running".to_owned())
.await?;
return Ok(());
}
let Some(user) = self.authenticated_user.clone() else {
self.respond(ShellResponse {
result: Err("unauthenticated".to_owned()),
})
.await?;
self.respond_err("unauthenticated".to_owned()).await?;
return Ok(());
};
let result = self.shell(&user, req).await.map_err(|err| err.to_string());
self.respond_ancillary(
ShellResponse {
result: result.as_ref().map(drop).map_err(Clone::clone),
},
self.respond_ancillary::<ShellResponse>(
result.as_ref().map(drop).map_err(Clone::clone),
&result
.unwrap_or_default()
.iter()
@ -301,19 +273,16 @@ impl Server {
}
Request::Wait => match &mut self.shell_process {
None => {
self.respond(WaitResponse {
result: Err("no child running".to_owned()),
})
.await?;
self.respond_err("no child running".to_owned()).await?;
}
Some(child) => {
let result = child.wait().await;
self.respond(WaitResponse {
result: result
self.respond::<WaitResponse>(
result
.map(|status| status.code())
.map_err(|err| err.to_string()),
})
)
.await?;
// implicitly drop stdio
@ -385,11 +354,19 @@ impl Server {
Ok(fds1)
}
async fn respond(&self, resp: impl Serialize) -> Result<()> {
async fn respond_err(&self, resp: String) -> Result<()> {
self.respond::<()>(Err(resp)).await
}
async fn respond<T: Serialize>(&self, resp: ResponseResult<T>) -> Result<()> {
self.respond_ancillary(resp, &[]).await
}
async fn respond_ancillary(&self, resp: impl Serialize, fds: &[BorrowedFd<'_>]) -> Result<()> {
async fn respond_ancillary<T: Serialize>(
&self,
resp: ResponseResult<T>,
fds: &[BorrowedFd<'_>],
) -> Result<()> {
send_with_fds(&self.server, &postcard::to_allocvec(&resp)?, fds).await?;
Ok(())
@ -403,30 +380,24 @@ impl Client {
}
pub async fn sign(&self, hash: [u8; 32], public_key: PublicKey) -> Result<Signature> {
let resp = self
.request_response::<SignResponse>(&Request::Sign { hash, public_key })
.await?;
resp.signature.map_err(|err| eyre!(err))
self.request_response::<SignResponse>(&Request::Sign { hash, public_key })
.await
}
pub async fn check_pubkey(
pub async fn check_public_key(
&self,
user: String,
session_identifier: [u8; 32],
pubkey_alg_name: String,
pubkey: Vec<u8>,
) -> Result<bool> {
let resp = self
.request_response::<CheckPubkeyResponse>(&Request::CheckPubkey {
user,
session_identifier,
pubkey_alg_name,
pubkey,
})
.await?;
resp.is_ok.map_err(|err| eyre!(err))
self.request_response::<CheckPublicKeyResponse>(&Request::CheckPublicKey {
user,
session_identifier,
pubkey_alg_name,
pubkey,
})
.await
}
pub async fn verify_signature(
@ -437,17 +408,14 @@ impl Client {
pubkey: Vec<u8>,
signature: Vec<u8>,
) -> Result<bool> {
let resp = self
.request_response::<VerifySignatureResponse>(&Request::VerifySignature {
user,
session_identifier,
pubkey_alg_name,
pubkey,
signature,
})
.await?;
resp.is_ok.map_err(|err| eyre!(err))
self.request_response::<VerifySignatureResponse>(&Request::VerifySignature {
user,
session_identifier,
pubkey_alg_name,
pubkey,
signature,
})
.await
}
pub async fn pty_req(
@ -457,7 +425,7 @@ impl Client {
width_px: u32,
height_px: u32,
term_modes: Vec<u8>,
) -> Result<Vec<OwnedFd>> {
) -> Result<OwnedFd> {
self.send_request(&Request::PtyReq(PtyRequest {
height_rows,
width_chars,
@ -467,10 +435,16 @@ impl Client {
}))
.await?;
let (resp, fds) = self.recv_response_ancillary::<PtyResponse>().await?;
resp.result.map_err(|err| eyre!(err))?;
let (_, mut fds) = self.recv_response_ancillary::<PtyReqResponse>().await?;
ensure!(
fds.len() == 1,
"Incorrect amount of FDs received: {}",
fds.len()
);
Ok(fds)
let controller = fds.remove(0);
Ok(controller)
}
pub async fn shell(
@ -486,16 +460,13 @@ impl Client {
}))
.await?;
let (resp, fds) = self.recv_response_ancillary::<ShellResponse>().await?;
resp.result.map_err(|err| eyre!(err))?;
let (_, fds) = self.recv_response_ancillary::<ShellResponse>().await?;
Ok(fds)
}
pub async fn wait(&self) -> Result<Option<i32>> {
self.request_response::<WaitResponse>(&Request::Wait)
.await
.and_then(|resp| resp.result.map_err(|err| eyre!(err)))
self.request_response::<WaitResponse>(&Request::Wait).await
}
async fn request_response<R: DeserializeOwned + Debug + Send + 'static>(
@ -516,12 +487,14 @@ impl Client {
async fn recv_response_ancillary<R: DeserializeOwned + Debug + Send + 'static>(
&self,
) -> Result<(R, Vec<OwnedFd>)> {
let (resp, fds) = receive_with_fds(&self.socket)
let (resp, fds) = receive_with_fds::<ResponseResult<R>>(&self.socket)
.await
.wrap_err("failed to recv")?;
trace!(?resp, ?fds, "Received RPC response");
let resp = resp.map_err(|err| eyre!(err))?;
Ok((resp, fds))
}
}