diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000..4d7dd9e --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1,3 @@ +imports_granularity = "Crate" +newline_style = "Unix" +group_imports = "StdExternalCrate" \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index df92cec..a27a493 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,7 @@ #![allow(dead_code)] pub mod epoll; -pub mod listener; +pub mod sync_tcp; #[cfg(not(target_os = "linux"))] compile_error!("yeah not gonna compile that here, rip you"); diff --git a/src/listener.rs b/src/listener.rs deleted file mode 100644 index a7d47b8..0000000 --- a/src/listener.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::mem::MaybeUninit; -use std::{io, mem}; - -const PORT: libc::in_port_t = 1112; - -const SOCKADDR_IN_SIZE: libc::socklen_t = mem::size_of::() as _; - -macro_rules! check_zero { - ($result:expr) => { - if $result != 0 { - return Err(io::Error::last_os_error()); - } - }; -} - -pub fn listener() -> io::Result<()> { - unsafe { - let socket = libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0); - - if socket == -1 { - return Err(io::Error::last_os_error()); - } - - println!("Created socket ({})", socket); - - let addr = libc::sockaddr_in { - sin_family: libc::AF_INET.try_into().unwrap(), - sin_port: PORT.to_be(), - sin_addr: libc::in_addr { - s_addr: libc::INADDR_ANY, - }, - sin_zero: [0; 8], - }; - let addr_erased_ptr = &addr as *const libc::sockaddr_in as _; - - let result = libc::bind(socket, addr_erased_ptr, SOCKADDR_IN_SIZE); - if result == -1 { - return Err(io::Error::last_os_error()); - } - - println!("Bound socket ({socket}) on port {PORT}"); - - check_zero!(libc::listen(socket, 5)); - - println!("Listening on socket ({socket})"); - - let mut peer_sockaddr = MaybeUninit::uninit(); - let mut sockaddr_size = 0; - let connection = libc::accept(socket, peer_sockaddr.as_mut_ptr(), &mut sockaddr_size); - if connection == -1 { - return Err(io::Error::last_os_error()); - } - - println!("Received connection! (connfd={connection})"); - - check_zero!(libc::close(connection)); - check_zero!(libc::close(socket)); - } - - Ok(()) -} diff --git a/src/main.rs b/src/main.rs index 573f69a..e231d14 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,14 @@ +use std::{ + io, + io::{Read, Write}, +}; + +use survey::sync_tcp::{SyncTcpListener, SyncTcpStream}; + +const PORT: u16 = 6547; + pub fn main() { - match survey::listener::listener() { + match listener() { Ok(()) => {} Err(err) => { eprintln!("{err}"); @@ -7,3 +16,48 @@ pub fn main() { } } } + +pub fn listener() -> io::Result<()> { + let mut threads = Vec::new(); + + let mut listener = SyncTcpListener::bind_any(PORT)?; + + println!("Bound listener on port {PORT}"); + + for stream in listener.accept() { + let handle = std::thread::spawn(move || handler_thread(stream)); + threads.push(handle); + } + + for thread in threads { + thread.join().unwrap(); + } + + Ok(()) +} + +fn handler_thread(stream: SyncTcpStream) { + match handler(stream) { + Ok(()) => {} + Err(err) => eprintln!("An error occurred while processing connection: {err}"), + } +} + +fn handler(mut stream: SyncTcpStream) -> io::Result<()> { + println!("Received connection! {stream:?}"); + + stream.write_all(b"Hi! Write your favourite three characters: ")?; + let mut buf = [0u8; 3]; + stream.read_exact(&mut buf)?; + println!("Read data: {buf:?}"); + stream.write_all(b"\nAh, it's: '")?; + stream.write_all(&buf)?; + stream.write_all(b"'. I like them too owo")?; + println!("written stuff"); + Ok(()) +} + +fn format_addr(addr: libc::in_addr) -> String { + let bytes = addr.s_addr.to_be_bytes(); + format!("{}.{}.{}.{}", bytes[0], bytes[1], bytes[2], bytes[3]) +} diff --git a/src/sync_tcp.rs b/src/sync_tcp.rs new file mode 100644 index 0000000..f7a72a3 --- /dev/null +++ b/src/sync_tcp.rs @@ -0,0 +1,122 @@ +use std::{ + fmt::{Debug, Formatter}, + io, + io::{Read, Write}, + mem, + mem::MaybeUninit, + os::unix, +}; + +const SOCKADDR_IN_SIZE: libc::socklen_t = mem::size_of::() as _; + +macro_rules! check_is_zero { + ($result:expr) => { + if $result != 0 { + return Err(io::Error::last_os_error()); + } + }; +} + +pub struct SyncTcpListener { + fd: unix::io::RawFd, + addr: libc::sockaddr_in, +} + +impl SyncTcpListener { + pub fn bind_any(port: u16) -> io::Result { + let socket = unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0) }; + + if socket == -1 { + return Err(io::Error::last_os_error()); + } + + let addr = libc::sockaddr_in { + sin_family: libc::AF_INET.try_into().unwrap(), + sin_port: port.to_be(), + sin_addr: libc::in_addr { + s_addr: libc::INADDR_ANY, + }, + sin_zero: [0; 8], + }; + let addr_erased_ptr = &addr as *const libc::sockaddr_in as _; + + let result = unsafe { libc::bind(socket, addr_erased_ptr, SOCKADDR_IN_SIZE) }; + if result == -1 { + return Err(io::Error::last_os_error()); + } + check_is_zero!(unsafe { libc::listen(socket, 5) }); + + Ok(Self { fd: socket, addr }) + } + + pub fn accept(&mut self) -> impl Iterator + '_ { + std::iter::from_fn(|| { + let mut peer_sockaddr = MaybeUninit::uninit(); + let mut sockaddr_size = 0; + let fd = + unsafe { libc::accept(self.fd, peer_sockaddr.as_mut_ptr(), &mut sockaddr_size) }; + if fd == -1 { + return None; + } + + let peer_sockaddr = unsafe { + peer_sockaddr + .as_mut_ptr() + .cast::() + .read() + }; + + Some(SyncTcpStream { fd, peer_sockaddr }) + }) + } +} + +impl Drop for SyncTcpListener { + fn drop(&mut self) { + unsafe { libc::close(self.fd) }; + } +} + +pub struct SyncTcpStream { + fd: unix::io::RawFd, + peer_sockaddr: libc::sockaddr_in, +} + +impl Read for SyncTcpStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let size = unsafe { libc::read(self.fd, buf.as_mut_ptr().cast(), buf.len()) }; + if size == -1 { + return Err(io::Error::last_os_error()); + } + Ok(size.try_into().unwrap()) + } +} + +impl Write for SyncTcpStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + let size = unsafe { libc::write(self.fd, buf.as_ptr().cast(), buf.len()) }; + if size == -1 { + return Err(io::Error::last_os_error()); + } + Ok(size.try_into().unwrap()) + } + + fn flush(&mut self) -> io::Result<()> { + todo!() + } +} + +impl Drop for SyncTcpStream { + fn drop(&mut self) { + unsafe { libc::close(self.fd) }; + } +} + +impl Debug for SyncTcpStream { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SyncTcpStream") + .field("fd", &self.fd) + .field("peer_addr", &format_addr(self.peer_sockaddr.sin_addr)) + .finish() + } +}