Add LZSS loader
authorNick Downing <nick@ndcode.org>
Fri, 17 Jun 2022 11:49:45 +0000 (21:49 +1000)
committerNick Downing <nick@ndcode.org>
Fri, 17 Jun 2022 13:16:19 +0000 (23:16 +1000)
loader/Makefile
loader/lzss_decode.py [new file with mode: 0755]
loader/lzss_encode.py [new file with mode: 0755]
loader/lzss_loader.asm [new file with mode: 0644]

index cb7904c..ce61334 100755 (executable)
@@ -12,6 +12,7 @@ HEX2BIN=hex2bin.py
 RLE_LOADER=0x800
 RLE2_LOADER=0x800
 TREE_LOADER=0x800
+LZSS_LOADER=0x800
 HIRES_LOADER=0x2000
 
 .PHONY: all
@@ -19,6 +20,7 @@ all: \
 star_blazer.bin \
 star_blazer_dejunked0.bin \
 star_blazer_dejunked1.bin \
+star_blazer_lzss_loader.bin \
 star_blazer_tree_loader.bin \
 star_blazer_rle2_loader.bin \
 star_blazer_rle_loader.bin
@@ -32,6 +34,18 @@ star_blazer_dejunked0.bin: star_blazer.bin
 star_blazer_dejunked1.bin: star_blazer.bin
        ./dejunk.py $< $@ 0xff
 
+star_blazer_lzss_loader.bin: lzss_loader.bin star_blazer_hires_loader.bin
+       ./lzss_encode.py ${LZSS_LOADER} $^ $@
+
+lzss_loader.bin: lzss_loader.ihx
+       ${HEX2BIN} $< $@
+
+lzss_loader.ihx: lzss_loader.rel
+       ${ASLINK} -n -m -u -i -b text=${LZSS_LOADER} $@ $^
+
+lzss_loader.rel: lzss_loader.asm
+       ${AS6500} -l -o $<
+
 star_blazer_tree_loader.bin: tree_loader.bin star_blazer_hires_loader.bin
        ./tree_encode.py ${TREE_LOADER} $^ $@
 
