This commit is contained in:
nora 2022-04-02 17:34:23 +02:00
parent 19e43ea5a2
commit 97fadb835e
6 changed files with 235 additions and 176 deletions

View file

@ -1,94 +0,0 @@
/*
#define MAX_EVENTS 10
struct epoll_event ev, events[MAX_EVENTS];
int listen_sock, conn_sock, nfds, epollfd;
/* Code to set up listening socket, 'listen_sock',
(socket(), bind(), listen()) omitted */
epollfd = epoll_create1(0);
if (epollfd == -1) {
perror("epoll_create1");
exit(EXIT_FAILURE);
}
ev.events = EPOLLIN;
ev.data.fd = listen_sock;
if (epoll_ctl(epollfd, EPOLL_CTL_ADD, listen_sock, &ev) == -1) {
perror("epoll_ctl: listen_sock");
exit(EXIT_FAILURE);
}
for (;;) {
nfds = epoll_wait(epollfd, events, MAX_EVENTS, -1);
if (nfds == -1) {
perror("epoll_wait");
exit(EXIT_FAILURE);
}
for (n = 0; n < nfds; ++n) {
if (events[n].data.fd == listen_sock) {
conn_sock = accept(listen_sock,
(struct sockaddr *) &addr, &addrlen);
if (conn_sock == -1) {
perror("accept");
exit(EXIT_FAILURE);
}
setnonblocking(conn_sock);
ev.events = EPOLLIN | EPOLLET;
ev.data.fd = conn_sock;
if (epoll_ctl(epollfd, EPOLL_CTL_ADD, conn_sock,
&ev) == -1) {
perror("epoll_ctl: conn_sock");
exit(EXIT_FAILURE);
}
} else {
do_use_fd(events[n].data.fd);
}
}
}
*/
const MAX_EVENTS: usize = 10;
pub fn example_from_man_page() -> Result<(), &'static str> {
// SAFETY: I trust man pages (maybe a mistake)
/*unsafe {
let listen_sock: i32 = todo!();
let mut events = [libc::epoll_event { events: 0, u64: 0 }; MAX_EVENTS]; // empty value
let epollfd = libc::epoll_create1(1);
if epollfd == -1 {
return Err("Failed to crate epoll instance");
}
let ev = libc::epoll_event {
events: libc::EPOLLIN.try_into().unwrap(),
u64: listen_sock.try_into().unwrap(),
};
loop {
let nfds = libc::epoll_wait(
epollfd,
events.as_mut_ptr(),
MAX_EVENTS.try_into().unwrap(),
-1,
);
if nfds == -1 {
return Err("Failed to wait for next event");
}
for i in 0usize..nfds.try_into().unwrap() {
if events[i].u64 == listen_sock.try_into().unwrap() {
let conn_sock = libc::accept(listen_sock, todo!(), todo!());
if conn_sock != -1 {
return Err("Failed to accept");
}
todo!()
}
}
}
}*/
Ok(()) // this is safe.
}

84
src/epoll/mod.rs Normal file
View file

@ -0,0 +1,84 @@
use std::{
io,
mem::MaybeUninit,
os::unix::io::{AsRawFd, RawFd},
ptr::addr_of,
};
use crate::{check_non_neg1, epoll::tcp::AsyncTcpListener};
mod tcp;
unsafe fn make_nonblock(fd: RawFd) -> io::Result<()> {
let status = check_non_neg1!(libc::fcntl(fd, libc::F_GETFL));
check_non_neg1!(libc::fcntl(fd, libc::F_SETFL, status | libc::O_NONBLOCK));
Ok(())
}
pub fn example_from_man_page() -> io::Result<()> {
let listener = AsyncTcpListener::bind_any(8888)?;
println!("Created listener {listener:?}");
unsafe {
let listen_sock = listener.as_raw_fd();
let epollfd = check_non_neg1!(libc::epoll_create1(0));
println!("Created epoll instance");
let mut ev = libc::epoll_event {
events: libc::EPOLLIN as _,
u64: listen_sock as _,
};
check_non_neg1!(libc::epoll_ctl(
epollfd,
libc::EPOLL_CTL_ADD,
listen_sock,
&mut ev
));
loop {
let mut events = [libc::epoll_event { events: 0, u64: 0 }; 16];
let nfds = check_non_neg1!(libc::epoll_wait(
epollfd,
events.as_mut_ptr(),
events.len() as _,
-1,
));
for event in &events[0..nfds as _] {
if event.u64 == listen_sock as _ {
// our TCP listener received a new connection
let mut peer_sockaddr = MaybeUninit::uninit();
let mut sockaddr_size = 0;
let conn_sock = check_non_neg1!(libc::accept(
listener.as_raw_fd(),
peer_sockaddr.as_mut_ptr(),
&mut sockaddr_size,
));
make_nonblock(conn_sock)?;
let mut ev = libc::epoll_event {
events: (libc::EPOLLIN | libc::EPOLLET) as _,
u64: conn_sock as _,
};
check_non_neg1!(libc::epoll_ctl(
epollfd,
libc::EPOLL_CTL_ADD,
conn_sock,
&mut ev
));
println!("Received new connection! (fd: {conn_sock})");
} else {
println!(
"something else happened! {}",
addr_of!(event.u64).read_unaligned()
);
}
}
}
}
}

