#!/usr/bin/env python
"""Minimal Huffman coder for byte streams.

Usage::

    huffman.py compress < input > packed
    huffman.py < packed > output

Packed layout, as written by ``compress`` and read by ``decompress``.
Integers are written least-significant bit first, and bits are packed into
output bytes starting at each byte's least-significant bit:

* 8 bits  -- number of code-table rows minus one
* per row -- 8-bit byte value, 8-bit code length ``n``, then the ``n`` code bits
* 32 bits -- number of bytes in the original input
* the code of every input byte, concatenated (the final byte is zero-padded)
"""
import collections
import operator
import sys


class Node(object):
    """Internal Huffman-tree node; ``count`` is the sum of its children's counts."""

    def __init__(self, left, right, count):
        self.left = left
        self.right = right
        self.count = count

    def __unicode__(self):
        return 'Node(left={}, right={}, count={})'.format(self.left, self.right, self.count)

    def __str__(self):
        return self.__unicode__()


class Leaf(object):
    """Huffman-tree leaf: a byte value (int in 0..255) and how often it occurs."""

    def __init__(self, byte, count):
        self.byte = byte
        self.count = count

    def __unicode__(self):
        return 'Leaf(byte={}, count={})'.format(self.byte, self.count)

    def __str__(self):
        return self.__unicode__()


class TableRow(object):
    """Code-table entry: a byte value (int) and its code as a list of 0/1 bits."""

    def __init__(self, byte, bits):
        self.byte = byte
        self.bits = bits

    def __unicode__(self):
        return 'TableRow(byte={}, bits={})'.format(self.byte, self.bits)

    def __str__(self):
        return self.__unicode__()


