connection working

This commit is contained in:
nora 2022-02-19 18:12:28 +01:00
parent ca1f372665
commit 13deef42fd
9 changed files with 217 additions and 82 deletions

View file

@ -1,4 +1,4 @@
use crate::error::{ConException, ProtocolError, TransError}; use crate::error::{ConException, TransError};
use std::collections::HashMap; use std::collections::HashMap;
mod generated; mod generated;
@ -41,15 +41,17 @@ pub fn parse_method(payload: &[u8]) -> Result<generated::Class, TransError> {
match nom_result { match nom_result {
Ok(([], class)) => Ok(class), Ok(([], class)) => Ok(class),
Ok((_, _)) => Err(ProtocolError::ConException(ConException::SyntaxError(vec![ Ok((_, _)) => {
"could not consume all input".to_string(), Err(
])) ConException::SyntaxError(vec!["could not consume all input".to_string()])
.into()), .into_trans(),
)
}
Err(nom::Err::Incomplete(_)) => { Err(nom::Err::Incomplete(_)) => {
Err(ProtocolError::ConException(ConException::SyntaxError(vec![ Err(
"there was not enough data".to_string(), ConException::SyntaxError(vec!["there was not enough data".to_string()])
])) .into_trans(),
.into()) )
} }
Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err), Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err),
} }

View file

