safe custom tcp listener!!!

This commit is contained in:
nora 2022-04-01 22:06:40 +02:00
parent 43df60f510
commit 69f9a54164
5 changed files with 181 additions and 63 deletions

3
.rustfmt.toml Normal file
View file

@ -0,0 +1,3 @@
imports_granularity = "Crate"
newline_style = "Unix"
group_imports = "StdExternalCrate"

View file

@ -2,7 +2,7 @@
#![allow(dead_code)]
pub mod epoll;
pub mod listener;
pub mod sync_tcp;
#[cfg(not(target_os = "linux"))]
compile_error!("yeah not gonna compile that here, rip you");

View file

@ -1,61 +0,0 @@
use std::mem::MaybeUninit;
use std::{io, mem};
const PORT: libc::in_port_t = 1112;
const SOCKADDR_IN_SIZE: libc::socklen_t = mem::size_of::<libc::sockaddr_in>() as _;
macro_rules! check_zero {
($result:expr) => {
if $result != 0 {
return Err(io::Error::last_os_error());
}
};
}
pub fn listener() -> io::Result<()> {
unsafe {
let socket = libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0);
if socket == -1 {
return Err(io::Error::last_os_error());
}
println!("Created socket ({})", socket);
let addr = libc::sockaddr_in {
sin_family: libc::AF_INET.try_into().unwrap(),
sin_port: PORT.to_be(),
sin_addr: libc::in_addr {
s_addr: libc::INADDR_ANY,
},
sin_zero: [0; 8],
};
let addr_erased_ptr = &addr as *const libc::sockaddr_in as _;
let result = libc::bind(socket, addr_erased_ptr, SOCKADDR_IN_SIZE);
if result == -1 {
return Err(io::Error::last_os_error());
}
println!("Bound socket ({socket}) on port {PORT}");
check_zero!(libc::listen(socket, 5));
println!("Listening on socket ({socket})");
let mut peer_sockaddr = MaybeUninit::uninit();
let mut sockaddr_size = 0;
let connection = libc::accept(socket, peer_sockaddr.as_mut_ptr(), &mut sockaddr_size);
if connection == -1 {
return Err(io::Error::last_os_error());
}
println!("Received connection! (connfd={connection})");
check_zero!(libc::close(connection));
check_zero!(libc::close(socket));
}
Ok(())
}

View file

@ -1,5 +1,14 @@
use std::{
io,
io::{Read, Write},
};
use survey::sync_tcp::{SyncTcpListener, SyncTcpStream};
const PORT: u16 = 6547;
pub fn main() {
match survey::listener::listener() {
match listener() {
Ok(()) => {}
Err(err) => {
eprintln!("{err}");
@ -7,3 +16,48 @@ pub fn main() {
}
}
}
pub fn listener() -> io::Result<()> {
let mut threads = Vec::new();
let mut listener = SyncTcpListener::bind_any(PORT)?;
println!("Bound listener on port {PORT}");
for stream in listener.accept() {
let handle = std::thread::spawn(move || handler_thread(stream));
threads.push(handle);
}
for thread in threads {
thread.join().unwrap();
}
Ok(())
}
fn handler_thread(stream: SyncTcpStream) {
match handler(stream) {
Ok(()) => {}
Err(err) => eprintln!("An error occurred while processing connection: {err}"),
}
}
fn handler(mut stream: SyncTcpStream) -> io::Result<()> {
println!("Received connection! {stream:?}");
stream.write_all(b"Hi! Write your favourite three characters: ")?;
let mut buf = [0u8; 3];
stream.read_exact(&mut buf)?;
println!("Read data: {buf:?}");
stream.write_all(b"\nAh, it's: '")?;
stream.write_all(&buf)?;
stream.write_all(b"'. I like them too owo")?;
println!("written stuff");
Ok(())
}
fn format_addr(addr: libc::in_addr) -> String {
let bytes = addr.s_addr.to_be_bytes();
format!("{}.{}.{}.{}", bytes[0], bytes[1], bytes[2], bytes[3])
}

122
src/sync_tcp.rs Normal file
View file

@ -0,0 +1,122 @@
use std::{
fmt::{Debug, Formatter},
io,
io::{Read, Write},
mem,
mem::MaybeUninit,
os::unix,
};
const SOCKADDR_IN_SIZE: libc::socklen_t = mem::size_of::<libc::sockaddr_in>() as _;
macro_rules! check_is_zero {
($result:expr) => {
if $result != 0 {
return Err(io::Error::last_os_error());
}
};
}
pub struct SyncTcpListener {
fd: unix::io::RawFd,
addr: libc::sockaddr_in,
}
impl SyncTcpListener {
pub fn bind_any(port: u16) -> io::Result<Self> {
let socket = unsafe { libc::socket(libc::AF_INET, libc::SOCK_STREAM, 0) };
if socket == -1 {
return Err(io::Error::last_os_error());
}
let addr = libc::sockaddr_in {
sin_family: libc::AF_INET.try_into().unwrap(),
sin_port: port.to_be(),
sin_addr: libc::in_addr {
s_addr: libc::INADDR_ANY,
},
sin_zero: [0; 8],
};
let addr_erased_ptr = &addr as *const libc::sockaddr_in as _;
let result = unsafe { libc::bind(socket, addr_erased_ptr, SOCKADDR_IN_SIZE) };
if result == -1 {
return Err(io::Error::last_os_error());
}
check_is_zero!(unsafe { libc::listen(socket, 5) });
Ok(Self { fd: socket, addr })
}
pub fn accept(&mut self) -> impl Iterator<Item = SyncTcpStream> + '_ {
std::iter::from_fn(|| {
let mut peer_sockaddr = MaybeUninit::uninit();
let mut sockaddr_size = 0;
let fd =
unsafe { libc::accept(self.fd, peer_sockaddr.as_mut_ptr(), &mut sockaddr_size) };
if fd == -1 {
return None;
}
let peer_sockaddr = unsafe {
peer_sockaddr
.as_mut_ptr()
.cast::<libc::sockaddr_in>()
.read()
};
Some(SyncTcpStream { fd, peer_sockaddr })
})
}
}
impl Drop for SyncTcpListener {
fn drop(&mut self) {
unsafe { libc::close(self.fd) };
}
}
pub struct SyncTcpStream {
fd: unix::io::RawFd,
peer_sockaddr: libc::sockaddr_in,
}
impl Read for SyncTcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let size = unsafe { libc::read(self.fd, buf.as_mut_ptr().cast(), buf.len()) };
if size == -1 {
return Err(io::Error::last_os_error());
}
Ok(size.try_into().unwrap())
}
}
impl Write for SyncTcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
let size = unsafe { libc::write(self.fd, buf.as_ptr().cast(), buf.len()) };
if size == -1 {
return Err(io::Error::last_os_error());
}
Ok(size.try_into().unwrap())
}
fn flush(&mut self) -> io::Result<()> {
todo!()
}
}
impl Drop for SyncTcpStream {
fn drop(&mut self) {
unsafe { libc::close(self.fd) };
}
}
impl Debug for SyncTcpStream {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SyncTcpStream")
.field("fd", &self.fd)
.field("peer_addr", &format_addr(self.peer_sockaddr.sin_addr))
.finish()
}
}