2
\$\begingroup\$

This is a follow-up to this question.

I implemented most of the suggestions from the answers to the previous question.

Code:

#[macro_use]
extern crate quick_error;
extern crate byteorder;
#[cfg(test)]
#[macro_use]
extern crate quickcheck;
use std::collections::hash_map::HashMap;
use std::collections::BinaryHeap;
use byteorder::{ReadBytesExt, WriteBytesExt, LittleEndian};
use std::{result, io, path, cmp, fs};
use std::io::prelude::*;
// Crate-wide result alias: fallible operations in this file report a `HuffmanError`.
type Result<T> = result::Result<T, HuffmanError>;
// Number of bits in a byte; used for all bit/byte index arithmetic below.
const BITS: usize = 8;
// Buffer size, in bytes, used by the `compress` / `decompress` convenience methods.
const DEFAULT_BUFFER_SIZE: usize = 8192;
/// Unwraps a `Result` inside a function that returns `Option<Result<_, _>>`.
///
/// On `Ok(value)` the macro evaluates to `value`; on `Err(error)` it makes the
/// enclosing function return `Some(Err(error))`. This is the shape needed by
/// `Iterator::next` implementations whose items are `Result`s.
macro_rules! try_opt {
    ($e:expr) => {
        match $e {
            Ok(value) => value,
            Err(error) => return Some(Err(error)),
        }
    };
}
// Error type used throughout this file, generated by the `quick_error!` macro.
quick_error! {
 #[derive(Debug)]
 pub enum HuffmanError {
 // Wraps an underlying I/O failure; `from()` is what lets `?` convert an
 // `io::Error` into this variant.
 Io(err: io::Error) {
 from()
 cause(err)
 description(err.description())
 display("IO error: {}", err)
 }
 // Returned by `HuffmanTree::decode` when the encoded tree walk runs out of bits.
 ParseTree {
 description("Parse tree error: invalid encoded huffman tree")
 }
 // Returned by `HuffmanTree::decode` when the walk announces a leaf but no
 // alphabet byte is left for it.
 AlphabetMismatch {
 description("Alphabet mismatch error: alphabet doesn't match parsed tree")
 }
 // Returned by `HuffmanTree::new` when the input produced no symbols at all.
 Empty {
 description("Empty stream")
 }
 }
}
struct DecodeBitIterator<'a> {
 root: &'a HuffmanTree,
 current: &'a HuffmanTree,
}
impl<'a> DecodeBitIterator<'a> {
 fn next(&mut self, bit: bool) -> Option<u8> {
 use HuffmanTree::*;
 let (next, result) = match (self.root, self.current.next(bit)) {
 (_, Some(&Leaf { byte, .. })) |
 (&Leaf { byte, .. }, None) => (self.root, Some(byte)),
 (_, Some(inner)) => (inner, None),
 _ => panic!("assertion error"),
 };
 self.current = next;
 result
 }
}
fn create_mode_map<R: std::io::Read>(reader: &mut R,
 buffer_size: usize)
 -> Result<HashMap<u8, u32>> {
 let mut mode_map = HashMap::new();
 let mut buf = vec![0; buffer_size];
 loop {
 let size = reader.read(&mut buf)?;
 if size == 0 {
 break;
 }
 for &b in buf.iter().take(size) {
 *(mode_map.entry(b).or_insert(0)) += 1;
 }
 }
 Ok(mode_map)
}
/// Returns an iterator over the first `size` bits stored in `bits`.
///
/// Within each byte the least significant bit comes first. If `bits` holds
/// fewer than `size` bits the iterator simply ends early.
fn iterate_bits<'a>(bits: &'a [u8], size: usize) -> BitVectorIter<'a> {
    let every_bit = bits.iter()
        .flat_map(|&byte| (0..BITS).map(move |shift| ((byte >> shift) & 1) == 1));
    BitVectorIter { iter: Box::new(every_bit.take(size)) }
}
/// Binary Huffman tree over byte values.
///
/// Equality and ordering are implemented by hand further down in this file
/// and look only at the node frequency.
#[derive(Eq, Debug)]
pub enum HuffmanTree {
 /// Internal node with two subtrees; `join` sets `frequency` to the combined
 /// frequency of both children.
 Inner {
 frequency: u32,
 left: Box<HuffmanTree>,
 right: Box<HuffmanTree>,
 },
 /// Leaf carrying one byte value and its frequency (0 for trees rebuilt by `decode`).
 Leaf { frequency: u32, byte: u8 },
}
impl HuffmanTree {
 pub fn new<R: std::io::Read>(reader: &mut R, buffer_size: usize) -> Result<HuffmanTree> {
 let mut pq: BinaryHeap<_> = create_mode_map(reader, buffer_size)
 ?
 .into_iter()
 .map(|(c, f)| {
 HuffmanTree::Leaf {
 byte: c,
 frequency: f,
 }
 })
 .collect();
 for _ in 1..pq.len() {
 let min1 = pq.pop().expect("assertion error");
 let min2 = pq.pop().expect("assertion error");
 pq.push(min1.join(min2));
 }
 pq.pop().ok_or(HuffmanError::Empty)
 }
 fn freq(&self) -> u32 {
 use HuffmanTree::*;
 match *self {
 Inner { frequency, .. } |
 Leaf { frequency, .. } => frequency,
 }
 }
 fn join(self, other: HuffmanTree) -> HuffmanTree {
 HuffmanTree::Inner {
 frequency: self.freq() + other.freq(),
 left: Box::new(self),
 right: Box::new(other),
 }
 }
 fn create_mapper_recur(&self, bit_vec: &mut BitVector, map: &mut HashMap<u8, BitVector>) {
 use HuffmanTree::*;
 match *self {
 Inner { ref left, ref right, .. } => {
 bit_vec.push(true);
 left.create_mapper_recur(bit_vec, map);
 bit_vec.push(false);
 right.create_mapper_recur(bit_vec, map);
 }
 Leaf { byte, .. } => {
 map.insert(byte, bit_vec.clone());
 }
 }
 bit_vec.pop();
 }
 fn create_mapper(&self) -> HashMap<u8, BitVector> {
 let mut bit_vec = BitVector::new();
 let mut map = HashMap::new();
 if let HuffmanTree::Leaf { .. } = *self {
 bit_vec.push(true);
 }
 self.create_mapper_recur(
 &mut bit_vec, 
 &mut map);
 map
 }
 fn decode<I, T>(encoded_walk: &mut I, mut bytes: &mut T) -> Result<HuffmanTree>
 where I: Iterator<Item = bool>,
 T: Iterator<Item = u8>
 {
 match encoded_walk.next() {
 Some(true) => {
 let left = Self::decode(encoded_walk, bytes)?;
 let right = Self::decode(encoded_walk, bytes)?;
 Ok(left.join(right))
 }
 Some(false) => {
 let c = bytes.next()
 .ok_or(HuffmanError::AlphabetMismatch)?;
 Ok(HuffmanTree::Leaf {
 frequency: 0,
 byte: c,
 })
 }
 None => Err(HuffmanError::ParseTree),
 }
 }
 pub fn serialize<W: io::Write>(&self, writer: &mut W) -> Result<()> {
 let (encoded_walk, alphabet_bytes) = self.encode();
 let walk_bit_len = encoded_walk.len() as u64;
 writer.write_u64::<LittleEndian>(walk_bit_len)?;
 writer.write_all(&encoded_walk.bits)?;
 let alphabet_byte_len = alphabet_bytes.len() as u64;
 writer.write_u64::<LittleEndian>(alphabet_byte_len)?;
 writer.write_all(&alphabet_bytes)?;
 Ok(())
 }
 pub fn deserialize<R: io::Read>(reader: &mut R) -> Result<HuffmanTree> {
 let walk_bit_len = reader.read_u64::<LittleEndian>()?;
 let walk_byte_len = (walk_bit_len + BITS as u64 - 1) / BITS as u64;
 let mut walk_bytes = Vec::new();
 reader.take(walk_byte_len)
 .read_to_end(&mut walk_bytes)?;
 let alphabet_bytes_len = reader.read_u64::<LittleEndian>()?;
 let mut alphabet_bytes = Vec::new();
 reader.take(alphabet_bytes_len)
 .read_to_end(&mut alphabet_bytes)?;
 let bit_vec = BitVector {
 bits: walk_bytes,
 size: walk_bit_len as usize,
 };
 let bit_vec_iter = &mut bit_vec.iter();
 Self::decode(
 bit_vec_iter, 
 &mut alphabet_bytes.iter().cloned())
 }
 fn encode_recur(&self, bit_vec: &mut BitVector, alphabet_bytes: &mut Vec<u8>) {
 use HuffmanTree::*;
 match *self {
 Inner { ref left, ref right, .. } => {
 bit_vec.push(true);
 left.encode_recur(bit_vec, alphabet_bytes);
 right.encode_recur(bit_vec, alphabet_bytes);
 }
 Leaf { byte, .. } => {
 bit_vec.push(false);
 alphabet_bytes.push(byte);
 }
 }
 }
 fn encode(&self) -> (BitVector, Vec<u8>) {
 let mut bit_vector = BitVector::new();
 let mut alphabet_bytes = Vec::new();
 self.encode_recur(&mut bit_vector, &mut alphabet_bytes);
 (bit_vector, alphabet_bytes)
 }
 fn decode_bit_iter<'a>(&'a self) -> DecodeBitIterator {
 DecodeBitIterator {
 root: self,
 current: self,
 }
 }
 pub fn encode_iter<'a, R: io::Read>(&self,
 reader: &'a mut R,
 buffer_size: usize)
 -> EncodeIterator<'a, R> {
 EncodeIterator {
 reader: reader,
 mapper: self.create_mapper(),
 buffer: vec![0u8; buffer_size],
 }
 }
 pub fn decode_iter<'a, 'b, R: io::Read>(&'b self,
 reader: &'a mut R,
 buffer_size: u64,
 bit_junk: usize)
 -> io::Result<DecodeIterator<'a, 'b, R>> {
 let mut last_buffer = Vec::new();
 reader.take(buffer_size)
 .read_to_end(&mut last_buffer)?;
 let itr = DecodeIterator {
 reader: reader,
 bit_mapper: self.decode_bit_iter(),
 buffer: Vec::new(),
 buffer_size: buffer_size,
 last_buffer: last_buffer,
 bit_junk: bit_junk,
 };
 Ok(itr)
 }
 fn next(&self, bit: bool) -> Option<&HuffmanTree> {
 use HuffmanTree::*;
 match *self {
 Inner { ref left, .. } if bit => Some(left),
 Inner { ref right, .. } => Some(right),
 Leaf { .. } => None,
 }
 }
}
/// Iterator that turns chunks of input bytes into chunks of Huffman code bits.
pub struct EncodeIterator<'a, R: io::Read + 'a> {
    /// Source of the bytes to encode.
    reader: &'a mut R,
    /// Code assigned to each byte value.
    mapper: HashMap<u8, BitVector>,
    /// Scratch buffer; its length is the maximum chunk size.
    buffer: Vec<u8>,
}