@ -17,7 +17,7 @@ use std::collections::HashMap;
impl<T> nom::error::ParseError<T> for TransError { impl<T> nom::error::ParseError<T> for TransError {
fn from_error_kind(_input: T, _kind: ErrorKind) -> Self { fn from_error_kind(_input: T, _kind: ErrorKind) -> Self {
ProtocolError::ConException(ConException::SyntaxError(vec![])).into() ConException::SyntaxError(vec![]).into_trans()
} }
fn append(_input: T, _kind: ErrorKind, other: Self) -> Self { fn append(_input: T, _kind: ErrorKind, other: Self) -> Self {
@ -47,13 +47,11 @@ pub fn err<S: Into<String>>(msg: S) -> impl FnOnce(Err<TransError>) -> Err<Trans
}, },
_ => vec![msg], _ => vec![msg],
}; };
error_level(ProtocolError::ConException(ConException::SyntaxError(stack)).into()) error_level(ConException::SyntaxError(stack).into_trans())
} }
} }
pub fn err_other<E, S: Into<String>>(msg: S) -> impl FnOnce(E) -> Err<TransError> { pub fn err_other<E, S: Into<String>>(msg: S) -> impl FnOnce(E) -> Err<TransError> {
move |_| { move |_| Err::Error(ConException::SyntaxError(vec![msg.into()]).into_trans())
Err::Error(ProtocolError::ConException(ConException::SyntaxError(vec![msg.into()])).into())
}
} }
pub fn failure<E>(err: Err<E>) -> Err<E> { pub fn failure<E>(err: Err<E>) -> Err<E> {
@ -145,7 +143,7 @@ pub fn table(input: &[u8]) -> IResult<Table> {
let (input, values) = many0(table_value_pair)(table_input)?; let (input, values) = many0(table_value_pair)(table_input)?;
if input != &[] { if !input.is_empty() {
fail!(format!( fail!(format!(
"table longer than expected, expected = {size}, remaining = {}", "table longer than expected, expected = {size}, remaining = {}",
input.len() input.len()

View file

@ -1,50 +1,90 @@
use crate::classes::FieldValue;
use crate::error::{ConException, ProtocolError, Result}; use crate::error::{ConException, ProtocolError, Result};
use crate::frame::{Frame, FrameType}; use crate::frame::{Frame, FrameType};
use crate::{classes, frame}; use crate::{classes, frame, sasl};
use anyhow::Context; use anyhow::Context;
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tracing::{debug, error}; use tracing::{debug, error, info};
use uuid::Uuid;
const MIN_MAX_FRAME_SIZE: usize = 4096; fn ensure_conn(condition: bool) -> Result<()> {
if condition {
Ok(())
} else {
Err(ConException::Todo.into_trans())
}
}
const FRAME_SIZE_MIN_MAX: usize = 4096;
const CHANNEL_MAX: u16 = 0;
const FRAME_SIZE_MAX: u32 = 0;
const HEARTBEAT_DELAY: u16 = 0;
pub struct Connection { pub struct Connection {
stream: TcpStream, stream: TcpStream,
max_frame_size: usize, max_frame_size: usize,
heartbeat_delay: u16,
channel_max: u16,
id: Uuid,
} }
impl Connection { impl Connection {
pub fn new(stream: TcpStream) -> Self { pub fn new(stream: TcpStream, id: Uuid) -> Self {
Self { Self {
stream, stream,
max_frame_size: MIN_MAX_FRAME_SIZE, max_frame_size: FRAME_SIZE_MIN_MAX,
heartbeat_delay: HEARTBEAT_DELAY,
channel_max: CHANNEL_MAX,
id,
} }
} }
pub async fn open_connection(mut self) { pub async fn start_connection_processing(mut self) {
match self.run().await { match self.process_connection().await {
Ok(()) => {} Ok(()) => {}
Err(err) => error!(%err, "Error during processing of connection"), Err(err) => error!(%err, "Error during processing of connection"),
} }
} }
pub async fn run(&mut self) -> Result<()> { pub async fn process_connection(&mut self) -> Result<()> {
self.negotiate_version().await?; self.negotiate_version().await?;
self.start().await?; self.start().await?;
self.tune().await?;
self.open().await?;
info!("Connection is ready for usage!");
loop { loop {
let frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?; let method = self.recv_method().await?;
debug!(?frame, "received frame"); debug!(?method, "Received method");
if frame.kind == FrameType::Method {
let class = super::classes::parse_method(&frame.payload)?;
debug!(?class, "was method frame");
}
} }
} }
async fn send_method(&mut self, channel: u16, method: classes::Class) -> Result<()> {
let mut payload = Vec::with_capacity(64);
classes::write::write_method(method, &mut payload)?;
frame::write_frame(
&Frame {
kind: FrameType::Method,
channel,
payload,
},
&mut self.stream,
)
.await
}
async fn recv_method(&mut self) -> Result<classes::Class> {
let start_ok_frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?;
ensure_conn(start_ok_frame.kind == FrameType::Method)?;
let class = classes::parse_method(&start_ok_frame.payload)?;
Ok(class)
}
async fn start(&mut self) -> Result<()> { async fn start(&mut self) -> Result<()> {
let start_method = classes::Class::Connection(classes::Connection::Start { let start_method = classes::Class::Connection(classes::Connection::Start {
version_major: 0, version_major: 0,
@ -58,30 +98,72 @@ impl Connection {
locales: "en_US".into(), locales: "en_US".into(),
}); });
debug!(?start_method, "Sending start method"); debug!(?start_method, "Sending Start method");
self.send_method(0, start_method).await?;
let mut payload = Vec::with_capacity(64); let start_ok = self.recv_method().await?;
classes::write::write_method(start_method, &mut payload)?; debug!(?start_ok, "Received Start-Ok");
frame::write_frame(
&Frame {
kind: FrameType::Method,
channel: 0,
payload,
},
&mut self.stream,
)
.await?;
let start_ok_frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?; if let classes::Class::Connection(classes::Connection::StartOk {
debug!(?start_ok_frame, "Received Start-Ok frame"); mechanism,
locale,
if start_ok_frame.kind != FrameType::Method { response,
return Err(ProtocolError::ConException(ConException::Todo).into()); ..
}) = start_ok
{
ensure_conn(mechanism == "PLAIN")?;
ensure_conn(locale == "en_US")?;
let plain_user = sasl::parse_sasl_plain_response(&response)?;
info!(username = %plain_user.authentication_identity, "SASL Authentication successful")
} else {
return Err(ConException::Todo.into_trans());
} }
let class = classes::parse_method(&start_ok_frame.payload)?; Ok(())
}
debug!(?class, "extracted method"); async fn tune(&mut self) -> Result<()> {
let tune_method = classes::Class::Connection(classes::Connection::Tune {
channel_max: CHANNEL_MAX,
frame_max: FRAME_SIZE_MAX,
heartbeat: HEARTBEAT_DELAY,
});
debug!("Sending Tune method");
self.send_method(0, tune_method).await?;
let tune_ok = self.recv_method().await?;
debug!(?tune_ok, "Received Tune-Ok method");
if let classes::Class::Connection(classes::Connection::TuneOk {
channel_max,
frame_max,
heartbeat,
}) = tune_ok
{
self.channel_max = channel_max;
self.max_frame_size = usize::try_from(frame_max).unwrap();
self.heartbeat_delay = heartbeat;
}
Ok(())
}
async fn open(&mut self) -> Result<()> {
let open = self.recv_method().await?;
debug!(?open, "Received Open method");
if let classes::Class::Connection(classes::Connection::Open { virtual_host, .. }) = open {
ensure_conn(virtual_host == "/")?;
}
self.send_method(
0,
classes::Class::Connection(classes::Connection::OpenOk {
reserved_1: "".to_string(),
}),
)
.await?;
Ok(()) Ok(())
} }
@ -120,21 +202,18 @@ impl Connection {
} }
fn server_properties(host: SocketAddr) -> classes::Table { fn server_properties(host: SocketAddr) -> classes::Table {
fn ss(str: &str) -> FieldValue { fn ls(str: &str) -> classes::FieldValue {
FieldValue::LongString(str.into()) classes::FieldValue::LongString(str.into())
} }
let host_str = host.ip().to_string(); let host_str = host.ip().to_string();
HashMap::from([ HashMap::from([
("host".to_string(), ss(&host_str)), ("host".to_string(), ls(&host_str)),
( ("product".to_string(), ls("no name yet")),
"product".to_string(), ("version".to_string(), ls("0.1.0")),
ss("no name yet"), ("platform".to_string(), ls("microsoft linux")),
), ("copyright".to_string(), ls("MIT")),
("version".to_string(), ss("0.1.0")), ("information".to_string(), ls("hello reader")),
("platform".to_string(), ss("microsoft linux")), ("uwu".to_string(), ls("owo")),
("copyright".to_string(), ss("MIT")),
("information".to_string(), ss("hello reader")),
("uwu".to_string(), ss("owo")),
]) ])
} }

View file

@ -1,6 +1,8 @@
use std::io::Error; use std::io::Error;
pub type Result<T> = std::result::Result<T, TransError>; pub type StdResult<T, E> = std::result::Result<T, E>;
pub type Result<T> = StdResult<T, TransError>;
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum TransError { pub enum TransError {
@ -34,7 +36,7 @@ pub enum ConException {
FrameError, FrameError,
#[error("503 Command invalid")] #[error("503 Command invalid")]
CommandInvalid, CommandInvalid,
#[error("503 Syntax error")] #[error("503 Syntax error | {0:?}")]
/// A method was received but there was a syntax error. The string stores where it occured. /// A method was received but there was a syntax error. The string stores where it occured.
SyntaxError(Vec<String>), SyntaxError(Vec<String>),
#[error("504 Channel error")] #[error("504 Channel error")]
@ -43,5 +45,11 @@ pub enum ConException {
Todo, Todo,
} }
impl ConException {
pub fn into_trans(self) -> TransError {
TransError::Invalid(ProtocolError::ConException(self))
}
}
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum ChannelException {} pub enum ChannelException {}

View file

@ -1,7 +1,7 @@
use crate::error::{ConException, ProtocolError, Result}; use crate::error::{ConException, ProtocolError, Result};
use anyhow::Context; use anyhow::Context;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::debug; use tracing::trace;
const REQUIRED_FRAME_END: u8 = 0xCE; const REQUIRED_FRAME_END: u8 = 0xCE;
@ -30,11 +30,11 @@ pub enum FrameType {
Heartbeat = 8, Heartbeat = 8,
} }
pub async fn write_frame<W>(frame: &Frame, mut w: W, ) -> Result<()> pub async fn write_frame<W>(frame: &Frame, mut w: W) -> Result<()>
where where
W: AsyncWriteExt + Unpin, W: AsyncWriteExt + Unpin,
{ {
debug!(?frame, "sending frame"); trace!(?frame, "Sending frame");
w.write_u8(frame.kind as u8).await?; w.write_u8(frame.kind as u8).await?;
w.write_u16(frame.channel).await?; w.write_u16(frame.channel).await?;
@ -63,17 +63,21 @@ where
return Err(ProtocolError::Fatal.into()); return Err(ProtocolError::Fatal.into());
} }
if payload.len() > max_frame_size { if max_frame_size != 0 && payload.len() > max_frame_size {
return Err(ProtocolError::ConException(ConException::FrameError).into()); return Err(ConException::FrameError.into_trans());
} }
let kind = parse_frame_type(kind, channel)?; let kind = parse_frame_type(kind, channel)?;
Ok(Frame { let frame = Frame {
kind, kind,
channel, channel,
payload, payload,
}) };
trace!(?frame, "Received frame");
Ok(frame)
} }
fn parse_frame_type(kind: u8, channel: u16) -> Result<FrameType> { fn parse_frame_type(kind: u8, channel: u16) -> Result<FrameType> {
@ -88,7 +92,7 @@ fn parse_frame_type(kind: u8, channel: u16) -> Result<FrameType> {
Ok(FrameType::Heartbeat) Ok(FrameType::Heartbeat)
} }
} }
_ => Err(ProtocolError::ConException(ConException::FrameError).into()), _ => Err(ConException::FrameError.into_trans()),
} }
} }

View file

@ -6,13 +6,15 @@ mod classes;
mod connection; mod connection;
mod error; mod error;
mod frame; mod frame;
mod sasl;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
use crate::connection::Connection; use crate::connection::Connection;
use anyhow::Result; use anyhow::Result;
use tokio::net; use tokio::net;
use tracing::info; use tracing::{info, info_span, Instrument};
use uuid::Uuid;
pub async fn do_thing_i_guess() -> Result<()> { pub async fn do_thing_i_guess() -> Result<()> {
info!("Binding TCP listener..."); info!("Binding TCP listener...");
@ -22,10 +24,13 @@ pub async fn do_thing_i_guess() -> Result<()> {
loop { loop {
let (stream, _) = listener.accept().await?; let (stream, _) = listener.accept().await?;
info!(local_addr = ?stream.local_addr(), "Accepted new connection"); let id = Uuid::from_bytes(rand::random());
let connection = Connection::new(stream); info!(local_addr = ?stream.local_addr(), %id, "Accepted new connection");
let span = info_span!("client-connection", %id);
tokio::spawn(connection.open_connection()); let connection = Connection::new(stream, id);
tokio::spawn(connection.start_connection_processing().instrument(span));
} }
} }

View file

@ -0,0 +1,33 @@
//! Partial implementation of the SASL Authentication (see [RFC 4422](https://datatracker.ietf.org/doc/html/rfc4422))
//!
//! Currently only supports PLAN (see [RFC 4616](https://datatracker.ietf.org/doc/html/rfc4616))
use crate::error::{ConException, Result};
pub struct PlainUser {
pub authorization_identity: String,
pub authentication_identity: String,
pub password: String,
}
pub fn parse_sasl_plain_response(response: &[u8]) -> Result<PlainUser> {
let mut parts = response
.split(|&n| n == 0)
.map(|bytes| String::from_utf8(bytes.into()).map_err(|_| ConException::Todo.into_trans()));
let authorization_identity = parts
.next()
.ok_or_else(|| ConException::Todo.into_trans())??;
let authentication_identity = parts
.next()
.ok_or_else(|| ConException::Todo.into_trans())??;
let password = parts
.next()
.ok_or_else(|| ConException::Todo.into_trans())??;
Ok(PlainUser {
authorization_identity,
authentication_identity,
password,
})
}

View file

@ -29,8 +29,6 @@ async fn write_start_ok_frame() {
frame::write_frame(&frame, &mut output).await.unwrap(); frame::write_frame(&frame, &mut output).await.unwrap();
#[rustfmt::skip] #[rustfmt::skip]
let expected = [ let expected = [
/* type, octet, method */ /* type, octet, method */
@ -76,8 +74,6 @@ async fn write_start_ok_frame() {
#[test] #[test]
fn read_start_ok_payload() { fn read_start_ok_payload() {
#[rustfmt::skip] #[rustfmt::skip]
let raw_data = [ let raw_data = [
/* Connection.Start-Ok */ /* Connection.Start-Ok */

View file

@ -1,18 +1,28 @@
use anyhow::Result; use anyhow::Result;
use std::env;
use tracing::Level; use tracing::Level;
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
setup_tracing(); let mut level = Level::DEBUG;
for arg in env::args().skip(1) {
match arg.as_str() {
"--trace" => level = Level::TRACE,
_ => {}
}
}
setup_tracing(level);
amqp_transport::do_thing_i_guess().await amqp_transport::do_thing_i_guess().await
} }
fn setup_tracing() { fn setup_tracing(level: Level) {
tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_level(true) .with_level(true)
.with_timer(tracing_subscriber::fmt::time::time()) .with_timer(tracing_subscriber::fmt::time::time())
.with_ansi(true) .with_ansi(true)
.with_thread_names(true) .with_thread_names(true)
.with_max_level(Level::DEBUG) .with_max_level(level)
.init() .init()
} }