more things

This commit is contained in:
nora 2022-03-04 23:05:50 +01:00
parent 5d127eceee
commit b6355f5e35
11 changed files with 129 additions and 122 deletions

View file

@ -53,17 +53,18 @@ pub struct ConnectionInner {
pub global_data: GlobalData, pub global_data: GlobalData,
pub channels: Mutex<HashMap<ChannelNum, Channel>>, pub channels: Mutex<HashMap<ChannelNum, Channel>>,
pub exclusive_queues: Vec<Queue>, pub exclusive_queues: Vec<Queue>,
_events: ConEventSender, pub event_sender: ConEventSender,
} }
#[derive(Debug)] #[derive(Debug)]
pub enum QueuedMethod { pub enum ConnectionEvent {
Normal(Method), Shutdown,
WithContent(Method, ContentHeader, SmallVec<[Bytes; 1]>), Method(ChannelNum, Box<Method>),
MethodContent(ChannelNum, Box<Method>, ContentHeader, SmallVec<[Bytes; 1]>),
} }
pub type ConEventSender = mpsc::Sender<(ChannelNum, QueuedMethod)>; pub type ConEventSender = mpsc::Sender<ConnectionEvent>;
pub type ConEventReceiver = mpsc::Receiver<(ChannelNum, QueuedMethod)>; pub type ConEventReceiver = mpsc::Receiver<ConnectionEvent>;
impl ConnectionInner { impl ConnectionInner {
#[must_use] #[must_use]
@ -71,7 +72,7 @@ impl ConnectionInner {
id: ConnectionId, id: ConnectionId,
peer_addr: SocketAddr, peer_addr: SocketAddr,
global_data: GlobalData, global_data: GlobalData,
method_queue: ConEventSender, event_sender: ConEventSender,
) -> Connection { ) -> Connection {
Arc::new(Self { Arc::new(Self {
id, id,
@ -79,7 +80,7 @@ impl ConnectionInner {
global_data, global_data,
channels: Mutex::new(HashMap::new()), channels: Mutex::new(HashMap::new()),
exclusive_queues: vec![], exclusive_queues: vec![],
_events: method_queue, event_sender,
}) })
} }
@ -97,7 +98,7 @@ pub struct ChannelInner {
pub num: ChannelNum, pub num: ChannelNum,
pub connection: Connection, pub connection: Connection,
pub global_data: GlobalData, pub global_data: GlobalData,
method_queue: ConEventSender, pub event_sender: ConEventSender,
} }
impl ChannelInner { impl ChannelInner {
@ -114,7 +115,7 @@ impl ChannelInner {
num, num,
connection, connection,
global_data, global_data,
method_queue, event_sender: method_queue,
}) })
} }
@ -122,13 +123,6 @@ impl ChannelInner {
let mut global_data = self.global_data.lock(); let mut global_data = self.global_data.lock();
global_data.channels.remove(&self.id); global_data.channels.remove(&self.id);
} }
pub fn queue_method(&self, method: QueuedMethod) {
// todo: this is a horrible hack around the lock chaos
self.method_queue
.try_send((self.num, method))
.expect("could not send method to channel, RIP");
}
} }
/// A content frame header. /// A content frame header.

View file

