Make serializer use a temporary tree of xml.etree.ElementTree.Element
authorNick Downing <nick@ndcode.org>
Mon, 28 Jan 2019 23:49:45 +0000 (10:49 +1100)
committerNick Downing <nick@ndcode.org>
Tue, 29 Jan 2019 02:01:54 +0000 (13:01 +1100)
pitree.t
skel/element.py

index 242eaef..2dd216c 100644 (file)
--- a/pitree.t
+++ b/pitree.t
@@ -205,7 +205,7 @@ def generate_class_or_field_def(self, context):
   if len(context.fields) > n_base_fields:
     context.fout.write(
       '''{0:s}def serialize(self, ref_list):
-{1:s}  {2:s}.serialize(self, ref_list)
+{1:s}  _element = {2:s}.serialize(self, ref_list)
 '''.format(
         context.indent,
         context.indent,
@@ -217,23 +217,18 @@ def generate_class_or_field_def(self, context):
     for i in context.fields[n_base_fields:]:
       context.field_name = i[1].get_text()
       context.fout.write(
-        '''{0:s}  if self.{1:s} == {2:s}:
-{3:s}    self.attrib.pop('{4:s}', None)
-{5:s}  else:
-{6:s}    self.set(
-{7:s}      '{8:s}',
-{9:s}      json.dumps(
-{10:s}        {11:s}
-{12:s}      )
-{13:s}    )
+        '''{0:s}  if self.{1:s} != {2:s}:
+{3:s}    _element.set(
+{4:s}      '{5:s}',
+{6:s}      json.dumps(
+{7:s}        {8:s}
+{9:s}      )
+{10:s}    )
 '''.format(
           indent_save2,
           context.field_name,
           i[2][0].generate_expression(i[0]),
           indent_save2,
-          context.field_name,
-          indent_save2,
-          indent_save2,
           indent_save2,
           context.field_name,
           indent_save2,
@@ -246,7 +241,7 @@ def generate_class_or_field_def(self, context):
           indent_save2
         )
       if len(i[2]) else
-        '''{0:s}  self.set(
+        '''{0:s}  _element.set(
 {1:s}    '{2:s}',
 {3:s}    json.dumps(
 {4:s}      {5:s}
@@ -268,9 +263,11 @@ def generate_class_or_field_def(self, context):
       )
     context.indent = indent_save2
     context.fout.write(
-      '''{0:s}def deserialize(self, ref_list):
-{1:s}  {2:s}.deserialize(self, ref_list)
+      '''{0:s}  return _element
+{1:s}def deserialize(self, ref_list):
+{2:s}  {3:s}.deserialize(self, ref_list)
 '''.format(
+        context.indent,
         context.indent,
         context.indent,
         full_base_class
index 7b92345..d3b534e 100644 (file)
@@ -160,18 +160,29 @@ class Element:
     self.tag = tag
     self.attrib = attrib.copy()
     self.tail = None
-    self.ref = -1
-    self.seen = False
+    self.visited = None # (element, ref, seen)
     set_text(self, 0, text)
     self.children = children.copy()
   def serialize(self, ref_list):
-    for i in self:
-      # parented, enforce that child can only be parented at most once
-      # (although there can be unlimited numbers of numeric refs to it)
-      assert not i.seen
-      i.seen = True
-      if i.ref == -1:
-        i.serialize(ref_list)
+    element = self.visited[0]
+    text = get_text(self, 0)
+    if len(text):
+      element.text = text
+    for i in range(len(self)):
+      child = self[i]
+      if child.visited is None:
+        child_element = xml.etree.ElementTree.Element(child.tag)
+        child.visited = (child_element, -1, True)
+        child.serialize(ref_list)
+      else:
+        child_element, child_ref, child_seen = child.visited
+        assert not child_seen
+        child.visited = (child_element, child_ref, True)
+      tail = get_text(self, i + 1)
+      if len(tail):
+        child_element.tail = tail
+      element.append(child_element)
+    return element
   def deserialize(self, ref_list):
     for i in self:
       i.deserialize(ref_list)
@@ -186,16 +197,20 @@ def serialize_ref(value, ref_list):
   if value is None:
     ref = -1
   else:
-    ref = value.ref
-    if ref == -1:
+    if value.visited is None:
+      element = xml.etree.ElementTree.Element(value.tag)
       ref = len(ref_list)
       ref_list.append(value)
-      value.ref = ref
-      value.set('ref', str(ref))
-      # this doesn't set the seen flag, so it will be parented by the
-      # root, unless it is already parented or gets parented later on
-      if not value.seen:
-        value.serialize(ref_list)
+      element.set('ref', str(ref))
+      value.visited = (element, ref, False)
+      value.serialize(ref_list)
+    else:
+      element, ref, seen = value.visited
+      if ref == -1:
+        ref = len(ref_list)
+        ref_list.append(value)
+        element.set('ref', str(ref))
+        value.visited = (element, ref, seen)
   return ref
 
 def deserialize_ref(value, ref_list):
@@ -205,23 +220,20 @@ def deserialize_ref(value, ref_list):
 def serialize(value, fout, encoding = 'unicode'):
   ref_list = []
   serialize_ref(value, ref_list)
-  parents = [i for i in ref_list if not i.seen]
-  root = Element('root', children = parents)
-  for i in range(len(root)):
-    set_text(root, i, '\n  ')
-  set_text(root, len(root), '\n')
+  parents = [i for i in ref_list if not i.visited[2]]
+  root = xml.etree.ElementTree.Element('root')
+  root[:] = [i.visited[0] for i in parents]
+  root.text = '\n  '
+  for i in range(len(root) - 1):
+    root[i].tail = '\n  '
+  root[-1].tail = '\n'
   root.tail = '\n'
   xml.etree.ElementTree.ElementTree(root).write(fout, encoding)
-  for i in root:
-    i.tail = None
-  for i in ref_list:
-    i.ref = -1
-    del i.attrib['ref']
   i = 0
   while i < len(parents):
-    for j in parents[i]:
-      j.seen = False
-      parents.append(j)
+    child = parents[i]
+    child.visited = None
+    parents.extend(child[:])
     i += 1
 
 def deserialize(fin, factory = Element, encoding = 'unicode'):