Make syntax error call yyerror() rather than raising a Python exception
[piyacc.git] / element.py
index e8732cc..54cf4db 100644 (file)
 # this program; if not, write to the Free Software Foundation, Inc., 51
 # Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA
 
+'''
+Module which provides the base class for πtree Abstract Syntax Trees (ASTs) and
+provides serialization/deserialization. The code generated by πtree will define
+specific ASTs for your project, as derived classes of that defined here.
+
+Generally you shouldn't check this file in to your project. Instead, at around
+the same time you run πtree to generate your ``t_def.py`` or similar file, also
+install this module by using the ``pitree --install-element`` command.
+'''
+
 import xml.etree.ElementTree
 
-class Element(xml.etree.ElementTree._Element_Py):
-  def __init__(self, tag = 'Element', attrib = {}, text = '', children = []):
-    xml.etree.ElementTree._Element_Py.__init__(self, tag, attrib)
-    self.ref = -1
-    self.seen = False
-    set_text(self, 0, text)
-    self[:] = children
+class Element:
+  '''Class which holds a single node of a πtree AST.'''
+
+  tag = 'Element'
+  def __init__(self, text = None, children = None):
+    self.visited = None
+    '''
+During serialization, this indicates nodes that have been encountered before,
+so that DAGs or circular constructs can be serialized and reconstructed later.
+It contains either ``None`` or ``(element, ref, seen)``.
+
+Note that it is not allowed to have multiple use of direct children (direct
+children are nodes in the ``self.children`` list). That is, a node may have at
+most one direct parent. Any DAGs or circular constructs must be via fields.
+Fields are added by creating a derived class and overriding the serialization
+methods appropriately (or using the πtree generator, which does this for you). 
+'''
+    self.children = [] if children is None else children
+    '''
+Contains the direct child nodes, also of class ``Element`` or a derived class
+of ``Element``.
+
+* Often the number of children and their meanings will be known in advance. For
+  example, a ``BinaryExpression`` class might have left and right children,
+  accessed via ``children[0]`` and ``children[1]``.
+* Sometimes the number of children can be arbitrary. For example, a
+  ``Function`` class might contain an arbitrary number of statements as direct
+  children.
+
+It is expected that the types of the children will be known statically. In the
+``BinaryExpression`` example, the children would be of class ``Expression`` or
+a derived class of ``Expression``. In the ``Function`` example, the children
+would be of class ``Statement`` or a derived class of ``Statement``. When the
+children are implicitly a tuple the children can be typed independently of one
+another. When they are implicitly a list they should ideally have uniform type.
+
+If no ``children`` argument is passed to the constructor, it will default to
+``None`` and then be internally translated to a freshly constructed empty list
+``[]``. This ensures that it is mutable and not shared with other instances.
+'''
+    assert isinstance(self.children, list)
+    self.text = (
+      ['' for i in range(len(self.children) + 1)]
+    if text is None else
+      text
+    )
+    '''
+Contains strings of text to interpolate between the child nodes. Must have
+length ``len(self.children) + 1``. So, for example, if there are two child
+nodes the contents of the node can be conceptualized as ``text[0] children[0]
+text[1] children[1] text[2]``.
+
+* For nodes with children, the text is often not very significant and may be
+  set to all empty strings. For example, a ``BinaryExpression`` node could have
+  ``self.text == ['', '', '']`` and only ``children[0]`` and ``children[1]``
+  significant. On the other hand, it could store the operator as a string, such
+  as ``'+'`` or ``'-'``, in the ``text[1]`` field, if needed. This would print
+  in a natural way. A combination of approaches is also possible, so that text
+  isn't significant, but would be filled in before pretty-printing the tree.
+* For nodes with no children, often the ``text[0]`` value holds the content of
+  the node. For example, an ``Identifier`` node would usually store its
+  identifier string in ``text[0]``. Again, this would print in a natural way.
+
+If no ``text`` argument is passed to the constructor, it will default to
+``None`` and then be internally translated to a freshly constructed list of
+the correct number of empty strings. This ensures that it is the right length
+and also that it is mutable and not shared with other instances.
+'''
+    assert isinstance(self.text, list)
+    assert len(self.text) == len(self.children) + 1
+
   def serialize(self, ref_list):
