First cut at pitree.py converter, hacky conversion of generate_ast.py to ast.py
author Nick Downing <nick@ndcode.org>
Sat, 26 Jan 2019 13:03:10 +0000 (00:03 +1100)
committer Nick Downing <nick@ndcode.org>
Sat, 26 Jan 2019 13:03:42 +0000 (00:03 +1100)
ast.py
generate_ast.py
generate_c.py [new file with mode: 0644]
generate_py.py [new file with mode: 0644]
n.sh
pitree.py [new file with mode: 0755]

diff --git a/ast.py b/ast.py
index 69c0b7e..fc27870 100644 (file)
--- a/ast.py
+++ b/ast.py
@@ -1,7 +1,77 @@
 import element
+import sys
+
+# Source text of the default for each field type, keyed by the type's
+# spelling.  The class generator pastes these strings verbatim into the
+# emitted code: as keyword defaults in the generated __init__() and as the
+# value a scalar field is compared against in the generated repr_serialize().
+default_value = {
+  'bool': 'False',
+  'int': '-1',
+  'ref': 'None',
+  'str': '\'\'',
+  'list(bool)': '[]',
+  'list(int)': '[]',
+  'list(ref)': '[]',
+  'list(str)': '[]',
+  'set(bool)': 'set()',
+  'set(int)': 'set()',
+  'set(ref)': 'set()',
+  'set(str)': 'set()'
+}
+# Serialized (attribute string) form of the default for each scalar type.
+# The class generator places it inside quotes as the fallback argument of
+# self.get(name, ...) in the generated deserialize().  Only scalar types are
+# listed; the list/set branches of the generator use '' as fallback instead.
+default_value_str = {
+  'bool': 'false',
+  'int': '-1',
+  'ref': '-1',
+  'str': ''
+}
+
+class Context:
+  """Mutable state threaded through generate_class_or_field_def().
+
+  fout -- file object the generated code is written to
+  package_name -- prefix written before the dotted class path in the
+    generated __repr__() (e.g. 't_ast.')
+  indent -- current indentation string for emitted lines
+  stack -- names of the classes currently being generated, outermost first
+  classes -- dotted names of every class generated so far
+  base_classes -- stack of dicts (one per nesting level) mapping a class
+    name to its list of (type, name) fields, used to resolve base classes
+  fields -- (type, name) pairs collected for the class being generated
+
+  The list arguments default to None rather than [] so that each Context
+  gets its own fresh lists; the generator appends to stack and classes, and
+  a shared default list would leak state from one Context into the next.
+  """
+  def __init__(
+    self,
+    fout = sys.stdout,
+    package_name = 't_ast.',
+    indent = '',
+    stack = None,
+    classes = None,
+    base_classes = None,
+    fields = None
+  ):
+    self.fout = fout
+    self.package_name = package_name
+    self.indent = indent
+    self.stack = [] if stack is None else stack
+    self.classes = [] if classes is None else classes
+    self.base_classes = [{'element.Element': []}] if base_classes is None else base_classes # params
+    self.fields = [] if fields is None else fields
 
 class AST(element.Element):
   # internal classes:
