commit 04607923831336d0b3542f52076aae79010af191 Author: hellerve Date: Wed May 9 00:23:21 2018 +0200 initial diff --git a/README.md b/README.md new file mode 100644 index 0000000..76efa62 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +# se + +A simple data compressor as presented by Gary Bernhardt in [his +screencast](https://www.destroyallsoftware.com/screencasts/catalog/data-compressor-from-scratch), +ported to Python. diff --git a/se.py b/se.py new file mode 100755 index 0000000..e18ed5d --- /dev/null +++ b/se.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python + +import sys +import operator + +class Node: + def __init__(self, left, right, count): + self.left = left + self.right = right + self.count = count + + def __unicode__(self): + return 'Node(left={}, right={}, count={})'.format(self.left, + self.right, + self.count) + + def __str__(self): + return self.__unicode__() + + +class Leaf: + def __init__(self, byte, count): + self.byte = byte + self.count = count + + def __unicode__(self): + return 'Leaf(byte={}, count={})'.format(self.byte, self.count) + + def __str__(self): + return self.__unicode__() + + +class TableRow: + def __init__(self, byte, bits): + self.byte = byte + self.bits = bits + + def __unicode__(self): + return 'TableRow(byte={}, bits={})'.format(self.byte, self.bits) + + def __str__(self): + return self.__unicode__() + + +class BinPacker: + def __init__(self): + self.a = [] + + def bits(self, bits): + self.a.extend(bits) + + def int32(self, int32): + return self.int(32, int32) + + def int8(self, int8): + return self.int(8, int8) + + def int(self, n, intn): + a = [] + for i in range(n): + a.append(intn & 1) + intn = intn >> 1 + self.a.extend(a) + + def pack(self): + x = [] + e = 0 + for i, b in enumerate(self.a): + if i and not i % 8: + x.append(e) + e = 0 + e += b << (i % 8) + x.append(e) + return x + + +class BinUnpacker: + def __init__(self, data): + self.a = BinUnpacker.unpack(data) + + @staticmethod + def unpack(data): + a = [] + for e in data: + for i in range(8): + a.append(e & 1) + e = e >> 1 + return a + + def int(self, n): + a = self.a[:n] + self.a = self.a[n:] + x = 0 + for i, e in enumerate(a): + x = x + (e << i) + return x + + def int8(self): + return self.int(8) + + def int32(self): + return self.int(32) + + def bits(self, n): + a = self.a[:n] + self.a = self.a[n:] + return a + + def peek(self, n): + return self.a[:n] + + +def lookup_byte(table, byte): + for row in table: + if row.byte == byte: + return row.bits + raise RuntimeError("Internal Error") + + +def lookup_bits(table, unpacker): + for row in table: + if row.bits == unpacker.peek(len(row.bits)): + unpacker.bits(len(row.bits)) + return row.byte + raise RuntimeError("Internal Error") + + +def pack_table(table, packer): + packer.int8(len(table) - 1) + for row in table: + packer.int8(ord(row.byte)) + packer.int8(len(row.bits)) + packer.bits(row.bits) + + +def unpack_table(unpacker): + table_len = unpacker.int8() + 1 + table = [] + for _ in range(table_len): + byte = chr(unpacker.int8()) + n = unpacker.int8() + bits = unpacker.bits(n) + table.append(TableRow(byte, bits)) + return table + + +def compress(original): + tree = build_tree(original) + table = build_table(tree) + packer = BinPacker() + + pack_table(table, packer) + packer.int32(len(original)) + + for byte in bytes(original): + bits = lookup_byte(table, byte) + packer.bits(bits) + + return packer.pack() + + +def decompress(data): + unpacker = BinUnpacker(data) + + table = unpack_table(unpacker) + data_len = unpacker.int32() + + s = [] + for i in range(data_len): + s.append(lookup_bits(table, unpacker)) + + return "".join(s) + + + +def build_table(node, path=None): + if not path: path = [] + + if isinstance(node, Node): + return build_table(node.left, path + [0]) + build_table(node.right, path + [1]) + return [TableRow(node.byte, path)] + + +def build_tree(original): + bs = bytes(original) + unique_bs = set(bs) + + leaves = [] + for byte in unique_bs: + leaves.append(Leaf(byte, bs.count(byte))) + + while len(leaves) > 1: + l = min(leaves, key=operator.attrgetter("count")) + leaves.remove(l) + r = min(leaves, key=operator.attrgetter("count")) + leaves.remove(r) + leaves.append(Node(l, r, l.count + r.count)) + + return leaves[0] + +if len(sys.argv) > 1 and sys.argv[1] == "compress": + print("".join(chr(n) for n in compress(sys.stdin.read()))) +else: + print(decompress([ord(n) for n in sys.stdin.read()]))