From 705f6692dde81aab1a99e846eabd60e6ad1f54a2 Mon Sep 17 00:00:00 2001 From: Nick Downing Date: Sun, 27 Jan 2019 00:03:10 +1100 Subject: [PATCH] First cut at pitree.py converter, hacky conversion of generate_ast.py to ast.py --- ast.py | 461 ++++++++++++++++++++++++++++++++++++++++++++++++ generate_ast.py | 29 ++- generate_c.py | 2 + generate_py.py | 90 ++++++++++ n.sh | 8 + pitree.py | 63 +++++++ 6 files changed, 638 insertions(+), 15 deletions(-) create mode 100644 generate_c.py create mode 100644 generate_py.py create mode 100755 pitree.py diff --git a/ast.py b/ast.py index 69c0b7e..fc27870 100644 --- a/ast.py +++ b/ast.py @@ -1,7 +1,77 @@ import element +import sys + +default_value = { + 'bool': 'False', + 'int': '-1', + 'ref': 'None', + 'str': '\'\'', + 'list(bool)': '[]', + 'list(int)': '[]', + 'list(ref)': '[]', + 'list(str)': '[]', + 'set(bool)': 'set()', + 'set(int)': 'set()', + 'set(ref)': 'set()', + 'set(str)': 'set()' +} +default_value_str = { + 'bool': 'false', + 'int': '-1', + 'ref': '-1', + 'str': '' +} + +class Context: + def __init__( + self, + fout = sys.stdout, + package_name = 't_ast.', + indent = '', + stack = [], + classes = [], + base_classes = None, + fields = [] + ): + self.fout = fout + self.package_name = package_name + self.indent = indent + self.stack = stack + self.classes = classes + self.base_classes = [{'element.Element': []}] if base_classes is None else base_classes # params + self.fields = fields class AST(element.Element): # internal classes: + class ClassOrFieldDef(element.Element): + # GENERATE ELEMENT() BEGIN + def __init__( + self, + tag = 'AST_ClassOrFieldDef', + attrib = {}, + text = '', + children = [] + ): + element.Element.__init__( + self, + tag, + attrib, + text, + children + ) + def copy(self, factory = None): + result = element.Element.copy( + self, + ClassOrFieldDef if factory is None else factory + ) + return result + def __repr__(self): + params = [] + self.repr_serialize(params) + return 'ast.AST.ClassOrFieldDef({0:s})'.format(', '.join(params)) + # GENERATE END + def generate_class_or_field_def(self, context): + raise NotImplementedError class Expression(element.Element): # GENERATE ELEMENT() BEGIN def __init__( @@ -193,6 +263,8 @@ class AST(element.Element): self.repr_serialize(params) return 'ast.AST.Identifier({0:s})'.format(', '.join(params)) # GENERATE END + def get_text(self): + return element.get_text(self, 0) class LiteralBool(Expression): # GENERATE ELEMENT(bool value) BEGIN def __init__( @@ -728,6 +800,392 @@ class AST(element.Element): self.repr_serialize(params) return 'ast.AST.Section2.ClassDef({0:s})'.format(', '.join(params)) # GENERATE END + def generate_class_or_field_def(self, context): + class_name = self[0].get_text() + if len(self[1]): + base_class = '.'.join([i.get_text() for i in self[1][0]]) + else: + base_class = 'element.Element' + for i in range(len(context.base_classes) - 1, -1, -1): + if base_class in context.base_classes[i]: + full_base_class = '.'.join(context.stack[:i] + [base_class]) + context.base_classes[-1][class_name] = \ + context.base_classes[i][base_class].copy() + break + else: + assert False + context.fout.write( + '{0:s}class {1:s}({2:s}):\n'.format( + context.indent, + class_name, + base_class + ) + ) + indent_save = context.indent + context.indent += ' ' + context.stack.append(class_name) + context.classes.append('.'.join(context.stack)) + context.base_classes.append({}) + fields_save = context.fields + context.fields = [] + for i in self[2]: + i.generate_class_or_field_def(context) + i = len(context.base_classes[-2][class_name]) + context.base_classes[-2][class_name].extend(context.fields) + + context.fout.write( + '''{0:s}def __init__( +{1:s} self, +{2:s} tag = '{3:s}', +{4:s} attrib = {{}}, +{5:s} text = '', +{6:s} children = []{7:s} +{8:s}): +{9:s} {10:s}.__init__( +{11:s} self, +{12:s} tag, +{13:s} attrib, +{14:s} text, +{15:s} children{16:s} +{17:s} ) +'''.format( + context.indent, + context.indent, + context.indent, + '_'.join(context.stack), + context.indent, + context.indent, + context.indent, + ''.join( + [ + ',\n{0:s} {1:s} = {2:s}'.format( + context.indent, + name, + default_value[type] + ) + for type, name in context.base_classes[-2][class_name] + ] + ), + context.indent, + context.indent, + full_base_class, + context.indent, + context.indent, + context.indent, + context.indent, + context.indent, + ''.join( + [ + ',\n{0:s} {1:s}'.format( + context.indent, + name + ) + for type, name in context.base_classes[-2][class_name][:i] + ] + ), + context.indent + ) + ) + for type, name in context.fields: + if type == 'ref' or type == 'list(ref)' or type == 'set(ref)' or type == 'str': + context.fout.write( + '''{0:s} self.{1:s} = {2:s} +'''.format(context.indent, name, name) + ) + elif type[:5] == 'list(' and type[-1:] == ')': + subtype = type[5:-1] + context.fout.write( + '''{0:s} self.{1:s} = ( +{2:s} [element.deserialize_{3:s}(i) for i in {4:s}.split()] +{5:s} if isinstance({6:s}, str) else +{7:s} {8:s} +{9:s} ) +'''.format( + context.indent, + name, + context.indent, + subtype, + name, + context.indent, + name, + context.indent, + name, + context.indent + ) + ) + elif type[:4] == 'set(' and type[-1:] == ')': + subtype = type[4:-1] + context.fout.write( + '''{0:s} self.{1:s} = ( +{2:s} set([element.deserialize_{3:s}(i) for i in {4:s}.split()]) +{5:s} if isinstance({6:s}, str) else +{7:s} {8:s} +{9:s} ) +'''.format( + context.indent, + name, + context.indent, + subtype, + name, + context.indent, + name, + context.indent, + name, + context.indent + ) + ) + else: + context.fout.write( + '''{0:s} self.{1:s} = ( +{2:s} element.deserialize_{3:s}({4:s}) +{5:s} if isinstance({6:s}, str) else +{7:s} {8:s} +{9:s} ) +'''.format( + context.indent, + name, + context.indent, + type, + name, + context.indent, + name, + context.indent, + name, + context.indent + ) + ) + if len(context.fields): + context.fout.write( + '''{0:s}def serialize(self, ref_list): +{1:s} {2:s}.serialize(self, ref_list) +'''.format(context.indent, context.indent, full_base_class) + ) + for type, name in context.fields: + if type[:5] == 'list(' and type[-1:] == ')': + subtype = type[5:-1] + context.fout.write( + '''{0:s} self.set( +{1:s} '{2:s}', +{3:s} ' '.join([element.serialize_{4:s}(i{5:s}) for i in self.{6:s}]) +{7:s} ) +'''.format( + context.indent, + context.indent, + name, + context.indent, + subtype, + ', ref_list' if subtype == 'ref' else '', + name, + context.indent + ) + ) + elif type[:4] == 'set(' and type[-1:] == ')': + subtype = type[4:-1] + context.fout.write( + '''{0:s} self.set( +{1:s} '{2:s}', +{3:s} ' '.join([element.serialize_{4:s}(i{5:s}) for i in sorted(self.{6:s})]) +{7:s} ) +'''.format( + context.indent, + context.indent, + name, + context.indent, + subtype, + ', ref_list' if subtype == 'ref' else '', + name, + context.indent + ) + ) + else: + context.fout.write( + '''{0:s} self.set('{1:s}', element.serialize_{2:s}(self.{3:s}{4:s})) +'''.format( + context.indent, + name, + type, + name, + ', ref_list' if type == 'ref' else '' + ) + ) + context.fout.write( + '''{0:s}def deserialize(self, ref_list): +{1:s} {2:s}.deserialize(self, ref_list) +'''.format(context.indent, context.indent, full_base_class) + ) + for type, name in context.fields: + if type[:5] == 'list(' and type[-1:] == ')': + subtype = type[5:-1] + context.fout.write( + '''{0:s} self.{1:s} = [ +{2:s} element.deserialize_{3:s}(i{4:s}) +{5:s} for i in self.get('{6:s}', '').split() +{7:s} ] +'''.format( + context.indent, + name, + context.indent, + subtype, + ', ref_list' if subtype == 'ref' else '', + context.indent, + name, + context.indent + ) + ) + elif type[:4] == 'set(' and type[-1:] == ')': + subtype = type[4:-1] + context.fout.write( + '''{0:s} self.{1:s} = set( +{2:s} [ +{3:s} element.deserialize_{4:s}(i{5:s}) +{6:s} for i in self.get('{7:s}', '').split() +{8:s} ] +{9:s} ) +'''.format( + context.indent, + name, + context.indent, + context.indent, + subtype, + ', ref_list' if subtype == 'ref' else '', + context.indent, + name, + context.indent, + context.indent + ) + ) + else: + context.fout.write( + '''{0:s} self.{1:s} = element.deserialize_{2:s}(self.get('{3:s}', '{4:s}'){5:s}) +'''.format( + context.indent, + name, + type, + name, + default_value_str[type], + ', ref_list' if type == 'ref' else '' + ) + ) + context.fout.write( + '''{0:s}def copy(self, factory = None): +{1:s} result = {2:s}.copy( +{3:s} self, +{4:s} {5:s} if factory is None else factory +{6:s} ){7:s} +{8:s} return result +'''.format( + context.indent, + context.indent, + full_base_class, + context.indent, + context.indent, + class_name, + context.indent, + ''.join( + [ + '\n{0:s} result.{1:s} = self.{2:s}'.format( + context.indent, + name, + name + ) + for _, name in context.fields + ] + ), + context.indent + ) + ) + if len(context.fields): + context.fout.write( + '''{0:s}def repr_serialize(self, params): +{1:s} {2:s}.repr_serialize(self, params) +'''.format( + context.indent, + context.indent, + full_base_class + ) + ) + for type, name in context.fields: + if type[:5] == 'list(' and type[-1:] == ')': + subtype = type[5:-1] + context.fout.write( + '''{0:s} if len(self.{1:s}): +{2:s} params.append( +{3:s} '{4:s} = [{{0:s}}]'.format( +{5:s} ', '.join([repr(i) for i in self.{6:s}]) +{7:s} ) +{8:s} ) +'''.format( + context.indent, + name, + context.indent, + context.indent, + name, + context.indent, + name, + context.indent, + context.indent + ) + ) + elif type[:4] == 'set(' and type[-1:] == ')': + subtype = type[4:-1] + context.fout.write( + '''{0:s} if len(self.{1:s}): +{2:s} params.append( +{3:s} '{4:s} = set([{{0:s}}])'.format( +{5:s} ', '.join([repr(i) for i in sorted(self.{6:s})]) +{7:s} ) +{8:s} ) +'''.format( + context.indent, + name, + context.indent, + context.indent, + name, + context.indent, + name, + context.indent, + context.indent + ) + ) + else: + context.fout.write( + '''{0:s} if self.{1:s} != {2:s}: +{3:s} params.append( +{4:s} '{5:s} = {{0:s}}'.format(repr(self.{6:s})) +{7:s} ) +'''.format( + context.indent, + name, + default_value[type], + context.indent, + context.indent, + name, + name, + context.indent + ) + ) + context.fout.write( + '''{0:s}def __repr__(self): +{1:s} params = [] +{2:s} self.repr_serialize(params) +{3:s} return '{4:s}{5:s}({{0:s}})'.format(', '.join(params)) +'''.format( + context.indent, + context.indent, + context.indent, + context.indent, + context.package_name, + '.'.join(context.stack), + ) + ) + + context.indent = indent_save + del context.stack[-1] + for temp_base_class, temp_fields in context.base_classes.pop().items(): + context.base_classes[-1][ + '{0:s}.{1:s}'.format(class_name, temp_base_class) + ] = temp_fields + context.fields = fields_save class FieldDef(element.Element): # GENERATE ELEMENT() BEGIN def __init__( @@ -755,6 +1213,8 @@ class AST(element.Element): self.repr_serialize(params) return 'ast.AST.Section2.FieldDef({0:s})'.format(', '.join(params)) # GENERATE END + def generate_class_or_field_def(self, context): + context.fields.append((element.to_text(self[0]), self[1].get_text())) # GENERATE ELEMENT() BEGIN def __init__( self, @@ -811,6 +1271,7 @@ class AST(element.Element): # GENERATE FACTORY(element.Element) BEGIN tag_to_class = { 'AST': AST, + 'AST_ClassOrFieldDef': AST.ClassOrFieldDef, 'AST_Expression': AST.Expression, 'AST_Type': AST.Type, 'AST_BaseClass': AST.BaseClass, diff --git a/generate_ast.py b/generate_ast.py index f9a633b..9c5f0e8 100755 --- a/generate_ast.py +++ b/generate_ast.py @@ -94,21 +94,20 @@ while len(line): base_classes[-2][class_name].extend(fields) sys.stdout.write( - '''{0:s}# GENERATE ELEMENT({1:s}) BEGIN -{2:s}def __init__( -{3:s} self, -{4:s} tag = '{5:s}', -{6:s} attrib = {{}}, -{7:s} text = '', -{8:s} children = []{9:s} -{10:s}): -{11:s} {12:s}.__init__( -{13:s} self, -{14:s} tag, -{15:s} attrib, -{16:s} text, -{17:s} children{18:s} -{19:s} ) + '''{0:s}def __init__( +{1:s} self, +{2:s} tag = '{3:s}', +{4:s} attrib = {{}}, +{5:s} text = '', +{6:s} children = []{7:s} +{8:s}): +{9:s} {10:s}.__init__( +{11:s} self, +{12:s} tag, +{13:s} attrib, +{14:s} text, +{15:s} children{16:s} +{17:s} ) '''.format( indent, params, diff --git a/generate_c.py b/generate_c.py new file mode 100644 index 0000000..a7596d8 --- /dev/null +++ b/generate_c.py @@ -0,0 +1,2 @@ +def generate_c(_ast, home_dir, skel_file, out_file): + assert False diff --git a/generate_py.py b/generate_py.py new file mode 100644 index 0000000..c8bfab1 --- /dev/null +++ b/generate_py.py @@ -0,0 +1,90 @@ +import ast +import element +import os + +def text_to_python(text, indent): + #text_strip = text.strip() + #if text_strip[:1] == '{' and text_strip[-1:] == '}': + # text = text_strip[1:-1] + lines = text.rstrip().split('\n') + while len(lines) and len(lines[0].lstrip()) == 0: + lines = lines[1:] + while len(lines) and len(lines[-1].lstrip()) == 0: + lines = lines[:-1] + if len(lines) == 0: + return '' #{0:s}pass\n'.format(indent) + for j in range(len(lines[0])): + if lines[0][j] != '\t' and lines[0][j] != ' ': + break + else: + print(text) + assert False + #print('---') + #print(text) + prefix = lines[0][:j] + for j in range(len(lines)): + if len(lines[j]) == 0: + lines[j] = '\n' + else: + assert lines[j][:len(prefix)] == prefix + lines[j] = '{0:s}{1:s}\n'.format(indent, lines[j][len(prefix):]) + return ''.join(lines) + +def generate_py(_ast, home_dir, skel_file, out_file): + if skel_file is None: + skel_file = os.path.join(home_dir, 'skel/skel_py.py') + if out_file is None: + out_file = 't_ast.py' + with open(skel_file, 'r') as fin: + with open(out_file, 'w+') as fout: + line = fin.readline() + while len(line): + if line == '# GENERATE SECTION1\n': + fout.write( + '''# GENERATE SECTION1 BEGIN +{0:s}# GENERATE END +'''.format( + ''.join( + [ + text_to_python(element.get_text(i[0], 0), '') + for i in _ast[0] + ] + ) + ) + ) + elif line == '# GENERATE SECTION2\n': + fout.write('# GENERATE SECTION2 BEGIN\n') + #package_name = os.path.basename(out_file) + #if package_name[-3:] == '.py': + # package_name = package_name[:-3] + context = ast.Context(fout, 'ast.') #, package_name + '.') + for i in _ast[1]: + assert isinstance(i, ast.AST.Section2.ClassDef) + i.generate_class_or_field_def(context) + fout.write( + '''tag_to_class = {{{0:s} +}} +# GENERATE END +'''.format( + ','.join( + [ + '\n \'{0:s}\': {1:s}'.format( + i.replace('.', '_'), + i + ) + for i in context.classes + ] + ) + ) + ) + elif line == '# GENERATE SECTION3\n': + fout.write( + '''# GENERATE SECTION3 BEGIN +{0:s}# GENERATE END +'''.format( + '' if len(_ast) < 3 else element.get_text(_ast[2], 0) + ) + ) + else: + fout.write(line) + line = fin.readline() diff --git a/n.sh b/n.sh index 3576b91..2676045 100755 --- a/n.sh +++ b/n.sh @@ -15,3 +15,11 @@ then ./expected.sh ../piyacc.git/ast.py tests/piyacc.t >out/ast_piyacc.py.ok ./expected.sh ../piyacc.git/tests_ast/ast.py tests/cal.t >out/ast_cal.py.ok fi +./pitree.py --python -o out/ast_ansi_c.py tests/ansi_c.t +diff -q out/ast_ansi_c.py.ok out/ast_ansi_c.py +./pitree.py --python -o out/ast_pilex.py tests/pilex.t +diff -q out/ast_pilex.py.ok out/ast_pilex.py +./pitree.py --python -o out/ast_piyacc.py tests/piyacc.t +diff -q out/ast_piyacc.py.ok out/ast_piyacc.py +./pitree.py --python -o out/ast_cal.py tests/cal.t +diff -q out/ast_cal.py.ok out/ast_cal.py diff --git a/pitree.py b/pitree.py new file mode 100755 index 0000000..b7fe14f --- /dev/null +++ b/pitree.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python3 + +import ast +import element +import generate_c +import generate_py +import getopt +import os +import sys + +home_dir = os.path.dirname(sys.argv[0]) +try: + opts, args = getopt.getopt( + sys.argv[1:], + 'o:pS:', + ['outfile=', 'python', 'skel='] + ) +except getopt.GetoptError as err: + sys.stderr.write('{0:s}\n'.format(str(err))) + sys.exit(1) + +out_file = None +python = False +skel_file = None +for opt, arg in opts: + if opt == '-e' or opt == '--element': + _element = True + elif opt == '-o' or opt == '--outfile': + out_file = arg + elif opt == '-p' or opt == '--python': + python = True + elif opt == '-S' or opt == '--skel': + skel_file = arg + else: + assert False +if len(args) < 1: + sys.stdout.write( + 'usage: {0:s} [options] defs.t\n'.format( + sys.argv[0] + ) + ) + sys.exit(1) +in_file = args[0] + +with open(in_file) as fin: + if in_file[-4:] == '.xml': + _ast = element.deserialize(fin, ast.factory) + else: + import lex_yy + import y_tab + lex_yy.yyin = fin + _ast = y_tab.yyparse(ast.AST) +#element.serialize(_ast, 'a.xml', 'utf-8') +#_ast = element.deserialize('a.xml', ast.factory, 'utf-8') +#_ast.post_process() +#element.serialize(_ast, 'b.xml', 'utf-8') +#_ast = element.deserialize('b.xml', ast.factory, 'utf-8') +(generate_py.generate_py if python else generate_c.generate_c)( + _ast, + home_dir, + skel_file, + out_file +) -- 2.34.1