-
Notifications
You must be signed in to change notification settings - Fork 171
/
Copy pathlpython_parser.py
156 lines (133 loc) · 4.75 KB
/
lpython_parser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import sys
import python_ast
import ast
# Command line: <input .py file> <output file for the serialized AST>
filename = sys.argv[1]
filename_out = sys.argv[2]
# Read the whole source text inside a `with` block so the file handle is
# closed deterministically instead of being left to the garbage collector.
# NOTE: `input` shadows the builtin of the same name; the name is kept
# because later statements in this script read it.
with open(filename) as source_file:
    input = source_file.read()
# type_comments=True keeps `# type:` comments in the resulting tree.
a = ast.parse(input, type_comments=True)
#print(ast.unparse(a))
#print()
#print(ast.dump(a))
# Transform ast.AST to python_ast.AST:
def get_newlines(s):
    """Return the 0-based index of every newline character in `s`.

    The indices are returned as a list in ascending order; the list is
    empty when `s` contains no "\\n".
    """
    return [pos for pos, char in enumerate(s) if char == "\n"]
# Offsets of every "\n" in the source text; used below to convert
# (line, column) locations into linear positions.
newlines = get_newlines(input)
# `line` and `col` are 1-based.
# Returns a linear position, which is 0-based.
def linecol_to_pos(line, col, newlines):
    """Convert a 1-based (line, col) location to a 0-based linear position.

    Parameters:
        line: 1-based line number. Values <= 0 map to position 0.
        col: 1-based column number within the line.
        newlines: ascending list of the 0-based offsets of every "\\n"
            in the text (as produced by `get_newlines`).

    Returns:
        The 0-based offset of the location in the text. A `line` beyond
        the last line of the text is clamped to the last line.
    """
    if line <= 0:
        return 0
    # Clamp to the last existing line. A text with N newlines has N + 1
    # lines, so this also covers text without any newline at all (which
    # used to raise IndexError for line >= 2).
    line = min(line, len(newlines) + 1)
    if line == 1:
        return col - 1
    # The line starts one past the newline that ends the previous line:
    # newlines[line - 2] + 1, plus the 0-based column (col - 1).
    return newlines[line - 2] + col
class Transform(ast.NodeVisitor):
    """Rebuild a standard-library `ast` tree out of `python_ast` node classes.

    Every node produced gets `first` and `last` attributes, computed with
    `linecol_to_pos` and the module-level `newlines` table when the source
    node carries location information.
    """

    def _attach_span(self, target, source):
        # Copy the start/end location of `source` onto `target` as linear
        # positions. col_offset is 0-based, linecol_to_pos wants 1-based.
        target.first = linecol_to_pos(
            source.lineno, source.col_offset + 1, newlines)
        target.last = linecol_to_pos(
            source.end_lineno, source.end_col_offset, newlines)

    def _convert_list(self, items):
        # Convert the members of a list-valued field one by one.
        converted = []
        for member in items:
            if isinstance(member, ast.AST):
                converted.append(self.visit(member))
            elif type(member) is str:
                converted.append(member)
            elif member is None:
                converted.append(self.visit(python_ast.ConstantNone()))
            # Members of any other type are left out of the result.
        return converted

    def visit_Constant(self, node):
        """Turn `ast.Constant` into the `python_ast.Constant*` class that
        matches the Python type of its value.

        Raises `Exception` for a value type that has no counterpart.
        """
        value = node.value
        # The order of the checks matters: bool is a subclass of int and
        # therefore has to be recognized before int.
        if isinstance(value, str):
            result = python_ast.ConstantStr(value, node.kind)
        elif isinstance(value, bool):
            result = python_ast.ConstantBool(value, node.kind)
        elif isinstance(value, int):
            result = python_ast.ConstantInt(value, node.kind)
        elif isinstance(value, float):
            result = python_ast.ConstantFloat(value, node.kind)
        elif isinstance(value, complex):
            result = python_ast.ConstantComplex(
                value.real, value.imag, node.kind)
        elif value is Ellipsis:
            result = python_ast.ConstantEllipsis(node.kind)
        elif value is None:
            result = python_ast.ConstantNone(node.kind)
        elif isinstance(value, bytes):
            # Bytes are passed on as their str() representation.
            result = python_ast.ConstantBytes(str(value), node.kind)
        else:
            print(type(value))
            raise Exception("Unsupported Constant type")
        self._attach_span(result, node)
        return result

    def generic_visit(self, node):
        """Convert any other node into the `python_ast` class of the same
        name, converting each of its fields first.

        Raises `Exception` when a field holds a value of an unexpected type.
        """
        type_name = node.__class__.__name__
        kwargs = {}
        for name, value in ast.iter_fields(node):
            if name == "ops":  # For Compare()
                # We only represent one comparison operator
                assert len(value) == 1
                kwargs[name] = self.visit(value[0])
            elif isinstance(value, list):
                kwargs[name] = self._convert_list(value)
            elif name in ("vararg", "kwarg"):
                # Optional single nodes are stored as a list of 0 or 1 items.
                kwargs[name] = [] if value is None else [self.visit(value)]
            elif isinstance(value, ast.AST):
                kwargs[name] = self.visit(value)
            elif value is None or isinstance(value, (str, int)):
                # Plain values are copied unchanged.
                kwargs[name] = value
            else:
                print("Node type:", type_name)
                print("Value type:", type(value))
                raise Exception("Unsupported value type")
        constructor = getattr(python_ast, type_name)
        result = constructor(**kwargs)
        if hasattr(node, "col_offset"):
            self._attach_span(result, node)
        else:
            # Nodes without location information get a fixed placeholder.
            result.first = 1
            result.last = 1
        return result
#print()
# Convert the standard-library tree `a` into python_ast nodes.
v = Transform()
a2 = v.visit(a)
#print(a2)
# Test the visitor python_ast.AST
# (walks the converted tree once; the return value is not used)
v = python_ast.GenericASTVisitor()
v.visit(a2)
# Serialize
class Serialization(python_ast.SerializationBaseVisitor):
    """Collect the serialized form of a python_ast tree as text.

    The output is accumulated in `self.s`; every token written is followed
    by a single space.
    """

    def __init__(self):
        # Start with a "mod" class
        self.s = "0 "

    def write_int8(self, i):
        """Append the non-negative integer `i` as decimal text."""
        assert i >= 0
        self.s = self.s + str(i) + " "

    def write_int64(self, i):
        """Append `i` as decimal text; a negative `i` is first shifted
        into the non-negative range by adding 2**64."""
        value = i + 2**64 if i < 0 else i
        assert value >= 0
        self.s = self.s + str(value) + " "

    def write_float64(self, f):
        """Append the `str()` form of `f`."""
        self.s = self.s + str(f) + " "

    def write_string(self, s):
        """Append `len(s)` followed by the text of `s` itself."""
        self.write_int64(len(s))
        self.s = self.s + str(s) + " "

    def write_bool(self, b):
        """Append 1 when `b` is truthy, otherwise 0."""
        self.write_int8(1 if b else 0)
# Serialize the converted tree; the text accumulates in `v.s`.
v = Serialization()
v.visit(a2)
#print()
#print(v.s)
# Write the result inside a `with` block so the output is flushed and the
# handle closed deterministically, rather than whenever the temporary
# file object happens to be garbage-collected.
with open(filename_out, "w") as output_file:
    output_file.write(v.s)