Add a2_pack_rev.py and lzss_unpack_rev.py (loads game without clobbering DOS)
authorNick Downing <nick@ndcode.org>
Tue, 21 Jun 2022 05:04:32 +0000 (15:04 +1000)
committerNick Downing <nick@ndcode.org>
Tue, 21 Jun 2022 05:05:25 +0000 (15:05 +1000)
loader/Makefile
loader/a2_pack_fwd.py
loader/a2_pack_rev.py [new file with mode: 0755]
loader/lzss_unpack_rev.asm [new file with mode: 0644]

index ab27272..f5d2e95 100755 (executable)
@@ -9,12 +9,14 @@ ASLINK=../asxv5pxx/asxmak/linux/exe/aslink
 #   pip3 install --user intelhex
 HEX2BIN=hex2bin.py
 
+LOAD_ADDR=0x800
 END_ADDR=0xa000
 LZSS_LOADER=0x800
 RECRACK_LOADER=0x9ded
 
 .PHONY: all
 all: \
+star_blazer_pack_rev.a2bin \
 star_blazer_pack_fwd.a2bin \
 star_blazer.ihx \
 star_blazer.a2bin \
@@ -22,6 +24,18 @@ star_blazer_dejunked0.a2bin \
 star_blazer_dejunked1.a2bin \
 star_blazer_recrack_lzss.a2bin
 
+star_blazer_pack_rev.a2bin: lzss_unpack_rev.bin star_blazer.ihx
+       ./a2_pack_rev.py ${LOAD_ADDR} $^ $@
+
+lzss_unpack_rev.bin: lzss_unpack_rev.ihx
+       ${HEX2BIN} $< $@
+
+lzss_unpack_rev.ihx: lzss_unpack_rev.rel
+       ${ASLINK} -n -m -u -i -b text=0 $@ $^
+
+lzss_unpack_rev.rel: lzss_unpack_rev.asm
+       ${AS6500} -l -o $<
+
 star_blazer_pack_fwd.a2bin: lzss_unpack_fwd.bin star_blazer.ihx
        ./a2_pack_fwd.py ${END_ADDR} $^ $@
 
index bc21239..1211917 100755 (executable)
@@ -18,7 +18,9 @@ MAX_DIST = (1 << DIST_BITS1) # distance codes are 1..MAX_DIST
 MAX_LEN = (1 << LEN_BITS1) + 1 # length codes are 2..MAX_LEN
 
 if len(sys.argv) < 5:
-  print(f'usage: {sys.argv[0]:s} end_addr lzss_unpack_fwd.bin in.ihx out.a2bin')
+  print(
+    f'usage: {sys.argv[0]:s} end_addr lzss_unpack_fwd.bin in.ihx out.a2bin'
+  )
   sys.exit(EXIT_FAILURE)
 end_addr = int(sys.argv[1], 0)
 lzss_unpack_fwd_bin = sys.argv[2]
@@ -213,7 +215,7 @@ N_SECTIONS = 3
 
 # fixup is a 4-tuple:
 #   (fixup type, fixup address, target section, target address)
-# both addresses are negative and relative to the end addr of the section
+# both addresses are negative and relative to the end address of the section
 FIXUP_TYPE_LO_BYTE = 0
 FIXUP_TYPE_HI_BYTE = 1
 FIXUP_TYPE_WORD = 2
@@ -232,7 +234,7 @@ sections = [Section([], 0, []) for i in range(N_SECTIONS)]
 #   (report type, ihx start, ihx end, a2bin start, a2bin end)
 # for compressed the compression ratio will be printed
 # for direct poke the a2bin values are not used, otherwise they
-# are negative and relative to the end addr of the payload section
+# are negative and relative to the end address of the payload section
 # report is used to visually check for source/destination overlap
 REPORT_TYPE_DIRECT_POKE = 0
 REPORT_TYPE_UNCOMPRESSED = 1
@@ -247,6 +249,7 @@ sections[SECTION_LOADER].data.extend(
 )
 
 # segments
+# constructed top to bottom (in reverse order of unpacking: bottom to top)
 for i in range(len(segments) - 2, -2, -2):
   addr0 = segments[i]
   addr1 = segments[i + 1]
