client auth

This commit is contained in:
nora 2024-08-16 14:31:37 +02:00
parent b3081cfeb9
commit 85f89b6f84
7 changed files with 426 additions and 48 deletions

116
Cargo.lock generated
View file

@ -97,7 +97,7 @@ version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a"
dependencies = [
"windows-sys",
"windows-sys 0.52.0",
]
[[package]]
@ -107,7 +107,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8"
dependencies = [
"anstyle",
"windows-sys",
"windows-sys 0.52.0",
]
[[package]]
@ -620,7 +620,7 @@ dependencies = [
"hermit-abi",
"libc",
"wasi",
"windows-sys",
"windows-sys 0.52.0",
]
[[package]]
@ -692,7 +692,7 @@ dependencies = [
"libc",
"redox_syscall",
"smallvec",
"windows-targets",
"windows-targets 0.52.6",
]
[[package]]
@ -872,6 +872,27 @@ dependencies = [
"subtle",
]
[[package]]
name = "rpassword"
version = "7.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80472be3c897911d0137b2d2b9055faf6eeac5b14e324073d83bc17b191d7e3f"
dependencies = [
"libc",
"rtoolbox",
"windows-sys 0.48.0",
]
[[package]]
name = "rtoolbox"
version = "0.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c247d24e63230cdb56463ae328478bd5eac8b8faa8c69461a77e8e323afac90e"
dependencies = [
"libc",
"windows-sys 0.48.0",
]
[[package]]
name = "rustc-demangle"
version = "0.1.24"
@ -1009,7 +1030,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce305eb0b4296696835b71df73eb912e0f1ffd2556a501fcede6e0c50349191c"
dependencies = [
"libc",
"windows-sys",
"windows-sys 0.52.0",
]
[[package]]
@ -1029,6 +1050,7 @@ dependencies = [
"clap",
"eyre",
"rand",
"rpassword",
"ssh-protocol",
"ssh-transport",
"tokio",
@ -1122,7 +1144,7 @@ dependencies = [
"signal-hook-registry",
"socket2",
"tokio-macros",
"windows-sys",
"windows-sys 0.52.0",
]
[[package]]
@ -1278,13 +1300,37 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
dependencies = [
"windows-targets 0.48.5",
]
[[package]]
name = "windows-sys"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
dependencies = [
"windows-targets",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-targets"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
dependencies = [
"windows_aarch64_gnullvm 0.48.5",
"windows_aarch64_msvc 0.48.5",
"windows_i686_gnu 0.48.5",
"windows_i686_msvc 0.48.5",
"windows_x86_64_gnu 0.48.5",
"windows_x86_64_gnullvm 0.48.5",
"windows_x86_64_msvc 0.48.5",
]
[[package]]
@ -1293,28 +1339,46 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_aarch64_gnullvm 0.52.6",
"windows_aarch64_msvc 0.52.6",
"windows_i686_gnu 0.52.6",
"windows_i686_gnullvm",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
"windows_i686_msvc 0.52.6",
"windows_x86_64_gnu 0.52.6",
"windows_x86_64_gnullvm 0.52.6",
"windows_x86_64_msvc 0.52.6",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
@ -1327,24 +1391,48 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"

View file

@ -16,12 +16,14 @@ impl std::fmt::Display for ChannelNumber {
}
}
pub struct ServerChannelsState {
pub struct ChannelsState {
packets_to_send: VecDeque<Packet>,
channel_updates: VecDeque<ChannelUpdate>,
channels: HashMap<ChannelNumber, Channel>,
next_channel_id: ChannelNumber,
is_server: bool,
}
struct Channel {
@ -121,13 +123,15 @@ pub enum ChannelOperationKind {
Close,
}
impl ServerChannelsState {
pub fn new() -> Self {
ServerChannelsState {
impl ChannelsState {
pub fn new(is_server: bool) -> Self {
ChannelsState {
packets_to_send: VecDeque::new(),
channels: HashMap::new(),
channel_updates: VecDeque::new(),
next_channel_id: ChannelNumber(0),
is_server,
}
}
@ -549,7 +553,7 @@ impl ChannelOperation {
mod tests {
use ssh_transport::{numbers, packet::Packet};
use crate::{ChannelNumber, ChannelOperation, ChannelOperationKind, ServerChannelsState};
use crate::{ChannelNumber, ChannelOperation, ChannelOperationKind, ChannelsState};
/// If a test fails, add this to the test to get logs.
#[allow(dead_code)]
@ -560,7 +564,7 @@ mod tests {
}
#[track_caller]
fn assert_response_types(state: &mut ServerChannelsState, types: &[u8]) {
fn assert_response_types(state: &mut ChannelsState, types: &[u8]) {
let response = state
.packets_to_send()
.map(|p| numbers::packet_type_to_string(p.packet_type()))
@ -573,7 +577,7 @@ mod tests {
assert_eq!(expected, response);
}
fn open_session_channel(state: &mut ServerChannelsState) {
fn open_session_channel(state: &mut ChannelsState) {
state
.recv_packet(Packet::new_msg_channel_open_session(
b"session", 0, 2048, 1024,
@ -584,7 +588,7 @@ mod tests {
#[test]
fn interactive_pty() {
let state = &mut ServerChannelsState::new();
let state = &mut ChannelsState::new(true);
open_session_channel(state);
state
@ -615,7 +619,7 @@ mod tests {
#[test]
fn only_single_close_for_double_close_operation() {
let state = &mut ServerChannelsState::new();
let state = &mut ChannelsState::new(true);
open_session_channel(state);
state.do_operation(ChannelOperation {
number: ChannelNumber(0),
@ -630,7 +634,7 @@ mod tests {
#[test]
fn ignore_operation_after_close() {
let mut state = &mut ServerChannelsState::new();
let mut state = &mut ChannelsState::new(true);
open_session_channel(state);
state.recv_packet(Packet::new_msg_channel_close(0)).unwrap();
assert_response_types(&mut state, &[numbers::SSH_MSG_CHANNEL_CLOSE]);
@ -643,7 +647,7 @@ mod tests {
#[test]
fn respect_peer_windowing() {
let state = &mut ServerChannelsState::new();
let state = &mut ChannelsState::new(true);
state
.recv_packet(Packet::new_msg_channel_open_session(b"session", 0, 10, 50))
.unwrap();
@ -684,7 +688,7 @@ mod tests {
#[test]
fn send_windowing_adjustments() {
let state = &mut ServerChannelsState::new();
let state = &mut ChannelsState::new(true);
state
.recv_packet(Packet::new_msg_channel_open_session(
b"session", 0, 2000, 2000,

View file

@ -1,8 +1,11 @@
use std::mem;
pub use ssh_connection as connection;
use ssh_connection::ChannelOperation;
pub use ssh_connection::{ChannelUpdate, ChannelUpdateKind};
pub use ssh_transport as transport;
pub use ssh_transport::{Result, SshStatus};
use tracing::debug;
pub struct ServerConnection {
transport: ssh_transport::server::ServerConnection,
@ -11,7 +14,7 @@ pub struct ServerConnection {
enum ServerConnectionState {
Auth(auth::BadAuth),
Open(ssh_connection::ServerChannelsState),
Open(ssh_connection::ChannelsState),
}
impl ServerConnection {
@ -34,16 +37,15 @@ impl ServerConnection {
}
if auth.is_authenticated() {
self.state =
ServerConnectionState::Open(ssh_connection::ServerChannelsState::new());
ServerConnectionState::Open(ssh_connection::ChannelsState::new(true));
}
}
ServerConnectionState::Open(con) => {
con.recv_packet(packet)?;
for to_send in con.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
}
}
self.progress();
}
Ok(())
@ -65,6 +67,125 @@ impl ServerConnection {
ServerConnectionState::Auth(_) => panic!("tried to get connection during auth"),
ServerConnectionState::Open(con) => {
con.do_operation(op);
self.progress();
}
}
}
pub fn progress(&mut self) {
match &mut self.state {
ServerConnectionState::Auth(auth) => {
for to_send in auth.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
}
ServerConnectionState::Open(con) => {
for to_send in con.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
}
}
}
}
pub struct ClientConnection {
transport: ssh_transport::client::ClientConnection,
state: ClientConnectionState,
}
enum ClientConnectionState {
Setup(Option<auth::ClientAuth>),
Auth(auth::ClientAuth),
Open(ssh_connection::ChannelsState),
}
impl ClientConnection {
pub fn new(transport: ssh_transport::client::ClientConnection, auth: auth::ClientAuth) -> Self {
Self {
transport,
state: ClientConnectionState::Setup(Some(auth)),
}
}
pub fn recv_bytes(&mut self, bytes: &[u8]) -> Result<()> {
self.transport.recv_bytes(bytes)?;
if let ClientConnectionState::Setup(auth) = &mut self.state {
if self.transport.is_open() {
let mut auth = mem::take(auth).unwrap();
for to_send in auth.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
debug!("Connection has been opened");
self.state = ClientConnectionState::Auth(auth);
}
}
while let Some(packet) = self.transport.next_plaintext_packet() {
match &mut self.state {
ClientConnectionState::Setup(_) => unreachable!("handled above"),
ClientConnectionState::Auth(auth) => {
auth.recv_packet(packet)?;
for to_send in auth.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
if auth.is_authenticated() {
self.state =
ClientConnectionState::Open(ssh_connection::ChannelsState::new(false));
}
}
ClientConnectionState::Open(con) => {
con.recv_packet(packet)?;
for to_send in con.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
}
}
}
Ok(())
}
pub fn auth(&mut self) -> Option<&mut auth::ClientAuth> {
match &mut self.state {
ClientConnectionState::Auth(auth) => Some(auth),
_ => None,
}
}
pub fn next_msg_to_send(&mut self) -> Option<ssh_transport::Msg> {
self.transport.next_msg_to_send()
}
pub fn next_channel_update(&mut self) -> Option<ssh_connection::ChannelUpdate> {
match &mut self.state {
ClientConnectionState::Setup(_) => None,
ClientConnectionState::Auth(_) => None,
ClientConnectionState::Open(con) => con.next_channel_update(),
}
}
pub fn do_operation(&mut self, op: ChannelOperation) {
match &mut self.state {
ClientConnectionState::Setup(_) | ClientConnectionState::Auth(_) => {
panic!("tried to get connection during auth")
}
ClientConnectionState::Open(con) => {
con.do_operation(op);
self.progress();
}
}
}
pub fn progress(&mut self) {
match &mut self.state {
ClientConnectionState::Setup(_) => {}
ClientConnectionState::Auth(auth) => {
for to_send in auth.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
}
ClientConnectionState::Open(con) => {
for to_send in con.packets_to_send() {
self.transport.send_plaintext_packet(to_send);
}
@ -77,7 +198,7 @@ impl ServerConnection {
pub mod auth {
use std::collections::VecDeque;
use ssh_transport::{peer_error, numbers, packet::Packet, parse::NameList, Result};
use ssh_transport::{numbers, packet::Packet, parse::NameList, peer_error, Result};
use tracing::info;
pub struct BadAuth {
@ -183,4 +304,95 @@ pub mod auth {
self.packets_to_send.push_back(packet);
}
}
pub struct ClientAuth {
username: Vec<u8>,
packets_to_send: VecDeque<Packet>,
user_requests: VecDeque<ClientUserRequest>,
is_authenticated: bool,
}
pub enum ClientUserRequest {
Password,
Banner(Vec<u8>),
}
impl ClientAuth {
pub fn new(username: Vec<u8>) -> Self {
let mut packets_to_send = VecDeque::new();
let initial_useruath_req =
Packet::new_msg_userauth_request_none(&username, b"ssh-connection", b"none");
packets_to_send.push_back(initial_useruath_req);
Self {
packets_to_send,
username,
user_requests: VecDeque::new(),
is_authenticated: false,
}
}
pub fn is_authenticated(&self) -> bool {
self.is_authenticated
}
pub fn packets_to_send(&mut self) -> impl Iterator<Item = Packet> + '_ {
self.packets_to_send.drain(..)
}
pub fn user_requests(&mut self) -> impl Iterator<Item = ClientUserRequest> + '_ {
self.user_requests.drain(..)
}
pub fn send_password(&mut self, password: &str) {
let packet = Packet::new_msg_userauth_request_password(
&self.username,
b"ssh-connection",
b"password",
false,
password.as_bytes(),
);
self.packets_to_send.push_back(packet);
}
pub fn recv_packet(&mut self, packet: Packet) -> Result<()> {
assert!(!self.is_authenticated, "Must not feed more packets to authentication after authentication is been completed, check with .is_authenticated()");
let mut p = packet.payload_parser();
let packet_type = p.u8()?;
match packet_type {
numbers::SSH_MSG_USERAUTH_BANNER => {
let banner = p.string()?;
let _lang = p.string()?;
self.user_requests
.push_back(ClientUserRequest::Banner(banner.to_vec()));
}
numbers::SSH_MSG_USERAUTH_FAILURE => {
let authentications = p.name_list()?;
let _partial_success = p.bool()?;
if authentications.iter().any(|item| item == "password") {
self.user_requests.push_back(ClientUserRequest::Password);
} else {
return Err(peer_error!(
"server does not support password authentication"
));
}
}
numbers::SSH_MSG_USERAUTH_SUCCESS => {
self.is_authenticated = true;
}
_ => {
return Err(peer_error!(
"unexpected packet: {}",
numbers::packet_type_to_string(packet_type)
))
}
}
Ok(())
}
}
}

View file

@ -304,6 +304,18 @@ impl ClientConnection {
self.packet_transport.next_msg_to_send()
}
pub fn next_plaintext_packet(&mut self) -> Option<Packet> {
self.plaintext_packets.pop_front()
}
pub fn send_plaintext_packet(&mut self, packet: Packet) {
self.packet_transport.queue_packet(packet);
}
pub fn is_open(&self) -> bool {
matches!(self.state, ClientState::Open)
}
fn send_kexinit(&mut self, client_ident: Vec<u8>, server_ident: Vec<u8>) {
let mut cookie = [0; 16];
self.rng.fill_bytes(&mut cookie);

View file

@ -63,6 +63,18 @@ ctors! {
// User authentication protocol:
// 50 to 59 User authentication generic
fn new_msg_userauth_request_none(SSH_MSG_USERAUTH_REQUEST;
username: string,
service_name: string,
method_name_none: string,
);
fn new_msg_userauth_request_password(SSH_MSG_USERAUTH_REQUEST;
username: string,
service_name: string,
method_name_password: string,
false_: bool,
password: string,
);
fn new_msg_userauth_failure(SSH_MSG_USERAUTH_FAILURE;
auth_options: name_list,
partial_success: bool,

View file

@ -13,3 +13,4 @@ tokio = { version = "1.39.2", features = ["full"] }
tracing-subscriber = { version = "0.3.18", features = ["json", "env-filter"] }
tracing.workspace = true
rpassword = "7.3.1"

View file

@ -1,3 +1,5 @@
use std::io::Write;
use clap::Parser;
use eyre::Context;
@ -6,7 +8,7 @@ use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
};
use tracing::info;
use tracing::{debug, error, info};
use ssh_protocol::{
transport::{self},
@ -29,6 +31,10 @@ struct Args {
command: Vec<String>,
}
enum Operation {
PasswordEntered(std::io::Result<String>),
}
#[tokio::main]
async fn main() -> eyre::Result<()> {
let args = Args::parse();
@ -40,7 +46,14 @@ async fn main() -> eyre::Result<()> {
.await
.wrap_err("connecting")?;
let mut state = transport::client::ClientConnection::new(ThreadRngRand);
let username = "hans-peter";
let mut state = ssh_protocol::ClientConnection::new(
transport::client::ClientConnection::new(ThreadRngRand),
ssh_protocol::auth::ClientAuth::new(username.as_bytes().to_vec()),
);
let (send_op, mut recv_op) = tokio::sync::mpsc::channel::<Operation>(10);
let mut buf = [0; 1024];
@ -51,26 +64,62 @@ async fn main() -> eyre::Result<()> {
.wrap_err("writing response")?;
}
let read = conn
.read(&mut buf)
.await
.wrap_err("reading from connection")?;
if read == 0 {
info!("Did not read any bytes from TCP stream, EOF");
return Ok(());
if let Some(auth) = state.auth() {
for req in auth.user_requests() {
match req {
ssh_protocol::auth::ClientUserRequest::Password => {
let username = username.to_owned();
let destination = args.destination.clone();
let send_op = send_op.clone();
std::thread::spawn(move || {
let password = rpassword::prompt_password(format!(
"{}@{}'s password: ",
username, destination
));
let _ = send_op.blocking_send(Operation::PasswordEntered(password));
});
}
ssh_protocol::auth::ClientUserRequest::Banner(banner) => {
let banner = String::from_utf8_lossy(&banner);
std::io::stdout().write(&banner.as_bytes())?;
}
}
}
}
if let Err(err) = state.recv_bytes(&buf[..read]) {
match err {
SshStatus::PeerError(err) => {
info!(?err, "disconnecting client after invalid operation");
tokio::select! {
read = conn.read(&mut buf) => {
let read = read.wrap_err("reading from connection")?;
if read == 0 {
info!("Did not read any bytes from TCP stream, EOF");
return Ok(());
}
SshStatus::Disconnect => {
info!("Received disconnect from client");
return Ok(());
if let Err(err) = state.recv_bytes(&buf[..read]) {
match err {
SshStatus::PeerError(err) => {
error!(?err, "disconnecting client after invalid operation");
return Ok(());
}
SshStatus::Disconnect => {
error!("Received disconnect from server");
return Ok(());
}
}
}
}
op = recv_op.recv() => {
match op {
Some(Operation::PasswordEntered(password)) => {
if let Some(auth) = state.auth() {
auth.send_password(&password?);
} else {
debug!("Ignoring entered password as the state has moved on");
}
}
None => {}
}
state.progress();
}
}
}
}