diff --git a/src/epoll.rs b/src/epoll.rs deleted file mode 100644 index 21777d6..0000000 --- a/src/epoll.rs +++ /dev/null @@ -1,94 +0,0 @@ -/* -#define MAX_EVENTS 10 - struct epoll_event ev, events[MAX_EVENTS]; - int listen_sock, conn_sock, nfds, epollfd; - - /* Code to set up listening socket, 'listen_sock', - (socket(), bind(), listen()) omitted */ - - epollfd = epoll_create1(0); - if (epollfd == -1) { - perror("epoll_create1"); - exit(EXIT_FAILURE); - } - - ev.events = EPOLLIN; - ev.data.fd = listen_sock; - if (epoll_ctl(epollfd, EPOLL_CTL_ADD, listen_sock, &ev) == -1) { - perror("epoll_ctl: listen_sock"); - exit(EXIT_FAILURE); - } - - for (;;) { - nfds = epoll_wait(epollfd, events, MAX_EVENTS, -1); - if (nfds == -1) { - perror("epoll_wait"); - exit(EXIT_FAILURE); - } - - for (n = 0; n < nfds; ++n) { - if (events[n].data.fd == listen_sock) { - conn_sock = accept(listen_sock, - (struct sockaddr *) &addr, &addrlen); - if (conn_sock == -1) { - perror("accept"); - exit(EXIT_FAILURE); - } - setnonblocking(conn_sock); - ev.events = EPOLLIN | EPOLLET; - ev.data.fd = conn_sock; - if (epoll_ctl(epollfd, EPOLL_CTL_ADD, conn_sock, - &ev) == -1) { - perror("epoll_ctl: conn_sock"); - exit(EXIT_FAILURE); - } - } else { - do_use_fd(events[n].data.fd); - } - } - } - */ - -const MAX_EVENTS: usize = 10; - -pub fn example_from_man_page() -> Result<(), &'static str> { - // SAFETY: I trust man pages (maybe a mistake) - /*unsafe { - let listen_sock: i32 = todo!(); - let mut events = [libc::epoll_event { events: 0, u64: 0 }; MAX_EVENTS]; // empty value - - let epollfd = libc::epoll_create1(1); - if epollfd == -1 { - return Err("Failed to crate epoll instance"); - } - - let ev = libc::epoll_event { - events: libc::EPOLLIN.try_into().unwrap(), - u64: listen_sock.try_into().unwrap(), - }; - - loop { - let nfds = libc::epoll_wait( - epollfd, - events.as_mut_ptr(), - MAX_EVENTS.try_into().unwrap(), - -1, - ); - if nfds == -1 { - return Err("Failed to wait for next event"); - } - - for i in 0usize..nfds.try_into().unwrap() { - if events[i].u64 == listen_sock.try_into().unwrap() { - let conn_sock = libc::accept(listen_sock, todo!(), todo!()); - if conn_sock != -1 { - return Err("Failed to accept"); - } - todo!() - } - } - } - }*/ - - Ok(()) // this is safe. -} diff --git a/src/epoll/mod.rs b/src/epoll/mod.rs new file mode 100644 index 0000000..794ec5d --- /dev/null +++ b/src/epoll/mod.rs @@ -0,0 +1,84 @@ +use std::{ + io, + mem::MaybeUninit, + os::unix::io::{AsRawFd, RawFd}, + ptr::addr_of, +}; + +use crate::{check_non_neg1, epoll::tcp::AsyncTcpListener}; + +mod tcp; + +unsafe fn make_nonblock(fd: RawFd) -> io::Result<()> { + let status = check_non_neg1!(libc::fcntl(fd, libc::F_GETFL)); + check_non_neg1!(libc::fcntl(fd, libc::F_SETFL, status | libc::O_NONBLOCK)); + Ok(()) +} + +pub fn example_from_man_page() -> io::Result<()> { + let listener = AsyncTcpListener::bind_any(8888)?; + println!("Created listener {listener:?}"); + + unsafe { + let listen_sock = listener.as_raw_fd(); + + let epollfd = check_non_neg1!(libc::epoll_create1(0)); + + println!("Created epoll instance"); + + let mut ev = libc::epoll_event { + events: libc::EPOLLIN as _, + u64: listen_sock as _, + }; + + check_non_neg1!(libc::epoll_ctl( + epollfd, + libc::EPOLL_CTL_ADD, + listen_sock, + &mut ev + )); + + loop { + let mut events = [libc::epoll_event { events: 0, u64: 0 }; 16]; + + let nfds = check_non_neg1!(libc::epoll_wait( + epollfd, + events.as_mut_ptr(), + events.len() as _, + -1, + )); + + for event in &events[0..nfds as _] { + if event.u64 == listen_sock as _ { + // our TCP listener received a new connection + let mut peer_sockaddr = MaybeUninit::uninit(); + let mut sockaddr_size = 0; + + let conn_sock = check_non_neg1!(libc::accept( + listener.as_raw_fd(), + peer_sockaddr.as_mut_ptr(), + &mut sockaddr_size, + )); + + make_nonblock(conn_sock)?; + let mut ev = libc::epoll_event { + events: (libc::EPOLLIN | libc::EPOLLET) as _, + u64: conn_sock as _, + }; + check_non_neg1!(libc::epoll_ctl( + epollfd, + libc::EPOLL_CTL_ADD, + conn_sock, + &mut ev + )); + println!("Received new connection! (fd: {conn_sock})"); + } else { + println!( + "something else happened! {}", + addr_of!(event.u64).read_unaligned() + ); + } + } + } + } +} diff --git a/src/epoll/tcp.rs b/src/epoll/tcp.rs new file mode 100644 index 0000000..af0cc3f --- /dev/null +++ b/src/epoll/tcp.rs @@ -0,0 +1,62 @@ +use std::{ + fmt::{Debug, Formatter}, + io, + os::{unix, unix::io::RawFd}, +}; + +use crate::{check_is_zero, format_addr, SOCKADDR_IN_SIZE}; + +pub struct AsyncTcpListener { + fd: unix::io::RawFd, + addr: libc::sockaddr_in, +} + +impl AsyncTcpListener { + pub fn bind_any(port: u16) -> io::Result { + let socket = + unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM | libc::SOCK_NONBLOCK, 0) }; + + if socket == -1 { + return Err(io::Error::last_os_error()); + } + + let addr = libc::sockaddr_in { + sin_family: libc::AF_INET.try_into().unwrap(), + sin_port: port.to_be(), + sin_addr: libc::in_addr { + s_addr: libc::INADDR_ANY, + }, + sin_zero: [0; 8], + }; + let addr_erased_ptr = &addr as *const libc::sockaddr_in as _; + + let result = unsafe { libc::bind(socket, addr_erased_ptr, SOCKADDR_IN_SIZE) }; + if result == -1 { + return Err(io::Error::last_os_error()); + } + check_is_zero!(unsafe { libc::listen(socket, 5) }); + + Ok(Self { fd: socket, addr }) + } +} + +impl unix::io::AsRawFd for AsyncTcpListener { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +impl Drop for AsyncTcpListener { + fn drop(&mut self) { + unsafe { libc::close(self.fd) }; + } +} + +impl Debug for AsyncTcpListener { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SyncTcpListener") + .field("fd", &self.fd) + .field("addr", &format_addr(self.addr)) + .finish() + } +} diff --git a/src/lib.rs b/src/lib.rs index a27a493..8fef512 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,44 @@ #![warn(rust_2018_idioms)] -#![allow(dead_code)] +#![allow(clippy::single_component_path_imports)] // lol clippy pub mod epoll; pub mod sync_tcp; #[cfg(not(target_os = "linux"))] compile_error!("yeah not gonna compile that here, rip you"); + +const SOCKADDR_IN_SIZE: libc::socklen_t = std::mem::size_of::() as _; + +fn format_addr(addr: libc::sockaddr_in) -> String { + let bytes = addr.sin_addr.s_addr.to_be_bytes(); + format!( + "{}.{}.{}.{}:{}", + bytes[0], + bytes[1], + bytes[2], + bytes[3], + u16::from_be_bytes(addr.sin_port.to_ne_bytes()) + ) +} + +macro_rules! check_is_zero { + ($result:expr) => { + if $result != 0 { + return Err(io::Error::last_os_error()); + } + }; +} + +use check_is_zero; + +macro_rules! check_non_neg1 { + ($result:expr) => {{ + let result = $result; + if result == -1 { + return Err(io::Error::last_os_error()); + } + result + }}; +} + +use check_non_neg1; diff --git a/src/main.rs b/src/main.rs index 633649c..a617186 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,15 +1,12 @@ use std::{ - io, io::{Read, Write}, time::Duration, }; -use survey::sync_tcp::{SyncTcpListener, SyncTcpStream}; - -const PORT: u16 = 6547; +use survey::epoll::example_from_man_page; pub fn main() { - match listener() { + match example_from_man_page() { Ok(()) => {} Err(err) => { eprintln!("{err}"); @@ -18,43 +15,43 @@ pub fn main() { } } -pub fn listener() -> io::Result<()> { - let mut threads = Vec::new(); - - let listener = SyncTcpListener::bind_any(PORT)?; - - println!("Bound listener on port {PORT}"); - - for stream in listener.incoming() { - let stream = stream?; - let handle = std::thread::spawn(move || handler_thread(stream)); - threads.push(handle); - } - - for thread in threads { - thread.join().unwrap(); - } - - Ok(()) -} - -fn handler_thread(stream: SyncTcpStream) { - match handler(stream) { - Ok(()) => {} - Err(err) => eprintln!("An error occurred while processing connection: {err}"), - } -} - -fn handler(mut stream: SyncTcpStream) -> io::Result<()> { - println!("Received connection! {stream:?}"); - - stream.write_all(b"Hi! Write your favourite three characters: ")?; - let mut buf = [0u8; 3]; - stream.read_exact(&mut buf)?; - println!("Read data: {buf:?}"); - stream.write_all(b"\nAh, it's: '")?; - stream.write_all(&buf)?; - stream.write_all(b"'. I like them too owo")?; - std::thread::sleep(Duration::from_millis(100)); - Ok(()) -} +//pub fn listener() -> io::Result<()> { +// let mut threads = Vec::new(); +// +// let listener = SyncTcpListener::bind_any(PORT)?; +// +// println!("Bound listener on port {PORT}"); +// +// for stream in listener.incoming() { +// let stream = stream?; +// let handle = std::thread::spawn(move || handler_thread(stream)); +// threads.push(handle); +// } +// +// for thread in threads { +// thread.join().unwrap(); +// } +// +// Ok(()) +//} +// +//fn handler_thread(stream: SyncTcpStream) { +// match handler(stream) { +// Ok(()) => {} +// Err(err) => eprintln!("An error occurred while processing connection: {err}"), +// } +//} +// +//fn handler(mut stream: SyncTcpStream) -> io::Result<()> { +// println!("Received connection! {stream:?}"); +// +// stream.write_all(b"Hi! Write your favourite three characters: ")?; +// let mut buf = [0u8; 3]; +// stream.read_exact(&mut buf)?; +// println!("Read data: {buf:?}"); +// stream.write_all(b"\nAh, it's: '")?; +// stream.write_all(&buf)?; +// stream.write_all(b"'. I like them too owo")?; +// std::thread::sleep(Duration::from_millis(100)); +// Ok(()) +//} diff --git a/src/sync_tcp.rs b/src/sync_tcp.rs index 1fe53d5..1da0c9d 100644 --- a/src/sync_tcp.rs +++ b/src/sync_tcp.rs @@ -2,20 +2,11 @@ use std::{ fmt::{Debug, Formatter}, io, io::{Read, Write}, - mem, mem::MaybeUninit, os::{unix, unix::io::RawFd}, }; -const SOCKADDR_IN_SIZE: libc::socklen_t = mem::size_of::() as _; - -macro_rules! check_is_zero { - ($result:expr) => { - if $result != 0 { - return Err(io::Error::last_os_error()); - } - }; -} +use crate::{check_is_zero, check_non_neg1, format_addr, SOCKADDR_IN_SIZE}; pub struct SyncTcpListener { fd: unix::io::RawFd, @@ -24,11 +15,7 @@ pub struct SyncTcpListener { impl SyncTcpListener { pub fn bind_any(port: u16) -> io::Result { - let socket = unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0) }; - - if socket == -1 { - return Err(io::Error::last_os_error()); - } + let socket = check_non_neg1!(unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0) }); let addr = libc::sockaddr_in { sin_family: libc::AF_INET.try_into().unwrap(), @@ -40,10 +27,9 @@ impl SyncTcpListener { }; let addr_erased_ptr = &addr as *const libc::sockaddr_in as _; - let result = unsafe { libc::bind(socket, addr_erased_ptr, SOCKADDR_IN_SIZE) }; - if result == -1 { - return Err(io::Error::last_os_error()); - } + let result = + check_non_neg1!(unsafe { libc::bind(socket, addr_erased_ptr, SOCKADDR_IN_SIZE) }); + check_is_zero!(unsafe { libc::listen(socket, 5) }); Ok(Self { fd: socket, addr }) @@ -100,20 +86,16 @@ pub struct SyncTcpStream { impl Read for SyncTcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result { - let size = unsafe { libc::read(self.fd, buf.as_mut_ptr().cast(), buf.len()) }; - if size == -1 { - return Err(io::Error::last_os_error()); - } + let size = + check_non_neg1!(unsafe { libc::read(self.fd, buf.as_mut_ptr().cast(), buf.len()) }); Ok(size.try_into().unwrap()) } } impl Write for SyncTcpStream { fn write(&mut self, buf: &[u8]) -> io::Result { - let size = unsafe { libc::send(self.fd, buf.as_ptr().cast(), buf.len(), 0) }; - if size == -1 { - return Err(io::Error::last_os_error()); - } + let size = + check_non_neg1!(unsafe { libc::send(self.fd, buf.as_ptr().cast(), buf.len(), 0) }); Ok(size.try_into().unwrap()) } @@ -135,7 +117,7 @@ impl Debug for SyncTcpStream { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("SyncTcpStream") .field("fd", &self.fd) - .field("peer_addr", &format_addr(self.peer_sockaddr)) + .field("addr", &format_addr(self.peer_sockaddr)) .finish() } } @@ -145,11 +127,3 @@ impl unix::io::AsRawFd for SyncTcpStream { self.fd } } - -fn format_addr(addr: libc::sockaddr_in) -> String { - let bytes = addr.sin_addr.s_addr.to_be_bytes(); - format!( - "{}.{}.{}.{}:{}", - bytes[0], bytes[1], bytes[2], bytes[3], addr.sin_port - ) -}