Rework token numbering so that the LR1DFA is generated using the internal numbering...
authorNick Downing <downing.nick@gmail.com>
Tue, 17 Jul 2018 06:55:02 +0000 (16:55 +1000)
committerNick Downing <downing.nick@gmail.com>
Tue, 17 Jul 2018 06:55:02 +0000 (16:55 +1000)
ast.py
bison_lr1dfa.py

diff --git a/ast.py b/ast.py
index 26b7a6f..3333fa4 100644 (file)
--- a/ast.py
+++ b/ast.py
@@ -40,7 +40,7 @@ class Item(element.Element):
 class PYACC(element.Element):
   # internal classes
   class Symbol(element.Element):
-    # GENERATE ELEMENT(str name, list(int) character_set, int precedence, int associativity) BEGIN
+    # GENERATE ELEMENT(str name, int token, list(int) character_set, int precedence, int associativity) BEGIN
     def __init__(
       self,
       tag = 'PYACC_Symbol',
@@ -48,6 +48,7 @@ class PYACC(element.Element):
       text = '',
       children = [],
       name = '',
+      token = -1,
       character_set = [],
       precedence = -1,
       associativity = -1
@@ -60,6 +61,11 @@ class PYACC(element.Element):
         children
       )
       self.name = name
+      self.token = (
+        element.deserialize_int(token)
+      if isinstance(token, str) else
+        token
+      )
       self.character_set = (
         [element.deserialize_int(i) for i in character_set.split()]
       if isinstance(character_set, str) else
@@ -78,6 +84,7 @@ class PYACC(element.Element):
     def serialize(self, ref_list, indent = 0):
       element.Element.serialize(self, ref_list, indent)
       self.set('name', element.serialize_str(self.name))
+      self.set('token', element.serialize_int(self.token))
       self.set(
         'character_set',
         ' '.join([element.serialize_int(i) for i in self.character_set])
@@ -87,6 +94,7 @@ class PYACC(element.Element):
     def deserialize(self, ref_list):
       element.Element.deserialize(self, ref_list)
       self.name = element.deserialize_str(self.get('name', ''))
+      self.token = element.deserialize_int(self.get('token', '-1'))
       self.character_set = [
         element.deserialize_int(i)
         for i in self.get('character_set', '').split()
@@ -99,6 +107,7 @@ class PYACC(element.Element):
         Symbol if factory is None else factory
       )
       result.name = self.name
+      result.token = self.token
       result.character_set = self.character_set
       result.precedence = self.precedence
       result.associativity = self.associativity
@@ -109,6 +118,10 @@ class PYACC(element.Element):
         params.append(
           'name = {0:s}'.format(repr(self.name))
         )
+      if self.token != -1:
+        params.append(
+          'token = {0:s}'.format(repr(self.token))
+        )
       if len(self.character_set):
         params.append(
           'character_set = [{0:s}]'.format(
@@ -897,13 +910,14 @@ class PYACC(element.Element):
         return self
 
     class Symbol(TagOrSymbol):
-      # GENERATE ELEMENT() BEGIN
+      # GENERATE ELEMENT(int user_token) BEGIN
       def __init__(
         self,
         tag = 'PYACC_Section1Or2_Symbol',
         attrib = {},
         text = '',
-        children = []
+        children = [],
+        user_token = -1
       ):
         PYACC.Section1Or2.TagOrSymbol.__init__(
           self,
@@ -912,12 +926,30 @@ class PYACC(element.Element):
           text,
           children
         )
+        self.user_token = (
+          element.deserialize_int(user_token)
+        if isinstance(user_token, str) else
+          user_token
+        )
+      def serialize(self, ref_list, indent = 0):
+        PYACC.Section1Or2.TagOrSymbol.serialize(self, ref_list, indent)
+        self.set('user_token', element.serialize_int(self.user_token))
+      def deserialize(self, ref_list):
+        PYACC.Section1Or2.TagOrSymbol.deserialize(self, ref_list)
+        self.user_token = element.deserialize_int(self.get('user_token', '-1'))
       def copy(self, factory = None):
         result = PYACC.Section1Or2.TagOrSymbol.copy(
           self,
           Symbol if factory is None else factory
         )
+        result.user_token = self.user_token
         return result
+      def repr_serialize(self, params):
+        PYACC.Section1Or2.TagOrSymbol.repr_serialize(self, params)
+        if self.user_token != -1:
+          params.append(
+            'user_token = {0:s}'.format(repr(self.user_token))
+          )
       def __repr__(self):
         params = []
         self.repr_serialize(params)
@@ -944,9 +976,11 @@ class PYACC(element.Element):
             character_to_symbol[character] = symbol
             pyacc.terminal_symbols.append(
               PYACC.Symbol(
-                character_set = [character, character + 1]
+                token = character,
+                character_set = [symbol, symbol + 1]
               )
             )
+            print('symbol', symbol, 'token', character)
         elif isinstance(self[0], PYACC.ID):
           name = element.get_text(self[0], 0)
           if name in name_to_symbol:
@@ -955,19 +989,21 @@ class PYACC(element.Element):
           else:
             symbol = len(pyacc.terminal_symbols)
             name_to_symbol[name] = symbol
-            character = 0x100 + symbol # fix this later
             pyacc.terminal_symbols.append(
               PYACC.Symbol(
                 name = name,
-                character_set = [character, character + 1]
+                character_set = [symbol, symbol + 1]
               )
             )
         else:
           assert False
-        if precedence >= 0:
+        if self.user_token != -1:
+          assert pyacc.terminal_symbols[symbol].token == -1
+          pyacc.terminal_symbols[symbol].token = self.user_token
+        if precedence != -1:
           assert pyacc.terminal_symbols[symbol].precedence == -1
           pyacc.terminal_symbols[symbol].precedence = precedence
-        if associativity >= 0:
+        if associativity != -1:
           assert pyacc.terminal_symbols[symbol].associativity == -1
           pyacc.terminal_symbols[symbol].associativity = associativity
         return tag
@@ -1968,17 +2004,21 @@ class PYACC(element.Element):
               if isinstance(self[i][0], PYACC.Char):
                 character = ord(self[i][0].get_text())
                 assert character != 0 # would conflict with YYEOF
-                if character not in character_to_symbol:
-                  character_to_symbol[character] = len(pyacc.terminal_symbols)
+                if character in character_to_symbol:
+                  symbol = character_to_symbol[character]
+                  assert symbol >= 0
+                else:
+                  symbol = len(pyacc.terminal_symbols)
+                  character_to_symbol[character] = symbol
                   pyacc.terminal_symbols.append(
                     PYACC.Symbol(
-                      character_set = [character, character + 1]
+                      token = character,
+                      character_set = [symbol, symbol + 1]
                     )
                   )
-                pyacc.characters_used.add(character) # remove this later
                 production.append(
                   grammar.Grammar.Production.Symbol(
-                    terminal_set = [character, character + 1]
+                    terminal_set = pyacc.terminal_symbols[symbol].character_set
                   )
                 )
               elif isinstance(self[i][0], PYACC.ID):
@@ -2178,8 +2218,7 @@ class PYACC(element.Element):
       return 'ast.PYACC.ValueReference({0:s})'.format(', '.join(params))
     # GENERATE END
 
-
-  # GENERATE ELEMENT(list(ref) prologue_text, set(int) characters_used, int precedences, list(ref) terminal_symbols, list(ref) nonterminal_symbols, ref grammar, list(ref) actions_braced_code) BEGIN
+  # GENERATE ELEMENT(list(ref) prologue_text, int precedences, list(ref) terminal_symbols, list(ref) nonterminal_symbols, ref grammar, list(ref) actions_braced_code) BEGIN
   def __init__(
     self,
     tag = 'PYACC',
@@ -2187,7 +2226,6 @@ class PYACC(element.Element):
     text = '',
     children = [],
     prologue_text = [],
-    characters_used = set(),
     precedences = -1,
     terminal_symbols = [],
     nonterminal_symbols = [],
@@ -2202,11 +2240,6 @@ class PYACC(element.Element):
       children
     )
     self.prologue_text = prologue_text
-    self.characters_used = (
-      set([element.deserialize_int(i) for i in characters_used.split()])
-    if isinstance(characters_used, str) else
-      characters_used
-    )
     self.precedences = (
       element.deserialize_int(precedences)
     if isinstance(precedences, str) else
@@ -2222,10 +2255,6 @@ class PYACC(element.Element):
       'prologue_text',
       ' '.join([element.serialize_ref(i, ref_list) for i in self.prologue_text])
     )
-    self.set(
-      'characters_used',
-      ' '.join([element.serialize_int(i) for i in sorted(self.characters_used)])
-    )
     self.set('precedences', element.serialize_int(self.precedences))
     self.set(
       'terminal_symbols',
@@ -2246,12 +2275,6 @@ class PYACC(element.Element):
       element.deserialize_ref(i, ref_list)
       for i in self.get('prologue_text', '').split()
     ]
-    self.characters_used = set(
-      [
-        element.deserialize_int(i)
-        for i in self.get('characters_used', '').split()
-      ]
-    )
     self.precedences = element.deserialize_int(self.get('precedences', '-1'))
     self.terminal_symbols = [
       element.deserialize_ref(i, ref_list)
@@ -2272,7 +2295,6 @@ class PYACC(element.Element):
       PYACC if factory is None else factory
     )
     result.prologue_text = self.prologue_text
-    result.characters_used = self.characters_used
     result.precedences = self.precedences
     result.terminal_symbols = self.terminal_symbols
     result.nonterminal_symbols = self.nonterminal_symbols
@@ -2287,12 +2309,6 @@ class PYACC(element.Element):
           ', '.join([repr(i) for i in self.prologue_text])
         )
       )
-    if len(self.characters_used):
-      params.append(
-        'characters_used = set([{0:s}])'.format(
-          ', '.join([repr(i) for i in sorted(self.characters_used)])
-        )
-      )
     if self.precedences != -1:
       params.append(
         'precedences = {0:s}'.format(repr(self.precedences))
@@ -2327,11 +2343,11 @@ class PYACC(element.Element):
   def post_process(self):
     # variables that will be serialized
     self.prologue_text = []
-    self.characters_used = set()
     self.precedences = 0
     self.terminal_symbols = [
-      PYACC.Symbol(name = 'error', character_set = [0x100, 0x101]),
-      PYACC.Symbol(name = '$undefined', character_set = [0x101, 0x102])
+      PYACC.Symbol(name = '$eof', token = 0, character_set = [0, 1]),
+      PYACC.Symbol(name = 'error', character_set = [1, 2]),
+      PYACC.Symbol(name = '$undefined', character_set = [2, 3])
     ]
     self.nonterminal_symbols = []
     self.grammar = grammar.Grammar(
@@ -2361,17 +2377,25 @@ class PYACC(element.Element):
         name_to_symbol
       )
 
+    # fill in token numbers that are not characters or overridden by user
+    token = 0x100
+    for i in self.terminal_symbols:
+      if i.token == -1:
+        i.token = token
+        token += 1
+
     # if start symbol not specified, use first nonterminal defined in file
     if len(self.grammar[0][0].name) == 0:
       self.grammar[0][0].name = self.nonterminal_symbols[0].name
 
     # look up rule names and substitute appropriate character_set for each
-    self.grammar.n_terminals = 0x100 + len(self.terminal_symbols)
+    self.grammar.n_terminals = len(self.terminal_symbols)
     self.grammar.post_process(
       dict(
         [
           (i.name, (i.character_set, []))
           for i in self.terminal_symbols
+          if len(i.name) # fix this later
         ] +
         [
           (i.name, ([], i.character_set))
index 7b56710..d55f12e 100644 (file)
@@ -35,11 +35,11 @@ class BisonLR1DFA:
     # note: the goto table is transposed with respect to the action table,
     # so the row in the table corresponds to the yypact[]/yypgoto[] index,
     # and the column in the table is what gets added to yypact[]/yypgoto[]
-    orig_action_table = numpy.zeros(
+    action_table = numpy.zeros(
       (len(lr1dfa.states), lr1dfa.n_terminals),
       numpy.int16
     )
-    orig_goto_table = numpy.zeros(
+    goto_table = numpy.zeros(
       (len(lr1dfa.productions), len(lr1dfa.states)),
       numpy.int16
     )
@@ -49,28 +49,24 @@ class BisonLR1DFA:
       terminal0 = 0
       for j in range(len(terminal_breaks)):
         terminal1 = terminal_breaks[j]
-        orig_action_table[i, terminal0:terminal1] = actions[j]
+        action_table[i, terminal0:terminal1] = actions[j]
         terminal0 = terminal1
       assert terminal0 == lr1dfa.n_terminals
 
       nonterminal0 = 0
       for j in range(len(nonterminal_breaks)):
         nonterminal1 = nonterminal_breaks[j]
-        orig_goto_table[nonterminal0:nonterminal1, i] = gotos[j]
+        goto_table[nonterminal0:nonterminal1, i] = gotos[j]
         nonterminal0 = nonterminal1
       assert nonterminal0 == len(lr1dfa.productions)
 
-    # permute and combine columns/rows on the basis of the translate vectors
-    action_table = numpy.zeros(
-      (len(lr1dfa.states), n_terminals),
-      numpy.int16
-    )
-    action_table[:, translate_terminals] = orig_action_table
-    goto_table = numpy.zeros(
+    # permute and combine rows based on the nonterminal translate vector
+    new_goto_table = numpy.zeros(
       (n_nonterminals, len(lr1dfa.states)),
       numpy.int16
     )
-    goto_table[translate_nonterminals, :] = orig_goto_table[1:] # ignore start
+    new_goto_table[translate_nonterminals, :] = goto_table[1:] # ignore start
+    goto_table = new_goto_table
 
     # manipulate the table entries as follows:
     # - ensure there is no shift or goto 0 (cannot return to starting state)
@@ -255,40 +251,31 @@ class BisonLR1DFA:
     self.n_terminals = n_terminals
 
 def generate(pyacc, skel_file, out_file):
-  # generate the tables using an expanded character set, consisting of:
-  # the full 0x100 character literals (whether they are referenced or not)
-  # the terminals with defined names (for each terminal, one character)
-  # the nonterminals (for each nonterminal, one character per production)
   lr1dfa = pyacc.grammar.to_lr1().to_lalr1()
 
-  # squash this down to the set of terminals, then the set of character
-  # literals that are referenced, then only one character per nonterminal
-  # (nonterminals referenced by pyacc.nonterminal_symbols[] index, rather
-  # than the internal way as only the set of lr1dfa.productions[] indices)
-
-  # generate translate table for character literals and terminal symbols
-  n_terminals = 1 # room for '$eof'
-  translate_terminals = numpy.zeros(
-    (lr1dfa.n_terminals,),
+  # generate translate table for terminal symbols
+  # this undoes yacc/bison's rather wasteful mapping of 0x00..0xff to literal
+  # characters, and also accommodates any token value overrides given by the
+  # user, yielding a consecutive set of terminal numbers that are really used
+  # (matching pyacc.terminal_symbols[*].character_set, and hence the lr1dfa)
+  reverse_translate_terminals = numpy.array(
+    [i.token for i in pyacc.terminal_symbols],
     numpy.int16
   )
-  translate_terminals[1:0x100] = 2 # '$undefined'
-  for i in pyacc.terminal_symbols:
-    if len(i.name): # fix this later
-      for j in range(0, len(i.character_set), 2):
-        translate_terminals[
-          i.character_set[j]:
-          i.character_set[j + 1]
-        ] = n_terminals
-      n_terminals += 1
-  for i in sorted(pyacc.characters_used):
-    translate_terminals[i] = n_terminals
-    n_terminals += 1
+  translate_terminals = numpy.full(
+    (numpy.max(reverse_translate_terminals) + 1,),
+    2, # '$undefined'
+    numpy.int16
+  )
+  translate_terminals[reverse_translate_terminals] = numpy.arange(
+    len(pyacc.terminal_symbols),
+    dtype = numpy.int16
+  )
 
   # generate translate table for nonterminal symbols
   # this is effectively a map from productions back to nonterminal symbols
   # we do not generate an entry for the first production (start production)
-  n_nonterminals = 0
+  nonterminal = 0
   translate_nonterminals = numpy.zeros(
     (len(lr1dfa.productions) - 1,),
     numpy.int16
@@ -298,15 +285,15 @@ def generate(pyacc, skel_file, out_file):
       translate_nonterminals[
         i.character_set[j] - 1:
         i.character_set[j + 1] - 1
-      ] = n_nonterminals
-    n_nonterminals += 1
+      ] = nonterminal
+    nonterminal += 1
 
   # translate and compress the tables
   bison_lr1dfa = BisonLR1DFA(
     lr1dfa,
-    n_terminals,
+    len(pyacc.terminal_symbols),
     translate_terminals,
-    n_nonterminals,
+    len(pyacc.nonterminal_symbols),
     translate_nonterminals
   )
 
@@ -329,12 +316,9 @@ def generate(pyacc, skel_file, out_file):
 '''.format(
               ','.join(
                 [
-                  '\n    {0:s} = {1:d}'.format(
-                    pyacc.terminal_symbols[i].name,
-                    0x100 + i
-                  )
-                  for i in range(2, len(pyacc.terminal_symbols))
-                  if len(pyacc.terminal_symbols[i].name) # fix this later
+                  '\n    {0:s} = {1:d}'.format(i.name, i.token)
+                  for i in pyacc.terminal_symbols[3:]
+                  if len(i.name)
                 ]
               )
             )
@@ -346,12 +330,9 @@ def generate(pyacc, skel_file, out_file):
 '''.format(
               ''.join(
                 [
-                  '#define {0:s} {1:d}\n'.format(
-                    pyacc.terminal_symbols[i].name,
-                    0x100 + i
-                  )
-                  for i in range(2, len(pyacc.terminal_symbols))
-                  if len(pyacc.terminal_symbols[i].name) # fix this later
+                  '#define {0:s} {1:d}\n'.format(i.name, i.token)
+                  for i in pyacc.terminal_symbols[3:]
+                  if len(i.name)
                 ]
               )
             )
@@ -375,15 +356,17 @@ def generate(pyacc, skel_file, out_file):
           x = 70
           yytname_lines = []
           for i in (
-            ['"$end"'] +
-            ['"{0:s}"'.format(i.name) for i in pyacc.terminal_symbols] +
             [
-              '"\'{0:s}\'"'.format(
-                chr(i)
-              if i >= 0x20 else
-                '\\\\x{0:02x}'.format(i)
-              )
-              for i in sorted(pyacc.characters_used)
+              (
+                '"{0:s}"'.format(i.name)
+              if len(i.name) else
+                '"\'{0:s}\'"'.format(
+                  chr(i.token)
+                if i.token >= 0x20 else
+                  '\\\\x{0:02x}'.format(i.token)
+                )
+              ) 
+              for i in pyacc.terminal_symbols
             ] +
             ['"{0:s}"'.format(i.name) for i in pyacc.nonterminal_symbols] +
             ['YY_NULLPTR']