--- /dev/null
+#!/usr/bin/env python3
+
+import sys
+
+EXIT_SUCCESS, EXIT_FAILURE = 0, 1
+
+# Short pointer: a single byte split into a 7-bit distance field and a
+# 1-bit length field.
+DIST_BITS0, LEN_BITS0 = 7, 1
+
+# Long pointer: a 16-bit word split into an 11-bit distance field and a
+# 5-bit length field.
+DIST_BITS1, LEN_BITS1 = 11, 5
+
+# Stored distances are biased by 1, so real distances run 1..MAX_DIST.
+MAX_DIST = 2 ** DIST_BITS1
+# Stored lengths are biased by 2, so real lengths run 2..MAX_LEN.
+MAX_LEN = 2 ** LEN_BITS1 + 1
+
+# Two positional arguments are required: the packed input file and the
+# file that receives the unpacked image. Extra arguments are ignored.
+args = sys.argv[1:3]
+if len(args) != 2:
+    print(f'usage: {sys.argv[0]:s} in.bin out.bin')
+    sys.exit(EXIT_FAILURE)
+in_bin, out_bin = args
+
+# The input is a 4-byte header (load address, then payload size, each a
+# little-endian 16-bit value) followed by the payload itself.
+with open(in_bin, 'rb') as fin:
+    data = list(fin.read())
+hdr, lzss = data[:4], data[4:]
+load_addr0 = hdr[0] | hdr[1] << 8
+load_size0 = hdr[2] | hdr[3] << 8
+# The size recorded in the header must match the bytes actually present.
+assert len(lzss) == load_size0
+
+def _lda_imm16(lo_opcode_at, hi_opcode_at):
+    # Return the 16-bit value whose low and high bytes are the operands
+    # of the two `lda #NN` instructions at the given payload offsets.
+    assert lzss[lo_opcode_at] == 0xa9 # lda #NN
+    assert lzss[hi_opcode_at] == 0xa9 # lda #NN
+    return lzss[lo_opcode_at + 1] | (lzss[hi_opcode_at + 1] << 8)
+
+# Parameters embedded as immediate operands at fixed offsets at the
+# start of the payload.
+src = _lda_imm16(0x01, 0x05)
+dest = _lda_imm16(0x09, 0x0d)
+count = _lda_imm16(0x11, 0x15)
+
+# Initial contents of the flag-bit buffer.
+# NOTE(review): a disabled variant also checked the opcode at 0x1d for
+# clc/sec and folded its bit 5 in as a ninth buffer bit.
+assert lzss[0x19] == 0xa9 # lda #NN
+bits = lzss[0x1a]
+
+# Address jumped to at offset 0x42.
+assert lzss[0x42] == 0x4c # jmp NNNN
+entry_point = lzss[0x43] | (lzss[0x44] << 8)
+
+# `src` must point one past the end of the loaded image.
+assert load_addr0 + load_size0 == src
+# The operand holds the one's complement of the refill count.
+count = 0xffff & ~count
+
+# Decode by consuming the payload from its tail. Every decoded byte is
+# appended to `bin`, and the list is reversed once decoding finishes, so
+# a back-reference `bin[-dist]` looks at bytes decoded earlier.
+bin = []
+while True:
+    # The flag buffer carries a sentinel at bit 8; a value of exactly 1
+    # means only the sentinel is left and the buffer must be refilled.
+    if bits == 1:
+        # Only refills at this point are counted; running out of them
+        # ends the stream.
+        if count == 0:
+            break
+        count -= 1
+        bits = lzss.pop() | 0x100
+        #print('e', bits)
+    cf = bits & 1
+    bits >>= 1
+
+    if cf:
+        # Pointer item: a second flag bit selects the long or short form.
+        # This refill does not touch `count`.
+        if bits == 1:
+            bits = lzss.pop() | 0x100
+            #print('d', bits)
+        cf = bits & 1
+        bits >>= 1
+
+        if cf:
+            # Long pointer: 16-bit word, low byte stored first.
+            item = lzss[-2] | (lzss[-1] << 8)
+            del lzss[-2:]
+            #print('c', item)
+            dist = item & ((1 << DIST_BITS1) - 1)
+            _len = item >> DIST_BITS1
+        else:
+            # Short pointer: a single byte.
+            item = lzss.pop()
+            #print('b', item)
+            dist = item & ((1 << DIST_BITS0) - 1)
+            _len = item >> DIST_BITS0
+        # Undo the storage bias: lengths start at 2, distances at 1.
+        _len += 2
+        dist += 1
+
+        # Copy one byte at a time so that a match may overlap the bytes
+        # it is producing.
+        for _ in range(_len):
+            bin.append(bin[-dist])
+    else:
+        # Literal item: the next payload byte is the output byte.
+        #print('a', lzss[-1])
+        bin.append(lzss.pop())
+# Exactly 0xb6 bytes (the uncompressed leading part of the payload) must
+# be left once the stream has been consumed.
+assert len(lzss) == 0xb6
+bin = bin[::-1]
+
+load_size = len(bin)
+print(f'lzss {load_size0:04x} orig {load_size:04x}')
+
+# `dest` is the address of the last unpacked byte, so the image starts
+# `load_size - 1` bytes before it.
+load_addr = dest - load_size + 1
+if load_addr != entry_point:
+    # Put a `jmp NNNN` to the entry point in front of the image and
+    # account for its three bytes in the header fields.
+    jmp = [0x4c, entry_point & 0xff, entry_point >> 8] # jmp NNNN
+    bin = jmp + bin
+    load_addr -= len(jmp)
+    load_size += len(jmp)
+
+# Header: load address then size, each low byte first.
+hdr = []
+for word in (load_addr, load_size):
+    hdr.extend((word & 0xff, word >> 8))
+with open(out_bin, 'wb') as fout:
+    fout.write(bytes(hdr + bin))
--- /dev/null
+#!/usr/bin/env python3
+
+import bisect
+import numpy
+import sys
+from heapdict import heapdict
+
+EXIT_SUCCESS, EXIT_FAILURE = 0, 1
+
+# Short pointer: a single byte split into a 7-bit distance field and a
+# 1-bit length field.
+DIST_BITS0, LEN_BITS0 = 7, 1
+
+# Long pointer: a 16-bit word split into an 11-bit distance field and a
+# 5-bit length field.
+DIST_BITS1, LEN_BITS1 = 11, 5
+
+# Stored distances are biased by 1, so real distances run 1..MAX_DIST.
+MAX_DIST = 2 ** DIST_BITS1
+# Stored lengths are biased by 2, so real lengths run 2..MAX_LEN.
+MAX_LEN = 2 ** LEN_BITS1 + 1
+
+# Four positional arguments are required; extra arguments are ignored.
+args = sys.argv[1:5]
+if len(args) != 4:
+    print(f'usage: {sys.argv[0]:s} load_addr lzss_loader.bin in.bin out.bin')
+    sys.exit(EXIT_FAILURE)
+# Base 0 lets the address be written as decimal, 0x.., 0o.. or 0b..
+load_addr = int(args[0], 0)
+lzss_loader_bin, in_bin, out_bin = args[1:]
+
+# The loader stub is copied in front of the compressed stream; its
+# operands are patched at fixed offsets, so its size must be exactly 0xb6.
+with open(lzss_loader_bin, 'rb') as fin:
+    lzss_loader = list(fin.read())
+assert 0xb6 == len(lzss_loader)
+
+# The input is a 4-byte header (load address, then size, each a
+# little-endian 16-bit value) followed by the image to compress.
+with open(in_bin, 'rb') as fin:
+    data = list(fin.read())
+hdr, bin = data[:4], data[4:]
+load_addr0 = hdr[0] | hdr[1] << 8
+load_size0 = hdr[2] | hdr[3] << 8
+# The size recorded in the header must match the bytes actually present.
+assert len(bin) == load_size0
+
+# Absorb a leading `jmp NNNN` into the recorded entry point instead of
+# compressing it, while making sure the jump can be reconstructed
+# losslessly when decoding later.
+entry_point = load_addr0
+# Images shorter than a full jmp instruction are left as plain data.
+if len(bin) >= 3 and bin[0] == 0x4c: # jmp NNNN
+    target = bin[1] | (bin[2] << 8)
+    # The unpacker re-inserts the jump only when the entry point differs
+    # from the start of the unpacked image. A jump to the byte right
+    # after itself would therefore be lost if absorbed, so that one is
+    # left in the image (entry_point stays at load_addr0, which still
+    # executes it).
+    if target != load_addr0 + 3:
+        entry_point = target
+        bin = bin[3:]
+        load_addr0 += 3
+        load_size0 -= 3
+
+# Construct the LZSS items in the order they'll be decoded, walking the
+# image from its last byte to its first, but store them in reverse
+# (stack-wise) as we don't know the final length.
+#
+# `heads` maps a byte pair (bin[p - 1], bin[p]) to the most recently
+# indexed position p holding it; `links[p]` is the position indexed
+# before p for the same pair. Since positions are indexed in descending
+# order, a chain runs from the nearest candidate to the farthest.
+heads = {}
+links = [-1] * len(bin)
+lzss = []
+i = len(bin) - 1
+while i >= 0:
+    # Default to a literal: length 1 with the byte value kept in `dist`.
+    _len = 1
+    dist = bin[i]
+
+    if i >= 1:
+        pair = bin[i - 1], bin[i]
+        j = heads.get(pair, -1)
+        # Distances only grow along the chain, so stop at the first
+        # candidate that is out of reach.
+        while j != -1 and j - i <= MAX_DIST:
+            assert bin[i] == bin[j]
+            assert bin[i - 1] == bin[j - 1]
+            # The pair already guarantees two matching bytes; extend
+            # towards lower addresses.
+            # NOTE(review): the `i - k >= 1` bound means a match never
+            # reaches index 0; `>= 0` looks sufficient - confirm before
+            # changing, as it alters the compressed output.
+            k = 2
+            while k < MAX_LEN and i - k >= 1 and bin[i - k] == bin[j - k]:
+                k += 1
+            # Strict comparison keeps the nearest of equally long matches.
+            if k > _len:
+                _len = k
+                dist = j - i
+            j = links[j]
+    lzss.append((_len, dist))
+
+    # Index every position covered by the item before moving past it.
+    for _ in range(_len):
+        if i >= 1:
+            pair = bin[i - 1], bin[i]
+            links[i] = heads.get(pair, -1)
+            heads[pair] = i
+        i -= 1
+lzss = lzss[::-1]
+
+# checking
+#bin1 = []
+#lzss1 = list(lzss)
+#while len(lzss1):
+# _len, dist = lzss1.pop()
+# if _len == 1:
+# bin1.append(dist)
+# else:
+# for i in range(_len):
+# bin1.append(bin1[-dist])
+#bin1 = bin1[::-1]
+#assert bin == bin1
+
+# Construct the real output in reverse to how it will be decoded. This
+# means we flush the flag bits at the right time for the decoder, and any
+# partial bit buffer is decoded at the start rather than the end.
+#
+# Each item carries one flag bit (0 = literal, 1 = pointer); a pointer
+# carries a second one (0 = short, 1 = long). Because the stream is
+# written backwards, the second bit is pushed before the first.
+lzss1 = list(lzss_loader)
+count = 0
+bits = 1 # empty buffer: only the sentinel bit is set
+for _len, dist in lzss:
+    if _len == 1:
+        # Literal: `dist` holds the byte value.
+        #print('a', dist)
+        lzss1.append(dist)
+        cf = 0
+    else:
+        # Apply the storage bias: lengths start at 2, distances at 1.
+        _len -= 2
+        dist -= 1
+        if _len < (1 << LEN_BITS0) and dist < (1 << DIST_BITS0):
+            item = dist | (_len << DIST_BITS0)
+            #print('b', item)
+            lzss1.append(item)
+            cf = 0
+        elif _len < (1 << LEN_BITS1) and dist < (1 << DIST_BITS1):
+            item = dist | (_len << DIST_BITS1)
+            #print('c', item)
+            lzss1.extend([item & 0xff, item >> 8])
+            cf = 1
+        else:
+            # Raised explicitly so the check survives `python -O`; a
+            # stripped assert would fall through with stale item/cf.
+            raise AssertionError(
+                f'cannot encode match: length {_len + 2}, distance {dist + 1}'
+            )
+
+        bits = (bits << 1) | cf
+        if bits & 0x100:
+            #print('d', bits)
+            lzss1.append(bits & 0xff)
+            bits = 1
+            # in this case we leave count alone (at decoding side we get
+            # another bit buffer for free without any increment or test)
+
+        cf = 1
+
+    bits = (bits << 1) | cf
+    if bits & 0x100:
+        # Only flushes after an item's first flag bit are counted.
+        #print('e', bits)
+        lzss1.append(bits & 0xff)
+        bits = 1
+        count += 1
+lzss = lzss1
+load_size = len(lzss)
+print('orig', f'{load_size0:04x}', 'lzss', f'{load_size:04x}')
+
+# checking
+#bin1 = []
+#lzss1 = lzss
+#count1 = count
+#bits1 = bits
+#while True:
+# if bits1 == 1:
+# if count1 == 0:
+# break
+# count1 -= 1
+# bits1 = lzss1.pop() | 0x100
+# #print('e', bits1)
+# cf = bits1 & 1
+# bits1 >>= 1
+#
+# if cf:
+# if bits1 == 1:
+# bits1 = lzss1.pop() | 0x100
+# #print('d', bits1)
+# cf = bits1 & 1
+# bits1 >>= 1
+#
+# if cf:
+# item = lzss1[-2] | (lzss1[-1] << 8)
+# del lzss1[-2:]
+# #print('c', item)
+# dist = item & ((1 << DIST_BITS1) - 1)
+# _len = item >> DIST_BITS1
+# else:
+# item = lzss1.pop()
+# #print('b', item)
+# dist = item & ((1 << DIST_BITS0) - 1)
+# _len = item >> DIST_BITS0
+# _len += 2
+# dist += 1
+#
+# for i in range(_len):
+# bin1.append(bin1[-dist])
+# else:
+# #print('a', lzss1[-1])
+# bin1.append(lzss1.pop())
+#assert lzss1 == lzss_loader
+#bin1 = bin1[::-1]
+#assert bin == bin1
+
+# `src` is one past the end of the packed image; `dest` is the address of
+# the last byte of the unpacked image.
+src = load_addr + load_size
+dest = load_addr0 + load_size0 - 1
+# Stored as the one's complement: inc/test is easier than test/dec.
+count = 0xffff & ~count
+
+# Each parameter byte is the operand of an `lda #NN` at a fixed offset in
+# the loader; check the opcode, then overwrite the byte after it.
+for opcode_at, operand in (
+    (0x01, src & 0xff),
+    (0x05, src >> 8),
+    (0x09, dest & 0xff),
+    (0x0d, dest >> 8),
+    (0x11, count & 0xff),
+    (0x15, count >> 8),
+    (0x19, bits),  # leftover partial flag buffer, sentinel included
+):
+    assert lzss[opcode_at] == 0xa9 # lda #NN
+    lzss[opcode_at + 1] = operand
+# NOTE(review): a disabled variant masked `bits` to 8 bits and stored the
+# ninth bit by turning the clc at 0x1b into sec.
+
+# The loader's jump is pointed at the program's entry point.
+assert lzss[0x42] == 0x4c # jmp NNNN
+lzss[0x43] = entry_point & 0xff
+lzss[0x44] = entry_point >> 8
+
+# Header (load address then size, each low byte first) followed by the
+# patched loader and the compressed stream.
+hdr = []
+for word in (load_addr, load_size):
+    hdr.extend((word & 0xff, word >> 8))
+with open(out_bin, 'wb') as fout:
+    fout.write(bytes(hdr + lzss))