From 3a8b622ae78d8c1a2446a50eb3d18d97f91b5087 Mon Sep 17 00:00:00 2001 From: sergeych Date: Fri, 24 Nov 2023 21:59:18 +0300 Subject: [PATCH] ser/de support including enums --- src/de.rs | 166 +++++++++++++++++++++++++++++++++++++++-------------- src/ser.rs | 6 +- 2 files changed, 129 insertions(+), 43 deletions(-) diff --git a/src/de.rs b/src/de.rs index b3b2af0..52a49bd 100644 --- a/src/de.rs +++ b/src/de.rs @@ -1,7 +1,5 @@ -use serde::de::{ - self, DeserializeSeed, MapAccess, SeqAccess, - Visitor, -}; +use serde::de::{self, DeserializeSeed, IntoDeserializer, MapAccess, SeqAccess, Visitor}; +use serde::de::value::U32Deserializer; use serde::Deserialize; use crate::bipack_source::{BipackSource, VecSource}; @@ -13,11 +11,13 @@ pub struct Deserializer { input: VecSource, } -pub fn from_bytes<'de,T: Deserialize<'de>>(source: Vec) -> Result { - let mut des = Deserializer { input: VecSource::from(source)}; +pub fn from_bytes<'de, T: Deserialize<'de>>(source: &[u8]) -> Result { + let mut des = Deserializer { input: VecSource::from(source.to_vec()) }; T::deserialize(&mut des) } +impl Deserializer {} + impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer { type Error = Error; @@ -26,7 +26,7 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer { } fn deserialize_bool(self, visitor: V) -> std::result::Result where V: Visitor<'de> { - visitor.visit_bool(if self.input.get_u8()? == 0 { false } else { true } ) + visitor.visit_bool(if self.input.get_u8()? == 0 { false } else { true }) } fn deserialize_i8(self, visitor: V) -> std::result::Result where V: Visitor<'de> { @@ -73,14 +73,13 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer { let ch = self.input.get_str()?; if ch.len() != 1 { Err(Error::BadFormat(format!("Char length is {}, should be 1 {:?}", ch.len(), ch))) - } - else { + } else { visitor.visit_char(ch.chars().next().unwrap()) } } fn deserialize_str(self, visitor: V) -> std::result::Result where V: Visitor<'de> { - visitor.visit_string( self.input.get_str()?) + visitor.visit_string(self.input.get_str()?) } fn deserialize_string(self, visitor: V) -> std::result::Result where V: Visitor<'de> { @@ -131,15 +130,24 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer { fn deserialize_map(self, visitor: V) -> std::result::Result where V: Visitor<'de> { let size = self.input.get_unsigned()?; - visitor.visit_map(SimpleMap { de: self, size: size as usize} ) + visitor.visit_map(SimpleMap { de: self, size: size as usize }) } fn deserialize_struct(self, name: &'static str, fields: &'static [&'static str], visitor: V) -> std::result::Result where V: Visitor<'de> { - visitor.visit_seq(SimpleSeq::new(self, fields.len() )) + visitor.visit_seq(SimpleSeq::new(self, fields.len())) } - fn deserialize_enum(self, name: &'static str, variants: &'static [&'static str], visitor: V) -> std::result::Result where V: Visitor<'de> { - todo!() + #[inline] + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_enum(self) } fn deserialize_identifier(self, visitor: V) -> std::result::Result where V: Visitor<'de> { @@ -155,13 +163,56 @@ impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer { } } +impl<'de, 'a> serde::de::EnumAccess<'de> for &'a mut Deserializer { + type Error = Error; + type Variant = Self; + + #[inline] + fn variant_seed>(self, seed: V) -> Result<(V::Value, Self)> { + let varint: u32 = self.input.get_unsigned()? as u32; + let v = DeserializeSeed::deserialize::>( + seed, + varint.into_deserializer())?; + Ok((v, self)) + } +} + +impl<'de, 'a> serde::de::VariantAccess<'de> for &'a mut Deserializer { + type Error = Error; + + #[inline] + fn unit_variant(self) -> Result<()> { + Ok(()) + } + + #[inline] + fn newtype_variant_seed>(self, seed: V) -> Result { + DeserializeSeed::deserialize(seed, self) + } + + #[inline] + fn tuple_variant>(self, len: usize, visitor: V) -> Result { + serde::de::Deserializer::deserialize_tuple(self, len, visitor) + } + + #[inline] + fn struct_variant>( + self, + fields: &'static [&'static str], + visitor: V, + ) -> Result { + serde::de::Deserializer::deserialize_tuple(self, fields.len(), visitor) + } +} + + struct SimpleSeq<'a> { de: &'a mut Deserializer, size: usize, } impl<'a> SimpleSeq<'a> { - fn new(de: &'a mut Deserializer,size: usize) -> Self { + fn new(de: &'a mut Deserializer, size: usize) -> Self { SimpleSeq { de, size: size, @@ -198,8 +249,7 @@ impl<'de, 'a> MapAccess<'de> for SimpleMap<'a> { fn next_key_seed(&mut self, seed: K) -> std::result::Result, Self::Error> where K: DeserializeSeed<'de> { if self.size < 1 { Ok(None) - } - else { + } else { self.size -= 1; seed.deserialize(&mut *self.de).map(Some) } @@ -210,30 +260,9 @@ impl<'de, 'a> MapAccess<'de> for SimpleMap<'a> { } } - -#[test] -fn test_ints() -> Result<()> { - // #[derive(Deserialize, PartialEq, Debug)] - // struct Test { - // int: u32, - // seq: Vec, - // } - - let b = vec![7]; - assert_eq!( 7u8, from_bytes(vec![7u8])?); - - - // let j = r#"{"int":1,"seq":["a","b"]}"#; - // let expected = Test { - // int: 1, - // seq: vec!["a".to_owned(), "b".to_owned()], - // }; - // assert_eq!(expected, from_str(j).unwrap()); - Ok(()) -} - mod tests { use std::collections::{HashMap, HashSet}; + use std::fmt::Debug; use serde::{Deserialize, Serialize}; @@ -242,6 +271,27 @@ mod tests { use crate::ser::to_bytes; use crate::tools::to_dump; + #[test] + fn test_ints() -> Result<()> { + // #[derive(Deserialize, PartialEq, Debug)] + // struct Test { + // int: u32, + // seq: Vec, + // } + + let b = vec![7]; + assert_eq!(7u8, from_bytes(&vec![7u8])?); + + + // let j = r#"{"int":1,"seq":["a","b"]}"#; + // let expected = Test { + // int: 1, + // seq: vec!["a".to_owned(), "b".to_owned()], + // }; + // assert_eq!(expected, from_str(j).unwrap()); + Ok(()) + } + #[test] fn test_struct() -> Result<()> { #[derive(Serialize, Deserialize, PartialEq, Debug)] @@ -256,7 +306,7 @@ mod tests { let packed = to_bytes(&expected)?; println!("::{}", to_dump(&packed)); - let unpacked: Test = from_bytes(packed)?; + let unpacked: Test = from_bytes(&packed)?; println!("::{:?}", unpacked); assert_eq!(&expected, &unpacked); Ok(()) @@ -275,7 +325,7 @@ mod tests { let packed = to_bytes(&src)?; println!("{}", to_dump(&packed)); - let restored: HashMap = from_bytes(packed)?; + let restored: HashMap = from_bytes(&packed)?; println!("{:?}", restored); assert_eq!(src, restored); @@ -288,10 +338,42 @@ mod tests { let packed = to_bytes(&src)?; println!("{}", to_dump(&packed)); - let restored: HashSet = from_bytes(packed)?; + let restored: HashSet = from_bytes(&packed)?; println!("{:?}", restored); assert_eq!(src, restored); Ok(()) } + + + fn testeq<'a, T: Serialize + Deserialize<'a> + PartialEq + Debug>(x: & 'a T) { + let packed = to_bytes(x).unwrap(); + println!("packed {:?}:\n{}", x, to_dump(&packed) ); + assert_eq!(*x, from_bytes(&packed).unwrap()); + } + + #[test] + fn test_enum() -> Result<()> { + #[derive(Serialize, Deserialize, Debug, PartialEq)] + enum E { + Unit, + Unit2, + Newtype(u32), + Tuple(u32, u32), + Struct { a: u32 }, + } + + let packed = to_bytes(&E::Newtype(7))?; + println!("{}", to_dump(&packed)); + let r: E = from_bytes(&packed)?; + println!("{:?}", r); + + testeq(&E::Unit); + testeq(&E::Unit2); + testeq(&E::Newtype(101)); + testeq(&E::Tuple(17, 42)); + testeq(&E::Struct {a: 19} ); + + Ok(()) + } } \ No newline at end of file diff --git a/src/ser.rs b/src/ser.rs index 3630c9e..0bf38ea 100644 --- a/src/ser.rs +++ b/src/ser.rs @@ -134,6 +134,7 @@ impl<'a> ser::Serializer for &'a mut Serializer { where T: ?Sized + Serialize, { + value.serialize(self) } @@ -200,7 +201,7 @@ impl<'a> ser::Serializer for &'a mut Serializer { variant: &'static str, _len: usize, ) -> Result { - variant.serialize(&mut *self)?; + self.output.put_unsigned(_variant_index); Ok(self) } } @@ -387,6 +388,9 @@ fn test_enum() -> std::result::Result<(), FromUtf8Error> { let t = E::Tuple(7,17); println!("u:{}",to_dump(to_bytes(&t).unwrap().as_slice())); assert_eq!("0c 1c 44",to_hex(to_bytes(&t).unwrap())?); + let t = E::Struct { a: 17 }; + println!("u:{}",to_dump(to_bytes(&t).unwrap().as_slice())); + assert_eq!("10 44",to_hex(to_bytes(&t).unwrap())?); // let expected = r#""Unit""#; // assert_eq!(to_string(&u).unwrap(), expected); //