connection working

This commit is contained in:
nora 2022-02-19 18:12:28 +01:00
parent ca1f372665
commit 13deef42fd
9 changed files with 217 additions and 82 deletions

View file

@ -1,4 +1,4 @@
use crate::error::{ConException, ProtocolError, TransError};
use crate::error::{ConException, TransError};
use std::collections::HashMap;
mod generated;
@ -41,15 +41,17 @@ pub fn parse_method(payload: &[u8]) -> Result<generated::Class, TransError> {
match nom_result {
Ok(([], class)) => Ok(class),
Ok((_, _)) => Err(ProtocolError::ConException(ConException::SyntaxError(vec![
"could not consume all input".to_string(),
]))
.into()),
Ok((_, _)) => {
Err(
ConException::SyntaxError(vec!["could not consume all input".to_string()])
.into_trans(),
)
}
Err(nom::Err::Incomplete(_)) => {
Err(ProtocolError::ConException(ConException::SyntaxError(vec![
"there was not enough data".to_string(),
]))
.into())
Err(
ConException::SyntaxError(vec!["there was not enough data".to_string()])
.into_trans(),
)
}
Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err),
}

View file

@ -17,7 +17,7 @@ use std::collections::HashMap;
impl<T> nom::error::ParseError<T> for TransError {
fn from_error_kind(_input: T, _kind: ErrorKind) -> Self {
ProtocolError::ConException(ConException::SyntaxError(vec![])).into()
ConException::SyntaxError(vec![]).into_trans()
}
fn append(_input: T, _kind: ErrorKind, other: Self) -> Self {
@ -47,13 +47,11 @@ pub fn err<S: Into<String>>(msg: S) -> impl FnOnce(Err<TransError>) -> Err<Trans
},
_ => vec![msg],
};
error_level(ProtocolError::ConException(ConException::SyntaxError(stack)).into())
error_level(ConException::SyntaxError(stack).into_trans())
}
}
pub fn err_other<E, S: Into<String>>(msg: S) -> impl FnOnce(E) -> Err<TransError> {
move |_| {
Err::Error(ProtocolError::ConException(ConException::SyntaxError(vec![msg.into()])).into())
}
move |_| Err::Error(ConException::SyntaxError(vec![msg.into()]).into_trans())
}
pub fn failure<E>(err: Err<E>) -> Err<E> {
@ -145,7 +143,7 @@ pub fn table(input: &[u8]) -> IResult<Table> {
let (input, values) = many0(table_value_pair)(table_input)?;
if input != &[] {
if !input.is_empty() {
fail!(format!(
"table longer than expected, expected = {size}, remaining = {}",
input.len()

View file

@ -1,48 +1,88 @@
use crate::classes::FieldValue;
use crate::error::{ConException, ProtocolError, Result};
use crate::frame::{Frame, FrameType};
use crate::{classes, frame};
use crate::{classes, frame, sasl};
use anyhow::Context;
use std::collections::HashMap;
use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, error};
use tracing::{debug, error, info};
use uuid::Uuid;
const MIN_MAX_FRAME_SIZE: usize = 4096;
fn ensure_conn(condition: bool) -> Result<()> {
if condition {
Ok(())
} else {
Err(ConException::Todo.into_trans())
}
}
const FRAME_SIZE_MIN_MAX: usize = 4096;
const CHANNEL_MAX: u16 = 0;
const FRAME_SIZE_MAX: u32 = 0;
const HEARTBEAT_DELAY: u16 = 0;
pub struct Connection {
stream: TcpStream,
max_frame_size: usize,
heartbeat_delay: u16,
channel_max: u16,
id: Uuid,
}
impl Connection {
pub fn new(stream: TcpStream) -> Self {
pub fn new(stream: TcpStream, id: Uuid) -> Self {
Self {
stream,
max_frame_size: MIN_MAX_FRAME_SIZE,
max_frame_size: FRAME_SIZE_MIN_MAX,
heartbeat_delay: HEARTBEAT_DELAY,
channel_max: CHANNEL_MAX,
id,
}
}
pub async fn open_connection(mut self) {
match self.run().await {
pub async fn start_connection_processing(mut self) {
match self.process_connection().await {
Ok(()) => {}
Err(err) => error!(%err, "Error during processing of connection"),
}
}
pub async fn run(&mut self) -> Result<()> {
pub async fn process_connection(&mut self) -> Result<()> {
self.negotiate_version().await?;
self.start().await?;
self.tune().await?;
self.open().await?;
info!("Connection is ready for usage!");
loop {
let frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?;
debug!(?frame, "received frame");
if frame.kind == FrameType::Method {
let class = super::classes::parse_method(&frame.payload)?;
debug!(?class, "was method frame");
let method = self.recv_method().await?;
debug!(?method, "Received method");
}
}
async fn send_method(&mut self, channel: u16, method: classes::Class) -> Result<()> {
let mut payload = Vec::with_capacity(64);
classes::write::write_method(method, &mut payload)?;
frame::write_frame(
&Frame {
kind: FrameType::Method,
channel,
payload,
},
&mut self.stream,
)
.await
}
async fn recv_method(&mut self) -> Result<classes::Class> {
let start_ok_frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?;
ensure_conn(start_ok_frame.kind == FrameType::Method)?;
let class = classes::parse_method(&start_ok_frame.payload)?;
Ok(class)
}
async fn start(&mut self) -> Result<()> {
@ -58,30 +98,72 @@ impl Connection {
locales: "en_US".into(),
});
debug!(?start_method, "Sending start method");
debug!(?start_method, "Sending Start method");
self.send_method(0, start_method).await?;
let mut payload = Vec::with_capacity(64);
classes::write::write_method(start_method, &mut payload)?;
frame::write_frame(
&Frame {
kind: FrameType::Method,
channel: 0,
payload,
},
&mut self.stream,
)
.await?;
let start_ok = self.recv_method().await?;
debug!(?start_ok, "Received Start-Ok");
let start_ok_frame = frame::read_frame(&mut self.stream, self.max_frame_size).await?;
debug!(?start_ok_frame, "Received Start-Ok frame");
if start_ok_frame.kind != FrameType::Method {
return Err(ProtocolError::ConException(ConException::Todo).into());
if let classes::Class::Connection(classes::Connection::StartOk {
mechanism,
locale,
response,
..
}) = start_ok
{
ensure_conn(mechanism == "PLAIN")?;
ensure_conn(locale == "en_US")?;
let plain_user = sasl::parse_sasl_plain_response(&response)?;
info!(username = %plain_user.authentication_identity, "SASL Authentication successful")
} else {
return Err(ConException::Todo.into_trans());
}
let class = classes::parse_method(&start_ok_frame.payload)?;
Ok(())
}
debug!(?class, "extracted method");
async fn tune(&mut self) -> Result<()> {
let tune_method = classes::Class::Connection(classes::Connection::Tune {
channel_max: CHANNEL_MAX,
frame_max: FRAME_SIZE_MAX,
heartbeat: HEARTBEAT_DELAY,
});
debug!("Sending Tune method");
self.send_method(0, tune_method).await?;
let tune_ok = self.recv_method().await?;
debug!(?tune_ok, "Received Tune-Ok method");
if let classes::Class::Connection(classes::Connection::TuneOk {
channel_max,
frame_max,
heartbeat,
}) = tune_ok
{
self.channel_max = channel_max;
self.max_frame_size = usize::try_from(frame_max).unwrap();
self.heartbeat_delay = heartbeat;
}
Ok(())
}
async fn open(&mut self) -> Result<()> {
let open = self.recv_method().await?;
debug!(?open, "Received Open method");
if let classes::Class::Connection(classes::Connection::Open { virtual_host, .. }) = open {
ensure_conn(virtual_host == "/")?;
}
self.send_method(
0,
classes::Class::Connection(classes::Connection::OpenOk {
reserved_1: "".to_string(),
}),
)
.await?;
Ok(())
}
@ -120,21 +202,18 @@ impl Connection {
}
fn server_properties(host: SocketAddr) -> classes::Table {
fn ss(str: &str) -> FieldValue {
FieldValue::LongString(str.into())
fn ls(str: &str) -> classes::FieldValue {
classes::FieldValue::LongString(str.into())
}
let host_str = host.ip().to_string();
HashMap::from([
("host".to_string(), ss(&host_str)),
(
"product".to_string(),
ss("no name yet"),
),
("version".to_string(), ss("0.1.0")),
("platform".to_string(), ss("microsoft linux")),
("copyright".to_string(), ss("MIT")),
("information".to_string(), ss("hello reader")),
("uwu".to_string(), ss("owo")),
("host".to_string(), ls(&host_str)),
("product".to_string(), ls("no name yet")),
("version".to_string(), ls("0.1.0")),
("platform".to_string(), ls("microsoft linux")),
("copyright".to_string(), ls("MIT")),
("information".to_string(), ls("hello reader")),
("uwu".to_string(), ls("owo")),
])
}

View file

@ -1,6 +1,8 @@
use std::io::Error;
pub type Result<T> = std::result::Result<T, TransError>;
pub type StdResult<T, E> = std::result::Result<T, E>;
pub type Result<T> = StdResult<T, TransError>;
#[derive(Debug, thiserror::Error)]
pub enum TransError {
@ -34,7 +36,7 @@ pub enum ConException {
FrameError,
#[error("503 Command invalid")]
CommandInvalid,
#[error("503 Syntax error")]
#[error("503 Syntax error | {0:?}")]
/// A method was received but there was a syntax error. The string stores where it occured.
SyntaxError(Vec<String>),
#[error("504 Channel error")]
@ -43,5 +45,11 @@ pub enum ConException {
Todo,
}
impl ConException {
pub fn into_trans(self) -> TransError {
TransError::Invalid(ProtocolError::ConException(self))
}
}
#[derive(Debug, thiserror::Error)]
pub enum ChannelException {}

View file

@ -1,7 +1,7 @@
use crate::error::{ConException, ProtocolError, Result};
use anyhow::Context;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tracing::debug;
use tracing::trace;
const REQUIRED_FRAME_END: u8 = 0xCE;
@ -30,11 +30,11 @@ pub enum FrameType {
Heartbeat = 8,
}
pub async fn write_frame<W>(frame: &Frame, mut w: W, ) -> Result<()>
pub async fn write_frame<W>(frame: &Frame, mut w: W) -> Result<()>
where
W: AsyncWriteExt + Unpin,
{
debug!(?frame, "sending frame");
trace!(?frame, "Sending frame");
w.write_u8(frame.kind as u8).await?;
w.write_u16(frame.channel).await?;
@ -63,17 +63,21 @@ where
return Err(ProtocolError::Fatal.into());
}
if payload.len() > max_frame_size {
return Err(ProtocolError::ConException(ConException::FrameError).into());
if max_frame_size != 0 && payload.len() > max_frame_size {
return Err(ConException::FrameError.into_trans());
}
let kind = parse_frame_type(kind, channel)?;
Ok(Frame {
let frame = Frame {
kind,
channel,
payload,
})
};
trace!(?frame, "Received frame");
Ok(frame)
}
fn parse_frame_type(kind: u8, channel: u16) -> Result<FrameType> {
@ -88,7 +92,7 @@ fn parse_frame_type(kind: u8, channel: u16) -> Result<FrameType> {
Ok(FrameType::Heartbeat)
}
}
_ => Err(ProtocolError::ConException(ConException::FrameError).into()),
_ => Err(ConException::FrameError.into_trans()),
}
}

View file

@ -6,13 +6,15 @@ mod classes;
mod connection;
mod error;
mod frame;
mod sasl;
#[cfg(test)]
mod tests;
use crate::connection::Connection;
use anyhow::Result;
use tokio::net;
use tracing::info;
use tracing::{info, info_span, Instrument};
use uuid::Uuid;
pub async fn do_thing_i_guess() -> Result<()> {
info!("Binding TCP listener...");
@ -22,10 +24,13 @@ pub async fn do_thing_i_guess() -> Result<()> {
loop {
let (stream, _) = listener.accept().await?;
info!(local_addr = ?stream.local_addr(), "Accepted new connection");
let id = Uuid::from_bytes(rand::random());
let connection = Connection::new(stream);
info!(local_addr = ?stream.local_addr(), %id, "Accepted new connection");
let span = info_span!("client-connection", %id);
tokio::spawn(connection.open_connection());
let connection = Connection::new(stream, id);
tokio::spawn(connection.start_connection_processing().instrument(span));
}
}

View file

@ -0,0 +1,33 @@
//! Partial implementation of the SASL Authentication (see [RFC 4422](https://datatracker.ietf.org/doc/html/rfc4422))
//!
//! Currently only supports PLAN (see [RFC 4616](https://datatracker.ietf.org/doc/html/rfc4616))
use crate::error::{ConException, Result};
pub struct PlainUser {
pub authorization_identity: String,
pub authentication_identity: String,
pub password: String,
}
pub fn parse_sasl_plain_response(response: &[u8]) -> Result<PlainUser> {
let mut parts = response
.split(|&n| n == 0)
.map(|bytes| String::from_utf8(bytes.into()).map_err(|_| ConException::Todo.into_trans()));
let authorization_identity = parts
.next()
.ok_or_else(|| ConException::Todo.into_trans())??;
let authentication_identity = parts
.next()
.ok_or_else(|| ConException::Todo.into_trans())??;
let password = parts
.next()
.ok_or_else(|| ConException::Todo.into_trans())??;
Ok(PlainUser {
authorization_identity,
authentication_identity,
password,
})
}

View file

@ -29,8 +29,6 @@ async fn write_start_ok_frame() {
frame::write_frame(&frame, &mut output).await.unwrap();
#[rustfmt::skip]
let expected = [
/* type, octet, method */
@ -76,8 +74,6 @@ async fn write_start_ok_frame() {
#[test]
fn read_start_ok_payload() {
#[rustfmt::skip]
let raw_data = [
/* Connection.Start-Ok */

View file

@ -1,18 +1,28 @@
use anyhow::Result;
use std::env;
use tracing::Level;
#[tokio::main]
async fn main() -> Result<()> {
setup_tracing();
let mut level = Level::DEBUG;
for arg in env::args().skip(1) {
match arg.as_str() {
"--trace" => level = Level::TRACE,
_ => {}
}
}
setup_tracing(level);
amqp_transport::do_thing_i_guess().await
}
fn setup_tracing() {
fn setup_tracing(level: Level) {
tracing_subscriber::fmt()
.with_level(true)
.with_timer(tracing_subscriber::fmt::time::time())
.with_ansi(true)
.with_thread_names(true)
.with_max_level(Level::DEBUG)
.with_max_level(level)
.init()
}