+  class ClassOrFieldDef(element.Element):
+    # Element type registered under the tag 'AST_ClassOrFieldDef'.  Its
+    # generate_class_or_field_def() is only a placeholder that raises
+    # NotImplementedError; Section2.ClassDef and Section2.FieldDef provide
+    # working methods of the same name.
+    # NOTE(review): copy() below falls back to the bare name ClassOrFieldDef,
+    # which is not in scope inside a method of a nested class unless a
+    # module-level name of that spelling exists -- confirm before calling
+    # copy() without a factory.  The code between the GENERATE markers is
+    # generator output and is left untouched here.
+    # GENERATE ELEMENT() BEGIN
+    def __init__(
+      self,
+      tag = 'AST_ClassOrFieldDef',
+      attrib = {},
+      text = '',
+      children = []
+    ):
+      element.Element.__init__(
+        self,
+        tag,
+        attrib,
+        text,
+        children
+      )
+    def copy(self, factory = None):
+      result = element.Element.copy(
+        self,
+        ClassOrFieldDef if factory is None else factory
+      )
+      return result
+    def __repr__(self):
+      params = []
+      self.repr_serialize(params)
+      return 'ast.AST.ClassOrFieldDef({0:s})'.format(', '.join(params))
+    # GENERATE END
+    # context is the ast.Context carrying the output file and generator state
+    def generate_class_or_field_def(self, context):
+      raise NotImplementedError
   class Expression(element.Element):
     # GENERATE ELEMENT() BEGIN
     def __init__(
@@ -193,6 +263,8 @@ class AST(element.Element):
       self.repr_serialize(params)
       return 'ast.AST.Identifier({0:s})'.format(', '.join(params))
     # GENERATE END
+    def get_text(self):
+      # return this identifier's text, as given by element.get_text() for
+      # text position 0 of this node
+      return element.get_text(self, 0)
   class LiteralBool(Expression):
     # GENERATE ELEMENT(bool value) BEGIN
     def __init__(
@@ -728,6 +800,392 @@ class AST(element.Element):
         self.repr_serialize(params)
         return 'ast.AST.Section2.ClassDef({0:s})'.format(', '.join(params))
       # GENERATE END
+      def generate_class_or_field_def(self, context):
+        # Write the Python source for this class definition to context.fout.
+        #
+        # Children used: self[0] is the class name (via get_text()), self[1]
+        # optionally holds the dotted base class name, self[2] holds the
+        # nested class and field definitions.
+        #
+        # Emitted, in order: the 'class NAME(BASE):' header, any nested
+        # classes (by recursion), __init__(), serialize() and deserialize()
+        # (only if the class declares fields of its own), copy(),
+        # repr_serialize() (same condition) and __repr__().
+        #
+        # context.indent, context.stack and context.fields are restored on
+        # exit; context.classes and context.base_classes keep the entries
+        # recorded for this class and the classes nested in it.
+        class_name = self[0].get_text()
+        if len(self[1]):
+          base_class = '.'.join([i.get_text() for i in self[1][0]])
+        else:
+          base_class = 'element.Element'
+        # look the base class up from the innermost scope outwards; the scope
+        # index i tells how much of the class stack qualifies its name, and
+        # this class starts out with a copy of the base's field list
+        for i in range(len(context.base_classes) - 1, -1, -1):
+          if base_class in context.base_classes[i]:
+            full_base_class = '.'.join(context.stack[:i] + [base_class])
+            context.base_classes[-1][class_name] = \
+              context.base_classes[i][base_class].copy()
+            break
+        else:
+          assert False
+        # the class header uses the base class name as written in the input,
+        # whereas calls into the base class below use the qualified name
+        context.fout.write(
+          '{0:s}class {1:s}({2:s}):\n'.format(
+            context.indent,
+            class_name,
+            base_class
+          )
+        )
+        # enter the scope of the new class (undone at the end of the method)
+        indent_save = context.indent
+        context.indent += '  '
+        context.stack.append(class_name)
+        context.classes.append('.'.join(context.stack))
+        context.base_classes.append({})
+        fields_save = context.fields
+        context.fields = []
+        # nested classes are emitted here; field definitions only append to
+        # context.fields
+        for i in self[2]:
+          i.generate_class_or_field_def(context)
+        # i = number of inherited fields; the full list becomes inherited
+        # fields followed by this class's own fields
+        i = len(context.base_classes[-2][class_name])
+        context.base_classes[-2][class_name].extend(context.fields)
+
+        # __init__(): every field (inherited and own) becomes a keyword
+        # parameter, and the inherited ones are passed on to the base class
+        context.fout.write(
+          '''{0:s}def __init__(
+{1:s}  self,
+{2:s}  tag = '{3:s}',
+{4:s}  attrib = {{}},
+{5:s}  text = '',
+{6:s}  children = []{7:s}
+{8:s}):
+{9:s}  {10:s}.__init__(
+{11:s}    self,
+{12:s}    tag,
+{13:s}    attrib,
+{14:s}    text,
+{15:s}    children{16:s}
+{17:s}  )
+'''.format(
+            context.indent,
+            context.indent,
+            context.indent,
+            '_'.join(context.stack),
+            context.indent,
+            context.indent,
+            context.indent,
+            ''.join(
+              [
+                ',\n{0:s}  {1:s} = {2:s}'.format(
+                  context.indent,
+                  name,
+                  default_value[type]
+                )
+                for type, name in context.base_classes[-2][class_name]
+              ]
+            ),
+            context.indent,
+            context.indent,
+            full_base_class,
+            context.indent,
+            context.indent,
+            context.indent,
+            context.indent,
+            context.indent,
+            ''.join(
+              [
+                ',\n{0:s}    {1:s}'.format(
+                  context.indent,
+                  name
+                )
+                for type, name in context.base_classes[-2][class_name][:i]
+              ]
+            ),
+            context.indent
+          )
+        )
+        # rest of __init__(): store each own field; ref and str typed values
+        # are stored as given, any other type is deserialized first when the
+        # argument arrives as a str
+        for type, name in context.fields:
+          if type == 'ref' or type == 'list(ref)' or type == 'set(ref)' or type == 'str':
+            context.fout.write(
+              '''{0:s}  self.{1:s} = {2:s}
+'''.format(context.indent, name, name)
+            )
+          elif type[:5] == 'list(' and type[-1:] == ')':
+            subtype = type[5:-1]
+            context.fout.write(
+              '''{0:s}  self.{1:s} = (
+{2:s}    [element.deserialize_{3:s}(i) for i in {4:s}.split()]
+{5:s}  if isinstance({6:s}, str) else
+{7:s}    {8:s}
+{9:s}  )
+'''.format(
+                context.indent,
+                name,
+                context.indent,
+                subtype,
+                name,
+                context.indent,
+                name,
+                context.indent,
+                name,
+                context.indent
+              )
+            )
+          elif type[:4] == 'set(' and type[-1:] == ')':
+            subtype = type[4:-1]
+            context.fout.write(
+              '''{0:s}  self.{1:s} = (
+{2:s}    set([element.deserialize_{3:s}(i) for i in {4:s}.split()])
+{5:s}  if isinstance({6:s}, str) else
+{7:s}    {8:s}
+{9:s}  )
+'''.format(
+                context.indent,
+                name,
+                context.indent,
+                subtype,
+                name,
+                context.indent,
+                name,
+                context.indent,
+                name,
+                context.indent
+              )
+            )
+          else:
+            context.fout.write(
+              '''{0:s}  self.{1:s} = (
+{2:s}    element.deserialize_{3:s}({4:s})
+{5:s}  if isinstance({6:s}, str) else
+{7:s}    {8:s}
+{9:s}  )
+'''.format(
+                context.indent,
+                name,
+                context.indent,
+                type,
+                name,
+                context.indent,
+                name,
+                context.indent,
+                name,
+                context.indent
+              )
+            )
+        # serialize() and deserialize() are only emitted when the class has
+        # fields of its own; otherwise the base class versions are inherited
+        if len(context.fields):
+          context.fout.write(
+            '''{0:s}def serialize(self, ref_list):
+{1:s}  {2:s}.serialize(self, ref_list)
+'''.format(context.indent, context.indent, full_base_class)
+          )
+          # lists and sets are written as one space-separated attribute (sets
+          # in sorted order); ref_list is passed along for ref typed values
+          for type, name in context.fields:
+            if type[:5] == 'list(' and type[-1:] == ')':
+              subtype = type[5:-1]
+              context.fout.write(
+                '''{0:s}  self.set(
+{1:s}    '{2:s}',
+{3:s}    ' '.join([element.serialize_{4:s}(i{5:s}) for i in self.{6:s}])
+{7:s}  )
+'''.format(
+                  context.indent,
+                  context.indent,
+                  name,
+                  context.indent,
+                  subtype,
+                  ', ref_list' if subtype == 'ref' else '',
+                  name,
+                  context.indent
+                )
+              )
+            elif type[:4] == 'set(' and type[-1:] == ')':
+              subtype = type[4:-1]
+              context.fout.write(
+                '''{0:s}  self.set(
+{1:s}    '{2:s}',
+{3:s}    ' '.join([element.serialize_{4:s}(i{5:s}) for i in sorted(self.{6:s})])
+{7:s}  )
+'''.format(
+                  context.indent,
+                  context.indent,
+                  name,
+                  context.indent,
+                  subtype,
+                  ', ref_list' if subtype == 'ref' else '',
+                  name,
+                  context.indent
+                )
+              )
+            else:
+              context.fout.write(
+                '''{0:s}  self.set('{1:s}', element.serialize_{2:s}(self.{3:s}{4:s}))
+'''.format(
+                  context.indent,
+                  name,
+                  type,
+                  name,
+                  ', ref_list' if type == 'ref' else ''
+                )
+              )
+          context.fout.write(
+            '''{0:s}def deserialize(self, ref_list):
+{1:s}  {2:s}.deserialize(self, ref_list)
+'''.format(context.indent, context.indent, full_base_class)
+          )
+          # mirror image of serialize(): a missing attribute reads as '' for
+          # lists and sets, and as default_value_str[type] for scalars
+          for type, name in context.fields:
+            if type[:5] == 'list(' and type[-1:] == ')':
+              subtype = type[5:-1]
+              context.fout.write(
+                '''{0:s}  self.{1:s} = [
+{2:s}    element.deserialize_{3:s}(i{4:s})
+{5:s}    for i in self.get('{6:s}', '').split()
+{7:s}  ]
+'''.format(
+                  context.indent,
+                  name,
+                  context.indent,
+                  subtype,
+                  ', ref_list' if subtype == 'ref' else '',
+                  context.indent,
+                  name,
+                  context.indent
+                )
+              )
+            elif type[:4] == 'set(' and type[-1:] == ')':
+              subtype = type[4:-1]
+              context.fout.write(
+                '''{0:s}  self.{1:s} = set(
+{2:s}    [
+{3:s}      element.deserialize_{4:s}(i{5:s})
+{6:s}      for i in self.get('{7:s}', '').split()
+{8:s}    ]
+{9:s}  )
+'''.format(
+                  context.indent,
+                  name,
+                  context.indent,
+                  context.indent,
+                  subtype,
+                  ', ref_list' if subtype == 'ref' else '',
+                  context.indent,
+                  name,
+                  context.indent,
+                  context.indent
+                )
+              )
+            else: 
+              context.fout.write(
+                '''{0:s}  self.{1:s} = element.deserialize_{2:s}(self.get('{3:s}', '{4:s}'){5:s})
+'''.format(
+                  context.indent,
+                  name,
+                  type,
+                  name,
+                  default_value_str[type],
+                  ', ref_list' if type == 'ref' else ''
+                )
+              )
+        # copy(): always emitted; the generated code assigns each own field
+        # to the copy by plain attribute assignment
+        context.fout.write(
+          '''{0:s}def copy(self, factory = None):
+{1:s}  result = {2:s}.copy(
+{3:s}    self,
+{4:s}    {5:s} if factory is None else factory
+{6:s}  ){7:s}
+{8:s}  return result
+'''.format(
+            context.indent,
+            context.indent,
+            full_base_class,
+            context.indent,
+            context.indent,
+            class_name,
+            context.indent,
+            ''.join(
+              [
+                '\n{0:s}  result.{1:s} = self.{2:s}'.format(
+                  context.indent,
+                  name,
+                  name
+                )
+                for _, name in context.fields
+              ]
+            ),
+            context.indent
+          )
+        )
+        # repr_serialize(): only for classes with own fields; the generated
+        # code appends 'name = value' only for fields that are non-empty or
+        # differ from their default
+        if len(context.fields):
+          context.fout.write(
+            '''{0:s}def repr_serialize(self, params):
+{1:s}  {2:s}.repr_serialize(self, params)
+'''.format(
+              context.indent,
+              context.indent,
+              full_base_class
+            )
+          )
+          for type, name in context.fields:
+            if type[:5] == 'list(' and type[-1:] == ')':
+              subtype = type[5:-1]
+              context.fout.write(
+                '''{0:s}  if len(self.{1:s}):
+{2:s}    params.append(
+{3:s}      '{4:s} = [{{0:s}}]'.format(
+{5:s}        ', '.join([repr(i) for i in self.{6:s}])
+{7:s}      )
+{8:s}    )
+'''.format(
+                  context.indent,
+                  name,
+                  context.indent,
+                  context.indent,
+                  name,
+                  context.indent,
+                  name,
+                  context.indent,
+                  context.indent
+                )
+              )
+            elif type[:4] == 'set(' and type[-1:] == ')':
+              subtype = type[4:-1]
+              context.fout.write(
+                '''{0:s}  if len(self.{1:s}):
+{2:s}    params.append(
+{3:s}      '{4:s} = set([{{0:s}}])'.format(
+{5:s}        ', '.join([repr(i) for i in sorted(self.{6:s})])
+{7:s}      )
+{8:s}    )
+'''.format(
+                  context.indent,
+                  name,
+                  context.indent,
+                  context.indent,
+                  name,
+                  context.indent,
+                  name,
+                  context.indent,
+                  context.indent
+                )
+              )
+            else:
+              context.fout.write(
+                '''{0:s}  if self.{1:s} != {2:s}:
+{3:s}    params.append(
+{4:s}      '{5:s} = {{0:s}}'.format(repr(self.{6:s}))
+{7:s}    )
+'''.format(
+                  context.indent,
+                  name,
+                  default_value[type],
+                  context.indent,
+                  context.indent,
+                  name,
+                  name,
+                  context.indent
+                )
+              )
+        # __repr__(): always emitted; prints the package prefix followed by
+        # the dotted path of the class
+        context.fout.write(
+          '''{0:s}def __repr__(self):
+{1:s}  params = []
+{2:s}  self.repr_serialize(params)
+{3:s}  return '{4:s}{5:s}({{0:s}})'.format(', '.join(params))
+'''.format(
+            context.indent,
+            context.indent,
+            context.indent,
+            context.indent,
+            context.package_name,
+            '.'.join(context.stack),
+          )
+        )
+
+        # leave the class scope; classes nested in this one stay resolvable
+        # from the enclosing scope under the qualified name 'class_name.nested'
+        context.indent = indent_save
+        del context.stack[-1]
+        for temp_base_class, temp_fields in context.base_classes.pop().items():
+          context.base_classes[-1][
+            '{0:s}.{1:s}'.format(class_name, temp_base_class)
+          ] = temp_fields
+        context.fields = fields_save
     class FieldDef(element.Element):
       # GENERATE ELEMENT() BEGIN
       def __init__(
@@ -755,6 +1213,8 @@ class AST(element.Element):
         self.repr_serialize(params)
         return 'ast.AST.Section2.FieldDef({0:s})'.format(', '.join(params))
       # GENERATE END
+      def generate_class_or_field_def(self, context):
+        # record this field on the class currently being generated, as a
+        # (type, name) pair: the type is the text form of self[0] and the
+        # name is the text of self[1]; nothing is written to the output here
+        context.fields.append((element.to_text(self[0]), self[1].get_text()))
     # GENERATE ELEMENT() BEGIN
     def __init__(
       self,
@@ -811,6 +1271,7 @@ class AST(element.Element):
 # GENERATE FACTORY(element.Element) BEGIN
 tag_to_class = {
   'AST': AST,
+  'AST_ClassOrFieldDef': AST.ClassOrFieldDef,
   'AST_Expression': AST.Expression,
   'AST_Type': AST.Type,
   'AST_BaseClass': AST.BaseClass,
diff --git a/generate_ast.py b/generate_ast.py
index f9a633b..9c5f0e8 100755 (executable)
--- a/generate_ast.py
+++ b/generate_ast.py
@@ -94,21 +94,20 @@ while len(line):
       base_classes[-2][class_name].extend(fields)
 
       sys.stdout.write(
-        '''{0:s}# GENERATE ELEMENT({1:s}) BEGIN
-{2:s}def __init__(
-{3:s}  self,
-{4:s}  tag = '{5:s}',
-{6:s}  attrib = {{}},
-{7:s}  text = '',
-{8:s}  children = []{9:s}
-{10:s}):
-{11:s}  {12:s}.__init__(
-{13:s}    self,
-{14:s}    tag,
-{15:s}    attrib,
-{16:s}    text,
-{17:s}    children{18:s}
-{19:s}  )
+        '''{0:s}def __init__(
+{1:s}  self,
+{2:s}  tag = '{3:s}',
+{4:s}  attrib = {{}},
+{5:s}  text = '',
+{6:s}  children = []{7:s}
+{8:s}):
+{9:s}  {10:s}.__init__(
+{11:s}    self,
+{12:s}    tag,
+{13:s}    attrib,
+{14:s}    text,
+{15:s}    children{16:s}
+{17:s}  )
 '''.format(
           indent,
           params,
diff --git a/generate_c.py b/generate_c.py
new file mode 100644 (file)
index 0000000..a7596d8
--- /dev/null
+++ b/generate_c.py
@@ -0,0 +1,2 @@
+def generate_c(_ast, home_dir, skel_file, out_file):
+  """Placeholder for the C back end, which is not implemented yet.
+
+  Takes the same arguments as generate_py.generate_py().  Always raises
+  NotImplementedError; an exception is used rather than 'assert False'
+  because assertions are removed under 'python -O', which would make this
+  stub return silently without generating anything.
+  """
+  raise NotImplementedError('C output is not implemented')
diff --git a/generate_py.py b/generate_py.py
new file mode 100644 (file)
index 0000000..c8bfab1
--- /dev/null
+++ b/generate_py.py
@@ -0,0 +1,90 @@
+import ast
+import element
+import os
+
+def text_to_python(text, indent):
+  """Re-indent a block of embedded Python text.
+
+  Leading blank lines and trailing whitespace are dropped.  The leading run
+  of spaces/tabs on the first remaining line is taken as the common prefix;
+  it is removed from every line and replaced by indent.  Empty lines are
+  kept as bare newlines.  Returns '' if text contains only whitespace,
+  otherwise the re-indented lines, each terminated by a newline.
+
+  A whitespace-only line that does not carry the common prefix is treated
+  as an empty line.  Raises ValueError if a line with content does not
+  start with the common prefix.
+  """
+  lines = text.rstrip().split('\n')
+  # rstrip() above already removed trailing blank lines, so only leading
+  # ones (or the single '' left over from all-whitespace text) remain
+  while len(lines) and len(lines[0].lstrip()) == 0:
+    lines = lines[1:]
+  if len(lines) == 0:
+    return ''
+  first = lines[0]
+  prefix = first[:len(first) - len(first.lstrip(' \t'))]
+  result = []
+  for line in lines:
+    if len(line) == 0:
+      result.append('\n')
+    elif line[:len(prefix)] == prefix:
+      result.append('{0:s}{1:s}\n'.format(indent, line[len(prefix):]))
+    elif len(line.strip()) == 0:
+      # stray whitespace shorter than (or different from) the prefix
+      result.append('\n')
+    else:
+      raise ValueError(
+        'inconsistent indentation in line {0:s}'.format(repr(line))
+      )
+  return ''.join(result)
+
+def generate_py(_ast, home_dir, skel_file, out_file):
+  """Write the Python output file by filling in a skeleton file.
+
+  _ast -- parsed definitions: _ast[0] holds the section 1 code blocks,
+    _ast[1] the class definitions, _ast[2] (optional) the section 3 text
+  home_dir -- directory searched for 'skel/skel_py.py' if skel_file is None
+  skel_file -- skeleton file to copy, or None for the default
+  out_file -- file to write, or None for 't_ast.py'
+
+  The skeleton is copied line by line.  A line consisting exactly of
+  '# GENERATE SECTION1', '# GENERATE SECTION2' or '# GENERATE SECTION3'
+  (followed by a newline) is replaced by the corresponding generated text,
+  bracketed by '... BEGIN' and '# GENERATE END' marker lines.
+  """
+  if skel_file is None:
+    skel_file = os.path.join(home_dir, 'skel/skel_py.py')
+  if out_file is None:
+    out_file = 't_ast.py'
+  with open(skel_file, 'r') as fin:
+    with open(out_file, 'w+') as fout:
+      for line in fin:
+        if line == '# GENERATE SECTION1\n':
+          fout.write(
+            '''# GENERATE SECTION1 BEGIN
+{0:s}# GENERATE END
+'''.format(
+              ''.join(
+                [
+                  text_to_python(element.get_text(i[0], 0), '')
+                  for i in _ast[0]
+                ]
+              )
+            )
+          )
+        elif line == '# GENERATE SECTION2\n':
+          fout.write('# GENERATE SECTION2 BEGIN\n')
+          # the package prefix used in generated __repr__() is fixed to
+          # 'ast.' rather than derived from out_file; the lists are passed
+          # explicitly so that no state is shared with any other Context
+          context = ast.Context(
+            fout,
+            'ast.',
+            stack = [],
+            classes = [],
+            fields = []
+          )
+          for i in _ast[1]:
+            assert isinstance(i, ast.AST.Section2.ClassDef)
+            i.generate_class_or_field_def(context)
+          # factory table: tag (dots replaced by underscores) -> class
+          fout.write(
+            '''tag_to_class = {{{0:s}
+}}
+# GENERATE END
+'''.format(
+              ','.join(
+                [
+                  '\n  \'{0:s}\': {1:s}'.format(
+                    i.replace('.', '_'),
+                    i
+                  )
+                  for i in context.classes
+                ]
+              )
+            )
+          )
+        elif line == '# GENERATE SECTION3\n':
+          fout.write(
+            '''# GENERATE SECTION3 BEGIN
+{0:s}# GENERATE END
+'''.format(
+              '' if len(_ast) < 3 else element.get_text(_ast[2], 0)
+            )
+          )
+        else:
+          fout.write(line)
diff --git a/n.sh b/n.sh
index 3576b91..2676045 100755 (executable)
--- a/n.sh
+++ b/n.sh
@@ -15,3 +15,11 @@ then
   ./expected.sh ../piyacc.git/ast.py tests/piyacc.t >out/ast_piyacc.py.ok
   ./expected.sh ../piyacc.git/tests_ast/ast.py tests/cal.t >out/ast_cal.py.ok
 fi
+./pitree.py --python -o out/ast_ansi_c.py tests/ansi_c.t
+diff -q out/ast_ansi_c.py.ok out/ast_ansi_c.py
+./pitree.py --python -o out/ast_pilex.py tests/pilex.t
+diff -q out/ast_pilex.py.ok out/ast_pilex.py
+./pitree.py --python -o out/ast_piyacc.py tests/piyacc.t
+diff -q out/ast_piyacc.py.ok out/ast_piyacc.py
+./pitree.py --python -o out/ast_cal.py tests/cal.t
+diff -q out/ast_cal.py.ok out/ast_cal.py
diff --git a/pitree.py b/pitree.py
new file mode 100755 (executable)
index 0000000..b7fe14f
--- /dev/null
+++ b/pitree.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+
+import ast
+import element
+import generate_c
+import generate_py
+import getopt
+import os
+import sys
+
+# Command line driver: parse a definitions file (or a previously serialized
+# .xml tree) and hand the resulting AST to the selected code generator.
+
+# directory part of the path this script was invoked by; passed to the
+# generators so they can locate their default skeleton files
+home_dir = os.path.dirname(sys.argv[0])
+try:
+  opts, args = getopt.getopt(
+    sys.argv[1:],
+    'o:pS:',
+    ['outfile=', 'python', 'skel=']
+  )
+except getopt.GetoptError as err:
+  sys.stderr.write('{0:s}\n'.format(str(err)))
+  sys.exit(1)
+
+out_file = None
+python = False
+skel_file = None
+for opt, arg in opts:
+  if opt == '-o' or opt == '--outfile':
+    out_file = arg
+  elif opt == '-p' or opt == '--python':
+    python = True
+  elif opt == '-S' or opt == '--skel':
+    skel_file = arg
+  else:
+    # getopt only returns the options declared above
+    assert False
+if len(args) < 1:
+  sys.stdout.write(
+    'usage: {0:s} [options] defs.t\n'.format(
+      sys.argv[0]
+    )
+  )
+  sys.exit(1)
+in_file = args[0]
+
+with open(in_file) as fin:
+  if in_file.endswith('.xml'):
+    _ast = element.deserialize(fin, ast.factory)
+  else:
+    # the scanner and parser modules are only needed for non-XML input
+    import lex_yy
+    import y_tab
+    lex_yy.yyin = fin
+    _ast = y_tab.yyparse(ast.AST)
+# disabled statements, apparently for round-tripping the tree through XML
+# while debugging:
+#element.serialize(_ast, 'a.xml', 'utf-8')
+#_ast = element.deserialize('a.xml', ast.factory, 'utf-8')
+#_ast.post_process()
+#element.serialize(_ast, 'b.xml', 'utf-8')
+#_ast = element.deserialize('b.xml', ast.factory, 'utf-8')
+# -p/--python selects the Python back end, otherwise the C back end is used
+(generate_py.generate_py if python else generate_c.generate_c)(
+  _ast,
+  home_dir,
+  skel_file,
+  out_file
+)