-    for i in self:
-      # parented, enforce that child can only be parented at most once
-      # (although there can be unlimited numbers of numeric refs to it)
-      assert not i.seen
-      i.seen = True
-      if i.ref == -1:
-        i.serialize(ref_list)
-  def deserialize(self, ref_list):
-    for i in self:
-      i.deserialize(ref_list)
-  def copy(self, factory = None):
-    result = (Element if factory is None else factory)(self.tag, self.attrib)
-    result.text = self.text
-    result.tail = self.tail
-    result[:] = [i.copy() for i in self]
-    return result
-
-bool_to_str = ['false', 'true']
-def serialize_bool(value):
-  return bool_to_str[int(value)]
-
-str_to_bool = {'false': False, 'true': True}
-def deserialize_bool(text):
-  assert text is not None
-  return str_to_bool[text]
-
-def serialize_int(value):
-  return str(value)
-
-def deserialize_int(text):
-  assert text is not None
-  return int(text)
+    '''
+Internal routine that supports serialization. It should not be called directly
+-- use ``element.serialize()`` instead. This method converts ``self`` into an
+``xml.etree.ElementTree`` node, populated with the recursive conversion of its
+children. It will be overridden in a derived class if the class has fields,
+to populate the attributes of the returned ``xml.etree.ElementTree`` node.
+'''
+    element = self.visited[0]
+    if len(self.text[0]):
+      element.text = self.text[0]
+    for i in range(len(self.children)):
+      child = self.children[i]
+      if child.visited is None:
+        child_element = xml.etree.ElementTree.Element(child.tag)
+        child.visited = (child_element, -1, True)
+        child.serialize(ref_list)
+      else:
+        child_element, child_ref, child_seen = child.visited
+        assert not child_seen
+        child.visited = (child_element, child_ref, True)
+      if len(self.text[i + 1]):
+        child_element.tail = self.text[i + 1]
+      element.append(child_element)
+    return element
+  def deserialize(self, element, ref_list):
+    '''
+Internal routine that supports deserialization. It should not be called
+directly -- use ``element.deserialize()`` instead. It will be overridden in a
+derived class if the class has fields, to set the fields from the attributes
+of the ``xml.etree.ElementTree`` node passed in as the ``element`` argument.
+'''
+    pass
 
 def serialize_ref(value, ref_list):
+  '''
+Internal routine to serialize a reference and return a value that can be placed
+in the attribute dictionary of an ``xml.etree.ElementTree`` node. It is meant
+to be called from the ``serialize()`` method of a derived class of ``Element``,
+for serializing fields that are references to an AST node or AST subtree.
+
+This is a special case, since other kinds of values (``int``, ``str`` etc) can
+be serialized by the ``json`` module, whereas references must be recursively
+converted. If the reference has already been serialized its value is returned
+directly, otherwise it will be added to the list ``ref_list`` and serialized.
+The value returned is its position in ``ref_list`` (it behaves like serializing
+an integer from that point on). The ``None`` value is serialized as ``-1``.
+'''
   if value is None:
     ref = -1
   else:
-    ref = value.ref
-    if ref == -1:
+    if value.visited is None:
+      element = xml.etree.ElementTree.Element(value.tag)
       ref = len(ref_list)
       ref_list.append(value)
