--- /dev/null
+#!/usr/bin/env python3
+
+import sys
+from intelhex import IntelHex
+
+EXIT_SUCCESS = 0
+EXIT_FAILURE = 1
+
+# short (8-bit) pointer
+DIST_BITS0 = 7
+LEN_BITS0 = 1
+
+# long (16-bit) pointer
+DIST_BITS1 = 10
+LEN_BITS1 = 6
+
+MAX_DIST = (1 << DIST_BITS1) # distance codes are 1..MAX_DIST
+MAX_LEN = (1 << LEN_BITS1) + 1 # length codes are 2..MAX_LEN
+
+if len(sys.argv) < 5:
+ print(
+ f'usage: {sys.argv[0]:s} load_addr lzss_unpack_rev.bin in.ihx out.a2bin'
+ )
+ sys.exit(EXIT_FAILURE)
+load_addr = int(sys.argv[1], 0)
+lzss_unpack_rev_bin = sys.argv[2]
+in_ihx = sys.argv[3]
+out_a2bin = sys.argv[4]
+
+with open(lzss_unpack_rev_bin, 'rb') as fin:
+ lzss_unpack_rev = list(fin.read())
+
+def lzss_pack(dest, bin):
+ bin = bin[::-1] # !!! for rev, makes it easier to construct LZSS items
+
+ heads = {}
+ links = [-1] * len(bin)
+ lzss = []
+ i = 0
+ while i < len(bin):
+ _len = 1
+ dist = bin[i]
+
+ if i + 1 < len(bin):
+ pair = bin[i], bin[i + 1]
+ j = heads.get(pair, -1)
+ while j != -1 and i - j <= MAX_DIST:
+ #assert bin[i:i + 2] == bin[j:j + 2]
+ if (
+ _len < MAX_LEN and
+ i + _len < len(bin) and
+ bin[i + 2:i + _len + 1] == bin[j + 2:j + _len + 1]
+ ):
+ _len += 1
+ while (
+ _len < MAX_LEN and
+ i + _len < len(bin) and
+ bin[i + _len] == bin[j + _len]
+ ):
+ _len += 1
+ dist = i - j
+ j = links[j]
+ lzss.append((_len, dist))
+
+ for j in range(_len):
+ if i + 1 < len(bin):
+ pair = bin[i], bin[i + 1]
+ links[i] = heads.get(pair, -1)
+ heads[pair] = i
+ i += 1
+
+ # checking
+ bin1 = []
+ lzss1 = lzss[::-1]
+ while len(lzss1):
+ _len, dist = lzss1.pop()
+ if _len == 1:
+ bin1.append(dist)
+ else:
+ for i in range(_len):
+ bin1.append(bin1[-dist])
+ assert bin == bin1
+
+ # construct the real output in reverse to how it will be decoded,
+ # this means we flush the bits at the right time for the decoder,
+ # and any partial bit buffer is decoded at start rather than end
+ lzss1 = []
+ count = 0
+ bits = 1
+ while len(lzss):
+ _len, dist = lzss.pop()
+ if _len == 1:
+ #print('a', dist)
+ lzss1.append(dist)
+ cf = 0
+ else:
+ _len -= 2
+ dist -= 1
+ if _len < (1 << LEN_BITS0) and dist < (1 << DIST_BITS0):
+ item = dist | (_len << DIST_BITS0)
+ #print('b', item)
+ lzss1.append(item)
+ cf = 0
+ elif _len < (1 << LEN_BITS1) and dist < (1 << DIST_BITS1):
+ item = dist | (_len << DIST_BITS1)
+ #print('c', item)
+ lzss1.extend([item & 0xff, item >> 8]) # !!! swapped for rev
+ cf = 1
+ else:
+ assert False
+
+ bits = (bits << 1) | cf
+ if bits & 0x100:
+ #print('d', bits)
+ lzss1.append(bits & 0xff)
+ bits = 1
+ # in this case we leave count alone (at decoding side we get
+ # another bit buffer for free without any increment or test)
+
+ cf = 1
+
+ bits = (bits << 1) | cf
+ if bits & 0x100:
+ #print('e', bits)
+ lzss1.append(bits & 0xff)
+ bits = 1
+ count += 1
+ lzss = lzss1 # !!! not reversed for rev
+
+ # checking
+ bin1 = []
+ lzss1 = list(lzss) # !!! not reversed for rev
+ count1 = count
+ bits1 = bits
+ while True:
+ if bits1 == 1:
+ if count1 == 0:
+ break
+ count1 -= 1
+ bits1 = lzss1.pop() | 0x100
+ #print('e', bits1)
+ cf = bits1 & 1
+ bits1 >>= 1
+
+ if cf:
+ if bits1 == 1:
+ bits1 = lzss1.pop() | 0x100
+ #print('d', bits1)
+ cf = bits1 & 1
+ bits1 >>= 1
+
+ if cf:
+ item = lzss1[-2] | (lzss1[-1] << 8) # !!! swapped for rev
+ del lzss1[-2:]
+ #print('c', item)
+ dist = item & ((1 << DIST_BITS1) - 1)
+ _len = item >> DIST_BITS1
+ else:
+ item = lzss1.pop()
+ #print('b', item)
+ dist = item & ((1 << DIST_BITS0) - 1)
+ _len = item >> DIST_BITS0
+ _len += 2
+ dist += 1
+
+ for i in range(_len):
+ bin1.append(bin1[-dist])
+ else:
+ #print('a', lzss1[-1])
+ bin1.append(lzss1.pop())
+ assert len(lzss1) == 0
+ assert bin1 == bin
+
+ # optimization: provided the input is not null, the first byte
+ # has to be literal, so the loader can fall straight into the
+ # literal decoding routine (saves a jump to the official loop)
+ if bits == 1:
+ assert count
+ count -= 1
+ bits = lzss.pop() | 0x100 # !!! from end for rev
+ assert (bits & 1) == 0
+ bits >>= 1
+
+ # append data block
+ count ^= 0xffff # inc/test is easier than test/dec
+ return lzss + [dest & 0xff, dest >> 8, count & 0xff, count >> 8, bits]
+
+intelhex = IntelHex(in_ihx)
+entry_point = intelhex.start_addr['EIP']
+segments = [j for i in intelhex.segments() for j in i]
+
+# zero page and stack are done last, after we finish with them,
+# and in 0x100-byte pieces so we can do them without zero page
+def intersect(segments, segment):
+ [addr0, addr1] = segment
+ segments1 = []
+ for i in range(0, len(segments), 2):
+ [addr2, addr3] = segments[i:i + 2]
+ if addr2 < addr0:
+ addr2 = addr0
+ if addr3 > addr1:
+ addr3 = addr1
+ if addr3 > addr2:
+ segments1.extend([addr2, addr3])
+ return segments1
+segments = (
+ intersect(segments, [0, 0x100]) +
+ intersect(segments, [0x100, 0x200]) +
+ intersect(segments, [0x200, 0x10000])
+)
+
+# sections are output to the a2bin file from bottom to top as follows:
+SECTION_LOADER = 0
+SECTION_UNPACKER = 1
+SECTION_PAYLOAD = 2
+N_SECTIONS = 3
+
+# fixup is a 4-tuple:
+# (fixup type, fixup address, target section, target address)
+# both addresses are relative to the load address of the section,
+# except if the fixup address is negative it is relative to end
+# (needed for the loader section, which has to be constructed in
+# reverse, because the 6502 cannot execute in reverse, thus the
+# order of execution has to be opposite to the payload section)
+FIXUP_TYPE_LO_BYTE = 0
+FIXUP_TYPE_HI_BYTE = 1
+FIXUP_TYPE_WORD = 2
+
+# each section has a data area, a load address and a list of fixups
+# relocation is done after section lengths and load addresses known
+class Section:
+ def __init__(self, data, load_addr, fixups):
+ self.data = data
+ self.load_addr = load_addr
+ self.fixups = fixups
+sections = [Section([], 0, []) for i in range(N_SECTIONS)]
+
+# report is a 5-tuple:
+# (report type, ihx start, ihx end, a2bin start, a2bin end)
+# for compressed the compression ratio will be printed
+# for direct poke the a2bin values are not used, otherwise they
+# are relative to the load address of the payload section
+# report is used to visually check for source/destination overlap
+REPORT_TYPE_DIRECT_POKE = 0
+REPORT_TYPE_UNCOMPRESSED = 1
+REPORT_TYPE_COMPRESSED = 2
+report = []
+
+# prologue
+sections[SECTION_LOADER].data.extend(
+ [
+ 0x4c, entry_point & 0xff, entry_point >> 8, # jmp entry_point
+ ][::-1]
+)
+
+# segments
+# constructed bottom to top (in reverse order of unpacking: top to bottom)
+for i in range(0, len(segments), 2):
+ addr0 = segments[i]
+ addr1 = segments[i + 1]
+ data = list(intelhex.tobinstr(addr0, addr1 - 1))
+
+ if len(data) <= 4:
+ report.append(
+ (REPORT_TYPE_DIRECT_POKE, addr0, addr1, 0, 0)
+ )
+
+ # use of zpage version is determined byte by byte
+ for i in data:
+ sections[SECTION_LOADER].data.extend(
+ [
+ 0xa9, i, # lda #data
+ 0x85, addr0, # sta *addr0
+ ][::-1]
+ if addr0 < 0x100 else
+ [
+ 0xa9, i, # lda #data
+ 0x8d, addr0 & 0xff, addr0 >> 8, # sta addr0
+ ][::-1]
+ )
+ addr0 += 1
+ elif len(data) <= 0x100:
+ addr2 = len(sections[SECTION_PAYLOAD].data)
+ sections[SECTION_PAYLOAD].data.extend(
+ data
+ )
+ addr3 = len(sections[SECTION_PAYLOAD].data)
+ report.append(
+ (REPORT_TYPE_UNCOMPRESSED, addr0, addr1, addr2, addr3)
+ )
+
+ # use of zpage version is determined in advance (if completely fits)
+ zpage = addr1 < 0x100
+
+ if len(data) == 0x100:
+ # for the full count we will copy forward (an exception)
+ sections[SECTION_LOADER].data.extend(
+ [
+ 0xa2, 0x00, # ldx #0
+ 0xbd, 0x00, 0x00, # lda addr2,x
+ 0x95, addr0 & 0xff, # sta *addr0,x
+ 0xe8, # inx
+ 0xd0, 0xf8 # bne .-6
+ ][::-1]
+ if zpage else
+ [
+ 0xa2, 0x00, # ldx #0
+ 0xbd, 0x00, 0x00, # lda addr2,x
+ 0x9d, addr0 & 0xff, (addr0 >> 8) & 0xff, # sta addr0,x
+ 0xe8, # inx
+ 0xd0, 0xf7 # bne .-7
+ ][::-1]
+ )
+ else:
+ addr0 -= 1
+ addr2 -= 1
+ sections[SECTION_LOADER].data.extend(
+ [
+ 0xa2, len(data), # ldx #count
+ 0xbd, 0x00, 0x00, # lda addr2,x
+ 0x95, addr0 & 0xff, # sta *addr0,x
+ 0xca, # dex
+ 0xd0, 0xf8 # bne .-6
+ ][::-1]
+ if zpage else
+ [
+ 0xa2, len(data), # ldx #count
+ 0xbd, 0x00, 0x00, # lda addr2,x
+ 0x9d, addr0 & 0xff, (addr0 >> 8) & 0xff, # sta addr0,x
+ 0xca, # dex
+ 0xd0, 0xf7 # bne .-7
+ ][::-1]
+ )
+ sections[SECTION_LOADER].fixups.extend(
+ [
+ (
+ FIXUP_TYPE_WORD,
+ 3 - len(sections[SECTION_LOADER].data),
+ SECTION_PAYLOAD,
+ addr2
+ ),
+ ]
+ )
+ else:
+ addr2 = len(sections[SECTION_PAYLOAD].data)
+ sections[SECTION_PAYLOAD].data.extend(
+ lzss_pack(addr1 - 1, data)
+ )
+ addr3 = len(sections[SECTION_PAYLOAD].data)
+ report.append(
+ (REPORT_TYPE_COMPRESSED, addr0, addr1, addr2, addr3)
+ )
+
+ if len(sections[SECTION_UNPACKER].data) == 0:
+ sections[SECTION_UNPACKER].data.extend(
+ lzss_unpack_rev
+ )
+
+ addr3 -= 5 + 1
+ sections[SECTION_LOADER].data.extend(
+ [
+ 0xa9, 0x00, # lda #<addr3
+ 0xa0, 0x00, # ldy #>addr3
+ 0x20, 0x00, 0x00, # jsr lzss_unpack_rev
+ ][::-1]
+ )
+ sections[SECTION_LOADER].fixups.extend(
+ [
+ (
+ FIXUP_TYPE_LO_BYTE,
+ 1 - len(sections[SECTION_LOADER].data),
+ SECTION_PAYLOAD,
+ addr3
+ ),
+ (
+ FIXUP_TYPE_HI_BYTE,
+ 3 - len(sections[SECTION_LOADER].data),
+ SECTION_PAYLOAD,
+ addr3
+ ),
+ (
+ FIXUP_TYPE_WORD,
+ 5 - len(sections[SECTION_LOADER].data),
+ SECTION_UNPACKER,
+ 0
+ ),
+ ]
+ )
+
+
+# prologue
+sections[SECTION_LOADER].data.extend(
+ [
+ 0xd8, # cld
+ 0xa2, 0xff, # ldx #0xff
+ 0x9a, # txs
+ ][::-1]
+)
+sections[SECTION_LOADER].data = sections[SECTION_LOADER].data[::-1]
+
+# relocate
+end_addr = load_addr
+for i in range(N_SECTIONS):
+ sections[i].load_addr = end_addr
+ end_addr += len(sections[i].data)
+load_size = end_addr - load_addr
+
+for report_type, addr0, addr1, addr2, addr3 in report:
+ if report_type == REPORT_TYPE_DIRECT_POKE:
+ print(f'[0x{addr0:04x}, 0x{addr1:04x})')
+ else:
+ addr2 += sections[SECTION_PAYLOAD].load_addr
+ addr3 += sections[SECTION_PAYLOAD].load_addr
+ print(
+ f'[0x{addr0:04x}, 0x{addr1:04x}) -> [0x{addr2:04x}, 0x{addr3:04x})' + (
+ f'{100. * (addr3 - addr2) / (addr1 - addr0):6.1f}%'
+ if report_type == REPORT_TYPE_COMPRESSED else
+ ''
+ )
+ )
+
+bin = []
+for i in range(N_SECTIONS):
+ for fixup_type, fixup_addr, section, addr in sections[i].fixups:
+ addr += sections[section].load_addr
+ if fixup_type == FIXUP_TYPE_LO_BYTE:
+ assert sections[i].data[fixup_addr] == 0
+ sections[i].data[fixup_addr] = addr & 0xff
+ elif fixup_type == FIXUP_TYPE_HI_BYTE:
+ assert sections[i].data[fixup_addr] == 0
+ sections[i].data[fixup_addr] = addr >> 8
+ elif fixup_type == FIXUP_TYPE_WORD:
+ assert sections[i].data[fixup_addr] == 0
+ sections[i].data[fixup_addr] = addr & 0xff
+ assert sections[i].data[fixup_addr + 1] == 0
+ sections[i].data[fixup_addr + 1] = addr >> 8
+ else:
+ assert False
+ bin.extend(sections[i].data)
+
+hdr = [load_addr & 0xff, load_addr >> 8, load_size & 0xff, load_size >> 8]
+with open(out_a2bin, 'wb') as fout:
+ fout.write(bytes(hdr + bin))