impl<'a, R: io::Read> Iterator for EncodeIterator<'a, R> {
    type Item = io::Result<BitVector>;

    /// Reads the next chunk and returns the concatenation of its codes.
    ///
    /// Returns `None` once the reader reports 0 bytes and `Some(Err(_))` on a
    /// read error. Reads failing with `ErrorKind::Interrupted` are retried
    /// rather than reported.
    ///
    /// # Panics
    /// Panics with "assertion error" if a byte has no entry in the code table.
    fn next(&mut self) -> Option<Self::Item> {
        let size;
        loop {
            match self.reader.read(&mut self.buffer) {
                Ok(read) => {
                    size = read;
                    break;
                }
                // Interrupted before any data arrived: try the read again.
                Err(ref err) if err.kind() == io::ErrorKind::Interrupted => {}
                Err(err) => return Some(Err(err)),
            }
        }
        if size == 0 {
            return None;
        }
        let mut bit_vec = BitVector::new();
        for byte in &self.buffer[..size] {
            let code = self.mapper
                .get(byte)
                .expect("assertion error");
            bit_vec.append(code);
        }
        Some(Ok(bit_vec))
    }
}
/// Iterator that turns chunks of Huffman code bits back into bytes.
pub struct DecodeIterator<'a, 'b, R: io::Read + 'a> {
    /// Source of the encoded data.
    reader: &'a mut R,
    /// Tree cursor; keeps its position across chunk boundaries.
    bit_mapper: DecodeBitIterator<'b>,
    /// Chunk read ahead during the current call.
    buffer: Vec<u8>,
    /// Maximum number of bytes read per chunk.
    buffer_size: u64,
    /// Chunk that will be decoded by the current call.
    last_buffer: Vec<u8>,
    /// Number of bits to ignore at the end of the final chunk.
    bit_junk: usize,
}

