Honour specification of field default value, allow mandatory field (no default)
authorNick Downing <nick@ndcode.org>
Sun, 27 Jan 2019 23:10:04 +0000 (10:10 +1100)
committerNick Downing <nick@ndcode.org>
Mon, 28 Jan 2019 10:27:40 +0000 (21:27 +1100)
ast.py
pitree.l
skel/element.py [new file with mode: 0644]

diff --git a/ast.py b/ast.py
index e8a8fd9..3314281 100644 (file)
--- a/ast.py
+++ b/ast.py
@@ -1,35 +1,6 @@
 import element
 import sys
 
-default_value = {
-  'bool': 'False',
-  'int': '-1',
-  'ref': 'None',
-  'str': '\'\'',
-  'list(bool)': '[]',
-  'list(int)': '[]',
-  'list(ref)': '[]',
-  'list(str)': '[]',
-  'set(bool)': 'set()',
-  'set(int)': 'set()',
-  'set(ref)': 'set()',
-  'set(str)': 'set()'
-}
-default_value_str = {
-  'bool': 'false',
-  'int': '-1',
-  'ref': '-1',
-  'str': '',
-  'list(bool)': '',
-  'list(int)': '',
-  'list(ref)': '',
-  'list(str)': '',
-  'set(bool)': '',
-  'set(int)': '',
-  'set(ref)': '',
-  'set(str)': ''
-}
-
 class Context:
   def __init__(
     self,
@@ -99,6 +70,10 @@ class AST(element.Element):
       )
       return result
     # GENERATE END
