Use a heap to speed up the minimum spanning tree generation
authorNick Downing <downing.nick@gmail.com>
Mon, 23 Jul 2018 08:38:21 +0000 (18:38 +1000)
committerNick Downing <downing.nick@gmail.com>
Mon, 23 Jul 2018 08:38:21 +0000 (18:38 +1000)
flex_dfa.py
numpy_heap.py [new file with mode: 0644]

index 807cd2b..0bc311e 100644 (file)
@@ -1,6 +1,7 @@
 import dfa
 import element
 import numpy
+import numpy_heap
 import regex
 
 class FlexDFA:
@@ -163,68 +164,61 @@ class FlexDFA:
     transitions[:, 0x100] = transitions[:, 0]
     transitions[:, 0] = eob_state
 
-    # state 0 is the jam state, the EOB state will be added later on
-    self.states = numpy.zeros((n_states, 2), numpy.int16) # base, def
-
     # calculate default states by constructing minimum spanning tree
-    dist = numpy.sum(
-      transitions[:, numpy.newaxis, :] != transitions[numpy.newaxis, :, :],
-      2
-    )
-    dist_key = numpy.stack(
-      [
-        numpy.zeros((n_states,), dtype = numpy.int16), # hop count
-        numpy.arange(0, n_states, dtype = numpy.int16) # permutation
-      ],
-      1
-    )
-    order = []
-    for i in range(1, n_states):
-      print('mst', i, 'of', n_states)
+    # heap contains n states todo followed by n_states - n states done
+    # each heap entry is [distance, hop count, state done, state todo]
+    heap = numpy.zeros((n_states, 4), numpy.int16)
+    heap[:-1, 0] = numpy.sum(transitions[1:, :] != transitions[:1, :], 1)
+    heap[:-1, 3] = numpy.arange(1, n_states, dtype = numpy.int16)
+    numpy_heap.heapify(heap, n_states - 1)
+    for n in range(n_states - 2, 0, -1):
+      if n % 100 == 0:
+        print('mst', n)
 
-      # Prim's algorithm: find most similar (done, todo) state pair
-      temp = dist[:i, i:]
-      done, todo = numpy.unravel_index(numpy.argmin(temp), temp.shape)
-      todo += i
-      order.append((dist_key[done, 1], dist_key[todo, 1])) # state done, todo
-
-      # permute to make done states consecutive, in order of hop count
-      dist_key[todo, 0] = dist_key[done, 0] + 1
-      j = done + 1
-      while j < i:
-        if tuple(dist_key[todo, :]) < tuple(dist_key[j, :]):
-          break
-        j += 1
-      temp = numpy.copy(dist[todo, :])
-      dist[j + 1:todo + 1, :] = dist[j:todo, :]
-      dist[j, :] = temp
-      temp = numpy.copy(dist[:, todo])
-      dist[:, j + 1:todo + 1] = dist[:, j:todo]
-      dist[:, j] = temp
-      temp = numpy.copy(dist_key[todo, :])
-      dist_key[j + 1:todo + 1, :] = dist_key[j:todo, :]
-      dist_key[j, :] = temp
+      key = tuple(heap[n, :])
+      heap[n, :] = heap[0, :]
+      numpy_heap.bubble_down(heap, 0, key, n)
+      hop_count = heap[n, 1] + 1 # proposed hop_count is current hop_count + 1
+      state_done = heap[n, 3] # proposed state_done is current state_todo
+      dist = numpy.sum(
+        transitions[heap[:n, 3], :] !=
+        transitions[state_done:state_done + 1, :],
+        1
+      )
+      # although numpy cannot do lexicographic comparisons, check the
+      # first field via numpy to quickly generate a list of candidates
+      for i in numpy.nonzero(dist <= heap[:n, 0])[0]:
+        key = (dist[i], hop_count, state_done, heap[i, 3])
+        if key < tuple(heap[i, :]):
+          numpy_heap.bubble_up(heap, i, key)
 
-    # encode states in reverse order (larger distances first)
+    # state 0 is the jam state, the EOB state will be added later on
+    self.states = numpy.zeros((n_states, 2), numpy.int16) # base, def
     self.entries = numpy.full((0x200, 2), -1, numpy.int16) # nxt, chk
     self.entries[:0x101, :] = 0 # jam state just returns to jam state
     self.entries[0, 0] = eob_state # except for the EOB transition
     entries_used = numpy.zeros(0x200, numpy.bool)
     entries_used[:0x101] = True # account for the jam (don't care) state
     n_entries = 0x101
+
+    # pack states in reverse order (larger distances first)
     dupes = []
-    while len(order):
-      print('order', len(order))
-      state_done, state_todo = order.pop()
-      indices = numpy.nonzero(
-        transitions[state_todo, :] != transitions[state_done, :]
-      )[0]
-      if indices.shape[0] == 0:
+    for i in range(n_states - 1):
+      if (n_states - i) % 100 == 0:
+        print('pack', n_states - i)
+      if heap[i, 0] == 0:
         # when copying another state, need to have the same base, though
         # the base will not matter since there will be no entries, it is
         # is because of the awkward way the compressed lookup is written
-        dupes.append((state_done, state_todo))
+        dupes.append(i)
       else:
+        state_done = heap[i, 2]
+        state_todo = heap[i, 3]
+        indices = numpy.nonzero(
+          transitions[state_todo, :] != transitions[state_done, :]
+        )[0]
+
         # make sure entries array is at least large enough to find a spot
         while self.entries.shape[0] < n_entries + 0x101:
           # extend entries, copying only n_entries entries
@@ -261,8 +255,12 @@ class FlexDFA:
         self.states[state_todo, 0] = start_index
         self.states[state_todo, 1] = state_done
     while len(dupes):
-      print('dupes', len(dupes))
-      state_done, state_todo = dupes.pop()
+      if len(dupes) % 100 == 0:
+        print('dupe', len(dupes))
+
+      i = dupes.pop()
+      state_done = heap[i, 2]
+      state_todo = heap[i, 3]
       self.states[state_todo, 0] = self.states[state_done, 0]
       self.states[state_todo, 1] = state_done
 
diff --git a/numpy_heap.py b/numpy_heap.py
new file mode 100644 (file)
index 0000000..e8d3add
--- /dev/null
@@ -0,0 +1,33 @@
+import numpy
+
+def bubble_up(heap, i, key):
+  # call with key == tuple(heap[i, :])
+  while i:
+    parent = (i - 1) >> 1
+    if key >= tuple(heap[parent, :]):
+      break
+    heap[i, :] = heap[parent, :]
+    i = parent
+  heap[i, :] = numpy.array(key, heap.dtype)
+
+def bubble_down(heap, i, key, n):
+  # call with key == tuple(heap[i, :])
+  child = (i << 1) + 1
+  while child < n:
+    child_key = tuple(heap[child, :])
+    child1 = child + 1
+    if child1 < n:
+      child1_key = tuple(heap[child1, :])
+      if child1_key < child_key:
+        child = child1
+        child_key = child1_key
+    if key <= child_key:
+      break
+    heap[i, :] = heap[child, :]
+    i = child
+    child = (i << 1) + 1
+  heap[i, :] = numpy.array(key, heap.dtype)
+
+def heapify(heap, n):
+  for i in range((n - 1) >> 1, -1, -1):
+    bubble_down(heap, i, tuple(heap[i, :]), n)