diff --git a/loader/a2_pack_rev.py b/loader/a2_pack_rev.py
new file mode 100755 (executable)
index 0000000..2145cda
--- /dev/null
@@ -0,0 +1,443 @@
+#!/usr/bin/env python3
+
+import sys
+from intelhex import IntelHex
+
+EXIT_SUCCESS = 0
+EXIT_FAILURE = 1
+
+# short (8-bit) pointer
+DIST_BITS0 = 7
+LEN_BITS0 = 1
+
+# long (16-bit) pointer
+DIST_BITS1 = 10
+LEN_BITS1 = 6
+
+MAX_DIST = (1 << DIST_BITS1) # distance codes are 1..MAX_DIST
+MAX_LEN = (1 << LEN_BITS1) + 1 # length codes are 2..MAX_LEN
+
+if len(sys.argv) < 5:
+  print(
+    f'usage: {sys.argv[0]:s} load_addr lzss_unpack_rev.bin in.ihx out.a2bin'
+  )
+  sys.exit(EXIT_FAILURE)
+load_addr = int(sys.argv[1], 0)
+lzss_unpack_rev_bin = sys.argv[2]
+in_ihx = sys.argv[3]
+out_a2bin = sys.argv[4]
+
+with open(lzss_unpack_rev_bin, 'rb') as fin:
+  lzss_unpack_rev = list(fin.read())
+
+def lzss_pack(dest, bin):
+  bin = bin[::-1] # !!! for rev, makes it easier to construct LZSS items
+
+  heads = {}
+  links = [-1] * len(bin)
+  lzss = []
+  i = 0
+  while i < len(bin):
+    _len = 1
+    dist = bin[i]
+
+    if i + 1 < len(bin):
+      pair = bin[i], bin[i + 1]
+      j = heads.get(pair, -1)
+      while j != -1 and i - j <= MAX_DIST:
+        #assert bin[i:i + 2] == bin[j:j + 2]
+        if (
+          _len < MAX_LEN and
+            i + _len < len(bin) and
+            bin[i + 2:i + _len + 1] == bin[j + 2:j + _len + 1]
+        ):
+          _len += 1
+          while (
+            _len < MAX_LEN and
+              i + _len < len(bin) and
+              bin[i + _len] == bin[j + _len]
+          ):
+            _len += 1
+          dist = i - j
+        j = links[j]
+    lzss.append((_len, dist))
+
+    for j in range(_len):
+      if i + 1 < len(bin):
+        pair = bin[i], bin[i + 1]
+        links[i] = heads.get(pair, -1)
+        heads[pair] = i
+      i += 1
+
+  # checking
+  bin1 = []
+  lzss1 = lzss[::-1]
+  while len(lzss1):
+    _len, dist = lzss1.pop()
+    if _len == 1:
+      bin1.append(dist)
+    else:
+      for i in range(_len):
+        bin1.append(bin1[-dist])
+  assert bin == bin1
+
+  # construct the real output in reverse to how it will be decoded,
+  # this means we flush the bits at the right time for the decoder,
+  # and any partial bit buffer is decoded at start rather than end
+  lzss1 = []
+  count = 0
+  bits = 1
+  while len(lzss):
+    _len, dist = lzss.pop()
+    if _len == 1:
+      #print('a', dist)
+      lzss1.append(dist)
+      cf = 0
+    else:
+      _len -= 2
+      dist -= 1
+      if _len < (1 << LEN_BITS0) and dist < (1 << DIST_BITS0):
+        item = dist | (_len << DIST_BITS0)
+        #print('b', item)
+        lzss1.append(item)
+        cf = 0
+      elif _len < (1 << LEN_BITS1) and dist < (1 << DIST_BITS1):
+        item = dist | (_len << DIST_BITS1)
+        #print('c', item)
+        lzss1.extend([item & 0xff, item >> 8]) # !!! swapped for rev
+        cf = 1
+      else:
+        assert False
+
+      bits = (bits << 1) | cf
+      if bits & 0x100:
+        #print('d', bits)
+        lzss1.append(bits & 0xff)
+        bits = 1
+        # in this case we leave count alone (at decoding side we get
+        # another bit buffer for free without any increment or test)
+
+      cf = 1
+
+    bits = (bits << 1) | cf
+    if bits & 0x100:
+      #print('e', bits)
+      lzss1.append(bits & 0xff)
+      bits = 1
+      count += 1
+  lzss = lzss1 # !!! not reversed for rev
+
+  # checking
+  bin1 = []
+  lzss1 = list(lzss) # !!! not reversed for rev
+  count1 = count
+  bits1 = bits
+  while True:
+    if bits1 == 1:
+      if count1 == 0:
+        break
+      count1 -= 1
+      bits1 = lzss1.pop() | 0x100
+      #print('e', bits1)
+    cf = bits1 & 1
+    bits1 >>= 1
+  
+    if cf:
+      if bits1 == 1:
+        bits1 = lzss1.pop() | 0x100
+        #print('d', bits1)
+      cf = bits1 & 1
+      bits1 >>= 1
+  
+      if cf:
+        item = lzss1[-2] | (lzss1[-1] << 8) # !!! swapped for rev
+        del lzss1[-2:]
+        #print('c', item)
+        dist = item & ((1 << DIST_BITS1) - 1)
+        _len = item >> DIST_BITS1
+      else: 
+        item = lzss1.pop()
+        #print('b', item)
+        dist = item & ((1 << DIST_BITS0) - 1)
+        _len = item >> DIST_BITS0
+      _len += 2
+      dist += 1
+  
+      for i in range(_len):
+        bin1.append(bin1[-dist])
+    else:
+      #print('a', lzss1[-1])
+      bin1.append(lzss1.pop())
+  assert len(lzss1) == 0
+  assert bin1 == bin
+
+  # optimization: provided the input is not null, the first byte
+  # has to be literal, so the loader can fall straight into the
+  # literal decoding routine (saves a jump to the official loop)
+  if bits == 1:
+    assert count
+    count -= 1
+    bits = lzss.pop() | 0x100 # !!! from end for rev
+  assert (bits & 1) == 0
+  bits >>= 1
+
+  # append data block
+  count ^= 0xffff # inc/test is easier than test/dec
+  return lzss + [dest & 0xff, dest >> 8, count & 0xff, count >> 8, bits]
+
+intelhex = IntelHex(in_ihx)
+entry_point = intelhex.start_addr['EIP']
+segments = [j for i in intelhex.segments() for j in i]
+
+# zero page and stack are done last, after we finish with them,
+# and in 0x100-byte pieces so we can do them without zero page
+def intersect(segments, segment):
+  [addr0, addr1] = segment
+  segments1 = []
+  for i in range(0, len(segments), 2):
+    [addr2, addr3] = segments[i:i + 2]
+    if addr2 < addr0:
+      addr2 = addr0
+    if addr3 > addr1:
+      addr3 = addr1
+    if addr3 > addr2:
+      segments1.extend([addr2, addr3])
+  return segments1
+segments = (
+  intersect(segments, [0, 0x100]) +
+  intersect(segments, [0x100, 0x200]) +
+  intersect(segments, [0x200, 0x10000])
+)
+
+# sections are output to the a2bin file from bottom to top as follows:
+SECTION_LOADER = 0
+SECTION_UNPACKER = 1
+SECTION_PAYLOAD = 2
+N_SECTIONS = 3
+
+# fixup is a 4-tuple:
+#   (fixup type, fixup address, target section, target address)
+# both addresses are relative to the load address of the section,
+# except if the fixup address is negative it is relative to end
+# (needed for the loader section, which has to be constructed in
+# reverse, because the 6502 cannot execute in reverse, thus the
+# order of execution has to be opposite to the payload section)
+FIXUP_TYPE_LO_BYTE = 0
+FIXUP_TYPE_HI_BYTE = 1
+FIXUP_TYPE_WORD = 2
+
+# each section has a data area, a load address and a list of fixups
+# relocation is done after section lengths and load addresses known
+class Section:
+  def __init__(self, data, load_addr, fixups):
+    self.data = data
+    self.load_addr = load_addr
+    self.fixups = fixups
+sections = [Section([], 0, []) for i in range(N_SECTIONS)]
+
+# report is a 5-tuple:
+#   (report type, ihx start, ihx end, a2bin start, a2bin end)
+# for compressed the compression ratio will be printed
+# for direct poke the a2bin values are not used, otherwise they
+# are relative to the load address of the payload section
+# report is used to visually check for source/destination overlap
+REPORT_TYPE_DIRECT_POKE = 0
+REPORT_TYPE_UNCOMPRESSED = 1
+REPORT_TYPE_COMPRESSED = 2
+report = []
+
+# prologue
+sections[SECTION_LOADER].data.extend(
+  [
+    0x4c, entry_point & 0xff, entry_point >> 8,        # jmp entry_point
+  ][::-1]
+)
+
+# segments
+# constructed bottom to top (in reverse order of unpacking: top to bottom)
+for i in range(0, len(segments), 2):
+  addr0 = segments[i]
+  addr1 = segments[i + 1]
+  data = list(intelhex.tobinstr(addr0, addr1 - 1))
+
+  if len(data) <= 4:
+    report.append(
+      (REPORT_TYPE_DIRECT_POKE, addr0, addr1, 0, 0)
+    )
+
+    # use of zpage version is determined byte by byte
+    for i in data:
+      sections[SECTION_LOADER].data.extend(
+        [
+          0xa9, i,                             # lda #data
+          0x85, addr0,                         # sta *addr0
+        ][::-1]
+      if addr0 < 0x100 else
+        [
+          0xa9, i,                             # lda #data
+          0x8d, addr0 & 0xff, addr0 >> 8,      # sta addr0
+        ][::-1]
+      )
+      addr0 += 1
+  elif len(data) <= 0x100:
+    addr2 = len(sections[SECTION_PAYLOAD].data)
+    sections[SECTION_PAYLOAD].data.extend(
+      data
+    )
+    addr3 = len(sections[SECTION_PAYLOAD].data)
+    report.append(
+      (REPORT_TYPE_UNCOMPRESSED, addr0, addr1, addr2, addr3)
+    )
+
+    # use of zpage version is determined in advance (if completely fits)
+    zpage = addr1 < 0x100
+
+    if len(data) == 0x100:
+      # for the full count we will copy forward (an exception)
+      sections[SECTION_LOADER].data.extend(
+        [
+          0xa2, 0x00,                                  # ldx #0
+          0xbd, 0x00, 0x00,                            # lda addr2,x
+          0x95, addr0 & 0xff,                          # sta *addr0,x
+          0xe8,                                                # inx
+          0xd0, 0xf8                                   # bne .-6
+        ][::-1]
+      if zpage else
+        [
+          0xa2, 0x00,                                  # ldx #0
+          0xbd, 0x00, 0x00,                            # lda addr2,x
+          0x9d, addr0 & 0xff, (addr0 >> 8) & 0xff,     # sta addr0,x
+          0xe8,                                                # inx
+          0xd0, 0xf7                                   # bne .-7
+        ][::-1]
+      )
+    else:
+      addr0 -= 1
+      addr2 -= 1
+      sections[SECTION_LOADER].data.extend(
+        [
+          0xa2, len(data),                             # ldx #count
+          0xbd, 0x00, 0x00,                            # lda addr2,x
+          0x95, addr0 & 0xff,                          # sta *addr0,x
+          0xca,                                                # dex
+          0xd0, 0xf8                                   # bne .-6
+        ][::-1]
+      if zpage else
+        [
+          0xa2, len(data),                             # ldx #count
+          0xbd, 0x00, 0x00,                            # lda addr2,x
+          0x9d, addr0 & 0xff, (addr0 >> 8) & 0xff,     # sta addr0,x
+          0xca,                                                # dex
+          0xd0, 0xf7                                   # bne .-7
+        ][::-1]
+      )
+    sections[SECTION_LOADER].fixups.extend(
+      [
+        (
+          FIXUP_TYPE_WORD,
+          3 - len(sections[SECTION_LOADER].data),
+          SECTION_PAYLOAD,
+          addr2
+        ),
+      ]
+    )
+  else:
+    addr2 = len(sections[SECTION_PAYLOAD].data)
+    sections[SECTION_PAYLOAD].data.extend(
+      lzss_pack(addr1 - 1, data)
+    )
+    addr3 = len(sections[SECTION_PAYLOAD].data)
+    report.append(
+      (REPORT_TYPE_COMPRESSED, addr0, addr1, addr2, addr3)
+    )
+
+    if len(sections[SECTION_UNPACKER].data) == 0:
+      sections[SECTION_UNPACKER].data.extend(
+        lzss_unpack_rev
+      )
+
+    addr3 -= 5 + 1
+    sections[SECTION_LOADER].data.extend(
+      [
+        0xa9, 0x00,            # lda #<addr3
+        0xa0, 0x00,            # ldy #>addr3
+        0x20, 0x00, 0x00,      # jsr lzss_unpack_rev
+      ][::-1]
+    )
+    sections[SECTION_LOADER].fixups.extend(
+      [
+        (
+          FIXUP_TYPE_LO_BYTE,
+          1 - len(sections[SECTION_LOADER].data),
+          SECTION_PAYLOAD,
+          addr3
+        ),
+        (
+          FIXUP_TYPE_HI_BYTE,
+          3 - len(sections[SECTION_LOADER].data),
+          SECTION_PAYLOAD,
+          addr3
+        ),
+        (
+          FIXUP_TYPE_WORD,
+          5 - len(sections[SECTION_LOADER].data),
+          SECTION_UNPACKER,
+          0
+        ),
+      ]
+    )
+
+
+# prologue
+sections[SECTION_LOADER].data.extend(
+  [
+    0xd8,      # cld
+    0xa2, 0xff,        # ldx #0xff
+    0x9a,      # txs
+  ][::-1]
+)
+sections[SECTION_LOADER].data = sections[SECTION_LOADER].data[::-1]
+
+# relocate
+end_addr = load_addr
+for i in range(N_SECTIONS):
+  sections[i].load_addr = end_addr
+  end_addr += len(sections[i].data)
+load_size = end_addr - load_addr
+
+for report_type, addr0, addr1, addr2, addr3 in report:
+  if report_type == REPORT_TYPE_DIRECT_POKE:
+    print(f'[0x{addr0:04x}, 0x{addr1:04x})')
+  else:
+    addr2 += sections[SECTION_PAYLOAD].load_addr
+    addr3 += sections[SECTION_PAYLOAD].load_addr
+    print(
+      f'[0x{addr0:04x}, 0x{addr1:04x}) -> [0x{addr2:04x}, 0x{addr3:04x})' + (
+        f'{100. * (addr3 - addr2) / (addr1 - addr0):6.1f}%'
+      if report_type == REPORT_TYPE_COMPRESSED else
+        ''
+      )
+    )
+
+bin = []
+for i in range(N_SECTIONS):
+  for fixup_type, fixup_addr, section, addr in sections[i].fixups:
+    addr += sections[section].load_addr
+    if fixup_type == FIXUP_TYPE_LO_BYTE:
+      assert sections[i].data[fixup_addr] == 0
+      sections[i].data[fixup_addr] = addr & 0xff
+    elif fixup_type == FIXUP_TYPE_HI_BYTE:
+      assert sections[i].data[fixup_addr] == 0
+      sections[i].data[fixup_addr] = addr >> 8
+    elif fixup_type == FIXUP_TYPE_WORD:
+      assert sections[i].data[fixup_addr] == 0
+      sections[i].data[fixup_addr] = addr & 0xff
+      assert sections[i].data[fixup_addr + 1] == 0
+      sections[i].data[fixup_addr + 1] = addr >> 8
+    else:
+      assert False
+  bin.extend(sections[i].data)
+
+hdr = [load_addr & 0xff, load_addr >> 8, load_size & 0xff, load_size >> 8]
+with open(out_a2bin, 'wb') as fout:
+  fout.write(bytes(hdr + bin))
diff --git a/loader/lzss_unpack_rev.asm b/loader/lzss_unpack_rev.asm
new file mode 100644 (file)
index 0000000..752d7a4
--- /dev/null
@@ -0,0 +1,171 @@
+       .r65c02
+
+       .area   zpage
+       .setdp
+
+       .ds     0xf0
+src:   .ds     2                       ; address of last byte read
+dest:  .ds     2                       ; address of last byte written
+count: .ds     2                       ; count of bit buffer refills to do
+bits:  .ds     1                       ; bit buffer (highest 1 = sentinel)
+dist:  .ds     2                       ; distance, or address of repeated data
+
+       .area   text
+
+       ; enter with y:a = source address - 1 (last byte of LZSS data)
+       ; LZSS data is followed by data block
+       ;   0 (word): destination address - 1 (last byte of decoded data)
+       ;   2 (word): count of bit buffer refills to do
+       ;   4 (byte): bit buffer (highest 1 = sentinel)
+       ; type of LZSS item depends on a bit from bit buffer:
+       ;   0: literal
+       ;   1: pointer
+       ; type of pointer depends on a bit from bit buffer:
+       ;   0: short pointer, lddddddd
+       ;   1: long pointer, lllllldd:dddddddd (LS byte first)
+
+       ; src = y:a
+       sta     src
+       sty     src + 1
+
+       ; copy data block
+       ldy     #5
+0$:    lda     [src],y
+       sta     dest - 1,y
+       dey
+       bne     0$
+
+       clc
+       ; optimization: the first byte has to be literal
+       ;bcc    loop1
+
+literal: ; copy one byte
+       lda     [src],y
+       sta     [dest],y
+
+       lda     dest
+       bne     0$
+       dec     dest + 1
+0$:    dec     dest
+
+loop0: ; decrement src for literal or last byte of pointer
+       lda     src
+       bne     0$
+       dec     src + 1
+0$:    dec     src
+
+loop1: ; process LZSS item
+       ;clc
+       ror     bits
+       bne     literal_or_pointer
+
+       ; bit buffer exhausted
+       ; count refills of bit buffer
+       inc     count
+       bne     0$
+       inc     count + 1
+       beq     done
+
+0$:    ; load one byte to bit buffer
+       lda     [src],y
+       ;sec
+       ror     a
+       sta     bits
+
+       lda     src
+       bne     1$
+       dec     src + 1
+1$:    dec     src
+
+literal_or_pointer:
+       ; cf=0 literal, cf=1 pointer
+       bcc     literal
+
+       ; pointer
+       clc
+       ror     bits
+       bne     short_or_long_pointer
+
+       ; bit buffer exhausted
+       ; load one byte to bit buffer
+       lda     [src],y
+       ;sec
+       ror     a
+       sta     bits
+
+       lda     src
+       bne     0$
+       dec     src + 1
+0$:    dec     src
+
+short_or_long_pointer:
+       ; cf=0 short pointer, cf=1 long pointer
+       bcs     long_pointer
+
+       ; short pointer, lddddddd
+       ; take source byte, but don't decrement yet
+       lda     [src],y
+       tax
+       and     #0x7f
+       sta     dist
+       sty     dist + 1
+
+       txa
+       asl     a ; cf = len - 2
+       tya
+       beq     pointer
+
+long_pointer:
+       ; high of long pointer, lllllldd
+       ; take source byte
+       lda     [src],y
+       tax
+       and     #3
+       sta     dist + 1
+
+       lda     src
+       bne     0$
+       dec     src + 1
+0$:    dec     src
+
+       ; low of long pointer, dddddddd
+       ; take source byte, but don't decrement yet
+       lda     [src],y
+       sta     dist
+
+       txa
+       lsr     a
+       lsr     a
+
+       clc
+pointer: ; dist 0 based, a + cf = len 0 based, source needs decrement
+       adc     #2
+       tay
+
+       ; dest -= len
+       sec
+       eor     #0xff
+       adc     dest
+       sta     dest
+       bcs     0$
+       dec     dest + 1
+
+       ; dist = dest + dist + 1
+       sec
+0$:    lda     dest
+       adc     dist
+       sta     dist
+       lda     dest + 1
+       adc     dist + 1
+       sta     dist + 1
+
+       ; copy previous data
+1$:    lda     [dist],y
+       sta     [dest],y
+       dey
+       bne     1$
+
+       ; src += 1, process LZSS item
+       beq     loop0
+
+done:  rts