Improve how first_nonterminal and last_terminal are handled (may be overridden)
authorNick Downing <downing.nick@gmail.com>
Wed, 18 Jul 2018 13:18:59 +0000 (23:18 +1000)
committerNick Downing <downing.nick@gmail.com>
Wed, 18 Jul 2018 13:19:14 +0000 (23:19 +1000)
ast.py

diff --git a/ast.py b/ast.py
index 2571e13..596c722 100644 (file)
--- a/ast.py
+++ b/ast.py
@@ -378,8 +378,7 @@ class PYACC(element.Element):
         section,
         production,
         character_to_symbol,
-        name_to_symbol,
-        last_terminal
+        name_to_symbol
       ):
         raise NotImplementedException
       def add_to_symbols(self, pyacc, last_action, symbols):
@@ -418,10 +417,9 @@ class PYACC(element.Element):
         section,
         production,
         character_to_symbol,
-        name_to_symbol,
-        last_terminal
+        name_to_symbol
       ):
-        return last_terminal
+        pass
       def add_to_symbols(self, pyacc, last_action, symbols):
         assert last_action is None
         return self[0]
@@ -562,8 +560,7 @@ class PYACC(element.Element):
         section,
         production,
         character_to_symbol,
-        name_to_symbol,
-        last_terminal
+        name_to_symbol
       ):
         self[0].post_process(
           pyacc,
@@ -576,7 +573,6 @@ class PYACC(element.Element):
         )
         assert production.precedence_terminal == -1
         production.precedence_terminal = self[0].terminal
-        return last_terminal
 
     class Symbol(Item):
       # GENERATE ELEMENT(int symbol) BEGIN
@@ -630,8 +626,7 @@ class PYACC(element.Element):
         section,
         production,
         character_to_symbol,
-        name_to_symbol,
-        last_terminal
+        name_to_symbol
       ):
         if isinstance(self[0], PYACC.Char):
           character = ord(self[0].get_text())
@@ -657,7 +652,8 @@ class PYACC(element.Element):
             )
         else:
           assert False
-        return self.symbol if self.symbol >= 0 else last_terminal
+        if self.symbol >= 0:
+          self.last_terminal = self.symbol
       def add_to_symbols(self, pyacc, last_action, symbols):
         assert last_action is None
         symbols.append(
@@ -667,7 +663,7 @@ class PYACC(element.Element):
         )
         return None
 
-    # GENERATE ELEMENT(int lhs_nonterminal, int precedence_terminal) BEGIN
+    # GENERATE ELEMENT(int lhs_nonterminal, int last_terminal, int precedence_terminal) BEGIN
     def __init__(
       self,
       tag = 'PYACC_Production',
@@ -675,6 +671,7 @@ class PYACC(element.Element):
       text = '',
       children = [],
       lhs_nonterminal = -1,
+      last_terminal = -1,
       precedence_terminal = -1
     ):
       element.Element.__init__(
@@ -689,6 +686,11 @@ class PYACC(element.Element):
       if isinstance(lhs_nonterminal, str) else
         lhs_nonterminal
       )