diff --git a/loader/lzss_decode.py b/loader/lzss_decode.py
new file mode 100755 (executable)
index 0000000..8a0c0b6
--- /dev/null
@@ -0,0 +1,109 @@
+#!/usr/bin/env python3
+
+import sys
+
+EXIT_SUCCESS = 0
+EXIT_FAILURE = 1
+
+# short (8-bit) pointer
+DIST_BITS0 = 7
+LEN_BITS0 = 1
+
+# long (16-bit) pointer
+DIST_BITS1 = 11
+LEN_BITS1 = 5
+
+MAX_DIST = (1 << DIST_BITS1) # distance codes are 1..MAX_DIST
+MAX_LEN = (1 << LEN_BITS1) + 1 # length codes are 2..MAX_LEN
+
+if len(sys.argv) < 3:
+  print(f'usage: {sys.argv[0]:s} in.bin out.bin')
+  sys.exit(EXIT_FAILURE)
+in_bin = sys.argv[1]
+out_bin = sys.argv[2]
+
+with open(in_bin, 'rb') as fin:
+  data = list(fin.read())
+  hdr = data[:4]
+  lzss = data[4:]
+load_addr0 = hdr[0] | (hdr[1] << 8)
+load_size0 = hdr[2] | (hdr[3] << 8)
+assert load_size0 == len(lzss)
+
+assert lzss[1] == 0xa9 # lda #NN
+assert lzss[5] == 0xa9 # lda #NN
+src = lzss[2] | (lzss[6] << 8)
+
+assert lzss[9] == 0xa9 # lda #NN
+assert lzss[0xd] == 0xa9 # lda #NN
+dest = lzss[0xa] + (lzss[0xe] << 8)
+
+assert lzss[0x11] == 0xa9 # lda #NN
+assert lzss[0x15] == 0xa9 # lda #NN
+count = lzss[0x12] + (lzss[0x16] << 8)
+
+assert lzss[0x19] == 0xa9 # lda #NN
+#assert lzss[0x1d] & ~0x20 == 0x18 # clc
+bits = lzss[0x1a] #| ((lzss[0x1d] & 0x20) << 3)
+
+assert lzss[0x42] == 0x4c # jmp NNNN
+entry_point = lzss[0x43] | (lzss[0x44] << 8)
+
+assert src == load_addr0 + load_size0
+count = ~count & 0xffff
+
+bin = []
+lzss = lzss
+count = count
+bits = bits
+while True:
+  if bits == 1:
+    if count == 0:
+      break
+    count -= 1
+    bits = lzss.pop() | 0x100
+    #print('e', bits)
+  cf = bits & 1
+  bits >>= 1
+
+  if cf:
+    if bits == 1:
+      bits = lzss.pop() | 0x100
+      #print('d', bits)
+    cf = bits & 1
+    bits >>= 1
+
+    if cf:
+      item = lzss[-2] | (lzss[-1] << 8)
+      del lzss[-2:]
+      #print('c', item)
+      dist = item & ((1 << DIST_BITS1) - 1)
+      _len = item >> DIST_BITS1
+    else: 
+      item = lzss.pop()
+      #print('b', item)
+      dist = item & ((1 << DIST_BITS0) - 1)
+      _len = item >> DIST_BITS0
+    _len += 2
+    dist += 1
+
+    for i in range(_len):
+      bin.append(bin[-dist])
+  else:
+    #print('a', lzss[-1])
+    bin.append(lzss.pop())
+assert len(lzss) == 0xb6
+bin = bin[::-1]
+
+load_size = len(bin)
+print(f'lzss {load_size0:04x} orig {load_size:04x}')
+
+load_addr = dest + 1 - load_size
+if entry_point != load_addr:
+  bin = [0x4c, entry_point & 0xff, entry_point >> 8] + bin # jmp NNNN
+  load_addr -= 3
+  load_size += 3
+
+hdr = [load_addr & 0xff, load_addr >> 8, load_size & 0xff, load_size >> 8]
+with open(out_bin, 'wb') as fout:
+  fout.write(bytes(hdr + bin))
diff --git a/loader/lzss_encode.py b/loader/lzss_encode.py
new file mode 100755 (executable)
index 0000000..2c0a552
--- /dev/null
@@ -0,0 +1,221 @@
+#!/usr/bin/env python3
+
+import bisect
+import numpy
+import sys
+from heapdict import heapdict
+
+EXIT_SUCCESS = 0
+EXIT_FAILURE = 1
+
+# short (8-bit) pointer
+DIST_BITS0 = 7
+LEN_BITS0 = 1
+
+# long (16-bit) pointer
+DIST_BITS1 = 11
+LEN_BITS1 = 5
+
+MAX_DIST = (1 << DIST_BITS1) # distance codes are 1..MAX_DIST
+MAX_LEN = (1 << LEN_BITS1) + 1 # length codes are 2..MAX_LEN
+
+if len(sys.argv) < 5:
+  print(f'usage: {sys.argv[0]:s} load_addr lzss_loader.bin in.bin out.bin')
+  sys.exit(EXIT_FAILURE)
+load_addr = int(sys.argv[1], 0)
+lzss_loader_bin = sys.argv[2]
+in_bin = sys.argv[3]
+out_bin = sys.argv[4]
+
+with open(lzss_loader_bin, 'rb') as fin:
+  lzss_loader = list(fin.read())
+assert len(lzss_loader) == 0xb6
+
+with open(in_bin, 'rb') as fin:
+  data = list(fin.read())
+  hdr = data[:4]
+  bin = data[4:]
+load_addr0 = hdr[0] | (hdr[1] << 8)
+load_size0 = hdr[2] | (hdr[3] << 8)
+assert load_size0 == len(bin)
+
+# absorb a jump to the real entry point, ensuring that we
+# can losslessly reconstruct the jump when decoding later
+entry_point = load_addr0
+if bin[0] == 0x4c: # jmp NNNN
+  entry_point = bin[1] | (bin[2] << 8)
+  assert entry_point != load_addr0 + 3
+
+  bin = bin[3:]
+  load_addr0 += 3
+  load_size0 -= 3
+
+# construct the LZSS items in order they'll be decoded, but store
+# them in reverse (stack-wise) as we don't know the final length
+heads = {}
+links = [-1] * len(bin)
+lzss = []
+i = len(bin) - 1
+while i >= 0:
+  _len = 1
+  dist = bin[i]
+
+  if i >= 1:
+    pair = bin[i - 1], bin[i]
+    j = heads.get(pair, -1)
+    while j != -1 and j - i <= MAX_DIST:
+      assert bin[i] == bin[j]
+      assert bin[i - 1] == bin[j - 1]
+      k = 2
+      while k < MAX_LEN and i - k >= 1 and bin[i - k] == bin[j - k]:
+        k += 1
+      if k > _len:
+        _len = k
+        dist = j - i
+      j = links[j]
+  lzss.append((_len, dist))
+
+  for j in range(_len):
+    if i >= 1:
+      pair = bin[i - 1], bin[i]
+      links[i] = heads.get(pair, -1)
+      heads[pair] = i
+    i -= 1
+lzss = lzss[::-1]
+
+# checking
+#bin1 = []
+#lzss1 = list(lzss)
+#while len(lzss1):
+#  _len, dist = lzss1.pop()
+#  if _len == 1:
+#    bin1.append(dist)
+#  else:
+#    for i in range(_len):
+#      bin1.append(bin1[-dist])
+#bin1 = bin1[::-1]
+#assert bin == bin1
+
+# construct the real output in reverse to how it will be decoded,
+# this means we flush the bits at the right time for the decoder,
+# and any partial bit buffer is decoded at start rather than end
+lzss1 = list(lzss_loader)
+count = 0
+bits = 1
+for _len, dist in lzss:
+  if _len == 1:
+    #print('a', dist)
+    lzss1.append(dist)
+    cf = 0
+  else:
+    _len -= 2
+    dist -= 1
+    if _len < (1 << LEN_BITS0) and dist < (1 << DIST_BITS0):
+      item = dist | (_len << DIST_BITS0)
+      #print('b', item)
+      lzss1.append(item)
+      cf = 0
+    elif _len < (1 << LEN_BITS1) and dist < (1 << DIST_BITS1):
+      item = dist | (_len << DIST_BITS1)
+      #print('c', item)
+      lzss1.extend([item & 0xff, item >> 8])
+      cf = 1
+    else:
+      assert False
+
+    bits = (bits << 1) | cf
+    if bits & 0x100:
+      #print('d', bits)
+      lzss1.append(bits & 0xff)
+      bits = 1
+      # in this case we leave count alone (at decoding side we get
+      # another bit buffer for free without any increment or test)
+
+    cf = 1
+
+  bits = (bits << 1) | cf
+  if bits & 0x100:
+    #print('e', bits)
+    lzss1.append(bits & 0xff)
+    bits = 1
+    count += 1
+lzss = lzss1
+load_size = len(lzss)
+print('orig', f'{load_size0:04x}', 'lzss', f'{load_size:04x}')
+
+# checking
+#bin1 = []
+#lzss1 = lzss
+#count1 = count
+#bits1 = bits
+#while True:
+#  if bits1 == 1:
+#    if count1 == 0:
+#      break
+#    count1 -= 1
+#    bits1 = lzss1.pop() | 0x100
+#    #print('e', bits1)
+#  cf = bits1 & 1
+#  bits1 >>= 1
+#
+#  if cf:
+#    if bits1 == 1:
+#      bits1 = lzss1.pop() | 0x100
+#      #print('d', bits1)
+#    cf = bits1 & 1
+#    bits1 >>= 1
+#
+#    if cf:
+#      item = lzss1[-2] | (lzss1[-1] << 8)
+#      del lzss1[-2:]
+#      #print('c', item)
+#      dist = item & ((1 << DIST_BITS1) - 1)
+#      _len = item >> DIST_BITS1
+#    else: 
+#      item = lzss1.pop()
+#      #print('b', item)
+#      dist = item & ((1 << DIST_BITS0) - 1)
+#      _len = item >> DIST_BITS0
+#    _len += 2
+#    dist += 1
+#
+#    for i in range(_len):
+#      bin1.append(bin1[-dist])
+#  else:
+#    #print('a', lzss1[-1])
+#    bin1.append(lzss1.pop())
+#assert lzss1 == lzss_loader
+#bin1 = bin1[::-1]
+#assert bin == bin1
+
+src = load_addr + load_size
+dest = load_addr0 + load_size0 - 1
+count = ~count & 0xffff # inc/test is easier than test/dec
+
+assert lzss[0x1] == 0xa9 # lda #NN
+lzss[0x2] = src & 0xff
+assert lzss[5] == 0xa9 # lda #NN
+lzss[6] = src >> 8
+
+assert lzss[9] == 0xa9 # lda #NN
+lzss[0xa] = dest & 0xff
+assert lzss[0xd] == 0xa9 # lda #NN
+lzss[0xe] = dest >> 8
+
+assert lzss[0x11] == 0xa9 # lda #NN
+lzss[0x12] = count & 0xff
+assert lzss[0x15] == 0xa9 # lda #NN
+lzss[0x16] = count >> 8
+
+assert lzss[0x19] == 0xa9 # lda #NN
+lzss[0x1a] = bits #& 0xff
+#assert lzss[0x1b] == 0x18 # clc
+#lzss[0x1b] |= (bits >> 3) & 0x20 # clc or sec
+
+assert lzss[0x42] == 0x4c # jmp NNNN
+lzss[0x43] = entry_point & 0xff
+lzss[0x44] = entry_point >> 8
+
+hdr = [load_addr & 0xff, load_addr >> 8, load_size & 0xff, load_size >> 8]
+with open(out_bin, 'wb') as fout:
+  fout.write(bytes(hdr + lzss))
diff --git a/loader/lzss_loader.asm b/loader/lzss_loader.asm
new file mode 100644 (file)
index 0000000..4586359
--- /dev/null
@@ -0,0 +1,159 @@
+       .r65c02
+
+       .area   zpage
+       .setdp
+
+       .ds     80
+src:   .ds     2                       ; address of last byte read
+dest:  .ds     2                       ; address of next byte to write
+count: .ds     2                       ; count of bit buffer refills to do
+bits:  .ds     1                       ; bit buffer (highest 1 = sentinel)
+len:   .ds     1                       ; length
+dist:  .ds     2                       ; distance, or address of repeated data
+
+       .area   text
+
+       cld
+
+       lda     #0
+       sta     src
+       lda     #0
+       sta     src + 1
+
+       lda     #0
+       sta     dest
+       lda     #0
+       sta     dest + 1
+
+       lda     #0
+       sta     count
+       lda     #0
+       sta     count + 1
+
+       lda     #0
+       sta     bits
+       clc
+
+       ldy     #0
+       beq     loop
+
+literal:
+       lda     src
+       bne     0$
+       dec     src + 1
+0$:    dec     src
+
+       lda     [src],y
+       sta     [dest],y
+
+       lda     dest
+       bne     1$
+       dec     dest + 1
+1$:    dec     dest
+
+loop:  ;clc
+       ror     bits
+       bne     literal_or_pointer
+
+       inc     count
+       bne     0$
+       inc     count + 1
+       bne     0$
+       jmp     0
+
+0$:    lda     src
+       bne     1$
+       dec     src + 1
+1$:    dec     src
+
+       lda     [src],y
+       ;sec
+       ror     a
+       sta     bits
+
+literal_or_pointer:
+       ; cf=0 literal, cf=1 pointer
+       bcc     literal
+
+       ; pointer
+       clc
+       ror     bits
+       bne     short_or_long_pointer
+
+       lda     src
+       bne     0$
+       dec     src + 1
+0$:    dec     src
+
+       lda     [src],y
+       ;sec
+       ror     a
+       sta     bits
+
+short_or_long_pointer:
+       ; cf=0 short pointer, cf=1 long pointer
+       ; take source byte (short pointer or high of long pointer)
+       lda     src
+       bne     0$
+       dec     src + 1
+0$:    dec     src
+
+       lda     [src],y
+       bcs     long_pointer
+
+       ; short pointer lddddddd
+       sty     dist + 1 ; 0
+       bpl     1$
+       and     #0x7f
+       iny
+1$:    sta     dist
+
+       bcc     pointer
+
+long_pointer:
+       ; high of long pointer, lllllddd
+       tax
+       and     #7
+       sta     dist + 1
+
+       lda     src
+       bne     0$
+       dec     src + 1
+0$:    dec     src
+
+       lda     [src],y
+       sta     dist
+
+       txa
+       lsr     a
+       lsr     a
+       lsr     a
+       tay
+
+pointer: ; dist set up, y = len - 2
+       iny
+       iny
+       sty     len
+
+       ; dest -= len
+       sec
+       lda     dest
+       sbc     len
+       sta     dest
+       bcs     0$
+       dec     dest + 1
+
+       ; dist += dest + 1
+       sec
+0$:    ;lda    dest
+       adc     dist
+       sta     dist
+       lda     dest + 1
+       adc     dist + 1
+       sta     dist + 1
+
+1$:    lda     [dist],y
+       sta     [dest],y
+       dey
+       bne     1$
+       beq     loop ; rely on cf = 0 here