205 lines
4.4 KiB
Python
Executable File
205 lines
4.4 KiB
Python
Executable File
#!/usr/bin/env python
|
|
|
|
import collections
import operator
import sys
|
|
|
|
class Node:
    """Inner tree node holding two children and a count.

    ``left`` and ``right`` are the child subtrees and ``count`` is the
    value supplied by the creator of the node.
    """

    def __init__(self, left, right, count):
        self.left, self.right, self.count = left, right, count

    def __unicode__(self):
        """Return a readable description of the node and its children."""
        template = 'Node(left={}, right={}, count={})'
        return template.format(self.left, self.right, self.count)

    def __str__(self):
        # Both string conversions share a single implementation.
        return self.__unicode__()
|
|
|
|
|
|
class Leaf:
    """Terminal tree node pairing a symbol (``byte``) with its ``count``."""

    def __init__(self, byte, count):
        self.byte, self.count = byte, count

    def __unicode__(self):
        """Return a readable description of the leaf."""
        template = 'Leaf(byte={}, count={})'
        return template.format(self.byte, self.count)

    def __str__(self):
        # Both string conversions share a single implementation.
        return self.__unicode__()
|
|
|
|
|
|
class TableRow:
    """One code-table entry: a symbol (``byte``) and its bit list (``bits``)."""

    def __init__(self, byte, bits):
        self.byte, self.bits = byte, bits

    def __unicode__(self):
        """Return a readable description of the row."""
        template = 'TableRow(byte={}, bits={})'
        return template.format(self.byte, self.bits)

    def __str__(self):
        # Both string conversions share a single implementation.
        return self.__unicode__()
|
|
|
|
|
|
class BinPacker:
    """Collect single bits and pack them into 8-bit integers.

    Written bits are kept in ``self.a`` in write order. Integers are
    stored least-significant bit first.
    """

    def __init__(self):
        self.a = []

    def bits(self, bits):
        """Append every element of the iterable *bits* to the buffer."""
        self.a.extend(bits)

    def int32(self, int32):
        """Append the 32 lowest bits of *int32*."""
        return self.int(32, int32)

    def int8(self, int8):
        """Append the 8 lowest bits of *int8*."""
        return self.int(8, int8)

    def int(self, n, intn):
        """Append the *n* lowest bits of *intn*, least significant first."""
        self.a.extend([(intn >> shift) & 1 for shift in range(n)])

    def pack(self):
        """Return the buffered bits as a list of integers, 8 bits each.

        Within each group of eight, the first bit becomes the least
        significant one. A final partial group is padded with zero
        bits. An empty buffer yields ``[0]``.
        """
        buffered = self.a
        packed = [
            sum(bit << offset
                for offset, bit in enumerate(buffered[start:start + 8]))
            for start in range(0, len(buffered), 8)
        ]
        # Even with nothing written there is always one output value.
        return packed or [0]
|
|
|
|
|
|
class BinUnpacker(object):
    """Read bits and integers back out of a sequence of 8-bit integers.

    The input is expanded once into a flat list of bits (least
    significant bit of each value first). Reads then advance a cursor
    instead of re-slicing the remaining list, so a whole stream is
    consumed in linear rather than quadratic time.

    ``a`` is kept as a property returning the unread bits, so code that
    reads or assigns ``unpacker.a`` keeps working.
    """

    def __init__(self, data):
        self._bits = BinUnpacker.unpack(data)
        self._pos = 0

    @property
    def a(self):
        """Return a new list holding the bits that have not been read."""
        return self._bits[self._pos:]

    @a.setter
    def a(self, bits):
        # Replacing the buffer restarts reading from its first bit.
        self._bits = list(bits)
        self._pos = 0

    @staticmethod
    def unpack(data):
        """Expand each value in *data* into 8 bits, lowest bit first."""
        return [(value >> shift) & 1 for value in data for shift in range(8)]

    def _take(self, n):
        """Consume and return up to *n* bits (fewer if the input runs out)."""
        start = self._pos
        stop = min(start + n, len(self._bits))
        self._pos = stop
        return self._bits[start:stop]

    def int(self, n):
        """Consume *n* bits and return them as an integer, lowest bit first.

        If fewer than *n* bits remain, the missing high bits count as 0.
        """
        return sum(bit << index for index, bit in enumerate(self._take(n)))

    def int8(self):
        """Consume 8 bits and return them as an integer."""
        return self.int(8)

    def int32(self):
        """Consume 32 bits and return them as an integer."""
        return self.int(32)

    def bits(self, n):
        """Consume and return the next *n* bits as a list."""
        return self._take(n)

    def peek(self, n):
        """Return the next *n* bits as a list without consuming them."""
        return self._bits[self._pos:self._pos + n]