impl<'a, 'b, R: io::Read> Iterator for DecodeIterator<'a, 'b, R> {
    type Item = io::Result<Vec<u8>>;

    /// Decodes the pending chunk and returns the bytes it produced.
    ///
    /// One chunk is always read ahead: if the read-ahead comes back empty the
    /// pending chunk is the final one and its last `bit_junk` bits are skipped.
    /// Returns `None` once both chunks are empty.
    fn next(&mut self) -> Option<Self::Item> {
        let size = try_opt!(self.reader
            .take(self.buffer_size)
            .read_to_end(&mut self.buffer));
        let at_end = size == 0;
        if at_end && self.last_buffer.is_empty() {
            return None;
        }
        // Only the final chunk carries padding bits.
        let junk = if at_end { self.bit_junk } else { 0 };
        let bit_len = BITS * self.last_buffer.len() - junk;
        let mut bytes = Vec::new();
        for bit in iterate_bits(&self.last_buffer, bit_len) {
            if let Some(byte) = self.bit_mapper.next(bit) {
                bytes.push(byte);
            }
        }
        // The read-ahead chunk becomes the pending one for the next call.
        std::mem::swap(&mut self.buffer, &mut self.last_buffer);
        self.buffer.clear();
        Some(Ok(bytes))
    }
}
/// Orders trees by descending frequency, so that the max-heap `BinaryHeap`
/// pops the least frequent tree first.
impl Ord for HuffmanTree {
    fn cmp(&self, other: &HuffmanTree) -> cmp::Ordering {
        // Swapping the operands is the same as reversing the natural order.
        other.freq().cmp(&self.freq())
    }
}