@ -1,7 +1,7 @@
use crate::Result; use crate::Result;
use amqp_core::{ use amqp_core::{
amqp_todo, amqp_todo,
connection::{Channel, QueuedMethod}, connection::{Channel, ConnectionEvent},
error::ChannelException, error::ChannelException,
message::Message, message::Message,
methods::{BasicPublish, Method}, methods::{BasicPublish, Method},
@ -31,19 +31,24 @@ pub async fn publish(channel_handle: Channel, message: Message) -> Result<()> {
// consuming is hard, but this should work *for now* // consuming is hard, but this should work *for now*
let consumers = queue.consumers.lock(); let consumers = queue.consumers.lock();
if let Some(consumer) = consumers.first() { if let Some(consumer) = consumers.first() {
let method = Method::BasicPublish(BasicPublish { let method = Box::new(Method::BasicPublish(BasicPublish {
reserved_1: 0, reserved_1: 0,
exchange: routing.exchange.clone(), exchange: routing.exchange.clone(),
routing_key: routing.routing_key.clone(), routing_key: routing.routing_key.clone(),
mandatory: false, mandatory: false,
immediate: false, immediate: false,
}); }));
consumer.channel.queue_method(QueuedMethod::WithContent( consumer
method, .channel
message.header.clone(), .event_sender
message.content.clone(), .try_send(ConnectionEvent::MethodContent(
)); consumer.channel.num,
method,
message.header.clone(),
message.content.clone(),
))
.unwrap();
} }
} }

View file

@ -8,17 +8,17 @@ use amqp_core::{
amqp_todo, amqp_todo,
connection::{ connection::{
Channel, ChannelInner, ChannelNum, ConEventReceiver, ConEventSender, Connection, Channel, ChannelInner, ChannelNum, ConEventReceiver, ConEventSender, Connection,
ConnectionId, ContentHeader, QueuedMethod, ConnectionEvent, ConnectionId, ContentHeader,
}, },
message::{MessageId, RawMessage, RoutingInformation}, message::{MessageId, RawMessage, RoutingInformation},
methods::{ methods::{
BasicPublish, ChannelClose, ChannelCloseOk, ChannelOpenOk, ConnectionClose, BasicPublish, ChannelClose, ChannelCloseOk, ChannelOpenOk, ConnectionClose,
ConnectionCloseOk, ConnectionOpen, ConnectionOpenOk, ConnectionStart, ConnectionStartOk, ConnectionCloseOk, ConnectionOpen, ConnectionOpenOk, ConnectionStart, ConnectionStartOk,
ConnectionTune, ConnectionTuneOk, FieldValue, Method, Table, ConnectionTune, ConnectionTuneOk, FieldValue, Method, ReplyCode, ReplyText, Table,
}, },
GlobalData, GlobalData,
}; };
use anyhow::Context; use anyhow::{anyhow, Context};
use bytes::Bytes; use bytes::Bytes;
use smallvec::SmallVec; use smallvec::SmallVec;
use std::{ use std::{
@ -64,9 +64,10 @@ pub struct TransportConnection {
channels: HashMap<ChannelNum, TransportChannel>, channels: HashMap<ChannelNum, TransportChannel>,
global_con: Connection, global_con: Connection,
global_data: GlobalData, global_data: GlobalData,
/// Only here to forward to other futures so they can send events
method_queue_send: ConEventSender, event_sender: ConEventSender,
method_queue_recv: ConEventReceiver, /// To receive events from other futures
event_receiver: ConEventReceiver,
} }
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30); const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
@ -102,8 +103,8 @@ impl TransportConnection {
global_con, global_con,
channels: HashMap::with_capacity(4), channels: HashMap::with_capacity(4),
global_data, global_data,
method_queue_send, event_sender: method_queue_send,
method_queue_recv, event_receiver: method_queue_recv,
} }
} }
@ -117,27 +118,12 @@ impl TransportConnection {
} }
Err(TransError::Protocol(ProtocolError::ConException(ex))) => { Err(TransError::Protocol(ProtocolError::ConException(ex))) => {
warn!(%ex, "Connection exception occurred. This indicates a faulty client."); warn!(%ex, "Connection exception occurred. This indicates a faulty client.");
if let Err(err) = self let close_result = self.close(ex.reply_code(), ex.reply_text()).await;
.send_method(
ChannelNum::zero(), match close_result {
Method::ConnectionClose(ConnectionClose { Ok(()) => {}
reply_code: ex.reply_code(),
reply_text: ex.reply_text(),
class_id: 0, // todo: do this
method_id: 0,
}),
)
.await
{
error!(%ex, %err, "Failed to close connection after ConnectionException");
}
match self.recv_method().await {
Ok(Method::ConnectionCloseOk(_)) => {}
Ok(method) => {
error!(%ex, ?method, "Received wrong method after ConnectionException")
}
Err(err) => { Err(err) => {
error!(%ex, %err, "Failed to receive Connection.CloseOk method after ConnectionException") error!(%ex, %err, "Failed to close connection after ConnectionException");
} }
} }
} }
@ -161,7 +147,7 @@ impl TransportConnection {
async fn send_method_content( async fn send_method_content(
&mut self, &mut self,
channel: ChannelNum, channel: ChannelNum,
method: Method, method: &Method,
_header: ContentHeader, _header: ContentHeader,
_body: SmallVec<[Bytes; 1]>, _body: SmallVec<[Bytes; 1]>,
) -> Result<()> { ) -> Result<()> {
@ -169,7 +155,7 @@ impl TransportConnection {
amqp_todo!() amqp_todo!()
} }
async fn send_method(&mut self, channel: ChannelNum, method: Method) -> Result<()> { async fn send_method(&mut self, channel: ChannelNum, method: &Method) -> Result<()> {
trace!(%channel, ?method, "Sending method"); trace!(%channel, ?method, "Sending method");
let mut payload = Vec::with_capacity(64); let mut payload = Vec::with_capacity(64);
@ -208,7 +194,7 @@ impl TransportConnection {
}); });
debug!(?start_method, "Sending Start method"); debug!(?start_method, "Sending Start method");
self.send_method(ChannelNum::zero(), start_method).await?; self.send_method(ChannelNum::zero(), &start_method).await?;
let start_ok = self.recv_method().await?; let start_ok = self.recv_method().await?;
debug!(?start_ok, "Received Start-Ok"); debug!(?start_ok, "Received Start-Ok");
@ -239,7 +225,7 @@ impl TransportConnection {
}); });
debug!("Sending Tune method"); debug!("Sending Tune method");
self.send_method(ChannelNum::zero(), tune_method).await?; self.send_method(ChannelNum::zero(), &tune_method).await?;
let tune_ok = self.recv_method().await?; let tune_ok = self.recv_method().await?;
debug!(?tune_ok, "Received Tune-Ok method"); debug!(?tune_ok, "Received Tune-Ok method");
@ -269,8 +255,8 @@ impl TransportConnection {
self.send_method( self.send_method(
ChannelNum::zero(), ChannelNum::zero(),
Method::ConnectionOpenOk(ConnectionOpenOk { &Method::ConnectionOpenOk(ConnectionOpenOk {
reserved_1: "".to_string(), reserved_1: "".to_owned(),
}), }),
) )
.await?; .await?;
@ -285,10 +271,11 @@ impl TransportConnection {
let frame = frame?; let frame = frame?;
self.handle_frame(frame).await?; self.handle_frame(frame).await?;
} }
queued_method = self.method_queue_recv.recv() => { queued_method = self.event_receiver.recv() => {
match queued_method { match queued_method {
Some((channel, QueuedMethod::Normal(method))) => self.send_method(channel, method).await?, Some(ConnectionEvent::Method(channel, method)) => self.send_method(channel, &method).await?,
Some((channel, QueuedMethod::WithContent(method, header, body))) => self.send_method_content(channel, method, header, body).await?, Some(ConnectionEvent::MethodContent(channel, method, header, body)) => self.send_method_content(channel, &method, header, body).await?,
Some(ConnectionEvent::Shutdown) => return self.close(0, "".to_owned()).await,
None => {} None => {}
} }
} }
@ -315,7 +302,7 @@ impl TransportConnection {
warn!(%ex, "Channel exception occurred"); warn!(%ex, "Channel exception occurred");
self.send_method( self.send_method(
channel, channel,
Method::ChannelClose(ChannelClose { &Method::ChannelClose(ChannelClose {
reply_code: ex.reply_code(), reply_code: ex.reply_code(),
reply_text: ex.reply_text(), reply_text: ex.reply_text(),
class_id: 0, // todo: do this class_id: 0, // todo: do this
@ -349,7 +336,7 @@ impl TransportConnection {
info!(%reply_code, %reply_text, %class_id, %method_id, "Closing connection"); info!(%reply_code, %reply_text, %class_id, %method_id, "Closing connection");
self.send_method( self.send_method(
ChannelNum::zero(), ChannelNum::zero(),
Method::ConnectionCloseOk(ConnectionCloseOk), &Method::ConnectionCloseOk(ConnectionCloseOk),
) )
.await?; .await?;
return Err(ProtocolError::GracefullyClosed.into()); return Err(ProtocolError::GracefullyClosed.into());
@ -375,7 +362,7 @@ impl TransportConnection {
// maybe this might become an `Option` in the future // maybe this might become an `Option` in the future
let return_method = let return_method =
amqp_messaging::methods::handle_method(channel_handle, method).await?; amqp_messaging::methods::handle_method(channel_handle, method).await?;
self.send_method(frame.channel, return_method).await?; self.send_method(frame.channel, &return_method).await?;
} }
} }
Ok(()) Ok(())
@ -489,7 +476,7 @@ impl TransportConnection {
channel_num, channel_num,
self.global_con.clone(), self.global_con.clone(),
self.global_data.clone(), self.global_data.clone(),
self.method_queue_send.clone(), self.event_sender.clone(),
); );
let channel = TransportChannel { let channel = TransportChannel {
@ -519,7 +506,7 @@ impl TransportConnection {
self.send_method( self.send_method(
channel_num, channel_num,
Method::ChannelOpenOk(ChannelOpenOk { &Method::ChannelOpenOk(ChannelOpenOk {
reserved_1: Vec::new(), reserved_1: Vec::new(),
}), }),
) )
@ -539,7 +526,7 @@ impl TransportConnection {
if let Some(channel) = self.channels.remove(&channel_id) { if let Some(channel) = self.channels.remove(&channel_id) {
drop(channel); drop(channel);
self.send_method(channel_id, Method::ChannelCloseOk(ChannelCloseOk)) self.send_method(channel_id, &Method::ChannelCloseOk(ChannelCloseOk))
.await?; .await?;
} else { } else {
return Err(ConException::Todo.into()); return Err(ConException::Todo.into());
@ -598,6 +585,33 @@ impl TransportConnection {
Err(ProtocolError::ProtocolNegotiationFailed.into()) Err(ProtocolError::ProtocolNegotiationFailed.into())
} }
} }
async fn close(&mut self, reply_code: ReplyCode, reply_text: ReplyText) -> Result<()> {
self.send_method(
ChannelNum::zero(),
&Method::ConnectionClose(ConnectionClose {
reply_code,
reply_text,
class_id: 0, // todo: do this
method_id: 0,
}),
)
.await?;
match self.recv_method().await {
Ok(Method::ConnectionCloseOk(_)) => Ok(()),
Ok(method) => {
return Err(TransError::Other(anyhow!(
"Received wrong method after closing, method: {method:?}"
)));
}
Err(err) => {
return Err(TransError::Other(anyhow!(
"Failed to receive Connection.CloseOk method after closing, err: {err}"
)));
}
}
}
} }
impl Drop for TransportConnection { impl Drop for TransportConnection {
@ -619,12 +633,12 @@ fn server_properties(host: SocketAddr) -> Table {
let host_str = host.ip().to_string(); let host_str = host.ip().to_string();
HashMap::from([ HashMap::from([
("host".to_string(), ls(&host_str)), ("host".to_owned(), ls(&host_str)),
("product".to_string(), ls("no name yet")), ("product".to_owned(), ls("no name yet")),
("version".to_string(), ls("0.1.0")), ("version".to_owned(), ls("0.1.0")),
("platform".to_string(), ls("microsoft linux")), ("platform".to_owned(), ls("microsoft linux")),
("copyright".to_string(), ls("MIT")), ("copyright".to_owned(), ls("MIT")),
("information".to_string(), ls("hello reader")), ("information".to_owned(), ls("hello reader")),
("uwu".to_string(), ls("owo")), ("uwu".to_owned(), ls("owo")),
]) ])
} }

View file

@ -119,10 +119,10 @@ pub fn parse_content_header(input: &[u8]) -> Result<ContentHeader> {
match content_header_parse::header(input) { match content_header_parse::header(input) {
Ok(([], header)) => Ok(header), Ok(([], header)) => Ok(header),
Ok((_, _)) => { Ok((_, _)) => {
Err(ConException::SyntaxError(vec!["could not consume all input".to_string()]).into()) Err(ConException::SyntaxError(vec!["could not consume all input".to_owned()]).into())
} }
Err(nom::Err::Incomplete(_)) => { Err(nom::Err::Incomplete(_)) => {
Err(ConException::SyntaxError(vec!["there was not enough data".to_string()]).into()) Err(ConException::SyntaxError(vec!["there was not enough data".to_owned()]).into())
} }
Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err), Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err),
} }

View file

@ -890,7 +890,7 @@ pub mod write {
use amqp_core::methods::*; use amqp_core::methods::*;
use std::io::Write; use std::io::Write;
pub fn write_method<W: Write>(method: Method, mut writer: W) -> Result<(), TransError> { pub fn write_method<W: Write>(method: &Method, mut writer: W) -> Result<(), TransError> {
match method { match method {
Method::ConnectionStart(ConnectionStart { Method::ConnectionStart(ConnectionStart {
version_major, version_major,

View file

@ -20,10 +20,10 @@ pub fn parse_method(payload: &[u8]) -> Result<Method, TransError> {
match nom_result { match nom_result {
Ok(([], method)) => Ok(method), Ok(([], method)) => Ok(method),
Ok((_, _)) => { Ok((_, _)) => {
Err(ConException::SyntaxError(vec!["could not consume all input".to_string()]).into()) Err(ConException::SyntaxError(vec!["could not consume all input".to_owned()]).into())
} }
Err(nom::Err::Incomplete(_)) => { Err(nom::Err::Incomplete(_)) => {
Err(ConException::SyntaxError(vec!["there was not enough data".to_string()]).into()) Err(ConException::SyntaxError(vec!["there was not enough data".to_owned()]).into())
} }
Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err), Err(nom::Err::Failure(err) | nom::Err::Error(err)) => Err(err),
} }

View file

@ -67,11 +67,8 @@ fn random_ser_de() {
#[test] #[test]
fn nested_table() { fn nested_table() {
let table = HashMap::from([( let table = HashMap::from([(
"A".to_string(), "A".to_owned(),
FieldValue::FieldTable(HashMap::from([( FieldValue::FieldTable(HashMap::from([("B".to_owned(), FieldValue::Boolean(true))])),
"B".to_string(),
FieldValue::Boolean(true),
)])),
)]); )]);
eprintln!("{table:?}"); eprintln!("{table:?}");

View file

@ -3,27 +3,27 @@ use amqp_core::methods::{Bit, Long, Longlong, Longstr, Octet, Short, Shortstr, T
use anyhow::Context; use anyhow::Context;
use std::io::Write; use std::io::Write;
pub fn octet<W: Write>(value: Octet, writer: &mut W) -> Result<(), TransError> { pub fn octet<W: Write>(value: &Octet, writer: &mut W) -> Result<(), TransError> {
writer.write_all(&[value])?; writer.write_all(&[*value])?;
Ok(()) Ok(())
} }
pub fn short<W: Write>(value: Short, writer: &mut W) -> Result<(), TransError> { pub fn short<W: Write>(value: &Short, writer: &mut W) -> Result<(), TransError> {
writer.write_all(&value.to_be_bytes())?; writer.write_all(&value.to_be_bytes())?;
Ok(()) Ok(())
} }
pub fn long<W: Write>(value: Long, writer: &mut W) -> Result<(), TransError> { pub fn long<W: Write>(value: &Long, writer: &mut W) -> Result<(), TransError> {
writer.write_all(&value.to_be_bytes())?; writer.write_all(&value.to_be_bytes())?;
Ok(()) Ok(())
} }
pub fn longlong<W: Write>(value: Longlong, writer: &mut W) -> Result<(), TransError> { pub fn longlong<W: Write>(value: &Longlong, writer: &mut W) -> Result<(), TransError> {
writer.write_all(&value.to_be_bytes())?; writer.write_all(&value.to_be_bytes())?;
Ok(()) Ok(())
} }
pub fn bit<W: Write>(value: &[Bit], writer: &mut W) -> Result<(), TransError> { pub fn bit<W: Write>(value: &[&Bit], writer: &mut W) -> Result<(), TransError> {
// accumulate bits into bytes, starting from the least significant bit in each byte // accumulate bits into bytes, starting from the least significant bit in each byte
// how many bits have already been packed into `current_buf` // how many bits have already been packed into `current_buf`
@ -37,7 +37,7 @@ pub fn bit<W: Write>(value: &[Bit], writer: &mut W) -> Result<(), TransError> {
already_filled = 0; already_filled = 0;
} }
let new_bit = (u8::from(bit)) << already_filled; let new_bit = (u8::from(*bit)) << already_filled;
current_buf |= new_bit; current_buf |= new_bit;
already_filled += 1; already_filled += 1;
} }
@ -49,7 +49,7 @@ pub fn bit<W: Write>(value: &[Bit], writer: &mut W) -> Result<(), TransError> {
Ok(()) Ok(())
} }
pub fn shortstr<W: Write>(value: Shortstr, writer: &mut W) -> Result<(), TransError> { pub fn shortstr<W: Write>(value: &Shortstr, writer: &mut W) -> Result<(), TransError> {
let len = u8::try_from(value.len()).context("shortstr too long")?; let len = u8::try_from(value.len()).context("shortstr too long")?;
writer.write_all(&[len])?; writer.write_all(&[len])?;
writer.write_all(value.as_bytes())?; writer.write_all(value.as_bytes())?;
@ -57,7 +57,7 @@ pub fn shortstr<W: Write>(value: Shortstr, writer: &mut W) -> Result<(), TransEr
Ok(()) Ok(())
} }
pub fn longstr<W: Write>(value: Longstr, writer: &mut W) -> Result<(), TransError> { pub fn longstr<W: Write>(value: &Longstr, writer: &mut W) -> Result<(), TransError> {
let len = u32::try_from(value.len()).context("longstr too long")?; let len = u32::try_from(value.len()).context("longstr too long")?;
writer.write_all(&len.to_be_bytes())?; writer.write_all(&len.to_be_bytes())?;
writer.write_all(value.as_slice())?; writer.write_all(value.as_slice())?;
@ -67,12 +67,12 @@ pub fn longstr<W: Write>(value: Longstr, writer: &mut W) -> Result<(), TransErro
// this appears to be unused right now, but it could be used in `Basic` things? // this appears to be unused right now, but it could be used in `Basic` things?
#[allow(dead_code)] #[allow(dead_code)]
pub fn timestamp<W: Write>(value: Timestamp, writer: &mut W) -> Result<(), TransError> { pub fn timestamp<W: Write>(value: &Timestamp, writer: &mut W) -> Result<(), TransError> {
writer.write_all(&value.to_be_bytes())?; writer.write_all(&value.to_be_bytes())?;
Ok(()) Ok(())
} }
pub fn table<W: Write>(table: Table, writer: &mut W) -> Result<(), TransError> { pub fn table<W: Write>(table: &Table, writer: &mut W) -> Result<(), TransError> {
let mut table_buf = Vec::new(); let mut table_buf = Vec::new();
for (field_name, value) in table { for (field_name, value) in table {
@ -87,17 +87,17 @@ pub fn table<W: Write>(table: Table, writer: &mut W) -> Result<(), TransError> {
Ok(()) Ok(())
} }
fn field_value<W: Write>(value: FieldValue, writer: &mut W) -> Result<(), TransError> { fn field_value<W: Write>(value: &FieldValue, writer: &mut W) -> Result<(), TransError> {
match value { match value {
FieldValue::Boolean(bool) => { FieldValue::Boolean(bool) => {
writer.write_all(&[b't', u8::from(bool)])?; writer.write_all(&[b't', u8::from(*bool)])?;
} }
FieldValue::ShortShortInt(int) => { FieldValue::ShortShortInt(int) => {
writer.write_all(b"b")?; writer.write_all(b"b")?;
writer.write_all(&int.to_be_bytes())?; writer.write_all(&int.to_be_bytes())?;
} }
FieldValue::ShortShortUInt(int) => { FieldValue::ShortShortUInt(int) => {
writer.write_all(&[b'B', int])?; writer.write_all(&[b'B', *int])?;
} }
FieldValue::ShortInt(int) => { FieldValue::ShortInt(int) => {
writer.write_all(b"U")?; writer.write_all(b"U")?;
@ -132,7 +132,7 @@ fn field_value<W: Write>(value: FieldValue, writer: &mut W) -> Result<(), TransE
writer.write_all(&float.to_be_bytes())?; writer.write_all(&float.to_be_bytes())?;
} }
FieldValue::DecimalValue(scale, long) => { FieldValue::DecimalValue(scale, long) => {
writer.write_all(&[b'D', scale])?; writer.write_all(&[b'D', *scale])?;
writer.write_all(&long.to_be_bytes())?; writer.write_all(&long.to_be_bytes())?;
} }
FieldValue::ShortString(str) => { FieldValue::ShortString(str) => {
@ -174,7 +174,7 @@ mod tests {
let bits = [true, false, true]; let bits = [true, false, true];
let mut buffer = [0u8; 1]; let mut buffer = [0u8; 1];
super::bit(&bits, &mut buffer.as_mut_slice()).unwrap(); super::bit(&bits.map(|b| &b), &mut buffer.as_mut_slice()).unwrap();
assert_eq!(buffer, [0b00000101]) assert_eq!(buffer, [0b00000101])
} }
@ -188,7 +188,7 @@ mod tests {
]; ];
let mut buffer = [0u8; 2]; let mut buffer = [0u8; 2];
super::bit(&bits, &mut buffer.as_mut_slice()).unwrap(); super::bit(&bits.map(|b| &b), &mut buffer.as_mut_slice()).unwrap();
assert_eq!(buffer, [0b00001111, 0b00001101]); assert_eq!(buffer, [0b00001111, 0b00001101]);
} }

View file

@ -12,7 +12,7 @@ async fn write_start_ok_frame() {
version_major: 0, version_major: 0,
version_minor: 9, version_minor: 9,
server_properties: HashMap::from([( server_properties: HashMap::from([(
"product".to_string(), "product".to_owned(),
FieldValue::LongString("no name yet".into()), FieldValue::LongString("no name yet".into()),
)]), )]),
mechanisms: "PLAIN".into(), mechanisms: "PLAIN".into(),
@ -145,41 +145,38 @@ fn read_start_ok_payload() {
Method::ConnectionStartOk(ConnectionStartOk { Method::ConnectionStartOk(ConnectionStartOk {
client_properties: HashMap::from([ client_properties: HashMap::from([
( (
"product".to_string(), "product".to_owned(),
FieldValue::LongString("Pika Python Client Library".into()) FieldValue::LongString("Pika Python Client Library".into())
), ),
( (
"platform".to_string(), "platform".to_owned(),
FieldValue::LongString("Python 3.8.10".into()) FieldValue::LongString("Python 3.8.10".into())
), ),
( (
"capabilities".to_string(), "capabilities".to_owned(),
FieldValue::FieldTable(HashMap::from([ FieldValue::FieldTable(HashMap::from([
( (
"authentication_failure_close".to_string(), "authentication_failure_close".to_owned(),
FieldValue::Boolean(true) FieldValue::Boolean(true)
), ),
("basic.nack".to_string(), FieldValue::Boolean(true)), ("basic.nack".to_owned(), FieldValue::Boolean(true)),
("connection.blocked".to_string(), FieldValue::Boolean(true)), ("connection.blocked".to_owned(), FieldValue::Boolean(true)),
( (
"consumer_cancel_notify".to_string(), "consumer_cancel_notify".to_owned(),
FieldValue::Boolean(true) FieldValue::Boolean(true)
), ),
("publisher_confirms".to_string(), FieldValue::Boolean(true)), ("publisher_confirms".to_owned(), FieldValue::Boolean(true)),
])) ]))
), ),
( (
"information".to_string(), "information".to_owned(),
FieldValue::LongString("See http://pika.rtfd.org".into()) FieldValue::LongString("See http://pika.rtfd.org".into())
), ),
( ("version".to_owned(), FieldValue::LongString("1.1.0".into()))
"version".to_string(),
FieldValue::LongString("1.1.0".into())
)
]), ]),
mechanism: "PLAIN".to_string(), mechanism: "PLAIN".to_owned(),
response: "\x00admin\x00".into(), response: "\x00admin\x00".into(),
locale: "en_US".to_string() locale: "en_US".to_owned()
}) })
); );
} }

View file

@ -305,7 +305,7 @@ pub struct {class_name}{method_name}"
use heck::ToSnakeCase; use heck::ToSnakeCase;
if ident == "type" { if ident == "type" {
"r#type".to_string() "r#type".to_owned()
} else { } else {
ident.to_snake_case() ident.to_snake_case()
} }
@ -336,7 +336,7 @@ pub struct {class_name}{method_name}"
fn invariants<'a>(&self, asserts: impl Iterator<Item = &'a Assert>) -> String { fn invariants<'a>(&self, asserts: impl Iterator<Item = &'a Assert>) -> String {
asserts asserts
.map(|assert| match &*assert.check { .map(|assert| match &*assert.check {
"notnull" => "must not be null".to_string(), "notnull" => "must not be null".to_owned(),
"length" => format!("must be shorter than {}", assert.value.as_ref().unwrap()), "length" => format!("must be shorter than {}", assert.value.as_ref().unwrap()),
"regexp" => format!("must match `{}`", assert.value.as_ref().unwrap()), "regexp" => format!("must match `{}`", assert.value.as_ref().unwrap()),
"le" => { "le" => {
@ -354,7 +354,7 @@ pub struct {class_name}{method_name}"
fn doc_comment(&mut self, docs: &[Doc], indent: usize) { fn doc_comment(&mut self, docs: &[Doc], indent: usize) {
for doc in docs { for doc in docs {
if doc.kind == Some("grammar".to_string()) { if doc.kind == Some("grammar".to_owned()) {
continue; continue;
} }
for line in doc.text.lines() { for line in doc.text.lines() {

View file

@ -11,7 +11,7 @@ use crate::error::TransError;
use crate::methods::write_helper::*; use crate::methods::write_helper::*;
use std::io::Write; use std::io::Write;
pub fn write_method<W: Write>(method: Method, mut writer: W) -> Result<(), TransError> {{ pub fn write_method<W: Write>(method: &Method, mut writer: W) -> Result<(), TransError> {{
match method {{" match method {{"
) )
.ok(); .ok();