|
|
|
|
|
|
def lookup_byte(table, byte):
    """Return the ``bits`` of the first row in *table* whose byte is *byte*.

    Raises:
        RuntimeError: if no row in *table* matches.
    """
    missing = object()
    code = next((row.bits for row in table if row.byte == byte), missing)
    if code is missing:
        raise RuntimeError("Internal Error")
    return code
|
|
|
|
|
|
def lookup_bits(table, unpacker):
    """Decode one symbol from the front of *unpacker*.

    Rows are tried in table order. The first row whose ``bits`` equal
    the next ``len(row.bits)`` bits of *unpacker* wins: those bits are
    consumed and the row's ``byte`` is returned.

    Raises:
        RuntimeError: if no row matches the upcoming bits.
    """
    for row in table:
        width = len(row.bits)
        if row.bits != unpacker.peek(width):
            continue
        unpacker.bits(width)
        return row.byte
    raise RuntimeError("Internal Error")
|
|
|
|
|
|
def pack_table(table, packer):
    """Write the code table *table* to *packer*.

    Layout: one 8-bit value holding ``len(table) - 1``, then for each
    row the 8-bit ordinal of ``row.byte``, the 8-bit length of
    ``row.bits`` and finally the bits themselves.
    """
    last_index = len(table) - 1
    packer.int8(last_index)
    for entry in table:
        symbol_code = ord(entry.byte)
        packer.int8(symbol_code)
        packer.int8(len(entry.bits))
        packer.bits(entry.bits)
|
|
|
|
|
|
def unpack_table(unpacker):
    """Read a code table from *unpacker* and return it as a list of rows.

    The first 8-bit value is the row count minus one. Each row is then
    an 8-bit symbol ordinal, an 8-bit code length and that many bits.
    """
    def read_row():
        # Field order matters: symbol, then code length, then the code.
        symbol = chr(unpacker.int8())
        width = unpacker.int8()
        return TableRow(symbol, unpacker.bits(width))

    row_count = unpacker.int8() + 1
    return [read_row() for _ in range(row_count)]
|
|
|
|
|
|
def compress(original):
    """Compress *original* and return the result as a list of 8-bit ints.

    The output holds the code table, the 32-bit symbol count and then
    the code of every input symbol, packed 8 bits per list element.

    *original* is expected to be text whose characters each fit in one
    byte, because the table stores ``ord(symbol)`` in 8 bits.

    Empty input is supported: it produces a one-row placeholder table
    followed by a zero symbol count.
    """
    # ``bytes(text)`` raises TypeError on Python 3, so text is used as-is.
    # On Python 2 ``str`` is ``bytes`` and behaviour is unchanged.
    symbols = original if isinstance(original, str) else bytes(original)

    if len(symbols) == 0:
        # The table format cannot express zero rows (it stores the row
        # count minus one), so emit a single unused row instead.
        table = [TableRow(chr(0), [])]
    else:
        table = build_table(build_tree(original))

    packer = BinPacker()
    pack_table(table, packer)
    packer.int32(len(original))

    # Remember each symbol's code after the first lookup so the table is
    # scanned once per distinct symbol rather than once per input symbol.
    codes = {}
    for symbol in symbols:
        bits = codes.get(symbol)
        if bits is None:
            bits = codes[symbol] = lookup_byte(table, symbol)
        packer.bits(bits)

    return packer.pack()
