This commit is contained in:
nora 2022-04-02 17:34:23 +02:00
parent 19e43ea5a2
commit 97fadb835e
6 changed files with 235 additions and 176 deletions

View file

@ -1,94 +0,0 @@
/*
#define MAX_EVENTS 10
struct epoll_event ev, events[MAX_EVENTS];
int listen_sock, conn_sock, nfds, epollfd;
/* Code to set up listening socket, 'listen_sock',
(socket(), bind(), listen()) omitted */
epollfd = epoll_create1(0);
if (epollfd == -1) {
perror("epoll_create1");
exit(EXIT_FAILURE);
}
ev.events = EPOLLIN;
ev.data.fd = listen_sock;
if (epoll_ctl(epollfd, EPOLL_CTL_ADD, listen_sock, &ev) == -1) {
perror("epoll_ctl: listen_sock");
exit(EXIT_FAILURE);
}
for (;;) {
nfds = epoll_wait(epollfd, events, MAX_EVENTS, -1);
if (nfds == -1) {
perror("epoll_wait");
exit(EXIT_FAILURE);
}
for (n = 0; n < nfds; ++n) {
if (events[n].data.fd == listen_sock) {
conn_sock = accept(listen_sock,
(struct sockaddr *) &addr, &addrlen);
if (conn_sock == -1) {
perror("accept");
exit(EXIT_FAILURE);
}
setnonblocking(conn_sock);
ev.events = EPOLLIN | EPOLLET;
ev.data.fd = conn_sock;
if (epoll_ctl(epollfd, EPOLL_CTL_ADD, conn_sock,
&ev) == -1) {
perror("epoll_ctl: conn_sock");
exit(EXIT_FAILURE);
}
} else {
do_use_fd(events[n].data.fd);
}
}
}
*/
const MAX_EVENTS: usize = 10;
pub fn example_from_man_page() -> Result<(), &'static str> {
// SAFETY: I trust man pages (maybe a mistake)
/*unsafe {
let listen_sock: i32 = todo!();
let mut events = [libc::epoll_event { events: 0, u64: 0 }; MAX_EVENTS]; // empty value
let epollfd = libc::epoll_create1(1);
if epollfd == -1 {
return Err("Failed to crate epoll instance");
}
let ev = libc::epoll_event {
events: libc::EPOLLIN.try_into().unwrap(),
u64: listen_sock.try_into().unwrap(),
};
loop {
let nfds = libc::epoll_wait(
epollfd,
events.as_mut_ptr(),
MAX_EVENTS.try_into().unwrap(),
-1,
);
if nfds == -1 {
return Err("Failed to wait for next event");
}
for i in 0usize..nfds.try_into().unwrap() {
if events[i].u64 == listen_sock.try_into().unwrap() {
let conn_sock = libc::accept(listen_sock, todo!(), todo!());
if conn_sock != -1 {
return Err("Failed to accept");
}
todo!()
}
}
}
}*/
Ok(()) // this is safe.
}

84
src/epoll/mod.rs Normal file
View file

@ -0,0 +1,84 @@
use std::{
io,
mem::MaybeUninit,
os::unix::io::{AsRawFd, RawFd},
ptr::addr_of,
};
use crate::{check_non_neg1, epoll::tcp::AsyncTcpListener};
mod tcp;
unsafe fn make_nonblock(fd: RawFd) -> io::Result<()> {
let status = check_non_neg1!(libc::fcntl(fd, libc::F_GETFL));
check_non_neg1!(libc::fcntl(fd, libc::F_SETFL, status | libc::O_NONBLOCK));
Ok(())
}
pub fn example_from_man_page() -> io::Result<()> {
let listener = AsyncTcpListener::bind_any(8888)?;
println!("Created listener {listener:?}");
unsafe {
let listen_sock = listener.as_raw_fd();
let epollfd = check_non_neg1!(libc::epoll_create1(0));
println!("Created epoll instance");
let mut ev = libc::epoll_event {
events: libc::EPOLLIN as _,
u64: listen_sock as _,
};
check_non_neg1!(libc::epoll_ctl(
epollfd,
libc::EPOLL_CTL_ADD,
listen_sock,
&mut ev
));
loop {
let mut events = [libc::epoll_event { events: 0, u64: 0 }; 16];
let nfds = check_non_neg1!(libc::epoll_wait(
epollfd,
events.as_mut_ptr(),
events.len() as _,
-1,
));
for event in &events[0..nfds as _] {
if event.u64 == listen_sock as _ {
// our TCP listener received a new connection
let mut peer_sockaddr = MaybeUninit::uninit();
let mut sockaddr_size = 0;
let conn_sock = check_non_neg1!(libc::accept(
listener.as_raw_fd(),
peer_sockaddr.as_mut_ptr(),
&mut sockaddr_size,
));
make_nonblock(conn_sock)?;
let mut ev = libc::epoll_event {
events: (libc::EPOLLIN | libc::EPOLLET) as _,
u64: conn_sock as _,
};
check_non_neg1!(libc::epoll_ctl(
epollfd,
libc::EPOLL_CTL_ADD,
conn_sock,
&mut ev
));
println!("Received new connection! (fd: {conn_sock})");
} else {
println!(
"something else happened! {}",
addr_of!(event.u64).read_unaligned()
);
}
}
}
}
}