+    def generate_expression(self, context):
+      raiseNotImplementedError
+    def generate_expression_serialized(self, context):
+      raiseNotImplementedError
   class Type(element.Element):
     # GENERATE ELEMENT() BEGIN
     def __init__(
@@ -286,6 +261,10 @@ class AST(element.Element):
       result.value = self.value
       return result
     # GENERATE END
+    def generate_expression(self, context):
+      return "True" if self.value else "False"
+    def generate_expression_serialized(self, context):
+      return "true" if self.value else "false"
   class LiteralInt(Expression):
     # GENERATE ELEMENT(str sign, int base, str digits) BEGIN
     def __init__(
@@ -337,6 +316,13 @@ class AST(element.Element):
       result.digits = self.digits
       return result
     # GENERATE END
+    def get_value(self):
+      value = int(self.digits, self.base)
+      return -value if len(self.sign) else value
+    def generate_expression(self, context):
+      return str(self.get_value())
+    def generate_expression_serialized(self, context):
+      return str(self.get_value())
   class LiteralList(Expression):
     # GENERATE ELEMENT() BEGIN
     def __init__(
@@ -360,6 +346,22 @@ class AST(element.Element):
       )
       return result
     # GENERATE END
+    def generate_expression(self, context):
+      return '[{0:s}]'.format(
+        ', '.join(
+          [
+            i.generate_expression(context)
+            for i in self
+          ]
+        )
+      )
+    def generate_expression_serialized(self, context):
+      return ' '.join(
+        [
+          i.generate_expression_serialized(context)
+          for i in self
+        ]
+      )
   class LiteralRef(Expression):
     # GENERATE ELEMENT() BEGIN
     def __init__(
@@ -383,6 +385,10 @@ class AST(element.Element):
       )
       return result
     # GENERATE END
+    def generate_expression(self, context):
+      return 'None'
+    def generate_expression_serialized(self, context):
+      return '-1'
   class LiteralSet(Expression):
     # GENERATE ELEMENT() BEGIN
     def __init__(
@@ -406,6 +412,26 @@ class AST(element.Element):
       )
       return result
     # GENERATE END
+    def generate_expression(self, context):
+      return (
+        'set([{0:s}])'.format(
+          ', '.join(
+            [
+              i.generate_expression(context)
+              for i in self
+            ]
+          )
+        )
+      if len(self) else
+        'set()'
+      )
+    def generate_expression_serialized(self, context):
+      return ' '.join(
+        [
+          i.generate_expression_serialized(context)
+          for i in self
+        ]
+      )
   class LiteralStr(Expression):
     # GENERATE ELEMENT() BEGIN
     def __init__(
@@ -429,6 +455,12 @@ class AST(element.Element):
       )
       return result
     # GENERATE END
+    def generate_expression(self, context):
+      return "'{0:s}'".format(
+        self[0].get_text().replace('\\', '\\\\').replace('\'', '\\\'')
+      )
+    def generate_expression_serialized(self, context):
+      return self[0].get_text()
   class Text(element.Element):
     class Escape(element.Element):
       # GENERATE ELEMENT(int value) BEGIN
@@ -487,6 +519,14 @@ class AST(element.Element):
       )
       return result
     # GENERATE END
+    def get_text(self):
+      result = []
+      for i in range(len(self)):
+        result.extend(
+          [element.get_text(self, i), chr(self[i].value)]
+        )
+      result.append(element.get_text(self, len(self)))
+      return ''.join(result)
   class TypeBool(Type):
     # GENERATE ELEMENT() BEGIN
     def __init__(
@@ -828,10 +868,14 @@ class AST(element.Element):
             context.indent,
             ''.join(
               [
-                ',\n{0:s}  {1:s} = {2:s}'.format(
+                ',\n{0:s}  {1:s}{2:s}'.format(
                   context.indent,
                   i[1].get_text(),
-                  default_value[element.to_text(i[0])]
+                  (
+                    ' = {0:s}'.format(i[2][0].generate_expression(context))
+                  if len(i[2]) else
+                    ''
+                  )
                 )
                 for i in context.fields
               ]
@@ -918,9 +962,18 @@ class AST(element.Element):
                 context.field_name,
                 i[0].generate_deserialize(
                   context,
-                  'self.get(\'{0:s}\', \'{1:s}\')'.format(
+                  'self.get(\'{0:s}\'{1:s})'.format(
                     context.field_name,
-                    default_value_str[element.to_text(i[0])]
+                    (
+                      ', \'{0:s}\''.format(
+                        i[2][0].
+                        generate_expression_serialized(context).
+                        replace('\\', '\\\\').
+                        replace('\'', '\\\'')
+                      )
+                    if len(i[2]) else
+                      ''
+                    )
                   )
                 )
               )
index 2cb01fc..2880930 100644 (file)
--- a/pitree.l
+++ b/pitree.l
@@ -53,8 +53,8 @@
   }"False")                    return y_tab.LITERAL_BOOL
   (?E{
     ast.AST.LiteralBool,
-    value = False
-  }"False")                    return y_tab.LITERAL_BOOL
+    value = True
+  }"True")                     return y_tab.LITERAL_BOOL
   (?E{
     ast.AST.LiteralRef
   }"None")                     return y_tab.LITERAL_REF