/// Total order delegated to `Ord`.
impl PartialOrd for HuffmanTree {
    fn partial_cmp(&self, other: &HuffmanTree) -> Option<cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

/// Two trees are equal when their frequencies are equal; structure is ignored.
impl PartialEq for HuffmanTree {
    fn eq(&self, other: &HuffmanTree) -> bool {
        other.freq() == self.freq()
    }
}
/// Compression of a readable, seekable source into a writable, seekable sink.
pub trait HuffmanCompress {
    /// Compresses `self` into `writer` using `DEFAULT_BUFFER_SIZE` as chunk size.
    ///
    /// # Errors
    /// Whatever `compress_with_buffer_size` returns.
    fn compress<T>(&mut self, writer: &mut T) -> Result<()>
        where T: io::Write + io::Seek
    {
        self.compress_with_buffer_size(writer, DEFAULT_BUFFER_SIZE)
    }

    /// Compresses `self` into `writer`, processing the input in chunks of
    /// `buffer_size` bytes.
    // The parameters are named: anonymous trait-method parameters are
    // deprecated and rejected from the 2018 edition on.
    fn compress_with_buffer_size<T>(&mut self, writer: &mut T, buffer_size: usize) -> Result<()>
        where T: io::Write + io::Seek;
}
/// Decompression of a readable source into a writable sink.
pub trait HuffmanDeCompress {
    /// Decompresses `self` into `writer` using `DEFAULT_BUFFER_SIZE` as chunk size.
    ///
    /// # Errors
    /// Whatever `decompress_with_buffer_size` returns.
    fn decompress<T: io::Write>(&mut self, writer: &mut T) -> Result<()> {
        let buffer_size = DEFAULT_BUFFER_SIZE as u64;
        self.decompress_with_buffer_size(writer, buffer_size)
    }