62
src/epoll/tcp.rs Normal file
View file

@ -0,0 +1,62 @@
use std::{
fmt::{Debug, Formatter},
io,
os::{unix, unix::io::RawFd},
};
use crate::{check_is_zero, format_addr, SOCKADDR_IN_SIZE};
pub struct AsyncTcpListener {
fd: unix::io::RawFd,
addr: libc::sockaddr_in,
}
impl AsyncTcpListener {
pub fn bind_any(port: u16) -> io::Result<Self> {
let socket =
unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM | libc::SOCK_NONBLOCK, 0) };
if socket == -1 {
return Err(io::Error::last_os_error());
}
let addr = libc::sockaddr_in {
sin_family: libc::AF_INET.try_into().unwrap(),
sin_port: port.to_be(),
sin_addr: libc::in_addr {
s_addr: libc::INADDR_ANY,
},
sin_zero: [0; 8],
};
let addr_erased_ptr = &addr as *const libc::sockaddr_in as _;
let result = unsafe { libc::bind(socket, addr_erased_ptr, SOCKADDR_IN_SIZE) };
if result == -1 {
return Err(io::Error::last_os_error());
}
check_is_zero!(unsafe { libc::listen(socket, 5) });
Ok(Self { fd: socket, addr })
}
}
impl unix::io::AsRawFd for AsyncTcpListener {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl Drop for AsyncTcpListener {
fn drop(&mut self) {
unsafe { libc::close(self.fd) };
}
}
impl Debug for AsyncTcpListener {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SyncTcpListener")
.field("fd", &self.fd)
.field("addr", &format_addr(self.addr))
.finish()
}
}

View file

@ -1,8 +1,44 @@
#![warn(rust_2018_idioms)] #![warn(rust_2018_idioms)]
#![allow(dead_code)] #![allow(clippy::single_component_path_imports)] // lol clippy
pub mod epoll; pub mod epoll;
pub mod sync_tcp; pub mod sync_tcp;
#[cfg(not(target_os = "linux"))] #[cfg(not(target_os = "linux"))]
compile_error!("yeah not gonna compile that here, rip you"); compile_error!("yeah not gonna compile that here, rip you");
const SOCKADDR_IN_SIZE: libc::socklen_t = std::mem::size_of::<libc::sockaddr_in>() as _;
fn format_addr(addr: libc::sockaddr_in) -> String {
let bytes = addr.sin_addr.s_addr.to_be_bytes();
format!(
"{}.{}.{}.{}:{}",
bytes[0],
bytes[1],
bytes[2],
bytes[3],
u16::from_be_bytes(addr.sin_port.to_ne_bytes())
)
}
macro_rules! check_is_zero {
($result:expr) => {
if $result != 0 {
return Err(io::Error::last_os_error());
}
};
}
use check_is_zero;
macro_rules! check_non_neg1 {
($result:expr) => {{
let result = $result;
if result == -1 {
return Err(io::Error::last_os_error());
}
result
}};
}
use check_non_neg1;

View file