|
|
|
|
|
|
def decompress(data):
    """Decode *data* (a sequence of 8-bit ints) back into a string.

    Reads the code table, then the 32-bit symbol count, then decodes
    that many symbols and joins them.
    """
    reader = BinUnpacker(data)
    table = unpack_table(reader)
    symbol_count = reader.int32()
    return "".join(lookup_bits(table, reader) for _ in range(symbol_count))
|
|
|
|
|
|
|
|
def build_table(node, path=None):
    """Flatten the tree rooted at *node* into a list of ``TableRow``.

    Each leaf yields one row whose bits are the route from the root:
    0 for every left turn and 1 for every right turn, prefixed by
    *path* when given. Rows are listed left subtree first.
    """
    rows = []
    # Explicit stack of (subtree, bits leading to it).
    pending = [(node, path if path else [])]
    while pending:
        current, prefix = pending.pop()
        if isinstance(current, Node):
            # Push right first so the left subtree is expanded first.
            pending.append((current.right, prefix + [1]))
            pending.append((current.left, prefix + [0]))
        else:
            rows.append(TableRow(current.byte, prefix))
    return rows
|
|
|
|
|
|
def build_tree(original):
    """Build a Huffman tree from the symbol frequencies in *original*.

    Every distinct symbol becomes a ``Leaf`` carrying its number of
    occurrences. The two lowest-count trees are then repeatedly merged
    into a ``Node`` whose count is their sum, until one tree remains.

    Returns:
        The root: a ``Node``, or a single ``Leaf`` when the input has
        only one distinct symbol.

    Raises:
        IndexError: if *original* is empty (there is no tree to return).
    """
    # ``bytes(text)`` raises TypeError on Python 3, so text is used as-is.
    # On Python 2 ``str`` is ``bytes`` and behaviour is unchanged.
    symbols = original if isinstance(original, str) else bytes(original)

    # One counting pass instead of one full scan per distinct symbol.
    counts = collections.Counter(symbols)

    # Iterating the set (not the counter) keeps the leaf order, and so
    # the tie-breaking between equal counts, the same as before.
    leaves = [Leaf(symbol, counts[symbol]) for symbol in set(symbols)]

    by_count = operator.attrgetter("count")
    while len(leaves) > 1:
        lowest = min(leaves, key=by_count)
        leaves.remove(lowest)
        second = min(leaves, key=by_count)
        leaves.remove(second)
        leaves.append(Node(lowest, second, lowest.count + second.count))

    return leaves[0]
|
|
|
|
def main(argv):
    """Compress or decompress standard input to standard output.

    With ``compress`` as the first argument the input is compressed;
    otherwise it is decompressed.

    Data is moved as raw bytes and written without a trailing newline,
    so compressing and then decompressing reproduces the input exactly.
    """
    # Python 3 exposes the byte streams as ``.buffer``; on Python 2 the
    # standard streams already carry bytes.
    stdin = getattr(sys.stdin, "buffer", sys.stdin)
    stdout = getattr(sys.stdout, "buffer", sys.stdout)

    raw = bytearray(stdin.read())
    if len(argv) > 1 and argv[1] == "compress":
        # One character per input byte, which is what compress() expects.
        text = "".join(chr(value) for value in raw)
        result = compress(text)
    else:
        result = [ord(char) for char in decompress(list(raw))]

    stdout.write(bytes(bytearray(result)))
    stdout.flush()


# Run only as a script, so importing the module does not read stdin.
if __name__ == "__main__":
    main(sys.argv)
|