class BinPacker(object):
    """Accumulates bits and packs them into a list of byte values (ints)."""

    def __init__(self):
        self.a = []  # pending bits, each 0 or 1, in write order

    def bits(self, bits):
        """Append an iterable of 0/1 bits verbatim."""
        self.a.extend(bits)

    def int32(self, int32):
        """Append *int32* as 32 bits, least-significant bit first."""
        return self.int(32, int32)

    def int8(self, int8):
        """Append *int8* as 8 bits, least-significant bit first."""
        return self.int(8, int8)

    def int(self, n, intn):
        """Append the unsigned integer *intn* as *n* bits, LSB first.

        Raises:
            ValueError: if *intn* does not fit in *n* unsigned bits; writing
                only the low bits would silently produce a corrupt archive.
        """
        if not 0 <= intn < (1 << n):
            raise ValueError("{} does not fit in {} unsigned bits".format(intn, n))
        self.a.extend((intn >> i) & 1 for i in range(n))

    def pack(self):
        """Return the accumulated bits as a list of ints in 0..255.

        Bit ``i`` lands in byte ``i // 8`` at bit position ``i % 8``; a trailing
        partial byte is zero-padded.  At least one byte is always returned.
        """
        packed = [0] * max(1, (len(self.a) + 7) // 8)
        for i, bit in enumerate(self.a):
            packed[i // 8] |= bit << (i % 8)
        return packed


class BinUnpacker(object):
    """Sequential bit reader over a sequence of byte values (ints)."""

    def __init__(self, data):
        self.a = BinUnpacker.unpack(data)  # all bits of *data*, LSB first per byte
        self.pos = 0  # index in ``self.a`` of the next unread bit

    @staticmethod
    def unpack(data):
        """Expand byte values (ints) into a flat list of bits, LSB first."""
        return [(byte >> i) & 1 for byte in data for i in range(8)]

    def int(self, n):
        """Consume up to *n* bits and return them as an unsigned int (LSB first)."""
        return sum(bit << i for i, bit in enumerate(self.bits(n)))

    def int8(self):
        return self.int(8)

    def int32(self):
        return self.int(32)

    def bits(self, n):
        """Consume and return the next *n* bits (fewer if the input runs out)."""
        # Advancing a cursor instead of re-slicing the remaining list on every
        # read keeps decoding linear in the input size.
        chunk = self.a[self.pos:self.pos + n]
        self.pos += len(chunk)
        return chunk

    def peek(self, n):
        """Return the next *n* bits (fewer at end of input) without consuming them."""
        return self.a[self.pos:self.pos + n]


def lookup_byte(table, byte):
    """Return the code bits for *byte* by linear search of *table*."""
    for row in table:
        if row.byte == byte:
            return row.bits
    raise RuntimeError("Internal Error")


def lookup_bits(table, unpacker):
    """Consume the first code in *table* that prefixes the unread bits; return its byte."""
    for row in table:
        if row.bits == unpacker.peek(len(row.bits)):
            unpacker.bits(len(row.bits))
            return row.byte
    raise RuntimeError("Internal Error")


def _read_symbol(codes, max_len, unpacker):
    """Read bits until they form a key of *codes* (tuple of bits -> byte value).

    Relies on the codes being prefix-free, which holds for tables produced by
    ``build_table``.  Raises RuntimeError when no code can match (corrupt or
    truncated input).
    """
    code = ()
    while code not in codes:
        bit = unpacker.bits(1)
        if len(code) >= max_len or not bit:
            raise RuntimeError("Internal Error")
        code += (bit[0],)
    return codes[code]


def pack_table(table, packer):
    """Write *table* (a non-empty list of TableRow) to *packer*."""
    # The row count is stored minus one so that 256 rows fit in 8 bits.
    packer.int8(len(table) - 1)
    for row in table:
        packer.int8(row.byte)
        packer.int8(len(row.bits))
        packer.bits(row.bits)


def unpack_table(unpacker):
    """Read a table written by ``pack_table`` and return it as a list of TableRow."""
    table_len = unpacker.int8() + 1
    table = []
    for _ in range(table_len):
        byte = unpacker.int8()
        n = unpacker.int8()
        table.append(TableRow(byte, unpacker.bits(n)))
    return table


def _as_bytes(original):
    """Return *original* as a bytearray; text (non-bytes) input is UTF-8 encoded."""
    if not isinstance(original, (bytes, bytearray, memoryview)):
        original = original.encode("utf-8")
    return bytearray(original)


def compress(original):
    """Huffman-compress *original* and return the packed data as a list of ints.

    *original* may be bytes-like, or text, which is encoded as UTF-8 first.
    Empty input is supported.
    """
    data = _as_bytes(original)
    table = build_table(build_tree(data))
    codes = {row.byte: row.bits for row in table}
    packer = BinPacker()
    pack_table(table, packer)
    packer.int32(len(data))
    for byte in data:
        packer.bits(codes[byte])
    return packer.pack()


def decompress(data):
    """Decode data produced by ``compress`` and return the original bytes.

    *data* is any iterable of ints in 0..255 (a list, bytes, or bytearray).
    Bits beyond the recorded input length, such as a trailing newline byte,
    are ignored.  Raises RuntimeError on corrupt or truncated input.
    """
    unpacker = BinUnpacker(bytearray(data))
    table = unpack_table(unpacker)
    data_len = unpacker.int32()
    codes = {}
    for row in table:
        # setdefault keeps the first row for a duplicated code, as a
        # front-to-back table scan would.
        codes.setdefault(tuple(row.bits), row.byte)
    max_len = max(len(row.bits) for row in table)
    out = bytearray(_read_symbol(codes, max_len, unpacker) for _ in range(data_len))
    return bytes(out)


def build_table(node, path=None):
    """Return one TableRow per leaf under *node*, in left-to-right order.

    A left edge contributes bit 0 and a right edge bit 1.  When *node* is a
    lone Leaf (a single distinct byte), its code is empty.
    """
    if path is None:
        path = []
    if isinstance(node, Node):
        return build_table(node.left, path + [0]) + build_table(node.right, path + [1])
    return [TableRow(node.byte, path)]


def build_tree(original):
    """Build a Huffman tree from the byte frequencies of *original*.

    Empty input yields a placeholder ``Leaf(0, 0)``, because the packed format
    cannot represent a table with zero rows.
    """
    counts = collections.Counter(_as_bytes(original))
    if not counts:
        return Leaf(0, 0)
    # Sorting by byte value makes the tree independent of hash/set ordering.
    leaves = [Leaf(byte, count) for byte, count in sorted(counts.items())]
    by_count = operator.attrgetter("count")
    while len(leaves) > 1:
        left = min(leaves, key=by_count)
        leaves.remove(left)
        right = min(leaves, key=by_count)
        leaves.remove(right)
        leaves.append(Node(left, right, left.count + right.count))
    return leaves[0]


def main(argv=None):
    """CLI entry point: ``compress`` as first argument packs stdin, otherwise unpacks."""
    argv = sys.argv if argv is None else argv
    # Use the binary streams where available so packed bytes are not
    # re-encoded as text, and write without print()'s trailing newline.
    stdin = getattr(sys.stdin, "buffer", sys.stdin)
    stdout = getattr(sys.stdout, "buffer", sys.stdout)
    if len(argv) > 1 and argv[1] == "compress":
        stdout.write(bytes(bytearray(compress(stdin.read()))))
    else:
        stdout.write(decompress(stdin.read()))
    stdout.flush()


if __name__ == "__main__":
    main()