+      self.last_terminal = (
+        element.deserialize_int(last_terminal)
+      if isinstance(last_terminal, str) else
+        last_terminal
+      )
       self.precedence_terminal = (
         element.deserialize_int(precedence_terminal)
       if isinstance(precedence_terminal, str) else
@@ -697,10 +699,12 @@ class PYACC(element.Element):
     def serialize(self, ref_list, indent = 0):
       element.Element.serialize(self, ref_list, indent)
       self.set('lhs_nonterminal', element.serialize_int(self.lhs_nonterminal))
+      self.set('last_terminal', element.serialize_int(self.last_terminal))
       self.set('precedence_terminal', element.serialize_int(self.precedence_terminal))
     def deserialize(self, ref_list):
       element.Element.deserialize(self, ref_list)
       self.lhs_nonterminal = element.deserialize_int(self.get('lhs_nonterminal', '-1'))
+      self.last_terminal = element.deserialize_int(self.get('last_terminal', '-1'))
       self.precedence_terminal = element.deserialize_int(self.get('precedence_terminal', '-1'))
     def copy(self, factory = None):
       result = element.Element.copy(
@@ -708,6 +712,7 @@ class PYACC(element.Element):
         Production if factory is None else factory
       )
       result.lhs_nonterminal = self.lhs_nonterminal
+      result.last_terminal = self.last_terminal
       result.precedence_terminal = self.precedence_terminal
       return result
     def repr_serialize(self, params):
@@ -716,6 +721,10 @@ class PYACC(element.Element):
         params.append(
           'lhs_nonterminal = {0:s}'.format(repr(self.lhs_nonterminal))
         )
+      if self.last_terminal != -1:
+        params.append(
+          'last_terminal = {0:s}'.format(repr(self.last_terminal))
+        )
       if self.precedence_terminal != -1:
         params.append(
           'precedence_terminal = {0:s}'.format(repr(self.precedence_terminal))
@@ -736,19 +745,18 @@ class PYACC(element.Element):
     ):
       self.lhs_nonterminal = lhs_nonterminal
 
+      self.last_terminal = -1
       self.precedence_terminal = -1
-      last_terminal = -1
       for i in self:
-        last_terminal = i.post_process(
+        i.post_process(
           pyacc,
           section,
           self,
           character_to_symbol,
-          name_to_symbol,
-          last_terminal
+          name_to_symbol
         )
       if self.precedence_terminal == -1:
-        self.precedence_terminal = last_terminal
+        self.precedence_terminal = self.last_terminal
 
       character_set = pyacc.nonterminal_symbols[
         self.lhs_nonterminal
@@ -2317,6 +2325,8 @@ class PYACC(element.Element):
           -1, # associativity
           None # tag
         )
+        if pyacc.first_nonterminal == -1:
+          pyacc.first_nonterminal = self[0].nonterminal
         for i in self[1:]:
           i.post_process(
             pyacc,
@@ -2444,7 +2454,7 @@ class PYACC(element.Element):
       return 'ast.PYACC.ValueReference({0:s})'.format(', '.join(params))
     # GENERATE END
 
-  # GENERATE ELEMENT(list(ref) prologue_text, int precedences, list(ref) terminal_symbols, list(ref) nonterminal_symbols, int start_nonterminal, list(ref) productions) BEGIN
+  # GENERATE ELEMENT(list(ref) prologue_text, int precedences, list(ref) terminal_symbols, list(ref) nonterminal_symbols, list(ref) productions, int first_nonterminal, int start_nonterminal) BEGIN
   def __init__(
     self,
     tag = 'PYACC',
@@ -2455,8 +2465,9 @@ class PYACC(element.Element):
     precedences = -1,
     terminal_symbols = [],
     nonterminal_symbols = [],
-    start_nonterminal = -1,
-    productions = []
+    productions = [],
+    first_nonterminal = -1,
+    start_nonterminal = -1
   ):
     element.Element.__init__(
       self,
@@ -2473,12 +2484,17 @@ class PYACC(element.Element):
     )
     self.terminal_symbols = terminal_symbols
     self.nonterminal_symbols = nonterminal_symbols
+    self.productions = productions
+    self.first_nonterminal = (
+      element.deserialize_int(first_nonterminal)
+    if isinstance(first_nonterminal, str) else
+      first_nonterminal
+    )
     self.start_nonterminal = (
       element.deserialize_int(start_nonterminal)
     if isinstance(start_nonterminal, str) else
       start_nonterminal
     )
-    self.productions = productions
   def serialize(self, ref_list, indent = 0):
     element.Element.serialize(self, ref_list, indent)
     self.set(
@@ -2494,11 +2510,12 @@ class PYACC(element.Element):
       'nonterminal_symbols',
       ' '.join([element.serialize_ref(i, ref_list) for i in self.nonterminal_symbols])
     )
-    self.set('start_nonterminal', element.serialize_int(self.start_nonterminal))
     self.set(
       'productions',
       ' '.join([element.serialize_ref(i, ref_list) for i in self.productions])
     )
+    self.set('first_nonterminal', element.serialize_int(self.first_nonterminal))
+    self.set('start_nonterminal', element.serialize_int(self.start_nonterminal))
   def deserialize(self, ref_list):
     element.Element.deserialize(self, ref_list)
     self.prologue_text = [
@@ -2514,11 +2531,12 @@ class PYACC(element.Element):
       element.deserialize_ref(i, ref_list)
       for i in self.get('nonterminal_symbols', '').split()
     ]
-    self.start_nonterminal = element.deserialize_int(self.get('start_nonterminal', '-1'))
     self.productions = [
       element.deserialize_ref(i, ref_list)
       for i in self.get('productions', '').split()
     ]
+    self.first_nonterminal = element.deserialize_int(self.get('first_nonterminal', '-1'))
+    self.start_nonterminal = element.deserialize_int(self.get('start_nonterminal', '-1'))
   def copy(self, factory = None):
     result = element.Element.copy(
       self,
@@ -2528,8 +2546,9 @@ class PYACC(element.Element):
     result.precedences = self.precedences
     result.terminal_symbols = self.terminal_symbols
     result.nonterminal_symbols = self.nonterminal_symbols
-    result.start_nonterminal = self.start_nonterminal
     result.productions = self.productions
+    result.first_nonterminal = self.first_nonterminal
+    result.start_nonterminal = self.start_nonterminal
     return result
   def repr_serialize(self, params):
     element.Element.repr_serialize(self, params)
@@ -2555,16 +2574,20 @@ class PYACC(element.Element):
           ', '.join([repr(i) for i in self.nonterminal_symbols])
         )
       )
-    if self.start_nonterminal != -1:
-      params.append(
-        'start_nonterminal = {0:s}'.format(repr(self.start_nonterminal))
-      )
     if len(self.productions):
       params.append(
         'productions = [{0:s}]'.format(
           ', '.join([repr(i) for i in self.productions])
         )
       )
+    if self.first_nonterminal != -1:
+      params.append(
+        'first_nonterminal = {0:s}'.format(repr(self.first_nonterminal))
+      )
+    if self.start_nonterminal != -1:
+      params.append(
+        'start_nonterminal = {0:s}'.format(repr(self.start_nonterminal))
+      )
   def __repr__(self):
     params = []
     self.repr_serialize(params)
@@ -2581,7 +2604,6 @@ class PYACC(element.Element):
       PYACC.Symbol(name = '$undefined')
     ]
     self.nonterminal_symbols = []
-    self.start_nonterminal = -1
     self.productions = []
 
     # variables that won't be serialized
@@ -2591,12 +2613,16 @@ class PYACC(element.Element):
     name_to_symbol = {'error': 1}
 
     # perform the semantic analysis pass
+    self.first_nonterminal = -1
+    self.start_nonterminal = -1
     for i in self:
       i.post_process(
         self,
         character_to_symbol,
         name_to_symbol
       )
+    if self.start_nonterminal == -1:
+      self.start_nonterminal = self.first_nonterminal
 
     # fill in token numbers that are not characters or overridden by user
     token = 0x100
@@ -2605,11 +2631,6 @@ class PYACC(element.Element):
         i.character_set = [token, token + 1]
         token += 1
 
-    # fill in start nonterminal if not overridden by user
-    if self.start_nonterminal == -1:
-      assert len(self.nonterminal_symbols) != 0
-      self.start_nonterminal = 0
-
   def to_lr1(self):
     _lr1 = lr1.LR1(
       [