@ -1,15 +1,12 @@
use std::{ use std::{
io,
io::{Read, Write}, io::{Read, Write},
time::Duration, time::Duration,
}; };
use survey::sync_tcp::{SyncTcpListener, SyncTcpStream}; use survey::epoll::example_from_man_page;
const PORT: u16 = 6547;
pub fn main() { pub fn main() {
match listener() { match example_from_man_page() {
Ok(()) => {} Ok(()) => {}
Err(err) => { Err(err) => {
eprintln!("{err}"); eprintln!("{err}");
@ -18,43 +15,43 @@ pub fn main() {
} }
} }
pub fn listener() -> io::Result<()> { //pub fn listener() -> io::Result<()> {
let mut threads = Vec::new(); // let mut threads = Vec::new();
//
let listener = SyncTcpListener::bind_any(PORT)?; // let listener = SyncTcpListener::bind_any(PORT)?;
//
println!("Bound listener on port {PORT}"); // println!("Bound listener on port {PORT}");
//
for stream in listener.incoming() { // for stream in listener.incoming() {
let stream = stream?; // let stream = stream?;
let handle = std::thread::spawn(move || handler_thread(stream)); // let handle = std::thread::spawn(move || handler_thread(stream));
threads.push(handle); // threads.push(handle);
} // }
//
for thread in threads { // for thread in threads {
thread.join().unwrap(); // thread.join().unwrap();
} // }
//
Ok(()) // Ok(())
} //}
//
fn handler_thread(stream: SyncTcpStream) { //fn handler_thread(stream: SyncTcpStream) {
match handler(stream) { // match handler(stream) {
Ok(()) => {} // Ok(()) => {}
Err(err) => eprintln!("An error occurred while processing connection: {err}"), // Err(err) => eprintln!("An error occurred while processing connection: {err}"),
} // }
} //}
//
fn handler(mut stream: SyncTcpStream) -> io::Result<()> { //fn handler(mut stream: SyncTcpStream) -> io::Result<()> {
println!("Received connection! {stream:?}"); // println!("Received connection! {stream:?}");
//
stream.write_all(b"Hi! Write your favourite three characters: ")?; // stream.write_all(b"Hi! Write your favourite three characters: ")?;
let mut buf = [0u8; 3]; // let mut buf = [0u8; 3];
stream.read_exact(&mut buf)?; // stream.read_exact(&mut buf)?;
println!("Read data: {buf:?}"); // println!("Read data: {buf:?}");
stream.write_all(b"\nAh, it's: '")?; // stream.write_all(b"\nAh, it's: '")?;
stream.write_all(&buf)?; // stream.write_all(&buf)?;
stream.write_all(b"'. I like them too owo")?; // stream.write_all(b"'. I like them too owo")?;
std::thread::sleep(Duration::from_millis(100)); // std::thread::sleep(Duration::from_millis(100));
Ok(()) // Ok(())
} //}

View file

@ -2,20 +2,11 @@ use std::{
fmt::{Debug, Formatter}, fmt::{Debug, Formatter},
io, io,
io::{Read, Write}, io::{Read, Write},
mem,
mem::MaybeUninit, mem::MaybeUninit,
os::{unix, unix::io::RawFd}, os::{unix, unix::io::RawFd},
}; };
const SOCKADDR_IN_SIZE: libc::socklen_t = mem::size_of::<libc::sockaddr_in>() as _; use crate::{check_is_zero, check_non_neg1, format_addr, SOCKADDR_IN_SIZE};
macro_rules! check_is_zero {
($result:expr) => {
if $result != 0 {
return Err(io::Error::last_os_error());
}
};
}
pub struct SyncTcpListener { pub struct SyncTcpListener {
fd: unix::io::RawFd, fd: unix::io::RawFd,
@ -24,11 +15,7 @@ pub struct SyncTcpListener {
impl SyncTcpListener { impl SyncTcpListener {
pub fn bind_any(port: u16) -> io::Result<Self> { pub fn bind_any(port: u16) -> io::Result<Self> {
let socket = unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0) }; let socket = check_non_neg1!(unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0) });
if socket == -1 {
return Err(io::Error::last_os_error());
}
let addr = libc::sockaddr_in { let addr = libc::sockaddr_in {
sin_family: libc::AF_INET.try_into().unwrap(), sin_family: libc::AF_INET.try_into().unwrap(),
@ -40,10 +27,9 @@ impl SyncTcpListener {
}; };
let addr_erased_ptr = &addr as *const libc::sockaddr_in as _; let addr_erased_ptr = &addr as *const libc::sockaddr_in as _;
let result = unsafe { libc::bind(socket, addr_erased_ptr, SOCKADDR_IN_SIZE) }; let result =
if result == -1 { check_non_neg1!(unsafe { libc::bind(socket, addr_erased_ptr, SOCKADDR_IN_SIZE) });
return Err(io::Error::last_os_error());
}
check_is_zero!(unsafe { libc::listen(socket, 5) }); check_is_zero!(unsafe { libc::listen(socket, 5) });
Ok(Self { fd: socket, addr }) Ok(Self { fd: socket, addr })
@ -100,20 +86,16 @@ pub struct SyncTcpStream {
impl Read for SyncTcpStream { impl Read for SyncTcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let size = unsafe { libc::read(self.fd, buf.as_mut_ptr().cast(), buf.len()) }; let size =
if size == -1 { check_non_neg1!(unsafe { libc::read(self.fd, buf.as_mut_ptr().cast(), buf.len()) });
return Err(io::Error::last_os_error());
}
Ok(size.try_into().unwrap()) Ok(size.try_into().unwrap())
} }
} }
impl Write for SyncTcpStream { impl Write for SyncTcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let size = unsafe { libc::send(self.fd, buf.as_ptr().cast(), buf.len(), 0) }; let size =
if size == -1 { check_non_neg1!(unsafe { libc::send(self.fd, buf.as_ptr().cast(), buf.len(), 0) });
return Err(io::Error::last_os_error());
}
Ok(size.try_into().unwrap()) Ok(size.try_into().unwrap())
} }
@ -135,7 +117,7 @@ impl Debug for SyncTcpStream {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SyncTcpStream") f.debug_struct("SyncTcpStream")
.field("fd", &self.fd) .field("fd", &self.fd)
.field("peer_addr", &format_addr(self.peer_sockaddr)) .field("addr", &format_addr(self.peer_sockaddr))
.finish() .finish()
} }
} }
@ -145,11 +127,3 @@ impl unix::io::AsRawFd for SyncTcpStream {
self.fd self.fd
} }
} }
fn format_addr(addr: libc::sockaddr_in) -> String {
let bytes = addr.sin_addr.s_addr.to_be_bytes();
format!(
"{}.{}.{}.{}:{}",
bytes[0], bytes[1], bytes[2], bytes[3], addr.sin_port
)
}