-      value.ref = ref
-      value.set('ref', str(ref))
-      # this doesn't set the seen flag, so it will be parented by the
-      # root, unless it is already parented or gets parented later on
-      if not value.seen:
-        value.serialize(ref_list)
-  return str(ref)
-
-def deserialize_ref(text, ref_list):
-  assert text is not None
-  ref = int(text)
-  return None if ref < 0 else ref_list[ref]
-
-def serialize_str(value):
-  return value
-
-def deserialize_str(text):
-  assert text is not None
-  return text
+      element.set('ref', str(ref))
+      value.visited = (element, ref, False)
+      value.serialize(ref_list)
+    else:
+      element, ref, seen = value.visited
+      if ref == -1:
+        ref = len(ref_list)
+        ref_list.append(value)
+        element.set('ref', str(ref))
+        value.visited = (element, ref, seen)
+  return ref
+
+def deserialize_ref(value, ref_list):
+  '''
+Internal routine to deserialize a reference and return an object of type
+``Element`` or a derived class of ``Element``. It is meant to be called from
+the ``deserialize()`` method of a derived class of ``Element``, for
+deserializing fields that are references to an AST node or AST subtree. 
+
+The reference has already been processed as an integer and hence the ``value``
+is the position in ``ref_list`` where the referenced object is to be found,
+or ``-1``. Returns that object from ``ref_list``, or ``None`` for ``-1``.
+'''
+  assert value is not None
+  return None if value == -1 else ref_list[value]
 
 def serialize(value, fout, encoding = 'unicode'):
+  '''
+Front end to the serializer. Pass in an AST that is to be serialized and an
+output stream to place the serialized output on. The encoding should be passed
+as either ``'unicode'`` for encoding to standard output or ``'utf-8'`` for
+encoding to a file descriptor, this is a bit hacky and one should refer to the
+code of ``xml.etree.ElementTree`` to see what it really does with this value.
+
+The output stream will look something like::
+
+  <root>
+    <AnObject ref="0"><ANestedObject /></AnObject>
+    <AnotherObject ref="1" />
+  </root>
+
+The object with the attribute ``ref="0"`` corresponds to the ``value``
+parameter passed in here. Any direct children, grandchildren etc will be
+serialized inside it. Objects which are accessed through fields from those
+objects will be serialized separately and be given higher reference numbers.
+Note that those secondary objects can also have direct children, which are
+serialized inside those secondary objects, and so on. The ``<root>...</root>``
+element then ends up being a collection of all objects with no direct parent.
+'''
   ref_list = []
   serialize_ref(value, ref_list)
-  parents = [i for i in ref_list if not i.seen]
-  root = Element('root', children = parents)
-  for i in range(len(root)):
-    set_text(root, i, '\n  ')
-  set_text(root, len(root), '\n')
+  todo = [i for i in ref_list if not i.visited[2]]
+  root = xml.etree.ElementTree.Element('root')
+  root[:] = [i.visited[0] for i in todo]
+  root.text = '\n  '
+  for i in range(len(root) - 1):
+    root[i].tail = '\n  '
+  root[-1].tail = '\n'
   root.tail = '\n'
   xml.etree.ElementTree.ElementTree(root).write(fout, encoding)
-  for i in root:
-    i.tail = None
-  for i in ref_list:
-    i.ref = -1
-    del i.attrib['ref']
   i = 0
-  while i < len(parents):
-    for j in parents[i]:
-      j.seen = False
-      parents.append(j)
+  while i < len(todo):
+    node = todo[i]
+    node.visited = None
+    todo.extend(node.children)
     i += 1
 
 def deserialize(fin, factory = Element, encoding = 'unicode'):
-  root = xml.etree.ElementTree.parse(
-    fin,
-    xml.etree.ElementTree.XMLParser(
-      target = xml.etree.ElementTree.TreeBuilder(factory),
-      encoding = encoding
-    )
-  ).getroot()
+  '''
+Front end to the deserializer. Essentially, reverses the process of
+``element.serialize()``. All the same comments apply to this function also.
+
+The tricky part with deserializing is knowing what kind of object to construct.
+For instance if the XML looks like this, ::
+
+  <root>
+    <AnObject ref="0">some text</AnObject>
+  </root>
+
+we want to find a constructor for a derived class of ``Element`` called
+``AnObject``. This is the role of the ``factory`` function that you pass in to
+``deserialize()``. It takes a tag name such as ``'AnObject'``, followed by the
+arguments to be passed into ``AnObject``'s constructor. Typically the
+``factory`` function will be written like this::
+
+  tag_to_class = {
+    'AnObject': AnObject
+  }
+  def factory(tag, *args, **kwargs):
+    return tag_to_class[tag](*args, **kwargs)
+
+It is also possible to have a more complex factory function, for instance if
+you have defined an AST to use as a mini-language inside another AST, and you
+want to defer object creation to the mini-language's factory for those objects.
+'''
+  root = xml.etree.ElementTree.parse(fin).getroot()
   assert root.tag == 'root'
