compression: Implement ciscorn's dictionary approach

Massive savings.  Thanks so much @ciscorn for providing the initial
code for choosing the dictionary.

This adds a bit of time to the build, both to find the dictionary
and because (for reasons I don't fully understand) the binary
search in the compress() function no longer worked and had to be
replaced with a linear search.

I think this is because the intended invariant is that for codebook
entries that encode to the same number of bits, the entries are ordered
in ascending value.  However, I misplaced the transition from "words"
to "byte/char values", so the codebook entries for words are in word order
rather than their code order.

Because this price is only paid at build time, I didn't care to determine
exactly where the correct fix was.

I also commented out a line to produce the "estimated total memory size"
-- at least on the unix build with TRANSLATION=ja, this led to a build
time KeyError trying to compute the codebook size for all the strings.
I think this occurs because some single unicode code point ('ァ') is
no longer present as itself in the compressed strings, due to always
being replaced by a word.

As promised, this seems to save hundreds of bytes in the German translation
on the trinket m0.

Testing performed:
 - built trinket_m0 in several languages
 - built and ran unix port in several languages (en, de_DE, ja) and ran
   simple error-producing codes like ./micropython -c '1/0'
This commit is contained in:
Jeff Epler 2020-09-12 10:10:18 -05:00
parent 7611e71a1b
commit 40ab5c6b21
2 changed files with 152 additions and 66 deletions

View file

@ -100,77 +100,153 @@ def translate(translation_file, i18ns):
translations.append((original, translation))
return translations
def frequent_ngrams(corpus, sz, n):
    """Return the ``n`` most common substrings of length ``sz`` in ``corpus``.

    The result is a list of ``(ngram, count)`` pairs in the order produced by
    ``collections.Counter.most_common``.  A corpus shorter than ``sz`` yields
    an empty list.
    """
    # A string of length L has L - sz + 1 windows of width sz; the "+ 1" makes
    # sure the final window (ending at the last character) is counted too.
    window_count = len(corpus) - sz + 1
    counts = collections.Counter(corpus[i:i + sz] for i in range(window_count))
    return counts.most_common(n)
class TextSplitter:
    """Tokenize text into known dictionary words and single characters."""

    def __init__(self, words):
        """Build the word set and the tokenizing regex for ``words``.

        ``words`` is an iterable of strings.  It is not modified: a sorted
        copy is used internally, so callers that index into their own list
        keep whatever order they built it in.
        """
        # Longest words first, so the regex alternation prefers the longest
        # dictionary word when several candidates start at the same position.
        by_length = sorted(words, key=len, reverse=True)
        self.words = set(by_length)
        if by_length:
            pattern = "|".join(re.escape(w) for w in by_length) + "|."
        else:
            # With no words the joined pattern would be "|.", whose first
            # alternative is empty and therefore matches the empty string.
            # Fall back to plain "any single character".
            pattern = "."
        self.pat = re.compile(pattern, flags=re.DOTALL)
def encode_ngrams(translation, ngrams):
    """Replace each n-gram occurring in ``translation`` by a one-character code.

    The i-th entry of ``ngrams`` is replaced by ``chr(base + i)``, where
    ``base`` is 0xe000 when there are more than 32 n-grams and 0x80
    otherwise.  Replacements are applied one n-gram at a time, in list order,
    each on the result of the previous one.
    """
    base = 0xe000 if len(ngrams) > 32 else 0x80
    encoded = translation
    for offset, ngram in enumerate(ngrams):
        encoded = encoded.replace(ngram, chr(base + offset))
    return encoded
def iter_words(self, text):
    """Yield ``(is_word, chunk)`` pairs covering the matches of ``self.pat``.

    Each match that is a member of ``self.words`` is yielded on its own as
    ``(True, word)``.  Consecutive matches that are not in ``self.words`` are
    concatenated and yielded as a single ``(False, run)`` pair.
    """
    pending = []
    for match in self.pat.finditer(text):
        token = match.group(0)
        if token not in self.words:
            # Not a dictionary word: keep accumulating the current run.
            pending.append(token)
            continue
        # A dictionary word ends the current run (if any) before it is emitted.
        if pending:
            yield (False, "".join(pending))
            pending = []
        yield (True, token)
    # Flush whatever run is left after the last match.
    if pending:
        yield (False, "".join(pending))
def decode_ngrams(compressed, ngrams):
    """Expand one-character n-gram codes in ``compressed`` back into n-grams.

    Characters whose code point lies in the code window (0xe000-0xf8ff when
    there are more than 32 n-grams, 0x80-0x9f otherwise) are replaced by
    ``ngrams[code_point - window_start]``; all other characters are copied
    through unchanged.
    """
    if len(ngrams) > 32:
        low, high = 0xe000, 0xf8ff
    else:
        low, high = 0x80, 0x9f
    pieces = []
    for ch in compressed:
        code_point = ord(ch)
        pieces.append(ngrams[code_point - low] if low <= code_point <= high else ch)
    return "".join(pieces)
def iter(self, text):
    """Yield the text of every successive match of ``self.pat`` in ``text``."""
    # The unused accumulator list the original created here has been dropped.
    for match in self.pat.finditer(text):
        yield match.group(0)
def iter_substrings(s, minlen, maxlen):
    """Yield every substring of ``s`` whose length is in ``minlen..maxlen``.

    Substrings are produced shortest length first, and left to right within
    each length.  ``maxlen`` is capped at ``len(s)``.
    """
    total = len(s)
    longest = min(total, maxlen)
    for width in range(minlen, longest + 1):
        yield from (s[start:start + width] for start in range(total - width + 1))
def compute_huffman_coding(translations, compression_filename):
texts = [t[1] for t in translations]
all_strings_concat = "".join(texts)
words = []
max_ord = 0
begin_unused = 128
end_unused = 256
for text in texts:
for c in text:
ord_c = ord(c)
max_ord = max(max_ord, ord_c)
if 128 <= ord_c < 256:
end_unused = min(ord_c, end_unused)
max_words = end_unused - begin_unused
char_size = 1 if max_ord < 256 else 2
sum_word_len = 0
while True:
extractor = TextSplitter(words)
counter = collections.Counter()
for t in texts:
for (found, word) in extractor.iter_words(t):
if not found:
for substr in iter_substrings(word, minlen=2, maxlen=9):
counter[substr] += 1
scores = sorted(
(
# I don't know why this works good. This could be better.
(s, (len(s) - 1) ** ((max(occ - 2, 1) + 0.5) ** 0.8), occ)
for (s, occ) in counter.items()
),
key=lambda x: x[1],
reverse=True,
)
w = None
for (s, score, occ) in scores:
if score < 0:
break
if len(s) > 1:
w = s
break
if not w:
break
if len(w) + sum_word_len > 256:
break
if len(words) == max_words:
break
words.append(w)
sum_word_len += len(w)
extractor = TextSplitter(words)
counter = collections.Counter()
for t in texts:
for atom in extractor.iter(t):
counter[atom] += 1
cb = huffman.codebook(counter.items())
word_start = begin_unused
word_end = word_start + len(words) - 1
print("// # words", len(words))
print("// words", words)
def compute_huffman_coding(translations, qstrs, compression_filename):
all_strings = [x[1] for x in translations]
all_strings_concat = "".join(all_strings)
ngrams = [i[0] for i in frequent_ngrams(all_strings_concat, 2, 32)]
all_strings_concat = encode_ngrams(all_strings_concat, ngrams)
counts = collections.Counter(all_strings_concat)
cb = huffman.codebook(counts.items())
values = []
length_count = {}
renumbered = 0
last_l = None
canonical = {}
for ch, code in sorted(cb.items(), key=lambda x: (len(x[1]), x[0])):
values.append(ch)
for atom, code in sorted(cb.items(), key=lambda x: (len(x[1]), x[0])):
values.append(atom)
l = len(code)
if l not in length_count:
length_count[l] = 0
length_count[l] += 1
if last_l:
renumbered <<= (l - last_l)
canonical[ch] = '{0:0{width}b}'.format(renumbered, width=l)
s = C_ESCAPES.get(ch, ch)
print("//", ord(ch), s, counts[ch], canonical[ch], renumbered)
canonical[atom] = '{0:0{width}b}'.format(renumbered, width=l)
#print(f"atom={repr(atom)} code={code}", file=sys.stderr)
if len(atom) > 1:
o = words.index(atom) + 0x80
s = "".join(C_ESCAPES.get(ch1, ch1) for ch1 in atom)
else:
s = C_ESCAPES.get(atom, atom)
o = ord(atom)
print("//", o, s, counter[atom], canonical[atom], renumbered)
renumbered += 1
last_l = l
lengths = bytearray()
print("// length count", length_count)
print("// bigrams", ngrams)
for i in range(1, max(length_count) + 2):
lengths.append(length_count.get(i, 0))
print("// values", values, "lengths", len(lengths), lengths)
ngramdata = [ord(ni) for i in ngrams for ni in i]
print("// estimated total memory size", len(lengths) + 2*len(values) + 2 * len(ngramdata) + sum((len(cb[u]) + 7)//8 for u in all_strings_concat))
maxord = max(ord(u) for u in values if len(u) == 1)
values_type = "uint16_t" if maxord > 255 else "uint8_t"
ch_size = 1 if maxord > 255 else 2
print("//", values, lengths)
values = [(atom if len(atom) == 1 else chr(0x80 + words.index(atom))) for atom in values]
print("//", values, lengths)
values_type = "uint16_t" if max(ord(u) for u in values) > 255 else "uint8_t"
max_translation_encoded_length = max(len(translation.encode("utf-8")) for original,translation in translations)
with open(compression_filename, "w") as f:
f.write("const uint8_t lengths[] = {{ {} }};\n".format(", ".join(map(str, lengths))))
f.write("const {} values[] = {{ {} }};\n".format(values_type, ", ".join(str(ord(u)) for u in values)))
f.write("#define compress_max_length_bits ({})\n".format(max_translation_encoded_length.bit_length()))
f.write("const {} bigrams[] = {{ {} }};\n".format(values_type, ", ".join(str(u) for u in ngramdata)))
if len(ngrams) > 32:
bigram_start = 0xe000
else:
bigram_start = 0x80
bigram_end = bigram_start + len(ngrams) - 1 # End is inclusive
f.write("#define bigram_start {}\n".format(bigram_start))
f.write("#define bigram_end {}\n".format(bigram_end))
return values, lengths, ngrams
f.write("const {} words[] = {{ {} }};\n".format(values_type, ", ".join(str(ord(c)) for w in words for c in w)))
f.write("const uint8_t wlen[] = {{ {} }};\n".format(", ".join(str(len(w)) for w in words)))
f.write("#define word_start {}\n".format(word_start))
f.write("#define word_end {}\n".format(word_end))
extractor = TextSplitter(words)
return values, lengths, words, extractor
def decompress(encoding_table, encoded, encoded_length_bits):
values, lengths, ngrams = encoding_table
values, lengths, words, extractor = encoding_table
dec = []
this_byte = 0
this_bit = 7
@ -218,7 +294,8 @@ def decompress(encoding_table, encoded, encoded_length_bits):
searched_length += lengths[bit_length]
v = values[searched_length + bits - max_code]
v = decode_ngrams(v, ngrams)
if v >= chr(0x80) and v < chr(0x80 + len(words)):
v = words[ord(v) - 0x80]
i += len(v.encode('utf-8'))
dec.append(v)
return ''.join(dec)
@ -226,8 +303,8 @@ def decompress(encoding_table, encoded, encoded_length_bits):
def compress(encoding_table, decompressed, encoded_length_bits, len_translation_encoded):
if not isinstance(decompressed, str):
raise TypeError()
values, lengths, ngrams = encoding_table
decompressed = encode_ngrams(decompressed, ngrams)
values, lengths, words, extractor = encoding_table
enc = bytearray(len(decompressed) * 3)
#print(decompressed)
#print(lengths)
@ -246,9 +323,15 @@ def compress(encoding_table, decompressed, encoded_length_bits, len_translation_
else:
current_bit -= 1
for c in decompressed:
#print()
#print("char", c, values.index(c))
#print("values = ", values, file=sys.stderr)
for atom in extractor.iter(decompressed):
#print("", file=sys.stderr)
if len(atom) > 1:
c = chr(0x80 + words.index(atom))
else:
c = atom
assert c in values
start = 0
end = lengths[0]
bits = 1
@ -258,18 +341,12 @@ def compress(encoding_table, decompressed, encoded_length_bits, len_translation_
s = start
e = end
#print("{0:0{width}b}".format(code, width=bits))
# Binary search!
while e > s:
midpoint = (s + e) // 2
#print(s, e, midpoint)
if values[midpoint] == c:
compressed = code + (midpoint - start)
#print("found {0:0{width}b}".format(compressed, width=bits))
# Linear search!
for i in range(s, e):
if values[i] == c:
compressed = code + (i - start)
#print("found {0:0{width}b}".format(compressed, width=bits), file=sys.stderr)
break
elif c < values[midpoint]:
e = midpoint
else:
s = midpoint + 1
code += end - start
code <<= 1
start = end
@ -452,7 +529,7 @@ if __name__ == "__main__":
if args.translation:
i18ns = sorted(i18ns)
translations = translate(args.translation, i18ns)
encoding_table = compute_huffman_coding(translations, qstrs, args.compression_filename)
encoding_table = compute_huffman_coding(translations, args.compression_filename)
print_qstr_data(encoding_table, qcfgs, qstrs, translations)
else:
print_qstr_enums(qstrs)

View file

@ -47,13 +47,22 @@ STATIC int put_utf8(char *buf, int u) {
if(u <= 0x7f) {
*buf = u;
return 1;
} else if(bigram_start <= u && u <= bigram_end) {
int n = (u - 0x80) * 2;
// (note that at present, entries in the bigrams table are
// guaranteed not to represent bigrams themselves, so this adds
} else if(word_start <= u && u <= word_end) {
int n = (u - 0x80);
size_t off = 0;
for(int i=0; i<n; i++) {
off += wlen[i];
}
int ret = 0;
// note that at present, entries in the words table are
// guaranteed not to represent words themselves, so this adds
// at most 1 level of recursive call
int ret = put_utf8(buf, bigrams[n]);
return ret + put_utf8(buf + ret, bigrams[n+1]);
for(int i=0; i<wlen[n]; i++) {
int len = put_utf8(buf, words[off+i]);
buf += len;
ret += len;
}
return ret;
} else if(u <= 0x07ff) {
*buf++ = 0b11000000 | (u >> 6);
*buf = 0b10000000 | (u & 0b00111111);