diff --git a/skel/element.py b/skel/element.py
new file mode 100644 (file)
index 0000000..e8732cc
--- /dev/null
@@ -0,0 +1,172 @@
+# Copyright (C) 2019 Nick Downing <nick@ndcode.org>
+# SPDX-License-Identifier: GPL-2.0-only
+#
+# This program is free software; you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the Free Software
+# Foundation; version 2.
+#
+# This program is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
+# details.
+#
+# You should have received a copy of the GNU General Public License along with
+# this program; if not, write to the Free Software Foundation, Inc., 51
+# Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
+
+import xml.etree.ElementTree
+
+class Element(xml.etree.ElementTree._Element_Py):
+  def __init__(self, tag = 'Element', attrib = {}, text = '', children = []):
+    xml.etree.ElementTree._Element_Py.__init__(self, tag, attrib)
+    self.ref = -1
+    self.seen = False
+    set_text(self, 0, text)
+    self[:] = children
+  def serialize(self, ref_list):
+    for i in self:
+      # parented, enforce that child can only be parented at most once
+      # (although there can be unlimited numbers of numeric refs to it)
+      assert not i.seen
+      i.seen = True
+      if i.ref == -1:
+        i.serialize(ref_list)
+  def deserialize(self, ref_list):
+    for i in self:
+      i.deserialize(ref_list)
+  def copy(self, factory = None):
+    result = (Element if factory is None else factory)(self.tag, self.attrib)
+    result.text = self.text
+    result.tail = self.tail
+    result[:] = [i.copy() for i in self]
+    return result
+
+bool_to_str = ['false', 'true']
+def serialize_bool(value):
+  return bool_to_str[int(value)]
+
+str_to_bool = {'false': False, 'true': True}
+def deserialize_bool(text):
+  assert text is not None
+  return str_to_bool[text]
+
+def serialize_int(value):
+  return str(value)
+
+def deserialize_int(text):
+  assert text is not None
+  return int(text)
+
+def serialize_ref(value, ref_list):
+  if value is None:
+    ref = -1
+  else:
+    ref = value.ref
+    if ref == -1:
+      ref = len(ref_list)
+      ref_list.append(value)
+      value.ref = ref
+      value.set('ref', str(ref))
+      # this doesn't set the seen flag, so it will be parented by the
+      # root, unless it is already parented or gets parented later on
+      if not value.seen:
+        value.serialize(ref_list)
+  return str(ref)
+
+def deserialize_ref(text, ref_list):
+  assert text is not None
+  ref = int(text)
+  return None if ref < 0 else ref_list[ref]
+
+def serialize_str(value):
+  return value
+
+def deserialize_str(text):
+  assert text is not None
+  return text
+
+def serialize(value, fout, encoding = 'unicode'):
+  ref_list = []
+  serialize_ref(value, ref_list)
+  parents = [i for i in ref_list if not i.seen]
+  root = Element('root', children = parents)
+  for i in range(len(root)):
+    set_text(root, i, '\n  ')
+  set_text(root, len(root), '\n')
+  root.tail = '\n'
+  xml.etree.ElementTree.ElementTree(root).write(fout, encoding)
+  for i in root:
+    i.tail = None
+  for i in ref_list:
+    i.ref = -1
+    del i.attrib['ref']
+  i = 0
+  while i < len(parents):
+    for j in parents[i]:
+      j.seen = False
+      parents.append(j)
+    i += 1
+
+def deserialize(fin, factory = Element, encoding = 'unicode'):
+  root = xml.etree.ElementTree.parse(
+    fin,
+    xml.etree.ElementTree.XMLParser(
+      target = xml.etree.ElementTree.TreeBuilder(factory),
+      encoding = encoding
+    )
+  ).getroot()
+  assert root.tag == 'root'
+  for i in root:
+    i.tail = None
+  i = 0
+  parents = root[:]
+  ref_list = []
+  while i < len(parents):
+    j = parents[i]
+    if 'ref' in j.attrib:
+      ref = int(j.attrib['ref'])
+      del j.attrib['ref']
+      if len(ref_list) < ref + 1:
+        ref_list.extend([None] * (ref + 1 - len(ref_list)))
+      ref_list[ref] = j
+    parents.extend(j[:])
+    i += 1
+  for i in root:
+    i.deserialize(ref_list)
+  return ref_list[0]
+
+# compatibility scheme to access arbitrary xml.etree.ElementTree.Element-like
+# objects (not just Element defined above) using a more consistent interface:
+def get_text(root, i):
+  if i < 0:
+    i += len(root) + 1
+  text = root.text if i == 0 else root[i - 1].tail
+  return '' if text is None else text
+
+def set_text(root, i, text):
+  if i < 0:
+    i += len(root) + 1
+  if len(text) == 0:
+    text = None
+  if i == 0:
+    root.text = text
+  else:
+    root[i - 1].tail = text
+
+def to_text(root):
+  return ''.join(
+    [
+      j
+      for i in range(len(root))
+      for j in [get_text(root, i), to_text(root[i])]
+    ] +
+    [get_text(root, len(root))]
+  )
+
+def concatenate(children, factory = Element, *args, **kwargs):
+  root = factory(*args, **kwargs)
+  for child in children:
+    i = len(root)
+    set_text(root, i, get_text(root, i) + get_text(child, 0))
+    root[i:] = child[:]
+  return root