    /// Decompresses `self` into `writer`, reading the encoded data in chunks of
    /// `buffer_size` bytes.
    fn decompress_with_buffer_size<T: io::Write>(&mut self,
                                                 writer: &mut T,
                                                 buffer_size: u64)
                                                 -> Result<()>;
}
/// Convenience trait for types that can both compress and decompress.
pub trait HuffmanCodes: HuffmanCompress + HuffmanDeCompress {}

/// Every readable and seekable type gets `HuffmanCodes` automatically.
impl<T> HuffmanCodes for T where T: io::Read + io::Seek {}
impl<T: io::Read + io::Seek> HuffmanCompress for T {
    /// Compresses everything from the current position of `self` to its end.
    ///
    /// The input is read twice: once to build the Huffman tree and once to
    /// encode. Output layout: the serialized tree, one byte holding the number
    /// of padding bits in the last data byte, then the packed code bits.
    ///
    /// # Errors
    /// Returns `HuffmanError::Empty` if there is nothing to read and
    /// `HuffmanError::Io` on any read, write or seek failure.
    fn compress_with_buffer_size<W>(&mut self, writer: &mut W, buffer_size: usize) -> Result<()>
        where W: io::Write + io::Seek
    {
        use std::io::SeekFrom;
        // Remember where the payload starts so the second pass re-reads exactly
        // the bytes the tree was built from, even when the reader was not
        // positioned at offset 0 on entry.
        let start = self.seek(SeekFrom::Current(0))?;
        let tree = HuffmanTree::new(self, buffer_size)?;
        tree.serialize(writer)?;
        self.seek(SeekFrom::Start(start))?;
        // Reserve the padding-count byte; its value is only known at the end.
        let pos = writer.seek(SeekFrom::Current(0))?;
        writer.write_i8(0)?;
        let mut bit_vec = BitVector::new();
        for code in tree.encode_iter(self, buffer_size) {
            bit_vec.append(&code?);
            let bit_leftover = bit_vec.len() % BITS;
            let leftover = (bit_leftover != 0) as usize;
            let byte_len = bit_vec.byte_len();
            // Flush every complete byte; keep a trailing partial byte so the
            // next chunk's bits can be appended to it.
            writer.write_all(&bit_vec.bits[0..byte_len - leftover])?;
            bit_vec.bits.swap(0, byte_len - 1);
            bit_vec.bits.truncate(leftover);
            bit_vec.size = bit_leftover;
        }
        writer.write_all(&bit_vec.bits)?;
        // Unused bits in the final byte; 0 when the data ends on a byte boundary.
        let junk = (BITS - bit_vec.len() % BITS) % BITS;
        let current = writer.seek(SeekFrom::Current(0))?;
        writer.seek(SeekFrom::Start(pos))?;
        writer.write_i8(junk as i8)?;
        writer.seek(SeekFrom::Start(current))?;
        Ok(())
    }
}
impl<T: std::io::Read> HuffmanDeCompress for T {
 fn decompress_with_buffer_size<W>(&mut self, writer: &mut W, buffer_size: u64) -> Result<()>
 where W: io::Write
 {
 let tree = HuffmanTree::deserialize(self)?;
 let junk = self.read_i8()? as usize;
 for bytes in tree.decode_iter(self, buffer_size, junk)? {
 writer.write_all(&bytes?)?
 }
 Ok(())
 }
}
pub fn io_map<P, T, F>(path: P, target: T, mapping_function: F) -> Result<()>
 where P: AsRef<path::Path>,
 T: AsRef<path::Path>,
 F: Fn(&mut io::BufReader<fs::File>, &mut io::BufWriter<fs::File>) -> Result<()>
{
 let f1 = fs::File::open(path)?;
 let mut reader = io::BufReader::new(f1);
 let f2 = fs::File::create(target)?;
 let mut writer = io::BufWriter::new(f2);
 mapping_function(&mut reader, &mut writer)?;
 Ok(())
}
/// Decompresses the file at `path` into a newly created file at `target`.
///
/// # Errors
/// Returns whatever `io_map` or the decompression itself reports.
pub fn decode<P: AsRef<path::Path>, T: AsRef<path::Path>>(path: P, target: T) -> Result<()> {
    io_map(path, target, |reader, writer| reader.decompress(writer))
}
/// Compresses the file at `path` into a newly created file at `target`.
///
/// # Errors
/// Returns whatever `io_map` or the compression itself reports.
pub fn encode<P: AsRef<path::Path>, T: AsRef<path::Path>>(path: P, target: T) -> Result<()> {
    io_map(path, target, |reader, writer| reader.compress(writer))
}
/// Growable sequence of bits packed into bytes, least significant bit first.
///
/// The methods keep every bit beyond `size` in the last byte at zero, which
/// `append` relies on.
#[derive(Clone, Debug)]
pub struct BitVector {
    /// Packed storage; bit `i` lives in byte `i / BITS` at position `i % BITS`.
    bits: Vec<u8>,
    /// Number of valid bits.
    size: usize,
}

impl BitVector {
    /// Creates an empty bit vector.
    pub fn new() -> BitVector {
        BitVector {
            bits: Vec::new(),
            size: 0,
        }
    }

