Simplify handling of field list
authorNick Downing <nick@ndcode.org>
Sat, 26 Jan 2019 23:00:48 +0000 (10:00 +1100)
committerNick Downing <nick@ndcode.org>
Sat, 26 Jan 2019 23:00:48 +0000 (10:00 +1100)
ast.py

diff --git a/ast.py b/ast.py
index fc27870..1bdf4e5 100644 (file)
--- a/ast.py
+++ b/ast.py
@@ -802,18 +802,20 @@ class AST(element.Element):
       # GENERATE END
       def generate_class_or_field_def(self, context):
         class_name = self[0].get_text()
+        fields_save = context.fields
         if len(self[1]):
           base_class = '.'.join([i.get_text() for i in self[1][0]])
+          for i in range(len(context.base_classes) - 1, -1, -1):
+            if base_class in context.base_classes[i]:
+              full_base_class = '.'.join(context.stack[:i] + [base_class])
+              context.fields = context.base_classes[i][base_class].copy()
+              break
+          else:
+            assert False
         else:
           base_class = 'element.Element'
-        for i in range(len(context.base_classes) - 1, -1, -1):
-          if base_class in context.base_classes[i]:
-            full_base_class = '.'.join(context.stack[:i] + [base_class])
-            context.base_classes[-1][class_name] = \
-              context.base_classes[i][base_class].copy()
-            break
-        else:
-          assert False
+          full_base_class = 'element.Element'
+          context.fields = []
         context.fout.write(
           '{0:s}class {1:s}({2:s}):\n'.format(
             context.indent,
@@ -826,12 +828,9 @@ class AST(element.Element):
         context.stack.append(class_name)
         context.classes.append('.'.join(context.stack))
         context.base_classes.append({})
-        fields_save = context.fields
-        context.fields = []
+        n_base_fields = len(context.fields)
         for i in self[2]:
           i.generate_class_or_field_def(context)
-        i = len(context.base_classes[-2][class_name])
-        context.base_classes[-2][class_name].extend(context.fields)
 
         context.fout.write(
           '''{0:s}def __init__(
@@ -863,7 +862,7 @@ class AST(element.Element):
                   name,
                   default_value[type]
                 )
-                for type, name in context.base_classes[-2][class_name]
+                for type, name in context.fields
               ]
             ),
             context.indent,
@@ -880,13 +879,13 @@ class AST(element.Element):
                   context.indent,
                   name
                 )
-                for type, name in context.base_classes[-2][class_name][:i]
+                for type, name in context.fields[:n_base_fields]
               ]
             ),
             context.indent
           )
         )
-        for type, name in context.fields:
+        for type, name in context.fields[n_base_fields:]:
           if type == 'ref' or type == 'list(ref)' or type == 'set(ref)' or type == 'str':
             context.fout.write(
               '''{0:s}  self.{1:s} = {2:s}
@@ -954,13 +953,13 @@ class AST(element.Element):
                 context.indent
               )
             )
-        if len(context.fields):
+        if len(context.fields) > n_base_fields:
           context.fout.write(
             '''{0:s}def serialize(self, ref_list):
 {1:s}  {2:s}.serialize(self, ref_list)
 '''.format(context.indent, context.indent, full_base_class)
           )
-          for type, name in context.fields:
+          for type, name in context.fields[n_base_fields:]:
             if type[:5] == 'list(' and type[-1:] == ')':
               subtype = type[5:-1]
               context.fout.write(
@@ -1013,7 +1012,7 @@ class AST(element.Element):
 {1:s}  {2:s}.deserialize(self, ref_list)
 '''.format(context.indent, context.indent, full_base_class)
           )
-          for type, name in context.fields:
+          for type, name in context.fields[n_base_fields:]:
             if type[:5] == 'list(' and type[-1:] == ')':
               subtype = type[5:-1]
               context.fout.write(
@@ -1088,13 +1087,13 @@ class AST(element.Element):
                   name,
                   name
                 )
-                for _, name in context.fields
+                for _, name in context.fields[n_base_fields:]
               ]
             ),
             context.indent
           )
         )
-        if len(context.fields):
+        if len(context.fields) > n_base_fields:
           context.fout.write(
             '''{0:s}def repr_serialize(self, params):
 {1:s}  {2:s}.repr_serialize(self, params)
@@ -1104,7 +1103,7 @@ class AST(element.Element):
               full_base_class
             )
           )
-          for type, name in context.fields:
+          for type, name in context.fields[n_base_fields:]:
             if type[:5] == 'list(' and type[-1:] == ')':
               subtype = type[5:-1]
               context.fout.write(
@@ -1185,6 +1184,7 @@ class AST(element.Element):
           context.base_classes[-1][
             '{0:s}.{1:s}'.format(class_name, temp_base_class)
           ] = temp_fields
+        context.base_classes[-1][class_name] = context.fields
         context.fields = fields_save
     class FieldDef(element.Element):
       # GENERATE ELEMENT() BEGIN