62
src/epoll/tcp.rs Normal file
View file

@ -0,0 +1,62 @@
use std::{
fmt::{Debug, Formatter},
io,
os::{unix, unix::io::RawFd},
};
use crate::{check_is_zero, format_addr, SOCKADDR_IN_SIZE};
pub struct AsyncTcpListener {
fd: unix::io::RawFd,
addr: libc::sockaddr_in,
}
impl AsyncTcpListener {
pub fn bind_any(port: u16) -> io::Result<Self> {
let socket =
unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM | libc::SOCK_NONBLOCK, 0) };
if socket == -1 {
return Err(io::Error::last_os_error());
}
let addr = libc::sockaddr_in {
sin_family: libc::AF_INET.try_into().unwrap(),
sin_port: port.to_be(),
sin_addr: libc::in_addr {
s_addr: libc::INADDR_ANY,
},
sin_zero: [0; 8],
};
let addr_erased_ptr = &addr as *const libc::sockaddr_in as _;
let result = unsafe { libc::bind(socket, addr_erased_ptr, SOCKADDR_IN_SIZE) };
if result == -1 {
return Err(io::Error::last_os_error());
}
check_is_zero!(unsafe { libc::listen(socket, 5) });
Ok(Self { fd: socket, addr })
}
}
impl unix::io::AsRawFd for AsyncTcpListener {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl Drop for AsyncTcpListener {
fn drop(&mut self) {
unsafe { libc::close(self.fd) };
}
}
impl Debug for AsyncTcpListener {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SyncTcpListener")
.field("fd", &self.fd)
.field("addr", &format_addr(self.addr))
.finish()
}
}

View file

@ -1,8 +1,44 @@
#![warn(rust_2018_idioms)]
#![allow(dead_code)]
#![allow(clippy::single_component_path_imports)] // lol clippy
pub mod epoll;
pub mod sync_tcp;
#[cfg(not(target_os = "linux"))]
compile_error!("yeah not gonna compile that here, rip you");
const SOCKADDR_IN_SIZE: libc::socklen_t = std::mem::size_of::<libc::sockaddr_in>() as _;
fn format_addr(addr: libc::sockaddr_in) -> String {
let bytes = addr.sin_addr.s_addr.to_be_bytes();
format!(
"{}.{}.{}.{}:{}",
bytes[0],
bytes[1],
bytes[2],
bytes[3],
u16::from_be_bytes(addr.sin_port.to_ne_bytes())
)
}
macro_rules! check_is_zero {
($result:expr) => {
if $result != 0 {
return Err(io::Error::last_os_error());
}
};
}
use check_is_zero;
macro_rules! check_non_neg1 {
($result:expr) => {{
let result = $result;
if result == -1 {
return Err(io::Error::last_os_error());
}
result
}};
}
use check_non_neg1;

View file