    /// Appends one bit at the end.
    pub fn push(&mut self, bit: bool) {
        let leftover = self.size % BITS;
        // The previous byte is full (or there is none): start a new one.
        if leftover == 0 {
            self.bits.push(0);
        }
        let last_byte = self.bits
            .last_mut()
            .expect("Assertion error");
        *last_byte |= (bit as u8) << leftover;
        self.size += 1;
    }

    /// Removes all bits.
    #[allow(dead_code)]
    pub fn clear(&mut self) {
        self.bits.clear();
        self.size = 0;
    }

    /// Removes the last bit; does nothing when the vector is empty.
    pub fn pop(&mut self) {
        if self.size == 0 {
            return;
        }
        let last = self.size - 1;
        // Clear the bit so that unused storage stays zeroed.
        self.put(last, false);
        self.size = last;
        if last % BITS == 0 {
            self.bits.pop();
        }
    }

    /// Appends every bit of `bits` in order.
    #[allow(dead_code)]
    pub fn push_all(&mut self, bits: &[bool]) {
        for bit in bits {
            self.push(*bit);
        }
    }

    /// Appends all bits of `other` to the end of `self`.
    pub fn append(&mut self, other: &BitVector) {
        let leftover = self.size % BITS;
        let empty_bits = BITS - leftover;
        let len = self.bits.len();
        self.bits.extend(other.bits.iter().cloned());
        self.size += other.size;
        // Byte aligned: the copied bytes are already in the right place.
        if leftover == 0 {
            return;
        }
        // Otherwise shift every copied byte down into the free bits of its
        // predecessor.
        for i in len..self.bits.len() {
            self.move_bits(empty_bits, i);
        }
        // The shift may have emptied the final byte completely.
        if (self.size - 1) / BITS != self.bits.len() - 1 {
            self.bits.pop();
        }
    }

    /// Asserts that `i` is a valid bit index.
    ///
    /// # Panics
    /// Panics with an "Index out of bounds" message if `i >= len()`.
    pub fn check(&self, i: usize) {
        // Format arguments instead of a pre-built `String`: a non-literal panic
        // message is linted against and rejected from the 2021 edition on.
        assert!(i < self.size,
                "Index out of bounds. Index: {} >= len: {}",
                i,
                self.size);
    }

