From 800ff88a6d3c7aaf8eb3b7d96aba3aab524361fe Mon Sep 17 00:00:00 2001 From: Nilstrieb <48135649+Nilstrieb@users.noreply.github.com> Date: Tue, 26 Sep 2023 07:50:20 +0200 Subject: [PATCH] expect byte test --- src/lib.rs | 27 ++++++++++---------- src/main.rs | 5 +++- tests/expect.rs | 66 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 14 deletions(-) create mode 100644 tests/expect.rs diff --git a/src/lib.rs b/src/lib.rs index 8a7378f..1926761 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,30 +2,31 @@ pub mod proto; use std::{ fmt::Debug, - io::{self, BufWriter, Read, Write}, - net::TcpStream, + io::{self, Read, Write}, }; use crate::proto::TLSPlaintext; type Result = std::result::Result; -pub struct ClientConnection {} +pub struct ClientConnection { + _w: W, +} -impl ClientConnection { - pub fn establish(host: &str, port: u16) -> Result { - let _setup = ClientSetupConnection::establish(host, port)?; +impl ClientConnection { + pub fn establish(w: W, host: &str) -> Result { + let _setup = ClientSetupConnection::establish(w, host)?; todo!() } } -struct ClientSetupConnection {} - -impl ClientSetupConnection { - fn establish(host: &str, port: u16) -> Result { - let mut stream = BufWriter::new(LoggingWriter(TcpStream::connect((host, port))?)); +struct ClientSetupConnection { + _w: W, +} +impl ClientSetupConnection { + fn establish(mut stream: W, host: &str) -> Result { let secret = x25519_dalek::EphemeralSecret::random_from_rng(rand::thread_rng()); let public = x25519_dalek::PublicKey::from(&secret); @@ -82,7 +83,7 @@ impl ClientSetupConnection { plaintext.write(&mut stream)?; stream.flush()?; - let out = proto::TLSPlaintext::read(stream.get_mut())?; + let out = proto::TLSPlaintext::read(&mut stream)?; dbg!(&out); if matches!(out, TLSPlaintext::Handshake { handshake } if handshake.is_hello_retry_request()) @@ -125,7 +126,7 @@ impl From for Error { } #[derive(Debug)] -struct LoggingWriter(W); +pub struct LoggingWriter(pub W); impl io::Write for LoggingWriter { fn write(&mut self, buf: &[u8]) -> io::Result { diff --git a/src/main.rs b/src/main.rs index 48c139d..a2f1644 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,7 @@ +use std::net::TcpStream; + // An example program that makes a shitty HTTP/1.1 request. fn main() { - tls::ClientConnection::establish("nilstrieb.dev", 443).unwrap(); + let conn = tls::LoggingWriter(TcpStream::connect(("nilstrieb.dev", 443)).unwrap()); + tls::ClientConnection::establish(conn, "nilstrieb.dev").unwrap(); } diff --git a/tests/expect.rs b/tests/expect.rs new file mode 100644 index 0000000..060651c --- /dev/null +++ b/tests/expect.rs @@ -0,0 +1,66 @@ +use std::io::{Read, Write}; + +struct ExpectServer { + expect: Vec, +} + +impl ExpectServer { + fn new(expect: Vec) -> Self { + ExpectServer { expect } + } +} + +impl Read for ExpectServer { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let Some(Expect::Server(server)) = self.expect.first_mut() else { + panic!("Reading from server, but client input is expected"); + }; + + let len = std::cmp::min(buf.len(), server.len()); + buf[..len].copy_from_slice(&mut server[..len]); + server.rotate_left(len); + server.truncate(server.len() - len); + if server.is_empty() { + self.expect.remove(0); + } + Ok(len) + } +} + +impl Write for ExpectServer { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let Some(Expect::Client(client)) = self.expect.first_mut() else { + panic!("Writing as client, but should read instead"); + }; + + let to_write = client + .get(..buf.len()) + .expect("writing more bytes than expected"); + assert_eq!(to_write, buf); + client.rotate_left(buf.len()); + client.truncate(client.len() - buf.len()); + if client.is_empty() { + self.expect.remove(0); + } + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + Ok(()) + } +} + +enum Expect { + Server(Vec), + Client(Vec), +} + +#[test] +#[ignore] +fn connect() { + let mut expect = ExpectServer::new(vec![ + Expect::Client(vec![0]), // TODO: do this + ]); + + let conn = tls::ClientConnection::establish(&mut expect, "example.com").unwrap(); +}