factor out auth

This commit is contained in:
nora 2024-08-25 16:14:55 +02:00
parent b0acf03502
commit 1c346659f6
7 changed files with 267 additions and 156 deletions

View file

@ -1,5 +1,8 @@
use core::panic;
use std::collections::HashSet;
use std::mem;
use auth::AuthOption;
pub use cluelessh_connection as connection;
use cluelessh_connection::ChannelOperation;
pub use cluelessh_connection::{ChannelUpdate, ChannelUpdateKind};
@ -21,33 +24,42 @@ pub struct ServerConnection {
}
enum ServerConnectionState {
Auth(auth::BadAuth),
Setup(HashSet<AuthOption>),
Auth(auth::ServerAuth),
Open(cluelessh_connection::ChannelsState),
}
impl ServerConnection {
pub fn new(transport: cluelessh_transport::server::ServerConnection) -> Self {
pub fn new(
transport: cluelessh_transport::server::ServerConnection,
auth_options: HashSet<AuthOption>,
) -> Self {
Self {
transport,
state: ServerConnectionState::Auth(auth::BadAuth::new()),
state: ServerConnectionState::Setup(auth_options),
}
}
pub fn recv_bytes(&mut self, bytes: &[u8]) -> Result<()> {
self.transport.recv_bytes(bytes)?;
if let ServerConnectionState::Setup(options) = &mut self.state {
if let Some(session_ident) = self.transport.is_open() {
self.state = ServerConnectionState::Auth(auth::ServerAuth::new(
mem::take(options),
session_ident,
));
}
}
while let Some(packet) = self.transport.next_plaintext_packet() {
match &mut self.state {
ServerConnectionState::Setup(_) => unreachable!(),
ServerConnectionState::Auth(auth) => {
auth.recv_packet(packet)?;
for to_send in auth.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
if auth.is_authenticated() {
self.state = ServerConnectionState::Open(
cluelessh_connection::ChannelsState::new(true),
);
}
}
ServerConnectionState::Open(con) => {
con.recv_packet(packet)?;
@ -66,14 +78,16 @@ impl ServerConnection {
pub fn next_channel_update(&mut self) -> Option<cluelessh_connection::ChannelUpdate> {
match &mut self.state {
ServerConnectionState::Auth(_) => None,
ServerConnectionState::Setup(_) | ServerConnectionState::Auth(_) => None,
ServerConnectionState::Open(con) => con.next_channel_update(),
}
}
pub fn do_operation(&mut self, op: ChannelOperation) {
match &mut self.state {
ServerConnectionState::Auth(_) => panic!("tried to get connection during auth"),
ServerConnectionState::Setup(_) | ServerConnectionState::Auth(_) => {
panic!("tried to get connection before it is ready")
}
ServerConnectionState::Open(con) => {
con.do_operation(op);
self.progress();
@ -83,10 +97,15 @@ impl ServerConnection {
pub fn progress(&mut self) {
match &mut self.state {
ServerConnectionState::Setup(_) => {}
ServerConnectionState::Auth(auth) => {
for to_send in auth.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
if auth.is_authenticated() {
self.state =
ServerConnectionState::Open(cluelessh_connection::ChannelsState::new(true));
}
}
ServerConnectionState::Open(con) => {
for to_send in con.packets_to_send() {
@ -103,7 +122,7 @@ impl ServerConnection {
}
}
pub fn auth(&mut self) -> Option<&mut auth::BadAuth> {
pub fn auth(&mut self) -> Option<&mut auth::ServerAuth> {
match &mut self.state {
ServerConnectionState::Auth(auth) => Some(auth),
_ => None,
@ -140,11 +159,10 @@ impl ClientConnection {
if let Some(session_ident) = self.transport.is_open() {
let mut auth = mem::take(auth).unwrap();
auth.set_session_identifier(session_ident);
for to_send in auth.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
debug!("Connection has been opened");
self.state = ClientConnectionState::Auth(auth);
self.progress();
}
}
@ -235,35 +253,53 @@ impl ClientConnection {
/// <https://datatracker.ietf.org/doc/html/rfc4252>
pub mod auth {
use std::collections::VecDeque;
use std::collections::{HashSet, VecDeque};
use cluelessh_transport::{numbers, packet::Packet, parse::NameList, peer_error, Result};
use tracing::{debug, info};
pub struct BadAuth {
pub struct ServerAuth {
has_failed: bool,
packets_to_send: VecDeque<Packet>,
is_authenticated: bool,
options: HashSet<AuthOption>,
server_requests: VecDeque<ServerRequest>,
session_ident: [u8; 32],
}
pub enum ServerRequest {
VerifyPassword {
user: String,
password: String,
},
VerifyPubkey {
session_identifier: [u8; 32],
user: String,
pubkey: Vec<u8>,
},
VerifyPassword(VerifyPassword),
VerifyPubkey(VerifyPubkey),
}
impl BadAuth {
pub fn new() -> Self {
pub struct VerifyPassword {
pub user: String,
pub password: String,
}
pub struct VerifyPubkey {
pub user: String,
pub session_identifier: [u8; 32],
pub pubkey_alg_name: Vec<u8>,
pub pubkey: Vec<u8>,
pub signature: Vec<u8>,
}
#[derive(Debug, PartialEq, Eq, Hash)]
pub enum AuthOption {
Password,
PublicKey,
}
impl ServerAuth {
pub fn new(options: HashSet<AuthOption>, session_ident: [u8; 32]) -> Self {
Self {
has_failed: false,
packets_to_send: VecDeque::new(),
options,
is_authenticated: false,
session_ident,
server_requests: VecDeque::new(),
}
}
@ -274,14 +310,14 @@ pub mod auth {
// We ask for a public key, and always let that one pass.
// The reason for this is that this makes it a lot easier to test locally.
// It's not very good, but it's good enough for now.
let mut auth_req = packet.payload_parser();
let mut p = packet.payload_parser();
if auth_req.u8()? != numbers::SSH_MSG_USERAUTH_REQUEST {
if p.u8()? != numbers::SSH_MSG_USERAUTH_REQUEST {
return Err(peer_error!("did not send SSH_MSG_SERVICE_REQUEST"));
}
let username = auth_req.utf8_string()?;
let service_name = auth_req.utf8_string()?;
let method_name = auth_req.utf8_string()?;
let username = p.utf8_string()?;
let service_name = p.utf8_string()?;
let method_name = p.utf8_string()?;
if method_name != "none" {
info!(
@ -300,22 +336,47 @@ pub mod auth {
match method_name {
"password" => {
let change_password = auth_req.bool()?;
if !self.options.contains(&AuthOption::Password) {
self.has_failed = true;
self.send_failure();
}
let change_password = p.bool()?;
if change_password {
return Err(peer_error!("client tried to change password unprompted"));
}
let password = auth_req.utf8_string()?;
let password = p.utf8_string()?;
info!(%password, "Got password");
// Don't worry queen, your password is correct!
self.queue_packet(Packet::new_msg_userauth_success());
self.is_authenticated = true;
self.server_requests
.push_back(ServerRequest::VerifyPassword(VerifyPassword {
user: username.to_owned(),
password: password.to_owned(),
}));
}
"publickey" => {
info!("Got public key");
// Don't worry queen, your key is correct!
self.queue_packet(Packet::new_msg_userauth_success());
self.is_authenticated = true;
if !self.options.contains(&AuthOption::PublicKey) {
self.has_failed = true;
self.send_failure();
}
// Whether the client is just checking whether the public key is allowed.
let is_check = p.bool()?;
if is_check {
todo!();
}
let pubkey_alg_name = p.string()?;
let public_key_blob = p.string()?;
let signature = p.string()?;
self.server_requests
.push_back(ServerRequest::VerifyPubkey(VerifyPubkey {
user: username.to_owned(),
session_identifier: self.session_ident,
pubkey_alg_name: pubkey_alg_name.to_vec(),
pubkey: public_key_blob.to_vec(),
signature: signature.to_vec(),
}));
}
_ if self.has_failed => {
return Err(peer_error!(
@ -323,8 +384,7 @@ pub mod auth {
));
}
_ => {
// Initial.
// Initial:
self.queue_packet(Packet::new_msg_userauth_banner(
b"!! this system ONLY allows catgirls to enter !!\r\n\
!! all other attempts WILL be prosecuted to the full extent of the rawr !!\r\n\
@ -333,16 +393,23 @@ pub mod auth {
b"",
));
self.queue_packet(Packet::new_msg_userauth_failure(
NameList::one("password"),
false,
));
self.send_failure();
// Stay in the same state
}
}
Ok(())
}
pub fn verification_result(&mut self, is_ok: bool) {
if is_ok {
self.queue_packet(Packet::new_msg_userauth_success());
self.is_authenticated = true;
} else {
self.send_failure();
self.has_failed = true;
}
}
pub fn packets_to_send(&mut self) -> impl Iterator<Item = Packet> + '_ {
self.packets_to_send.drain(..)
}
@ -352,7 +419,25 @@ pub mod auth {
}
pub fn server_requests(&mut self) -> impl Iterator<Item = ServerRequest> + '_ {
[].into_iter()
self.server_requests.drain(..)
}
fn send_failure(&mut self) {
self.queue_packet(Packet::new_msg_userauth_failure(
NameList(&self.option_list()),
false,
));
}
fn option_list(&self) -> String {
self.options
.iter()
.map(|op| match op {
AuthOption::Password => "password",
AuthOption::PublicKey => "publickey",
})
.collect::<Vec<&str>>()
.join(",")
}
fn queue_packet(&mut self, packet: Packet) {

View file

@ -1,4 +1,4 @@
use cluelessh_connection::{ChannelKind, ChannelNumber, ChannelOperation, ChannelOperationKind};
use cluelessh_connection::{ChannelKind, ChannelNumber, ChannelOperation};
use std::{collections::HashMap, pin::Pin, sync::Arc};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
@ -8,7 +8,7 @@ use futures::future::BoxFuture;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, info, warn};
use crate::Channel;
use crate::{Channel, ChannelState, PendingChannel};
pub struct ClientConnection<S> {
stream: Pin<Box<S>>,
@ -27,14 +27,6 @@ pub struct ClientConnection<S> {
auth: ClientAuth,
}
enum ChannelState {
Pending {
ready_send: tokio::sync::oneshot::Sender<Result<(), String>>,
updates_send: tokio::sync::mpsc::Sender<ChannelUpdateKind>,
},
Ready(tokio::sync::mpsc::Sender<ChannelUpdateKind>),
}
pub struct ClientAuth {
pub username: String,
pub prompt_password: Arc<dyn Fn() -> BoxFuture<'static, Result<String>> + Send + Sync>,
@ -53,11 +45,6 @@ pub struct SignatureResult {
pub signature: Vec<u8>,
}
pub struct PendingChannel {
ready_recv: tokio::sync::oneshot::Receiver<Result<(), String>>,
channel: Channel,
}
impl<S: AsyncRead + AsyncWrite> ClientConnection<S> {
pub async fn connect(stream: S, auth: ClientAuth) -> Result<Self> {
let (operations_send, operations_recv) = tokio::sync::mpsc::channel(15);
@ -272,22 +259,3 @@ impl<S: AsyncRead + AsyncWrite> ClientConnection<S> {
}
}
}
impl PendingChannel {
pub async fn wait_ready(self) -> Result<Channel, Option<String>> {
match self.ready_recv.await {
Ok(Ok(())) => Ok(self.channel),
Ok(Err(err)) => Err(Some(err)),
Err(_) => Err(None),
}
}
}
impl Channel {
pub async fn send_operation(&mut self, op: ChannelOperationKind) -> Result<()> {
self.ops_send
.send(self.number.construct_op(op))
.await
.map_err(Into::into)
}
}

View file

@ -31,3 +31,25 @@ impl Channel {
&self.kind
}
}
enum ChannelState {
Pending {
ready_send: tokio::sync::oneshot::Sender<Result<(), String>>,
updates_send: tokio::sync::mpsc::Sender<ChannelUpdateKind>,
},
Ready(tokio::sync::mpsc::Sender<ChannelUpdateKind>),
}
pub struct PendingChannel {
ready_recv: tokio::sync::oneshot::Receiver<Result<(), String>>,
channel: Channel,
}
impl PendingChannel {
pub async fn wait_ready(self) -> Result<Channel, Option<String>> {
match self.ready_recv.await {
Ok(Ok(())) => Ok(self.channel),
Ok(Err(err)) => Err(Some(err)),
Err(_) => Err(None),
}
}
}

View file

@ -1,24 +1,30 @@
use cluelessh_connection::{ChannelKind, ChannelNumber, ChannelOperation};
use futures::future::BoxFuture;
use std::{
collections::{HashMap, VecDeque},
collections::{HashMap, HashSet, VecDeque},
net::SocketAddr,
pin::Pin,
sync::Arc,
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
};
use cluelessh_protocol::{ChannelUpdateKind, SshStatus};
use eyre::{eyre, ContextCompat, Result, WrapErr};
use cluelessh_protocol::{
auth::{AuthOption, VerifyPassword, VerifyPubkey},
ChannelUpdateKind, SshStatus,
};
use eyre::{eyre, ContextCompat, OptionExt, Result, WrapErr};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
use crate::Channel;
use crate::{Channel, ChannelState, PendingChannel};
pub struct ServerListener {
listener: TcpListener,
// todo ratelimits etc
auth_verify: ServerAuthVerify,
// TODO ratelimits etc
}
pub struct ServerConnection<S> {
@ -38,38 +44,27 @@ pub struct ServerConnection<S> {
/// New channels opened by the peer.
new_channels: VecDeque<Channel>,
}
enum ChannelState {
Pending {
ready_send: tokio::sync::oneshot::Sender<Result<(), String>>,
updates_send: tokio::sync::mpsc::Sender<ChannelUpdateKind>,
},
Ready(tokio::sync::mpsc::Sender<ChannelUpdateKind>),
auth_verify: ServerAuthVerify,
}
enum Operation {
VerifyPassword {
user: String,
password: String,
},
VerifyPubkey {
session_identifier: [u8; 32],
user: String,
pubkey: Vec<u8>,
},
VerifyPassword(Result<()>),
VerifyPubkey(Result<()>),
}
pub struct SignatureResult {
pub key_alg_name: &'static str,
pub public_key: Vec<u8>,
pub signature: Vec<u8>,
#[derive(Clone)]
pub struct ServerAuthVerify {
pub verify_password:
Option<Arc<dyn Fn(VerifyPassword) -> BoxFuture<'static, Result<()>> + Send + Sync>>,
pub verify_pubkey:
Option<Arc<dyn Fn(VerifyPubkey) -> BoxFuture<'static, Result<()>> + Send + Sync>>,
}
fn _assert_send_sync() {
fn send<T: Send + Sync>() {}
send::<ServerAuthVerify>();
}
pub struct PendingChannel {
ready_recv: tokio::sync::oneshot::Receiver<Result<(), String>>,
channel: Channel,
}
pub enum Error {
SshStatus(SshStatus),
ServerError(eyre::Report),
@ -81,22 +76,41 @@ impl From<eyre::Report> for Error {
}
impl ServerListener {
pub fn new(listener: TcpListener) -> Self {
Self { listener }
pub fn new(listener: TcpListener, auth_verify: ServerAuthVerify) -> Self {
Self {
listener,
auth_verify,
}
}
pub async fn accept(&mut self) -> Result<ServerConnection<TcpStream>> {
let (conn, peer_addr) = self.listener.accept().await?;
Ok(ServerConnection::new(conn, peer_addr))
Ok(ServerConnection::new(
conn,
peer_addr,
self.auth_verify.clone(),
))
}
}
impl<S: AsyncRead + AsyncWrite> ServerConnection<S> {
pub fn new(stream: S, peer_addr: SocketAddr) -> Self {
pub fn new(stream: S, peer_addr: SocketAddr, auth_verify: ServerAuthVerify) -> Self {
let (operations_send, operations_recv) = tokio::sync::mpsc::channel(15);
let (channel_ops_send, channel_ops_recv) = tokio::sync::mpsc::channel(15);
let mut options = HashSet::new();
if auth_verify.verify_password.is_some() {
options.insert(AuthOption::Password);
}
if auth_verify.verify_pubkey.is_some() {
options.insert(AuthOption::PublicKey);
}
if options.is_empty() {
panic!("no auth options provided");
}
Self {
stream: Box::pin(stream),
peer_addr,
@ -110,8 +124,10 @@ impl<S: AsyncRead + AsyncWrite> ServerConnection<S> {
cluelessh_transport::server::ServerConnection::new(
cluelessh_protocol::ThreadRngRand,
),
options,
),
new_channels: VecDeque::new(),
auth_verify,
}
}
@ -125,28 +141,28 @@ impl<S: AsyncRead + AsyncWrite> ServerConnection<S> {
if let Some(auth) = self.proto.auth() {
for req in auth.server_requests() {
match req {
cluelessh_protocol::auth::ServerRequest::VerifyPassword { user, password } => {
cluelessh_protocol::auth::ServerRequest::VerifyPassword(password_verify) => {
let send = self.operations_send.clone();
let verify = self
.auth_verify
.verify_password
.clone()
.ok_or_eyre("password auth not supported")?;
tokio::spawn(async move {
let _ = send
.send(Operation::VerifyPassword { user, password })
.await;
let result = verify(password_verify).await;
let _ = send.send(Operation::VerifyPassword(result)).await;
});
}
cluelessh_protocol::auth::ServerRequest::VerifyPubkey {
session_identifier,
pubkey,
user,
} => {
cluelessh_protocol::auth::ServerRequest::VerifyPubkey(pubkey_verify) => {
let send = self.operations_send.clone();
let verify = self
.auth_verify
.verify_pubkey
.clone()
.ok_or_eyre("pubkey auth not supported")?;
tokio::spawn(async move {
let _ = send
.send(Operation::VerifyPubkey {
session_identifier,
user,
pubkey,
})
.await;
let result = verify(pubkey_verify).await;
let _ = send.send(Operation::VerifyPubkey(result)).await;
});
}
}
@ -247,7 +263,7 @@ impl<S: AsyncRead + AsyncWrite> ServerConnection<S> {
let read = read.wrap_err("reading from connection")?;
if read == 0 {
info!("Did not read any bytes from TCP stream, EOF");
return Ok(());
return Err(Error::SshStatus(SshStatus::Disconnect));
}
if let Err(err) = self.proto.recv_bytes(&self.buf[..read]) {
return Err(Error::SshStatus(err));
@ -261,8 +277,12 @@ impl<S: AsyncRead + AsyncWrite> ServerConnection<S> {
}
op = self.operations_recv.recv() => {
match op {
Some(Operation::VerifyPubkey { .. }) => todo!(),
Some(Operation::VerifyPassword { .. }) => todo!(),
Some(Operation::VerifyPubkey(result)) => if let Some(auth) = self.proto.auth() {
auth.verification_result(result.is_ok());
},
Some(Operation::VerifyPassword(result)) => if let Some(auth) = self.proto.auth() {
auth.verification_result(result.is_ok());
},
None => {}
}
self.send_off_data().await?;
@ -315,13 +335,3 @@ impl<S: AsyncRead + AsyncWrite> ServerConnection<S> {
self.new_channels.pop_front()
}
}
impl PendingChannel {
pub async fn wait_ready(self) -> Result<Channel, Option<String>> {
match self.ready_recv.await {
Ok(Ok(())) => Ok(self.channel),
Ok(Err(err)) => Err(Some(err)),
Err(_) => Err(None),
}
}
}

View file

@ -44,8 +44,12 @@ enum ServerState {
encryption_client_to_server: EncryptionAlgorithm,
encryption_server_to_client: EncryptionAlgorithm,
},
ServiceRequest,
Open,
ServiceRequest {
session_ident: [u8; 32],
},
Open {
session_ident: [u8; 32],
},
}
impl ServerConnection {
@ -289,9 +293,9 @@ impl ServerConnection {
*encryption_server_to_client,
true,
);
self.state = ServerState::ServiceRequest {};
self.state = ServerState::ServiceRequest { session_ident: *h };
}
ServerState::ServiceRequest => {
ServerState::ServiceRequest { session_ident } => {
// TODO: this should probably move out of here? unsure.
if packet.payload.first() != Some(&numbers::SSH_MSG_SERVICE_REQUEST) {
return Err(peer_error!("did not send SSH_MSG_SERVICE_REQUEST"));
@ -312,9 +316,11 @@ impl ServerConnection {
writer.finish()
},
});
self.state = ServerState::Open;
self.state = ServerState::Open {
session_ident: *session_ident,
};
}
ServerState::Open => {
ServerState::Open { .. } => {
self.plaintext_packets.push_back(packet);
}
}
@ -322,6 +328,13 @@ impl ServerConnection {
Ok(())
}
pub fn is_open(&self) -> Option<[u8; 32]> {
match self.state {
ServerState::Open { session_ident } => Some(session_ident),
_ => None,
}
}
pub fn next_msg_to_send(&mut self) -> Option<Msg> {
self.packet_transport.next_msg_to_send()
}