    /// Returns the bit at index `i`.
    ///
    /// # Panics
    /// Panics if `i >= len()`.
    #[allow(dead_code)]
    pub fn get(&self, i: usize) -> bool {
        self.check(i);
        (1 & (self.bits[i / BITS] >> (i % BITS))) != 0
    }

    /// Sets the bit at index `i` to `bit`.
    ///
    /// # Panics
    /// Panics if `i >= len()`.
    pub fn put(&mut self, i: usize, bit: bool) {
        self.check(i);
        let (byte, bit_pos) = (i / BITS, i % BITS);
        self.bits[byte] &= !(1 << bit_pos);
        self.bits[byte] |= (bit as u8) << bit_pos;
    }

    /// Moves the lowest `bits` bits of byte `i` into the top of byte `i - 1`
    /// and shifts the rest of byte `i` down accordingly.
    fn move_bits(&mut self, bits: usize, i: usize) {
        let leftover = BITS - bits;
        let (from, to) = (self.bits[i], self.bits[i - 1]);
        self.bits[i - 1] = to ^ (from << leftover);
        self.bits[i] = from >> bits;
    }

    /// Iterates over the stored bits in order.
    pub fn iter<'a>(&'a self) -> BitVectorIter<'a> {
        iterate_bits(&self.bits, self.size)
    }

    /// Number of bits stored.
    pub fn len(&self) -> usize {
        self.size
    }

    /// Returns `true` if no bits are stored.
    #[allow(dead_code)]
    pub fn empty(&self) -> bool {
        self.size == 0
    }

    /// Number of bytes currently used for storage.
    #[allow(dead_code)]
    pub fn byte_len(&self) -> usize {
        self.bits.len()
    }
}
pub struct BitVectorIter<'a> {
 iter: Box<std::iter::Iterator<Item = bool> + 'a>,
}
impl<'a> IntoIterator for &'a BitVector {
 type Item = bool;
 type IntoIter = BitVectorIter<'a>;
 fn into_iter(self) -> Self::IntoIter {
 self.iter()
 }
}
impl<'a> Iterator for BitVectorIter<'a> {
 type Item = bool;
 fn next(&mut self) -> Option<Self::Item> {
 self.iter.next()
 }
}

Test:

#[cfg(test)]
mod test {
    use quickcheck::TestResult;
    use std::io;
    use super::*;
    use std::io::prelude::*;

    /// Reads everything left in `reader` into a `String`.
    fn read_stream<W: io::Read>(reader: &mut W) -> io::Result<String> {
        let mut text = String::new();
        reader.read_to_string(&mut text).map(|_| text)
    }

    quickcheck! {
        // Round trip: compressing and then decompressing any non-empty string
        // must give back the original text.
        fn prop(text: String) -> super::Result<TestResult> {
            if text.is_empty() {
                return Ok(TestResult::discard());
            }
            let rewind = io::SeekFrom::Start(0);
            let mut original = io::Cursor::new(text.into_bytes());
            let mut compressed = io::Cursor::new(Vec::new());
            let mut restored = io::Cursor::new(Vec::new());
            original.compress(&mut compressed)?;
            compressed.seek(rewind)?;
            compressed.decompress(&mut restored)?;
            original.seek(rewind)?;
            restored.seek(rewind)?;
            let expected = read_stream(&mut original)?;
            let actual = read_stream(&mut restored)?;
            Ok(TestResult::from_bool(expected == actual))
        }
    }
}
asked Jun 6, 2017 at 0:30
\$\endgroup\$

0

Know someone who can answer? Share a link to this question via email, Twitter, or Facebook.

Your Answer

Draft saved
Draft discarded

Sign up or log in

Sign up using Google
Sign up using Email and Password

Post as a guest

Required, but never shown

Post as a guest

Required, but never shown

By clicking "Post Your Answer", you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.