Rationalize constructor parameters, create mutable default in the constructor rather...
authorNick Downing <nick@ndcode.org>
Tue, 29 Jan 2019 09:49:03 +0000 (20:49 +1100)
committerNick Downing <nick@ndcode.org>
Tue, 29 Jan 2019 09:49:03 +0000 (20:49 +1100)
pitree.t
skel/element.py

index e1a06a8..6b74fd3 100644 (file)
--- a/pitree.t
+++ b/pitree.t
@@ -98,7 +98,7 @@ class Context:
     self.field_name = field_name
 
 def factory(tag, *args, **kwargs):
-  return tag_to_class.get(tag, element.Element)(tag, *args, **kwargs)
+  return tag_to_class[tag](*args, **kwargs)
 
 @method(AST.Section2.ClassOrFieldDef)
 def generate_class_or_field_def(self, context):
@@ -120,89 +120,96 @@ def generate_class_or_field_def(self, context):
     base_class = 'element.Element'
     full_base_class = 'element.Element'
     context.fields = []
-  context.fout.write(
-    '{0:s}class {1:s}({2:s}):\n'.format(
-      context.indent,
-      class_name,
-      base_class
-    )
-  )
   indent_save = context.indent
   context.indent += '  '
   context.stack.append(class_name)
   context.classes.append('.'.join(context.stack))
   context.base_classes.append({})