-  for i in root:
-    i.tail = None
+  children = [factory(i.tag) for i in root]
+  todo = list(zip(root, children))
   i = 0
-  parents = root[:]
   ref_list = []
-  while i < len(parents):
-    j = parents[i]
-    if 'ref' in j.attrib:
-      ref = int(j.attrib['ref'])
-      del j.attrib['ref']
+  while i < len(todo):
+    element, node = todo[i]
+    ref = element.get('ref')
+    if ref is not None:
+      ref = int(ref)
       if len(ref_list) < ref + 1:
         ref_list.extend([None] * (ref + 1 - len(ref_list)))
-      ref_list[ref] = j
-    parents.extend(j[:])
+      ref_list[ref] = node
+    children = [factory(i.tag) for i in element]
+    node.children = children
+    node.text = (
+      ['' if element.text is None else element.text] +
+      ['' if j.tail is None else j.tail for j in element]
+    )
+    todo.extend(zip(element, children))
     i += 1
-  for i in root:
-    i.deserialize(ref_list)
+  for element, node in todo:
+    node.deserialize(element, ref_list)
   return ref_list[0]
 
-# compatibility scheme to access arbitrary xml.etree.ElementTree.Element-like
-# objects (not just Element defined above) using a more consistent interface:
-def get_text(root, i):
-  if i < 0:
-    i += len(root) + 1
-  text = root.text if i == 0 else root[i - 1].tail
-  return '' if text is None else text
-
-def set_text(root, i, text):
-  if i < 0:
-    i += len(root) + 1
-  if len(text) == 0:
-    text = None
-  if i == 0:
-    root.text = text
-  else:
-    root[i - 1].tail = text
-
 def to_text(root):
+  '''
+Convenience function to recursively extract all the text from a subtree.
+
+The result is similar to serializing the subtree (itself, its direct children,
+grandchildren etc) to XML, and then throwing away all XML tags and attributes.
+
+For example, if there are two child nodes, the value returned will be::
+
+  text[0] + to_text(children[0]) + text[1] + to_text(children[1] + text[2]
+'''
+  assert len(root.text) == len(root.children) + 1
   return ''.join(
     [
       j
-      for i in range(len(root))
-      for j in [get_text(root, i), to_text(root[i])]
+      for i in range(len(root.children))
+      for j in [root.text[i], to_text(root[i])]
     ] +
-    [get_text(root, len(root))]
+    [root.text[-1]]
   )
 
 def concatenate(children, factory = Element, *args, **kwargs):
+  '''
+Convenience function to concatenate an arbitrary number of nodes into one.
+
+The nodes are concatenated into a new empty node constructed by the ``factory``
+function that you specify. Only the text and children are taken from the nodes
+being concatenated, the types of the nodes and any data in fields are ignored.
+
+The ``factory`` argument is usually a constructor for an ``Element``-derived
+object, but it can also be any arbitrary function, and any further arguments
+sent after the ``factory`` argument will be sent into the ``factory`` call.
+
+For example, suppose node ``a`` has two children and node ``b`` has one. Then
+the call ``concatenate([a, b])`` is equivalent to::
+
+  Element(
+    children = [a.children[0], a.children[1], b.children[0]],
+    text = [a.text[0], a.text[1], a.text[2] + b.text[0], b.text[1]
+  )
+'''
   root = factory(*args, **kwargs)
   for child in children:
-    i = len(root)
-    set_text(root, i, get_text(root, i) + get_text(child, 0))
-    root[i:] = child[:]
+    assert len(root.text) == len(root.children) + 1
+    assert len(child.text) == len(child.children) + 1
+    root.text[-1] += child.text[0]
+    root.children.extend(child.children)
+    root.text.extend(child.text[1:])
+  assert len(root.text) == len(root.children) + 1
   return root