@ -1,15 +1,12 @@
use std::{
io,
io::{Read, Write},
time::Duration,
};
use survey::sync_tcp::{SyncTcpListener, SyncTcpStream};
const PORT: u16 = 6547;
use survey::epoll::example_from_man_page;
pub fn main() {
match listener() {
match example_from_man_page() {
Ok(()) => {}
Err(err) => {
eprintln!("{err}");
@ -18,43 +15,43 @@ pub fn main() {
}
}
pub fn listener() -> io::Result<()> {
let mut threads = Vec::new();
let listener = SyncTcpListener::bind_any(PORT)?;
println!("Bound listener on port {PORT}");
for stream in listener.incoming() {
let stream = stream?;
let handle = std::thread::spawn(move || handler_thread(stream));
threads.push(handle);
}
for thread in threads {
thread.join().unwrap();
}
Ok(())
}
fn handler_thread(stream: SyncTcpStream) {
match handler(stream) {
Ok(()) => {}
Err(err) => eprintln!("An error occurred while processing connection: {err}"),
}
}
fn handler(mut stream: SyncTcpStream) -> io::Result<()> {
println!("Received connection! {stream:?}");
stream.write_all(b"Hi! Write your favourite three characters: ")?;
let mut buf = [0u8; 3];
stream.read_exact(&mut buf)?;
println!("Read data: {buf:?}");
stream.write_all(b"\nAh, it's: '")?;
stream.write_all(&buf)?;
stream.write_all(b"'. I like them too owo")?;
std::thread::sleep(Duration::from_millis(100));
Ok(())
}
//pub fn listener() -> io::Result<()> {
// let mut threads = Vec::new();
//
// let listener = SyncTcpListener::bind_any(PORT)?;
//
// println!("Bound listener on port {PORT}");
//
// for stream in listener.incoming() {
// let stream = stream?;
// let handle = std::thread::spawn(move || handler_thread(stream));
// threads.push(handle);
// }
//
// for thread in threads {
// thread.join().unwrap();
// }
//
// Ok(())
//}
//
//fn handler_thread(stream: SyncTcpStream) {
// match handler(stream) {
// Ok(()) => {}
// Err(err) => eprintln!("An error occurred while processing connection: {err}"),
// }
//}
//
//fn handler(mut stream: SyncTcpStream) -> io::Result<()> {
// println!("Received connection! {stream:?}");
//
// stream.write_all(b"Hi! Write your favourite three characters: ")?;
// let mut buf = [0u8; 3];
// stream.read_exact(&mut buf)?;
// println!("Read data: {buf:?}");
// stream.write_all(b"\nAh, it's: '")?;
// stream.write_all(&buf)?;
// stream.write_all(b"'. I like them too owo")?;
// std::thread::sleep(Duration::from_millis(100));
// Ok(())
//}

View file

@ -2,20 +2,11 @@ use std::{
fmt::{Debug, Formatter},
io,
io::{Read, Write},
mem,
mem::MaybeUninit,
os::{unix, unix::io::RawFd},
};
const SOCKADDR_IN_SIZE: libc::socklen_t = mem::size_of::<libc::sockaddr_in>() as _;
macro_rules! check_is_zero {
($result:expr) => {
if $result != 0 {
return Err(io::Error::last_os_error());
}
};
}
use crate::{check_is_zero, check_non_neg1, format_addr, SOCKADDR_IN_SIZE};
pub struct SyncTcpListener {
fd: unix::io::RawFd,
@ -24,11 +15,7 @@ pub struct SyncTcpListener {
impl SyncTcpListener {
pub fn bind_any(port: u16) -> io::Result<Self> {
let socket = unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0) };
if socket == -1 {
return Err(io::Error::last_os_error());
}
let socket = check_non_neg1!(unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0) });
let addr = libc::sockaddr_in {
sin_family: libc::AF_INET.try_into().unwrap(),
@ -40,10 +27,9 @@ impl SyncTcpListener {
};
let addr_erased_ptr = &addr as *const libc::sockaddr_in as _;
let result = unsafe { libc::bind(socket, addr_erased_ptr, SOCKADDR_IN_SIZE) };
if result == -1 {
return Err(io::Error::last_os_error());
}
let result =
check_non_neg1!(unsafe { libc::bind(socket, addr_erased_ptr, SOCKADDR_IN_SIZE) });
check_is_zero!(unsafe { libc::listen(socket, 5) });
Ok(Self { fd: socket, addr })
@ -100,20 +86,16 @@ pub struct SyncTcpStream {
impl Read for SyncTcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let size = unsafe { libc::read(self.fd, buf.as_mut_ptr().cast(), buf.len()) };
if size == -1 {
return Err(io::Error::last_os_error());
}
let size =
check_non_neg1!(unsafe { libc::read(self.fd, buf.as_mut_ptr().cast(), buf.len()) });
Ok(size.try_into().unwrap())
}
}
impl Write for SyncTcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let size = unsafe { libc::send(self.fd, buf.as_ptr().cast(), buf.len(), 0) };
if size == -1 {
return Err(io::Error::last_os_error());
}
let size =
check_non_neg1!(unsafe { libc::send(self.fd, buf.as_ptr().cast(), buf.len(), 0) });
Ok(size.try_into().unwrap())
}
@ -135,7 +117,7 @@ impl Debug for SyncTcpStream {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SyncTcpStream")
.field("fd", &self.fd)
.field("peer_addr", &format_addr(self.peer_sockaddr))
.field("addr", &format_addr(self.peer_sockaddr))
.finish()
}
}
@ -145,11 +127,3 @@ impl unix::io::AsRawFd for SyncTcpStream {
self.fd
}
}
fn format_addr(addr: libc::sockaddr_in) -> String {
let bytes = addr.sin_addr.s_addr.to_be_bytes();
format!(
"{}.{}.{}.{}:{}",
bytes[0], bytes[1], bytes[2], bytes[3], addr.sin_port
)
}