nora 2023-03-07 14:00:23 +01:00
parent 25adea4103
commit 7af1274587
160 changed files with 38999 additions and 4 deletions

1425
hyper/src/proto/h1/conn.rs Normal file

File diff suppressed because it is too large.

731
hyper/src/proto/h1/decode.rs Normal file

@@ -0,0 +1,731 @@
use std::error::Error as StdError;
use std::fmt;
use std::io;
use std::usize;
use bytes::Bytes;
use tracing::{debug, trace};
use crate::common::{task, Poll};
use super::io::MemRead;
use super::DecodedLength;
use self::Kind::{Chunked, Eof, Length};
/// Decoders to handle different Transfer-Encodings.
///
/// If a message body does not include a Transfer-Encoding, it *should*
/// include a Content-Length header.
#[derive(Clone, PartialEq)]
pub(crate) struct Decoder {
kind: Kind,
}
#[derive(Debug, Clone, Copy, PartialEq)]
enum Kind {
/// A Reader used when a Content-Length header is passed with a positive integer.
Length(u64),
/// A Reader used when Transfer-Encoding is `chunked`.
Chunked(ChunkedState, u64),
/// A Reader used for responses that don't indicate a length or chunked.
///
/// The bool tracks when EOF is seen on the transport.
///
/// Note: This should only be used for `Response`s. A `Request` body cannot
/// be delimited by closing the connection; if it has a body, it must declare
/// its length via `Content-Length` or `Transfer-Encoding: chunked`, as the
/// spec explains:
///
/// > If a Transfer-Encoding header field is present in a response and
/// > the chunked transfer coding is not the final encoding, the
/// > message body length is determined by reading the connection until
/// > it is closed by the server. If a Transfer-Encoding header field
/// > is present in a request and the chunked transfer coding is not
/// > the final encoding, the message body length cannot be determined
/// > reliably; the server MUST respond with the 400 (Bad Request)
/// > status code and then close the connection.
Eof(bool),
}
#[derive(Debug, PartialEq, Clone, Copy)]
enum ChunkedState {
Size,
SizeLws,
Extension,
SizeLf,
Body,
BodyCr,
BodyLf,
Trailer,
TrailerLf,
EndCr,
EndLf,
End,
}
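// Illustrative walk-through (added for exposition, not part of the original
// source): decoding the chunked body `4\r\nWiki\r\n0\r\n\r\n` visits the
// states roughly as
//   Size -> SizeLf -> Body -> BodyCr -> BodyLf -> Size -> SizeLf -> EndCr -> EndLf -> End
// with `size` accumulating 0x4 during the first pass through `Size` and the
// four payload bytes emitted from `Body`. The `Trailer` states are only
// entered if a trailer field follows the zero-size chunk.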
impl Decoder {
// constructors
pub(crate) fn length(x: u64) -> Decoder {
Decoder {
kind: Kind::Length(x),
}
}
pub(crate) fn chunked() -> Decoder {
Decoder {
kind: Kind::Chunked(ChunkedState::Size, 0),
}
}
pub(crate) fn eof() -> Decoder {
Decoder {
kind: Kind::Eof(false),
}
}
pub(super) fn new(len: DecodedLength) -> Self {
match len {
DecodedLength::CHUNKED => Decoder::chunked(),
DecodedLength::CLOSE_DELIMITED => Decoder::eof(),
length => Decoder::length(length.danger_len()),
}
}
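// Editor's note (added for exposition): the `DecodedLength` sentinels map
// one-to-one onto the decoder kinds: `CHUNKED` selects the chunked state
// machine, `CLOSE_DELIMITED` selects read-until-EOF, and any other value is
// treated as an exact `Content-Length`.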
// methods
pub(crate) fn is_eof(&self) -> bool {
matches!(self.kind, Length(0) | Chunked(ChunkedState::End, _) | Eof(true))
}
pub(crate) fn decode<R: MemRead>(
&mut self,
cx: &mut task::Context<'_>,
body: &mut R,
) -> Poll<Result<Bytes, io::Error>> {
trace!("decode; state={:?}", self.kind);
match self.kind {
Length(ref mut remaining) => {
if *remaining == 0 {
Poll::Ready(Ok(Bytes::new()))
} else {
let to_read = *remaining as usize;
let buf = ready!(body.read_mem(cx, to_read))?;
let num = buf.as_ref().len() as u64;
if num > *remaining {
*remaining = 0;
} else if num == 0 {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
IncompleteBody,
)));
} else {
*remaining -= num;
}
Poll::Ready(Ok(buf))
}
}
Chunked(ref mut state, ref mut size) => {
loop {
let mut buf = None;
// advances the chunked state
*state = ready!(state.step(cx, body, size, &mut buf))?;
if *state == ChunkedState::End {
trace!("end of chunked");
return Poll::Ready(Ok(Bytes::new()));
}
if let Some(buf) = buf {
return Poll::Ready(Ok(buf));
}
}
}
Eof(ref mut is_eof) => {
if *is_eof {
Poll::Ready(Ok(Bytes::new()))
} else {
// 8192 was chosen because it's about 2 packets; there probably
// won't be that much available, so don't have MemRead
// implementations allocate buffers that are too big
body.read_mem(cx, 8192).map_ok(|slice| {
*is_eof = slice.is_empty();
slice
})
}
}
}
}
#[cfg(test)]
async fn decode_fut<R: MemRead>(&mut self, body: &mut R) -> Result<Bytes, io::Error> {
futures_util::future::poll_fn(move |cx| self.decode(cx, body)).await
}
}
impl fmt::Debug for Decoder {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Debug::fmt(&self.kind, f)
}
}
macro_rules! byte (
($rdr:ident, $cx:expr) => ({
let buf = ready!($rdr.read_mem($cx, 1))?;
if !buf.is_empty() {
buf[0]
} else {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::UnexpectedEof,
"unexpected EOF during chunk size line")));
}
})
);
impl ChunkedState {
fn step<R: MemRead>(
&self,
cx: &mut task::Context<'_>,
body: &mut R,
size: &mut u64,
buf: &mut Option<Bytes>,
) -> Poll<Result<ChunkedState, io::Error>> {
use self::ChunkedState::*;
match *self {
Size => ChunkedState::read_size(cx, body, size),
SizeLws => ChunkedState::read_size_lws(cx, body),
Extension => ChunkedState::read_extension(cx, body),
SizeLf => ChunkedState::read_size_lf(cx, body, *size),
Body => ChunkedState::read_body(cx, body, size, buf),
BodyCr => ChunkedState::read_body_cr(cx, body),
BodyLf => ChunkedState::read_body_lf(cx, body),
Trailer => ChunkedState::read_trailer(cx, body),
TrailerLf => ChunkedState::read_trailer_lf(cx, body),
EndCr => ChunkedState::read_end_cr(cx, body),
EndLf => ChunkedState::read_end_lf(cx, body),
End => Poll::Ready(Ok(ChunkedState::End)),
}
}
fn read_size<R: MemRead>(
cx: &mut task::Context<'_>,
rdr: &mut R,
size: &mut u64,
) -> Poll<Result<ChunkedState, io::Error>> {
trace!("Read chunk hex size");
macro_rules! or_overflow {
($e:expr) => (
match $e {
Some(val) => val,
None => return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid chunk size: overflow",
))),
}
)
}
let radix = 16;
match byte!(rdr, cx) {
b @ b'0'..=b'9' => {
*size = or_overflow!(size.checked_mul(radix));
*size = or_overflow!(size.checked_add((b - b'0') as u64));
}
b @ b'a'..=b'f' => {
*size = or_overflow!(size.checked_mul(radix));
*size = or_overflow!(size.checked_add((b + 10 - b'a') as u64));
}
b @ b'A'..=b'F' => {
*size = or_overflow!(size.checked_mul(radix));
*size = or_overflow!(size.checked_add((b + 10 - b'A') as u64));
}
b'\t' | b' ' => return Poll::Ready(Ok(ChunkedState::SizeLws)),
b';' => return Poll::Ready(Ok(ChunkedState::Extension)),
b'\r' => return Poll::Ready(Ok(ChunkedState::SizeLf)),
_ => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk size line: Invalid Size",
)));
}
}
Poll::Ready(Ok(ChunkedState::Size))
}
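// Worked example (added for exposition): each hex digit shifts the running
// size by one nibble, i.e. `size = size * 16 + digit`. Reading `Ff` therefore
// yields 0xF = 15, then 15 * 16 + 15 = 255, matching the `Ff\r\n` case in the
// tests below.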
fn read_size_lws<R: MemRead>(
cx: &mut task::Context<'_>,
rdr: &mut R,
) -> Poll<Result<ChunkedState, io::Error>> {
trace!("read_size_lws");
match byte!(rdr, cx) {
// LWS can follow the chunk size, but no more digits can come
b'\t' | b' ' => Poll::Ready(Ok(ChunkedState::SizeLws)),
b';' => Poll::Ready(Ok(ChunkedState::Extension)),
b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk size linear white space",
))),
}
}
fn read_extension<R: MemRead>(
cx: &mut task::Context<'_>,
rdr: &mut R,
) -> Poll<Result<ChunkedState, io::Error>> {
trace!("read_extension");
// We don't care about extensions really at all. Just ignore them.
// They "end" at the next CRLF.
//
// However, some implementations may not check for the CR, so to save
// them from themselves, we reject extensions containing plain LF as
// well.
match byte!(rdr, cx) {
b'\r' => Poll::Ready(Ok(ChunkedState::SizeLf)),
b'\n' => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid chunk extension contains newline",
))),
_ => Poll::Ready(Ok(ChunkedState::Extension)), // no supported extensions
}
}
fn read_size_lf<R: MemRead>(
cx: &mut task::Context<'_>,
rdr: &mut R,
size: u64,
) -> Poll<Result<ChunkedState, io::Error>> {
trace!("Chunk size is {:?}", size);
match byte!(rdr, cx) {
b'\n' => {
if size == 0 {
Poll::Ready(Ok(ChunkedState::EndCr))
} else {
debug!("incoming chunked header: {0:#X} ({0} bytes)", size);
Poll::Ready(Ok(ChunkedState::Body))
}
}
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk size LF",
))),
}
}
fn read_body<R: MemRead>(
cx: &mut task::Context<'_>,
rdr: &mut R,
rem: &mut u64,
buf: &mut Option<Bytes>,
) -> Poll<Result<ChunkedState, io::Error>> {
trace!("Chunked read, remaining={:?}", rem);
// cap remaining bytes at the max capacity of usize
let rem_cap = match *rem {
r if r > usize::MAX as u64 => usize::MAX,
r => r as usize,
};
let to_read = rem_cap;
let slice = ready!(rdr.read_mem(cx, to_read))?;
let count = slice.len();
if count == 0 {
*rem = 0;
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
IncompleteBody,
)));
}
*buf = Some(slice);
*rem -= count as u64;
if *rem > 0 {
Poll::Ready(Ok(ChunkedState::Body))
} else {
Poll::Ready(Ok(ChunkedState::BodyCr))
}
}
fn read_body_cr<R: MemRead>(
cx: &mut task::Context<'_>,
rdr: &mut R,
) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr, cx) {
b'\r' => Poll::Ready(Ok(ChunkedState::BodyLf)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk body CR",
))),
}
}
fn read_body_lf<R: MemRead>(
cx: &mut task::Context<'_>,
rdr: &mut R,
) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr, cx) {
b'\n' => Poll::Ready(Ok(ChunkedState::Size)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk body LF",
))),
}
}
fn read_trailer<R: MemRead>(
cx: &mut task::Context<'_>,
rdr: &mut R,
) -> Poll<Result<ChunkedState, io::Error>> {
trace!("read_trailer");
match byte!(rdr, cx) {
b'\r' => Poll::Ready(Ok(ChunkedState::TrailerLf)),
_ => Poll::Ready(Ok(ChunkedState::Trailer)),
}
}
fn read_trailer_lf<R: MemRead>(
cx: &mut task::Context<'_>,
rdr: &mut R,
) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr, cx) {
b'\n' => Poll::Ready(Ok(ChunkedState::EndCr)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid trailer end LF",
))),
}
}
fn read_end_cr<R: MemRead>(
cx: &mut task::Context<'_>,
rdr: &mut R,
) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr, cx) {
b'\r' => Poll::Ready(Ok(ChunkedState::EndLf)),
_ => Poll::Ready(Ok(ChunkedState::Trailer)),
}
}
fn read_end_lf<R: MemRead>(
cx: &mut task::Context<'_>,
rdr: &mut R,
) -> Poll<Result<ChunkedState, io::Error>> {
match byte!(rdr, cx) {
b'\n' => Poll::Ready(Ok(ChunkedState::End)),
_ => Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidInput,
"Invalid chunk end LF",
))),
}
}
}
#[derive(Debug)]
struct IncompleteBody;
impl fmt::Display for IncompleteBody {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "end of file before message length reached")
}
}
impl StdError for IncompleteBody {}
#[cfg(test)]
mod tests {
use super::*;
use std::pin::Pin;
use std::time::Duration;
use tokio::io::{AsyncRead, ReadBuf};
impl<'a> MemRead for &'a [u8] {
fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
let n = std::cmp::min(len, self.len());
if n > 0 {
let (a, b) = self.split_at(n);
let buf = Bytes::copy_from_slice(a);
*self = b;
Poll::Ready(Ok(buf))
} else {
Poll::Ready(Ok(Bytes::new()))
}
}
}
impl<'a> MemRead for &'a mut (dyn AsyncRead + Unpin) {
fn read_mem(&mut self, cx: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
let mut v = vec![0; len];
let mut buf = ReadBuf::new(&mut v);
ready!(Pin::new(self).poll_read(cx, &mut buf)?);
Poll::Ready(Ok(Bytes::copy_from_slice(&buf.filled())))
}
}
#[cfg(feature = "nightly")]
impl MemRead for Bytes {
fn read_mem(&mut self, _: &mut task::Context<'_>, len: usize) -> Poll<io::Result<Bytes>> {
let n = std::cmp::min(len, self.len());
let ret = self.split_to(n);
Poll::Ready(Ok(ret))
}
}
/*
use std::io;
use std::io::Write;
use super::Decoder;
use super::ChunkedState;
use futures::{Async, Poll};
use bytes::{BytesMut, Bytes};
use crate::mock::AsyncIo;
*/
#[tokio::test]
async fn test_read_chunk_size() {
use std::io::ErrorKind::{InvalidData, InvalidInput, UnexpectedEof};
async fn read(s: &str) -> u64 {
let mut state = ChunkedState::Size;
let rdr = &mut s.as_bytes();
let mut size = 0;
loop {
let result =
futures_util::future::poll_fn(|cx| state.step(cx, rdr, &mut size, &mut None))
.await;
let desc = format!("read_size failed for {:?}", s);
state = result.expect(desc.as_str());
if state == ChunkedState::Body || state == ChunkedState::EndCr {
break;
}
}
size
}
async fn read_err(s: &str, expected_err: io::ErrorKind) {
let mut state = ChunkedState::Size;
let rdr = &mut s.as_bytes();
let mut size = 0;
loop {
let result =
futures_util::future::poll_fn(|cx| state.step(cx, rdr, &mut size, &mut None))
.await;
state = match result {
Ok(s) => s,
Err(e) => {
assert!(
expected_err == e.kind(),
"Reading {:?}, expected {:?}, but got {:?}",
s,
expected_err,
e.kind()
);
return;
}
};
if state == ChunkedState::Body || state == ChunkedState::End {
panic!("Was Ok. Expected Err for {:?}", s);
}
}
}
assert_eq!(1, read("1\r\n").await);
assert_eq!(1, read("01\r\n").await);
assert_eq!(0, read("0\r\n").await);
assert_eq!(0, read("00\r\n").await);
assert_eq!(10, read("A\r\n").await);
assert_eq!(10, read("a\r\n").await);
assert_eq!(255, read("Ff\r\n").await);
assert_eq!(255, read("Ff \r\n").await);
// Missing LF or CRLF
read_err("F\rF", InvalidInput).await;
read_err("F", UnexpectedEof).await;
// Invalid hex digit
read_err("X\r\n", InvalidInput).await;
read_err("1X\r\n", InvalidInput).await;
read_err("-\r\n", InvalidInput).await;
read_err("-1\r\n", InvalidInput).await;
// Acceptable (if not fully valid) extensions do not influence the size
assert_eq!(1, read("1;extension\r\n").await);
assert_eq!(10, read("a;ext name=value\r\n").await);
assert_eq!(1, read("1;extension;extension2\r\n").await);
assert_eq!(1, read("1;;; ;\r\n").await);
assert_eq!(2, read("2; extension...\r\n").await);
assert_eq!(3, read("3 ; extension=123\r\n").await);
assert_eq!(3, read("3 ;\r\n").await);
assert_eq!(3, read("3 ; \r\n").await);
// Invalid extensions cause an error
read_err("1 invalid extension\r\n", InvalidInput).await;
read_err("1 A\r\n", InvalidInput).await;
read_err("1;no CRLF", UnexpectedEof).await;
read_err("1;reject\nnewlines\r\n", InvalidData).await;
// Overflow
read_err("f0000000000000003\r\n", InvalidData).await;
}
#[tokio::test]
async fn test_read_sized_early_eof() {
let mut bytes = &b"foo bar"[..];
let mut decoder = Decoder::length(10);
assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7);
let e = decoder.decode_fut(&mut bytes).await.unwrap_err();
assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof);
}
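// Illustrative test (added for exposition, mirroring the style of the tests
// in this module): `Decoder::eof()` hands back whatever the transport yields
// and then returns an empty `Bytes` once the underlying read reports EOF.
#[tokio::test]
async fn test_read_eof_returns_data_then_empty() {
    let mut bytes = &b"foo bar"[..];
    let mut decoder = Decoder::eof();
    assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7);
    assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 0);
}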
#[tokio::test]
async fn test_read_chunked_early_eof() {
let mut bytes = &b"\
9\r\n\
foo bar\
"[..];
let mut decoder = Decoder::chunked();
assert_eq!(decoder.decode_fut(&mut bytes).await.unwrap().len(), 7);
let e = decoder.decode_fut(&mut bytes).await.unwrap_err();
assert_eq!(e.kind(), io::ErrorKind::UnexpectedEof);
}
#[tokio::test]
async fn test_read_chunked_single_read() {
let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n"[..];
let buf = Decoder::chunked()
.decode_fut(&mut mock_buf)
.await
.expect("decode");
assert_eq!(16, buf.len());
let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String");
assert_eq!("1234567890abcdef", &result);
}
#[tokio::test]
async fn test_read_chunked_trailer_with_missing_lf() {
let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\nbad\r\r\n"[..];
let mut decoder = Decoder::chunked();
decoder.decode_fut(&mut mock_buf).await.expect("decode");
let e = decoder.decode_fut(&mut mock_buf).await.unwrap_err();
assert_eq!(e.kind(), io::ErrorKind::InvalidInput);
}
#[tokio::test]
async fn test_read_chunked_after_eof() {
let mut mock_buf = &b"10\r\n1234567890abcdef\r\n0\r\n\r\n"[..];
let mut decoder = Decoder::chunked();
// normal read
let buf = decoder.decode_fut(&mut mock_buf).await.unwrap();
assert_eq!(16, buf.len());
let result = String::from_utf8(buf.as_ref().to_vec()).expect("decode String");
assert_eq!("1234567890abcdef", &result);
// eof read
let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode");
assert_eq!(0, buf.len());
// ensure read after eof also returns eof
let buf = decoder.decode_fut(&mut mock_buf).await.expect("decode");
assert_eq!(0, buf.len());
}
// perform an async read using a custom buffer size and causing a blocking
// read at the specified byte
async fn read_async(mut decoder: Decoder, content: &[u8], block_at: usize) -> String {
let mut outs = Vec::new();
let mut ins = if block_at == 0 {
tokio_test::io::Builder::new()
.wait(Duration::from_millis(10))
.read(content)
.build()
} else {
tokio_test::io::Builder::new()
.read(&content[..block_at])
.wait(Duration::from_millis(10))
.read(&content[block_at..])
.build()
};
let mut ins = &mut ins as &mut (dyn AsyncRead + Unpin);
loop {
let buf = decoder
.decode_fut(&mut ins)
.await
.expect("unexpected decode error");
if buf.is_empty() {
break; // eof
}
outs.extend(buf.as_ref());
}
String::from_utf8(outs).expect("decode String")
}
// iterate over the different ways that this async read could go.
// tests blocking a read at each byte along the content - The shotgun approach
async fn all_async_cases(content: &str, expected: &str, decoder: Decoder) {
let content_len = content.len();
for block_at in 0..content_len {
let actual = read_async(decoder.clone(), content.as_bytes(), block_at).await;
assert_eq!(expected, &actual) //, "Failed async. Blocking at {}", block_at);
}
}
#[tokio::test]
async fn test_read_length_async() {
let content = "foobar";
all_async_cases(content, content, Decoder::length(content.len() as u64)).await;
}
#[tokio::test]
async fn test_read_chunked_async() {
let content = "3\r\nfoo\r\n3\r\nbar\r\n0\r\n\r\n";
let expected = "foobar";
all_async_cases(content, expected, Decoder::chunked()).await;
}
#[tokio::test]
async fn test_read_eof_async() {
let content = "foobar";
all_async_cases(content, content, Decoder::eof()).await;
}
#[cfg(feature = "nightly")]
#[bench]
fn bench_decode_chunked_1kb(b: &mut test::Bencher) {
let rt = new_runtime();
const LEN: usize = 1024;
let mut vec = Vec::new();
vec.extend(format!("{:x}\r\n", LEN).as_bytes());
vec.extend(&[0; LEN][..]);
vec.extend(b"\r\n");
let content = Bytes::from(vec);
b.bytes = LEN as u64;
b.iter(|| {
let mut decoder = Decoder::chunked();
rt.block_on(async {
let mut raw = content.clone();
let chunk = decoder.decode_fut(&mut raw).await.unwrap();
assert_eq!(chunk.len(), LEN);
});
});
}
#[cfg(feature = "nightly")]
#[bench]
fn bench_decode_length_1kb(b: &mut test::Bencher) {
let rt = new_runtime();
const LEN: usize = 1024;
let content = Bytes::from(&[0; LEN][..]);
b.bytes = LEN as u64;
b.iter(|| {
let mut decoder = Decoder::length(LEN as u64);
rt.block_on(async {
let mut raw = content.clone();
let chunk = decoder.decode_fut(&mut raw).await.unwrap();
assert_eq!(chunk.len(), LEN);
});
});
}
#[cfg(feature = "nightly")]
fn new_runtime() -> tokio::runtime::Runtime {
tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("rt build")
}
}

750
hyper/src/proto/h1/dispatch.rs Normal file

@@ -0,0 +1,750 @@
use std::error::Error as StdError;
use bytes::{Buf, Bytes};
use http::Request;
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, trace};
use super::{Http1Transaction, Wants};
use crate::body::{Body, DecodedLength, HttpBody};
use crate::common::{task, Future, Pin, Poll, Unpin};
use crate::proto::{
BodyLength, Conn, Dispatched, MessageHead, RequestHead,
};
use crate::upgrade::OnUpgrade;
pub(crate) struct Dispatcher<D, Bs: HttpBody, I, T> {
conn: Conn<I, Bs::Data, T>,
dispatch: D,
body_tx: Option<crate::body::Sender>,
body_rx: Pin<Box<Option<Bs>>>,
is_closing: bool,
}
pub(crate) trait Dispatch {
type PollItem;
type PollBody;
type PollError;
type RecvItem;
fn poll_msg(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>>;
fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()>;
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>>;
fn should_poll(&self) -> bool;
}
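// Editor's note (added for exposition): `Dispatch` is the bridge between the
// protocol state machine (`Conn`) and the user-facing side. `poll_msg` yields
// the next outgoing head + body (a response for servers, a request for
// clients), `recv_msg` delivers an incoming message to the user, `poll_ready`
// reports whether another incoming message can be accepted, and
// `should_poll` tells the `Dispatcher` whether `poll_msg` is worth polling
// right now.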
cfg_server! {
use crate::service::HttpService;
pub(crate) struct Server<S: HttpService<B>, B> {
in_flight: Pin<Box<Option<S::Future>>>,
pub(crate) service: S,
}
}
cfg_client! {
pin_project_lite::pin_project! {
pub(crate) struct Client<B> {
callback: Option<crate::client::dispatch::Callback<Request<B>, http::Response<Body>>>,
#[pin]
rx: ClientRx<B>,
rx_closed: bool,
}
}
type ClientRx<B> = crate::client::dispatch::Receiver<Request<B>, http::Response<Body>>;
}
impl<D, Bs, I, T> Dispatcher<D, Bs, I, T>
where
D: Dispatch<
PollItem = MessageHead<T::Outgoing>,
PollBody = Bs,
RecvItem = MessageHead<T::Incoming>,
> + Unpin,
D::PollError: Into<Box<dyn StdError + Send + Sync>>,
I: AsyncRead + AsyncWrite + Unpin,
T: Http1Transaction + Unpin,
Bs: HttpBody + 'static,
Bs::Error: Into<Box<dyn StdError + Send + Sync>>,
{
pub(crate) fn new(dispatch: D, conn: Conn<I, Bs::Data, T>) -> Self {
Dispatcher {
conn,
dispatch,
body_tx: None,
body_rx: Box::pin(None),
is_closing: false,
}
}
#[cfg(feature = "server")]
pub(crate) fn disable_keep_alive(&mut self) {
self.conn.disable_keep_alive();
if self.conn.is_write_closed() {
self.close();
}
}
pub(crate) fn into_inner(self) -> (I, Bytes, D) {
let (io, buf) = self.conn.into_inner();
(io, buf, self.dispatch)
}
/// Run this dispatcher until HTTP says this connection is done,
/// but don't call `AsyncWrite::shutdown` on the underlying IO.
///
/// This is useful for old-style HTTP upgrades, but ignores
/// newer-style upgrade API.
pub(crate) fn poll_without_shutdown(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<crate::Result<()>>
where
Self: Unpin,
{
Pin::new(self).poll_catch(cx, false).map_ok(|ds| {
if let Dispatched::Upgrade(pending) = ds {
pending.manual();
}
})
}
fn poll_catch(
&mut self,
cx: &mut task::Context<'_>,
should_shutdown: bool,
) -> Poll<crate::Result<Dispatched>> {
Poll::Ready(ready!(self.poll_inner(cx, should_shutdown)).or_else(|e| {
// An error means we're shutting down either way.
// We just try to give the error to the user,
// and close the connection with an Ok. If we
// cannot give it to the user, then return the Err.
self.dispatch.recv_msg(Err(e))?;
Ok(Dispatched::Shutdown)
}))
}
fn poll_inner(
&mut self,
cx: &mut task::Context<'_>,
should_shutdown: bool,
) -> Poll<crate::Result<Dispatched>> {
T::update_date();
ready!(self.poll_loop(cx))?;
if self.is_done() {
if let Some(pending) = self.conn.pending_upgrade() {
self.conn.take_error()?;
return Poll::Ready(Ok(Dispatched::Upgrade(pending)));
} else if should_shutdown {
ready!(self.conn.poll_shutdown(cx)).map_err(crate::Error::new_shutdown)?;
}
self.conn.take_error()?;
Poll::Ready(Ok(Dispatched::Shutdown))
} else {
Poll::Pending
}
}
fn poll_loop(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
// Limit the looping on this connection, in case it is ready far too
// often, so that other futures don't starve.
//
// 16 was chosen arbitrarily, as that is number of pipelined requests
// benchmarks often use. Perhaps it should be a config option instead.
for _ in 0..16 {
let _ = self.poll_read(cx)?;
let _ = self.poll_write(cx)?;
let _ = self.poll_flush(cx)?;
// This could happen if reading paused before blocking on IO,
// such as getting to the end of a framed message, but then
// writing/flushing set the state back to Init. In that case,
// if the read buffer still had bytes, we'd want to try poll_read
// again, or else we wouldn't ever be woken up again.
//
// Using this instead of task::current() and notify() inside
// the Conn is noticeably faster in pipelined benchmarks.
if !self.conn.wants_read_again() {
//break;
return Poll::Ready(Ok(()));
}
}
trace!("poll_loop yielding (self = {:p})", self);
task::yield_now(cx).map(|never| match never {})
}
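// Editor's note (added for exposition): after 16 iterations the loop hands
// control back to the executor via `task::yield_now`, returning
// `Poll::Pending` after scheduling a wake-up so other futures get a chance
// to run instead of this connection monopolizing the thread.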
fn poll_read(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
loop {
if self.is_closing {
return Poll::Ready(Ok(()));
} else if self.conn.can_read_head() {
ready!(self.poll_read_head(cx))?;
} else if let Some(mut body) = self.body_tx.take() {
if self.conn.can_read_body() {
match body.poll_ready(cx) {
Poll::Ready(Ok(())) => (),
Poll::Pending => {
self.body_tx = Some(body);
return Poll::Pending;
}
Poll::Ready(Err(_canceled)) => {
// user doesn't care about the body
// so we should stop reading
trace!("body receiver dropped before eof, draining or closing");
self.conn.poll_drain_or_close_read(cx);
continue;
}
}
match self.conn.poll_read_body(cx) {
Poll::Ready(Some(Ok(chunk))) => match body.try_send_data(chunk) {
Ok(()) => {
self.body_tx = Some(body);
}
Err(_canceled) => {
if self.conn.can_read_body() {
trace!("body receiver dropped before eof, closing");
self.conn.close_read();
}
}
},
Poll::Ready(None) => {
// just drop, the body will close automatically
}
Poll::Pending => {
self.body_tx = Some(body);
return Poll::Pending;
}
Poll::Ready(Some(Err(e))) => {
body.send_error(crate::Error::new_body(e));
}
}
} else {
// just drop, the body will close automatically
}
} else {
return self.conn.poll_read_keep_alive(cx);
}
}
}
fn poll_read_head(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
// Can the dispatch receive (or does it still care about) an incoming message?
match ready!(self.dispatch.poll_ready(cx)) {
Ok(()) => (),
Err(()) => {
trace!("dispatch no longer receiving messages");
self.close();
return Poll::Ready(Ok(()));
}
}
// dispatch is ready for a message, try to read one
match ready!(self.conn.poll_read_head(cx)) {
Some(Ok((mut head, body_len, wants))) => {
let body = match body_len {
DecodedLength::ZERO => Body::empty(),
other => {
let (tx, rx) = Body::new_channel(other, wants.contains(Wants::EXPECT));
self.body_tx = Some(tx);
rx
}
};
if wants.contains(Wants::UPGRADE) {
let upgrade = self.conn.on_upgrade();
debug_assert!(!upgrade.is_none(), "empty upgrade");
debug_assert!(head.extensions.get::<OnUpgrade>().is_none(), "OnUpgrade already set");
head.extensions.insert(upgrade);
}
self.dispatch.recv_msg(Ok((head, body)))?;
Poll::Ready(Ok(()))
}
Some(Err(err)) => {
debug!("read_head error: {}", err);
self.dispatch.recv_msg(Err(err))?;
// if here, the dispatcher gave the user the error
// somewhere else. we still need to shutdown, but
// not as a second error.
self.close();
Poll::Ready(Ok(()))
}
None => {
// read eof, the write side will have been closed too unless
// allow_read_close was set to true, in which case just do
// nothing...
debug_assert!(self.conn.is_read_closed());
if self.conn.is_write_closed() {
self.close();
}
Poll::Ready(Ok(()))
}
}
}
fn poll_write(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
loop {
if self.is_closing {
return Poll::Ready(Ok(()));
} else if self.body_rx.is_none()
&& self.conn.can_write_head()
&& self.dispatch.should_poll()
{
if let Some(msg) = ready!(Pin::new(&mut self.dispatch).poll_msg(cx)) {
let (head, mut body) = msg.map_err(crate::Error::new_user_service)?;
// Check if the body knows its full data immediately.
//
// If so, we can skip a bit of bookkeeping that streaming
// bodies need to do.
if let Some(full) = crate::body::take_full_data(&mut body) {
self.conn.write_full_msg(head, full);
return Poll::Ready(Ok(()));
}
let body_type = if body.is_end_stream() {
self.body_rx.set(None);
None
} else {
let btype = body
.size_hint()
.exact()
.map(BodyLength::Known)
.or_else(|| Some(BodyLength::Unknown));
self.body_rx.set(Some(body));
btype
};
self.conn.write_head(head, body_type);
} else {
self.close();
return Poll::Ready(Ok(()));
}
} else if !self.conn.can_buffer_body() {
ready!(self.poll_flush(cx))?;
} else {
// A new scope is needed :(
if let (Some(mut body), clear_body) =
OptGuard::new(self.body_rx.as_mut()).guard_mut()
{
debug_assert!(!*clear_body, "opt guard defaults to keeping body");
if !self.conn.can_write_body() {
trace!(
"no more write body allowed, user body is_end_stream = {}",
body.is_end_stream(),
);
*clear_body = true;
continue;
}
let item = ready!(body.as_mut().poll_data(cx));
if let Some(item) = item {
let chunk = item.map_err(|e| {
*clear_body = true;
crate::Error::new_user_body(e)
})?;
let eos = body.is_end_stream();
if eos {
*clear_body = true;
if chunk.remaining() == 0 {
trace!("discarding empty chunk");
self.conn.end_body()?;
} else {
self.conn.write_body_and_end(chunk);
}
} else {
if chunk.remaining() == 0 {
trace!("discarding empty chunk");
continue;
}
self.conn.write_body(chunk);
}
} else {
*clear_body = true;
self.conn.end_body()?;
}
} else {
return Poll::Pending;
}
}
}
}
fn poll_flush(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
self.conn.poll_flush(cx).map_err(|err| {
debug!("error writing: {}", err);
crate::Error::new_body_write(err)
})
}
fn close(&mut self) {
self.is_closing = true;
self.conn.close_read();
self.conn.close_write();
}
fn is_done(&self) -> bool {
if self.is_closing {
return true;
}
let read_done = self.conn.is_read_closed();
if !T::should_read_first() && read_done {
// a client that cannot read may as well be done.
true
} else {
let write_done = self.conn.is_write_closed()
|| (!self.dispatch.should_poll() && self.body_rx.is_none());
read_done && write_done
}
}
}
impl<D, Bs, I, T> Future for Dispatcher<D, Bs, I, T>
where
D: Dispatch<
PollItem = MessageHead<T::Outgoing>,
PollBody = Bs,
RecvItem = MessageHead<T::Incoming>,
> + Unpin,
D::PollError: Into<Box<dyn StdError + Send + Sync>>,
I: AsyncRead + AsyncWrite + Unpin,
T: Http1Transaction + Unpin,
Bs: HttpBody + 'static,
Bs::Error: Into<Box<dyn StdError + Send + Sync>>,
{
type Output = crate::Result<Dispatched>;
#[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
self.poll_catch(cx, true)
}
}
// ===== impl OptGuard =====
/// A drop guard to allow a mutable borrow of an Option while being able to
/// set whether the `Option` should be cleared on drop.
struct OptGuard<'a, T>(Pin<&'a mut Option<T>>, bool);
impl<'a, T> OptGuard<'a, T> {
fn new(pin: Pin<&'a mut Option<T>>) -> Self {
OptGuard(pin, false)
}
fn guard_mut(&mut self) -> (Option<Pin<&mut T>>, &mut bool) {
(self.0.as_mut().as_pin_mut(), &mut self.1)
}
}
impl<'a, T> Drop for OptGuard<'a, T> {
fn drop(&mut self) {
if self.1 {
self.0.set(None);
}
}
}
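// Editor's note (added for exposition): `OptGuard` lets `poll_write` borrow
// the pinned body mutably while deciding, via the bool flag, whether the
// `Option` should be cleared when the guard drops; setting
// `*clear_body = true` is how a finished or errored body gets retired.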
// ===== impl Server =====
cfg_server! {
impl<S, B> Server<S, B>
where
S: HttpService<B>,
{
pub(crate) fn new(service: S) -> Server<S, B> {
Server {
in_flight: Box::pin(None),
service,
}
}
pub(crate) fn into_service(self) -> S {
self.service
}
}
// Service is never pinned
impl<S: HttpService<B>, B> Unpin for Server<S, B> {}
impl<S, Bs> Dispatch for Server<S, Body>
where
S: HttpService<Body, ResBody = Bs>,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
Bs: HttpBody,
{
type PollItem = MessageHead<http::StatusCode>;
type PollBody = Bs;
type PollError = S::Error;
type RecvItem = RequestHead;
fn poll_msg(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), Self::PollError>>> {
let mut this = self.as_mut();
let ret = if let Some(ref mut fut) = this.in_flight.as_mut().as_pin_mut() {
let resp = ready!(fut.as_mut().poll(cx)?);
let (parts, body) = resp.into_parts();
let head = MessageHead {
version: parts.version,
subject: parts.status,
headers: parts.headers,
extensions: parts.extensions,
};
Poll::Ready(Some(Ok((head, body))))
} else {
unreachable!("poll_msg shouldn't be called if no inflight");
};
// Since in_flight finished, remove it
this.in_flight.set(None);
ret
}
fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()> {
let (msg, body) = msg?;
let mut req = Request::new(body);
*req.method_mut() = msg.subject.0;
*req.uri_mut() = msg.subject.1;
*req.headers_mut() = msg.headers;
*req.version_mut() = msg.version;
*req.extensions_mut() = msg.extensions;
let fut = self.service.call(req);
self.in_flight.set(Some(fut));
Ok(())
}
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>> {
if self.in_flight.is_some() {
Poll::Pending
} else {
self.service.poll_ready(cx).map_err(|_e| {
// FIXME: return error value.
trace!("service closed");
})
}
}
fn should_poll(&self) -> bool {
self.in_flight.is_some()
}
}
}
// ===== impl Client =====
cfg_client! {
impl<B> Client<B> {
pub(crate) fn new(rx: ClientRx<B>) -> Client<B> {
Client {
callback: None,
rx,
rx_closed: false,
}
}
}
impl<B> Dispatch for Client<B>
where
B: HttpBody,
{
type PollItem = RequestHead;
type PollBody = B;
type PollError = crate::common::Never;
type RecvItem = crate::proto::ResponseHead;
fn poll_msg(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Option<Result<(Self::PollItem, Self::PollBody), crate::common::Never>>> {
let mut this = self.as_mut();
debug_assert!(!this.rx_closed);
match this.rx.poll_recv(cx) {
Poll::Ready(Some((req, mut cb))) => {
// check that future hasn't been canceled already
match cb.poll_canceled(cx) {
Poll::Ready(()) => {
trace!("request canceled");
Poll::Ready(None)
}
Poll::Pending => {
let (parts, body) = req.into_parts();
let head = RequestHead {
version: parts.version,
subject: crate::proto::RequestLine(parts.method, parts.uri),
headers: parts.headers,
extensions: parts.extensions,
};
this.callback = Some(cb);
Poll::Ready(Some(Ok((head, body))))
}
}
}
Poll::Ready(None) => {
// user has dropped sender handle
trace!("client tx closed");
this.rx_closed = true;
Poll::Ready(None)
}
Poll::Pending => Poll::Pending,
}
}
fn recv_msg(&mut self, msg: crate::Result<(Self::RecvItem, Body)>) -> crate::Result<()> {
match msg {
Ok((msg, body)) => {
if let Some(cb) = self.callback.take() {
let res = msg.into_response(body);
cb.send(Ok(res));
Ok(())
} else {
// Getting here is likely a bug! An error should have happened
// in Conn::require_empty_read() before ever parsing a
// full message!
Err(crate::Error::new_unexpected_message())
}
}
Err(err) => {
if let Some(cb) = self.callback.take() {
cb.send(Err((err, None)));
Ok(())
} else if !self.rx_closed {
self.rx.close();
if let Some((req, cb)) = self.rx.try_recv() {
trace!("canceling queued request with connection error: {}", err);
// in this case, the message was never even started, so it's safe to tell
// the user that the request was completely canceled
cb.send(Err((crate::Error::new_canceled().with(err), Some(req))));
Ok(())
} else {
Err(err)
}
} else {
Err(err)
}
}
}
}
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), ()>> {
match self.callback {
Some(ref mut cb) => match cb.poll_canceled(cx) {
Poll::Ready(()) => {
trace!("callback receiver has dropped");
Poll::Ready(Err(()))
}
Poll::Pending => Poll::Ready(Ok(())),
},
None => Poll::Ready(Err(())),
}
}
fn should_poll(&self) -> bool {
self.callback.is_none()
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::proto::h1::ClientTransaction;
use std::time::Duration;
#[test]
fn client_read_bytes_before_writing_request() {
let _ = pretty_env_logger::try_init();
tokio_test::task::spawn(()).enter(|cx, _| {
let (io, mut handle) = tokio_test::io::Builder::new().build_with_handle();
// Block at 0 for now, but we will release this response before
// the request is ready to write later...
let (mut tx, rx) = crate::client::dispatch::channel();
let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io);
let mut dispatcher = Dispatcher::new(Client::new(rx), conn);
// First poll is needed to allow tx to send...
assert!(Pin::new(&mut dispatcher).poll(cx).is_pending());
// Unblock our IO, which has a response before we've sent request!
//
handle.read(b"HTTP/1.1 200 OK\r\n\r\n");
let mut res_rx = tx
.try_send(crate::Request::new(crate::Body::empty()))
.unwrap();
tokio_test::assert_ready_ok!(Pin::new(&mut dispatcher).poll(cx));
let err = tokio_test::assert_ready_ok!(Pin::new(&mut res_rx).poll(cx))
.expect_err("callback should send error");
match (err.0.kind(), err.1) {
(&crate::error::Kind::Canceled, Some(_)) => (),
other => panic!("expected Canceled, got {:?}", other),
}
});
}
#[tokio::test]
async fn client_flushing_is_not_ready_for_next_request() {
let _ = pretty_env_logger::try_init();
let (io, _handle) = tokio_test::io::Builder::new()
.write(b"POST / HTTP/1.1\r\ncontent-length: 4\r\n\r\n")
.read(b"HTTP/1.1 200 OK\r\ncontent-length: 0\r\n\r\n")
.wait(std::time::Duration::from_secs(2))
.build_with_handle();
let (mut tx, rx) = crate::client::dispatch::channel();
let mut conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io);
conn.set_write_strategy_queue();
let dispatcher = Dispatcher::new(Client::new(rx), conn);
let _dispatcher = tokio::spawn(async move { dispatcher.await });
let req = crate::Request::builder()
.method("POST")
.body(crate::Body::from("reee"))
.unwrap();
let res = tx.try_send(req).unwrap().await.expect("response");
drop(res);
assert!(!tx.is_ready());
}
#[tokio::test]
async fn body_empty_chunks_ignored() {
let _ = pretty_env_logger::try_init();
let io = tokio_test::io::Builder::new()
// no reading or writing, just be blocked for the test...
.wait(Duration::from_secs(5))
.build();
let (mut tx, rx) = crate::client::dispatch::channel();
let conn = Conn::<_, bytes::Bytes, ClientTransaction>::new(io);
let mut dispatcher = tokio_test::task::spawn(Dispatcher::new(Client::new(rx), conn));
// First poll is needed to allow tx to send...
assert!(dispatcher.poll().is_pending());
let body = {
let (mut tx, body) = crate::Body::channel();
tx.try_send_data("".into()).unwrap();
body
};
let _res_rx = tx.try_send(crate::Request::new(body)).unwrap();
// Ensure conn.write_body wasn't called with the empty chunk.
// If it is, it will trigger an assertion.
assert!(dispatcher.poll().is_pending());
}
}

439
hyper/src/proto/h1/encode.rs Normal file

@@ -0,0 +1,439 @@
use std::fmt;
use std::io::IoSlice;
use bytes::buf::{Chain, Take};
use bytes::Buf;
use tracing::trace;
use super::io::WriteBuf;
type StaticBuf = &'static [u8];
/// Encoders to handle different Transfer-Encodings.
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct Encoder {
kind: Kind,
is_last: bool,
}
#[derive(Debug)]
pub(crate) struct EncodedBuf<B> {
kind: BufKind<B>,
}
#[derive(Debug)]
pub(crate) struct NotEof(u64);
#[derive(Debug, PartialEq, Clone)]
enum Kind {
/// An Encoder for when Transfer-Encoding includes `chunked`.
Chunked,
/// An Encoder for when Content-Length is set.
///
/// Enforces that the body is not longer than the Content-Length header.
Length(u64),
/// An Encoder for when neither Content-Length nor Chunked encoding is set.
///
/// This is mostly only used with HTTP/1.0 bodies that declare no length. This
/// kind requires the connection to be closed when the body is finished.
#[cfg(feature = "server")]
CloseDelimited,
}
#[derive(Debug)]
enum BufKind<B> {
Exact(B),
Limited(Take<B>),
Chunked(Chain<Chain<ChunkSize, B>, StaticBuf>),
ChunkedEnd(StaticBuf),
}
impl Encoder {
fn new(kind: Kind) -> Encoder {
Encoder {
kind,
is_last: false,
}
}
pub(crate) fn chunked() -> Encoder {
Encoder::new(Kind::Chunked)
}
pub(crate) fn length(len: u64) -> Encoder {
Encoder::new(Kind::Length(len))
}
#[cfg(feature = "server")]
pub(crate) fn close_delimited() -> Encoder {
Encoder::new(Kind::CloseDelimited)
}
pub(crate) fn is_eof(&self) -> bool {
matches!(self.kind, Kind::Length(0))
}
#[cfg(feature = "server")]
pub(crate) fn set_last(mut self, is_last: bool) -> Self {
self.is_last = is_last;
self
}
pub(crate) fn is_last(&self) -> bool {
self.is_last
}
pub(crate) fn is_close_delimited(&self) -> bool {
match self.kind {
#[cfg(feature = "server")]
Kind::CloseDelimited => true,
_ => false,
}
}
pub(crate) fn end<B>(&self) -> Result<Option<EncodedBuf<B>>, NotEof> {
match self.kind {
Kind::Length(0) => Ok(None),
Kind::Chunked => Ok(Some(EncodedBuf {
kind: BufKind::ChunkedEnd(b"0\r\n\r\n"),
})),
#[cfg(feature = "server")]
Kind::CloseDelimited => Ok(None),
Kind::Length(n) => Err(NotEof(n)),
}
}
pub(crate) fn encode<B>(&mut self, msg: B) -> EncodedBuf<B>
where
B: Buf,
{
let len = msg.remaining();
debug_assert!(len > 0, "encode() called with empty buf");
let kind = match self.kind {
Kind::Chunked => {
trace!("encoding chunked {}B", len);
let buf = ChunkSize::new(len)
.chain(msg)
.chain(b"\r\n" as &'static [u8]);
BufKind::Chunked(buf)
}
Kind::Length(ref mut remaining) => {
trace!("sized write, len = {}", len);
if len as u64 > *remaining {
let limit = *remaining as usize;
*remaining = 0;
BufKind::Limited(msg.take(limit))
} else {
*remaining -= len as u64;
BufKind::Exact(msg)
}
}
#[cfg(feature = "server")]
Kind::CloseDelimited => {
trace!("close delimited write {}B", len);
BufKind::Exact(msg)
}
};
EncodedBuf { kind }
}
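// Editor's note (added for exposition): for `Kind::Chunked`, `encode` frames
// each buffer as `<hex len>\r\n<data>\r\n` (e.g. a 7-byte chunk becomes
// `7\r\nfoo bar\r\n`); the terminating `0\r\n\r\n` comes from `end()`, or is
// folded into the final chunk by `encode_and_end`.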
pub(super) fn encode_and_end<B>(&self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>) -> bool
where
B: Buf,
{
let len = msg.remaining();
debug_assert!(len > 0, "encode() called with empty buf");
match self.kind {
Kind::Chunked => {
trace!("encoding chunked {}B", len);
let buf = ChunkSize::new(len)
.chain(msg)
.chain(b"\r\n0\r\n\r\n" as &'static [u8]);
dst.buffer(buf);
!self.is_last
}
Kind::Length(remaining) => {
use std::cmp::Ordering;
trace!("sized write, len = {}", len);
match (len as u64).cmp(&remaining) {
Ordering::Equal => {
dst.buffer(msg);
!self.is_last
}
Ordering::Greater => {
dst.buffer(msg.take(remaining as usize));
!self.is_last
}
Ordering::Less => {
dst.buffer(msg);
false
}
}
}
#[cfg(feature = "server")]
Kind::CloseDelimited => {
trace!("close delimited write {}B", len);
dst.buffer(msg);
false
}
}
}
/// Encodes the full body, without verifying the remaining length matches.
///
/// This is used in conjunction with HttpBody::__hyper_full_data(), which
/// means we can trust that the buf has the correct size (the same buf was
/// already inspected when the headers were written).
pub(super) fn danger_full_buf<B>(self, msg: B, dst: &mut WriteBuf<EncodedBuf<B>>)
where
B: Buf,
{
debug_assert!(msg.remaining() > 0, "encode() called with empty buf");
debug_assert!(
match self.kind {
Kind::Length(len) => len == msg.remaining() as u64,
_ => true,
},
"danger_full_buf length mismatches"
);
match self.kind {
Kind::Chunked => {
let len = msg.remaining();
trace!("encoding chunked {}B", len);
let buf = ChunkSize::new(len)
.chain(msg)
.chain(b"\r\n0\r\n\r\n" as &'static [u8]);
dst.buffer(buf);
}
_ => {
dst.buffer(msg);
}
}
}
}
impl<B> Buf for EncodedBuf<B>
where
B: Buf,
{
#[inline]
fn remaining(&self) -> usize {
match self.kind {
BufKind::Exact(ref b) => b.remaining(),
BufKind::Limited(ref b) => b.remaining(),
BufKind::Chunked(ref b) => b.remaining(),
BufKind::ChunkedEnd(ref b) => b.remaining(),
}
}
#[inline]
fn chunk(&self) -> &[u8] {
match self.kind {
BufKind::Exact(ref b) => b.chunk(),
BufKind::Limited(ref b) => b.chunk(),
BufKind::Chunked(ref b) => b.chunk(),
BufKind::ChunkedEnd(ref b) => b.chunk(),
}
}
#[inline]
fn advance(&mut self, cnt: usize) {
match self.kind {
BufKind::Exact(ref mut b) => b.advance(cnt),
BufKind::Limited(ref mut b) => b.advance(cnt),
BufKind::Chunked(ref mut b) => b.advance(cnt),
BufKind::ChunkedEnd(ref mut b) => b.advance(cnt),
}
}
#[inline]
fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
match self.kind {
BufKind::Exact(ref b) => b.chunks_vectored(dst),
BufKind::Limited(ref b) => b.chunks_vectored(dst),
BufKind::Chunked(ref b) => b.chunks_vectored(dst),
BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst),
}
}
}
#[cfg(target_pointer_width = "32")]
const USIZE_BYTES: usize = 4;
#[cfg(target_pointer_width = "64")]
const USIZE_BYTES: usize = 8;
// each byte will become 2 hex characters
const CHUNK_SIZE_MAX_BYTES: usize = USIZE_BYTES * 2;
#[derive(Clone, Copy)]
struct ChunkSize {
bytes: [u8; CHUNK_SIZE_MAX_BYTES + 2],
pos: u8,
len: u8,
}
impl ChunkSize {
fn new(len: usize) -> ChunkSize {
use std::fmt::Write;
let mut size = ChunkSize {
bytes: [0; CHUNK_SIZE_MAX_BYTES + 2],
pos: 0,
len: 0,
};
write!(&mut size, "{:X}\r\n", len).expect("CHUNK_SIZE_MAX_BYTES should fit any usize");
size
}
}
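// Example (added for exposition): `ChunkSize::new(1024)` stores the five
// bytes `b"400\r\n"` (0x400 == 1024); `CHUNK_SIZE_MAX_BYTES + 2` leaves room
// for the largest possible hex length plus the trailing CRLF.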
impl Buf for ChunkSize {
#[inline]
fn remaining(&self) -> usize {
(self.len - self.pos).into()
}
#[inline]
fn chunk(&self) -> &[u8] {
&self.bytes[self.pos.into()..self.len.into()]
}
#[inline]
fn advance(&mut self, cnt: usize) {
assert!(cnt <= self.remaining());
self.pos += cnt as u8; // just asserted cnt fits in u8
}
}
impl fmt::Debug for ChunkSize {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ChunkSize")
.field("bytes", &&self.bytes[..self.len.into()])
.field("pos", &self.pos)
.finish()
}
}
impl fmt::Write for ChunkSize {
fn write_str(&mut self, num: &str) -> fmt::Result {
use std::io::Write;
(&mut self.bytes[self.len.into()..])
.write_all(num.as_bytes())
.expect("&mut [u8].write() cannot error");
self.len += num.len() as u8; // safe because bytes is never bigger than 256
Ok(())
}
}
impl<B: Buf> From<B> for EncodedBuf<B> {
fn from(buf: B) -> Self {
EncodedBuf {
kind: BufKind::Exact(buf),
}
}
}
impl<B: Buf> From<Take<B>> for EncodedBuf<B> {
fn from(buf: Take<B>) -> Self {
EncodedBuf {
kind: BufKind::Limited(buf),
}
}
}
impl<B: Buf> From<Chain<Chain<ChunkSize, B>, StaticBuf>> for EncodedBuf<B> {
fn from(buf: Chain<Chain<ChunkSize, B>, StaticBuf>) -> Self {
EncodedBuf {
kind: BufKind::Chunked(buf),
}
}
}
impl fmt::Display for NotEof {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "early end, expected {} more bytes", self.0)
}
}
impl std::error::Error for NotEof {}
#[cfg(test)]
mod tests {
use bytes::BufMut;
use super::super::io::Cursor;
use super::Encoder;
#[test]
fn chunked() {
let mut encoder = Encoder::chunked();
let mut dst = Vec::new();
let msg1 = b"foo bar".as_ref();
let buf1 = encoder.encode(msg1);
dst.put(buf1);
assert_eq!(dst, b"7\r\nfoo bar\r\n");
let msg2 = b"baz quux herp".as_ref();
let buf2 = encoder.encode(msg2);
dst.put(buf2);
assert_eq!(dst, b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n");
let end = encoder.end::<Cursor<Vec<u8>>>().unwrap().unwrap();
dst.put(end);
assert_eq!(
dst,
b"7\r\nfoo bar\r\nD\r\nbaz quux herp\r\n0\r\n\r\n".as_ref()
);
}
#[test]
fn length() {
let max_len = 8;
let mut encoder = Encoder::length(max_len as u64);
let mut dst = Vec::new();
let msg1 = b"foo bar".as_ref();
let buf1 = encoder.encode(msg1);
dst.put(buf1);
assert_eq!(dst, b"foo bar");
assert!(!encoder.is_eof());
encoder.end::<()>().unwrap_err();
let msg2 = b"baz".as_ref();
let buf2 = encoder.encode(msg2);
dst.put(buf2);
assert_eq!(dst.len(), max_len);
assert_eq!(dst, b"foo barb");
assert!(encoder.is_eof());
assert!(encoder.end::<()>().unwrap().is_none());
}
#[test]
fn eof() {
let mut encoder = Encoder::close_delimited();
let mut dst = Vec::new();
let msg1 = b"foo bar".as_ref();
let buf1 = encoder.encode(msg1);
dst.put(buf1);
assert_eq!(dst, b"foo bar");
assert!(!encoder.is_eof());
encoder.end::<()>().unwrap();
let msg2 = b"baz".as_ref();
let buf2 = encoder.encode(msg2);
dst.put(buf2);
assert_eq!(dst, b"foo barbaz");
assert!(!encoder.is_eof());
encoder.end::<()>().unwrap();
}
}

1002
hyper/src/proto/h1/io.rs Normal file

File diff suppressed because it is too large.

122
hyper/src/proto/h1/mod.rs Normal file

@@ -0,0 +1,122 @@
#[cfg(all(feature = "server", feature = "runtime"))]
use std::{pin::Pin, time::Duration};
use bytes::BytesMut;
use http::{HeaderMap, Method};
use httparse::ParserConfig;
#[cfg(all(feature = "server", feature = "runtime"))]
use tokio::time::Sleep;
use crate::body::DecodedLength;
use crate::proto::{BodyLength, MessageHead};
pub(crate) use self::conn::Conn;
pub(crate) use self::decode::Decoder;
pub(crate) use self::dispatch::Dispatcher;
pub(crate) use self::encode::{EncodedBuf, Encoder};
//TODO: move out of h1::io
pub(crate) use self::io::MINIMUM_MAX_BUFFER_SIZE;
mod conn;
mod decode;
pub(crate) mod dispatch;
mod encode;
mod io;
mod role;
cfg_client! {
pub(crate) type ClientTransaction = role::Client;
}
cfg_server! {
pub(crate) type ServerTransaction = role::Server;
}
pub(crate) trait Http1Transaction {
type Incoming;
type Outgoing: Default;
const LOG: &'static str;
fn parse(bytes: &mut BytesMut, ctx: ParseContext<'_>) -> ParseResult<Self::Incoming>;
fn encode(enc: Encode<'_, Self::Outgoing>, dst: &mut Vec<u8>) -> crate::Result<Encoder>;
fn on_error(err: &crate::Error) -> Option<MessageHead<Self::Outgoing>>;
fn is_client() -> bool {
!Self::is_server()
}
fn is_server() -> bool {
!Self::is_client()
}
fn should_error_on_parse_eof() -> bool {
Self::is_client()
}
fn should_read_first() -> bool {
Self::is_server()
}
fn update_date() {}
}
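// Editor's note (added for exposition): the default methods capture the
// client/server asymmetry. A role overrides exactly one of
// `is_client`/`is_server` (each is defined as the negation of the other),
// and `should_read_first`/`should_error_on_parse_eof` follow from that:
// servers read a request before writing, clients treat EOF mid-parse as an
// error. `update_date` is a no-op hook a role may override.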
/// Result newtype for Http1Transaction::parse.
pub(crate) type ParseResult<T> = Result<Option<ParsedMessage<T>>, crate::error::Parse>;
#[derive(Debug)]
pub(crate) struct ParsedMessage<T> {
head: MessageHead<T>,
decode: DecodedLength,
expect_continue: bool,
keep_alive: bool,
wants_upgrade: bool,
}
pub(crate) struct ParseContext<'a> {
cached_headers: &'a mut Option<HeaderMap>,
req_method: &'a mut Option<Method>,
h1_parser_config: ParserConfig,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout: Option<Duration>,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout_fut: &'a mut Option<Pin<Box<Sleep>>>,
#[cfg(all(feature = "server", feature = "runtime"))]
h1_header_read_timeout_running: &'a mut bool,
preserve_header_case: bool,
#[cfg(feature = "ffi")]
preserve_header_order: bool,
h09_responses: bool,
#[cfg(feature = "ffi")]
on_informational: &'a mut Option<crate::ffi::OnInformational>,
#[cfg(feature = "ffi")]
raw_headers: bool,
}
/// Passed to Http1Transaction::encode
pub(crate) struct Encode<'a, T> {
head: &'a mut MessageHead<T>,
body: Option<BodyLength>,
#[cfg(feature = "server")]
keep_alive: bool,
req_method: &'a mut Option<Method>,
title_case_headers: bool,
}
/// Extra flags that a request "wants", like expect-continue or upgrades.
#[derive(Clone, Copy, Debug)]
struct Wants(u8);
impl Wants {
const EMPTY: Wants = Wants(0b00);
const EXPECT: Wants = Wants(0b01);
const UPGRADE: Wants = Wants(0b10);
#[must_use]
fn add(self, other: Wants) -> Wants {
Wants(self.0 | other.0)
}
fn contains(&self, other: Wants) -> bool {
(self.0 & other.0) == other.0
}
}
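// Example (added for exposition): flags combine with `add` and are queried
// with `contains`, e.g.
//   let w = Wants::EMPTY.add(Wants::EXPECT);
//   assert!(w.contains(Wants::EXPECT) && !w.contains(Wants::UPGRADE));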

2847
hyper/src/proto/h1/role.rs Normal file

File diff suppressed because it is too large.