-  n_base_fields = len(context.fields)
-  for i in self[2]:
-    i.generate_class_or_field_def(context)
   context.fout.write(
-    '''{0:s}def __init__(
-{1:s}  self,
-{2:s}  tag = '{3:s}',
-{4:s}  text = '',
-{5:s}  children = []{6:s}
-{7:s}):
-{8:s}  {9:s}.__init__(
-{10:s}    self,
-{11:s}    tag,
-{12:s}    text,
-{13:s}    children{14:s}
-{15:s}  )
-{16:s}'''.format(
-      context.indent,
-      context.indent,
-      context.indent,
-      '_'.join(context.stack),
-      context.indent,
-      context.indent,
-      ''.join(
-        [
-          ',\n{0:s}  {1:s}{2:s}'.format(
-            context.indent,
-            i[1].get_text(),
-            (
-              ' = {0:s}'.format(i[2][0].generate_expression(i[0]))
-            if len(i[2]) else
-              ''
-            )
-          )
-          for i in context.fields
-        ]
-      ),
-      context.indent,
-      context.indent,
-      full_base_class,
-      context.indent,
-      context.indent,
-      context.indent,
-      context.indent,
-      ''.join(
-        [
-          ',\n{0:s}    {1:s}'.format(
-            context.indent,
-            i[1].get_text()
-          )
-          for i in context.fields[:n_base_fields]
-        ]
-      ),
-      context.indent,
-      ''.join(
-        [
-          '{0:s}  self.{1:s} = {2:s}\n'.format(
-            context.indent,
-            i[1].get_text(),
-            i[1].get_text()
-          )
-          for i in context.fields[n_base_fields:]
-        ]
-      )
+    '''{0:s}class {1:s}({2:s}):
+{3:s}  tag = '{4:s}'
+'''.format(
+      indent_save,
+      class_name,
+      base_class,
+      indent_save,
+      '_'.join(context.stack)
     )
   )
+  n_base_fields = len(context.fields)
+  for i in self[2]:
+    i.generate_class_or_field_def(context)
   if len(context.fields) > n_base_fields:
     context.fout.write(
-      '''{0:s}def serialize(self, ref_list):
-{1:s}  _element = {2:s}.serialize(self, ref_list)
+      '''{0:s}def __init__(
+{1:s}  self,
+{2:s}  text = None,
+{3:s}  children = None{4:s}
+{5:s}):
+{6:s}  {7:s}.__init__(
+{8:s}    self,
+{9:s}    text,
+{10:s}    children{11:s}
+{12:s}  )
+{13:s}{14:s}def serialize(self, ref_list):
+{15:s}  _element = {16:s}.serialize(self, ref_list)
 '''.format(
+        context.indent,
+        context.indent,
+        context.indent,
+        context.indent,
+        ''.join(
+          [
+            ',\n{0:s}  {1:s}{2:s}'.format(
+              context.indent,
+              i[1].get_text(),
+              (
+                (
+                  ' = None'
+                if i[0].is_mutable() else
+                  ' = {0:s}'.format(i[2][0].generate_expression(i[0]))
+                )
+              if len(i[2]) else
+                ''
+              )
+            )
+            for i in context.fields
+          ]
+        ),
+        context.indent,
+        context.indent,
+        full_base_class,
+        context.indent,
+        context.indent,
+        context.indent,
+        ''.join(
+          [
+            ',\n{0:s}    {1:s}'.format(
+              context.indent,
+              i[1].get_text()
+            )
+            for i in context.fields[:n_base_fields]
+          ]
+        ),
+        context.indent,
+        ''.join(
+          [
+            '{0:s}  self.{1:s} = {2:s}\n'.format(
+              context.indent,
+              i[1].get_text(),
+              (
+                '{0:s} if {1:s} is None else {2:s}'.format(
+                  i[2][0].generate_expression(i[0]),
+                  i[1].get_text(),
+                  i[1].get_text()
+                )
+              if i[0].is_mutable() else
+                i[1].get_text()
+              )
+            )
+            for i in context.fields[n_base_fields:]
+          ]
+        ),
         context.indent,
         context.indent,
         full_base_class
@@ -296,34 +303,6 @@ def generate_class_or_field_def(self, context):
         )
       )
     context.indent = indent_save2
-  context.fout.write(
-    '''{0:s}def copy(self, factory = None):
-{1:s}  result = {2:s}.copy(
-{3:s}    self,
-{4:s}    {5:s} if factory is None else factory
-{6:s}  ){7:s}
-{8:s}  return result
-'''.format(
-      context.indent,
-      context.indent,
-      full_base_class,
-      context.indent,
-      context.indent,
-      class_name,
-      context.indent,
-      ''.join(
-        [
-          '\n{0:s}  result.{1:s} = self.{2:s}'.format(
-            context.indent,
-            i[1].get_text(),
-            i[1].get_text()
-          )
-          for i in context.fields[n_base_fields:]
-        ]
-      ),
-      context.indent
-    )
-  )
   context.indent = indent_save
   del context.stack[-1]
   for temp_base_class, temp_fields in context.base_classes.pop().items():
@@ -629,6 +608,19 @@ def generate_deserialize(self, context, expr):
   return result
 del generate_deserialize
 
+@method(AST.Type)
+def is_mutable(self):
+  return True
+@method(AST.TypeInt)
+def is_mutable(self):
+  return False
+@method(AST.TypeBool)
+def is_mutable(self):
+  return False
+@method(AST.TypeStr)
+def is_mutable(self):
+  return False
+
 @method(AST.Identifier)
 def get_text(self):
   return element.get_text(self, 0)
index a8d0c7f..b1eb364 100644 (file)
 import xml.etree.ElementTree
 
 class Element:
-  def __init__(self, tag = 'Element', text = '', children = []):
-    self.tag = tag
-    self.tail = None
+  tag = 'Element'
+  def __init__(self, text = None, children = None):
     self.visited = None # (element, ref, seen)
-    self.text = [text]
-    self.children = children.copy()
-    self.text.extend(['' for i in range(len(children))]) # bogus
+    self.children = [] if children is None else children
+    assert isinstance(self.children, list)
+    self.text = (
+      ['' for i in range(len(self.children) + 1)]
+    if text is None else
+      text
+    )
+    assert isinstance(self.text, list)
+    assert len(self.text) == len(self.children) + 1
   def serialize(self, ref_list):
     element = self.visited[0]
     if len(self.text[0]):
@@ -44,11 +49,6 @@ class Element:
     return element
   def deserialize(self, element, ref_list):
     pass
-  def copy(self, factory = None):
-    result = (Element if factory is None else factory)(self.tag, self.attrib)
-    result.text = self.text
-    result[:] = [i.copy() for i in self]
-    return result
 
 def serialize_ref(value, ref_list):
   if value is None: