client auth

This commit is contained in:
nora 2024-08-16 14:31:37 +02:00
parent b3081cfeb9
commit 85f89b6f84
7 changed files with 426 additions and 48 deletions

View file

@ -1,8 +1,11 @@
use std::mem;
pub use ssh_connection as connection;
use ssh_connection::ChannelOperation;
pub use ssh_connection::{ChannelUpdate, ChannelUpdateKind};
pub use ssh_transport as transport;
pub use ssh_transport::{Result, SshStatus};
use tracing::debug;
pub struct ServerConnection {
transport: ssh_transport::server::ServerConnection,
@ -11,7 +14,7 @@ pub struct ServerConnection {
enum ServerConnectionState {
Auth(auth::BadAuth),
Open(ssh_connection::ServerChannelsState),
Open(ssh_connection::ChannelsState),
}
impl ServerConnection {
@ -34,16 +37,15 @@ impl ServerConnection {
}
if auth.is_authenticated() {
self.state =
ServerConnectionState::Open(ssh_connection::ServerChannelsState::new());
ServerConnectionState::Open(ssh_connection::ChannelsState::new(true));
}
}
ServerConnectionState::Open(con) => {
con.recv_packet(packet)?;
for to_send in con.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
}
}
self.progress();
}
Ok(())
@ -65,6 +67,125 @@ impl ServerConnection {
ServerConnectionState::Auth(_) => panic!("tried to get connection during auth"),
ServerConnectionState::Open(con) => {
con.do_operation(op);
self.progress();
}
}
}
pub fn progress(&mut self) {
match &mut self.state {
ServerConnectionState::Auth(auth) => {
for to_send in auth.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
}
ServerConnectionState::Open(con) => {
for to_send in con.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
}
}
}
}
pub struct ClientConnection {
transport: ssh_transport::client::ClientConnection,
state: ClientConnectionState,
}
enum ClientConnectionState {
Setup(Option<auth::ClientAuth>),
Auth(auth::ClientAuth),
Open(ssh_connection::ChannelsState),
}
impl ClientConnection {
pub fn new(transport: ssh_transport::client::ClientConnection, auth: auth::ClientAuth) -> Self {
Self {
transport,
state: ClientConnectionState::Setup(Some(auth)),
}
}
pub fn recv_bytes(&mut self, bytes: &[u8]) -> Result<()> {
self.transport.recv_bytes(bytes)?;
if let ClientConnectionState::Setup(auth) = &mut self.state {
if self.transport.is_open() {
let mut auth = mem::take(auth).unwrap();
for to_send in auth.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
debug!("Connection has been opened");
self.state = ClientConnectionState::Auth(auth);
}
}
while let Some(packet) = self.transport.next_plaintext_packet() {
match &mut self.state {
ClientConnectionState::Setup(_) => unreachable!("handled above"),
ClientConnectionState::Auth(auth) => {
auth.recv_packet(packet)?;
for to_send in auth.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
if auth.is_authenticated() {
self.state =
ClientConnectionState::Open(ssh_connection::ChannelsState::new(false));
}
}
ClientConnectionState::Open(con) => {
con.recv_packet(packet)?;
for to_send in con.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
}
}
}
Ok(())
}
pub fn auth(&mut self) -> Option<&mut auth::ClientAuth> {
match &mut self.state {
ClientConnectionState::Auth(auth) => Some(auth),
_ => None,
}
}
pub fn next_msg_to_send(&mut self) -> Option<ssh_transport::Msg> {
self.transport.next_msg_to_send()
}
pub fn next_channel_update(&mut self) -> Option<ssh_connection::ChannelUpdate> {
match &mut self.state {
ClientConnectionState::Setup(_) => None,
ClientConnectionState::Auth(_) => None,
ClientConnectionState::Open(con) => con.next_channel_update(),
}
}
pub fn do_operation(&mut self, op: ChannelOperation) {
match &mut self.state {
ClientConnectionState::Setup(_) | ClientConnectionState::Auth(_) => {
panic!("tried to get connection during auth")
}
ClientConnectionState::Open(con) => {
con.do_operation(op);
self.progress();
}
}
}
pub fn progress(&mut self) {
match &mut self.state {
ClientConnectionState::Setup(_) => {}
ClientConnectionState::Auth(auth) => {
for to_send in auth.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
}
ClientConnectionState::Open(con) => {
for to_send in con.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
@ -77,7 +198,7 @@ impl ServerConnection {
pub mod auth {
use std::collections::VecDeque;
use ssh_transport::{peer_error, numbers, packet::Packet, parse::NameList, Result};
use ssh_transport::{numbers, packet::Packet, parse::NameList, peer_error, Result};
use tracing::info;
pub struct BadAuth {
@ -183,4 +304,95 @@ pub mod auth {
self.packets_to_send.push_back(packet);
}
}
pub struct ClientAuth {
username: Vec<u8>,
packets_to_send: VecDeque<Packet>,
user_requests: VecDeque<ClientUserRequest>,
is_authenticated: bool,
}
pub enum ClientUserRequest {
Password,
Banner(Vec<u8>),
}
impl ClientAuth {
pub fn new(username: Vec<u8>) -> Self {
let mut packets_to_send = VecDeque::new();
let initial_useruath_req =
Packet::new_msg_userauth_request_none(&username, b"ssh-connection", b"none");
packets_to_send.push_back(initial_useruath_req);
Self {
packets_to_send,
username,
user_requests: VecDeque::new(),
is_authenticated: false,
}
}
pub fn is_authenticated(&self) -> bool {
self.is_authenticated
}
pub fn packets_to_send(&mut self) -> impl Iterator<Item = Packet> + '_ {
self.packets_to_send.drain(..)
}
pub fn user_requests(&mut self) -> impl Iterator<Item = ClientUserRequest> + '_ {
self.user_requests.drain(..)
}
pub fn send_password(&mut self, password: &str) {
let packet = Packet::new_msg_userauth_request_password(
&self.username,
b"ssh-connection",
b"password",
false,
password.as_bytes(),
);
self.packets_to_send.push_back(packet);
}
pub fn recv_packet(&mut self, packet: Packet) -> Result<()> {
assert!(!self.is_authenticated, "Must not feed more packets to authentication after authentication is been completed, check with .is_authenticated()");
let mut p = packet.payload_parser();
let packet_type = p.u8()?;
match packet_type {
numbers::SSH_MSG_USERAUTH_BANNER => {
let banner = p.string()?;
let _lang = p.string()?;
self.user_requests
.push_back(ClientUserRequest::Banner(banner.to_vec()));
}
numbers::SSH_MSG_USERAUTH_FAILURE => {
let authentications = p.name_list()?;
let _partial_success = p.bool()?;
if authentications.iter().any(|item| item == "password") {
self.user_requests.push_back(ClientUserRequest::Password);
} else {
return Err(peer_error!(
"server does not support password authentication"
));
}
}
numbers::SSH_MSG_USERAUTH_SUCCESS => {
self.is_authenticated = true;
}
_ => {
return Err(peer_error!(
"unexpected packet: {}",
numbers::packet_type_to_string(packet_type)
))
}
}
Ok(())
}
}
}