@@ -79,39 +79,56 @@ struct SymbolScope;
79
79
80
80
struct Annotation {
81
81
shared_ptr<SymbolScope> scope;
82
+ shared_ptr<vector<string>> freeVariables;
82
83
};
83
84
84
85
typedef AstBase<Annotation> AstPL0;
86
+ shared_ptr<SymbolScope> get_closest_scope (shared_ptr<AstPL0> ast) {
87
+ ast = ast->parent ;
88
+ while (ast->tag != " block" _) {
89
+ ast = ast->parent ;
90
+ }
91
+ return ast->scope ;
92
+ }
85
93
86
94
/*
87
95
* Symbol Table
88
96
*/
89
97
struct SymbolScope {
90
98
SymbolScope (shared_ptr<SymbolScope> outer) : outer(outer) {}
91
99
92
- bool has_symbol (const string& ident) const {
100
+ bool has_symbol (const string& ident, bool extend = true ) const {
93
101
auto ret = constants.count (ident) || variables.count (ident);
94
- return ret ? true : (outer ? outer->has_symbol (ident) : false );
102
+ return ret ? true : (extend && outer ? outer->has_symbol (ident) : false );
103
+ }
104
+
105
+ bool has_constant (const string& ident, bool extend = true ) const {
106
+ return constants.count (ident)
107
+ ? true
108
+ : (extend && outer ? outer->has_constant (ident) : false );
95
109
}
96
110
97
- bool has_constant (const string& ident) const {
98
- return constants.count (ident) ? true : (outer ? outer->has_constant (ident)
99
- : false );
111
+ bool has_variable (const string& ident, bool extend = true ) const {
112
+ return variables.count (ident)
113
+ ? true
114
+ : (extend && outer ? outer->has_variable (ident) : false );
100
115
}
101
116
102
- bool has_variable (const string& ident) const {
103
- return variables.count (ident) ? true : (outer ? outer->has_variable (ident)
104
- : false );
117
+ bool has_procedure (const string& ident, bool extend = true ) const {
118
+ return procedures.count (ident)
119
+ ? true
120
+ : (extend && outer ? outer->has_procedure (ident) : false );
105
121
}
106
122
107
- bool has_procedure (const string& ident) const {
108
- return procedures.count (ident) ? true : (outer ? outer-> has_procedure (ident)
109
- : false );
123
+ shared_ptr<AstPL0> get_procedure (const string& ident) const {
124
+ auto it = procedures.find (ident);
125
+ return it != procedures. end () ? it-> second : outer-> get_procedure (ident );
110
126
}
111
127
112
128
map<string, int > constants;
113
129
set<string> variables;
114
130
map<string, shared_ptr<AstPL0>> procedures;
131
+ set<string> free_variables;
115
132
116
133
private:
117
134
shared_ptr<SymbolScope> outer;
@@ -161,8 +178,7 @@ struct SymbolTable {
161
178
162
179
static void constants (const shared_ptr<AstPL0> ast,
163
180
shared_ptr<SymbolScope> scope) {
164
- // const <- ('CONST' __ ident '=' _ number(',' _ ident '=' _ number)* ';'
165
- // _)?
181
+ // const <- ('CONST' __ ident '=' _ number(',' _ ident '=' _ number)* ';' _)?
166
182
const auto & nodes = ast->nodes ;
167
183
for (auto i = 0u ; i < nodes.size (); i += 2 ) {
168
184
const auto & ident = nodes[i + 0 ]->token ;
@@ -210,6 +226,12 @@ struct SymbolTable {
210
226
throw_runtime_error (ast->nodes [0 ],
211
227
" undefined variable '" + ident + " '..." );
212
228
}
229
+
230
+ build_on_ast (ast->nodes [1 ], scope);
231
+
232
+ if (!scope->has_symbol (ident, false )) {
233
+ scope->free_variables .emplace (ident);
234
+ }
213
235
}
214
236
215
237
static void call (const shared_ptr<AstPL0> ast,
@@ -220,6 +242,15 @@ struct SymbolTable {
220
242
throw_runtime_error (ast->nodes [0 ],
221
243
" undefined procedure '" + ident + " '..." );
222
244
}
245
+
246
+ auto block = scope->get_procedure (ident);
247
+ if (block->scope ) {
248
+ for (const auto & free : block->scope ->free_variables ) {
249
+ if (!scope->has_symbol (free, false )) {
250
+ scope->free_variables .emplace (free);
251
+ }
252
+ }
253
+ }
223
254
}
224
255
225
256
static void ident (const shared_ptr<AstPL0> ast,
@@ -228,6 +259,10 @@ struct SymbolTable {
228
259
if (!scope->has_symbol (ident)) {
229
260
throw_runtime_error (ast, " undefined variable '" + ident + " '..." );
230
261
}
262
+
263
+ if (!scope->has_symbol (ident, false )) {
264
+ scope->free_variables .emplace (ident);
265
+ }
231
266
}
232
267
};
233
268
@@ -260,9 +295,7 @@ struct Environment {
260
295
}
261
296
262
297
shared_ptr<AstPL0> get_procedure (const string& ident) const {
263
- auto it = scope->procedures .find (ident);
264
- return it != scope->procedures .end () ? it->second
265
- : outer->get_procedure (ident);
298
+ return scope->get_procedure (ident);
266
299
}
267
300
268
301
private:
@@ -602,46 +635,33 @@ struct LLVM {
602
635
{
603
636
auto BB = BasicBlock::Create (context_, " entry" , fn);
604
637
builder_.SetInsertPoint (BB);
605
- compile_block (ast->nodes [0 ], true );
638
+ compile_block (ast->nodes [0 ]);
606
639
builder_.CreateRetVoid ();
607
640
}
608
641
}
609
642
610
- void compile_block (const shared_ptr<AstPL0> ast, bool top ) {
611
- compile_const (ast->nodes [0 ], top );
612
- compile_var (ast->nodes [1 ], top );
643
+ void compile_block (const shared_ptr<AstPL0> ast) {
644
+ compile_const (ast->nodes [0 ]);
645
+ compile_var (ast->nodes [1 ]);
613
646
compile_procedure (ast->nodes [2 ]);
614
647
compile_statement (ast->nodes [3 ]);
615
648
}
616
649
617
- void compile_const (const shared_ptr<AstPL0> ast, bool top ) {
650
+ void compile_const (const shared_ptr<AstPL0> ast) {
618
651
for (auto i = 0u ; i < ast->nodes .size (); i += 2 ) {
619
652
auto ident = ast->nodes [i]->token ;
620
653
auto number = stoi (ast->nodes [i + 1 ]->token );
621
654
622
- if (top) {
623
- auto gv = cast<GlobalVariable>(
624
- module_->getOrInsertGlobal (ident, builder_.getInt32Ty ()));
625
- gv->setAlignment (4 );
626
- gv->setInitializer (builder_.getInt32 (number));
627
- } else {
628
- auto alloca =
629
- builder_.CreateAlloca (builder_.getInt32Ty (), nullptr , ident);
630
- builder_.CreateStore (builder_.getInt32 (number), alloca);
631
- }
655
+ auto alloca =
656
+ builder_.CreateAlloca (builder_.getInt32Ty (), nullptr , ident);
657
+ builder_.CreateStore (builder_.getInt32 (number), alloca);
632
658
}
633
659
}
634
660
635
- void compile_var (const shared_ptr<AstPL0> ast, bool top ) {
661
+ void compile_var (const shared_ptr<AstPL0> ast) {
636
662
for (const auto node : ast->nodes ) {
637
- if (top) {
638
- auto gv = cast<GlobalVariable>(
639
- module_->getOrInsertGlobal (node->token , builder_.getInt32Ty ()));
640
- gv->setAlignment (4 );
641
- gv->setInitializer (builder_.getInt32 (0 ));
642
- } else {
643
- builder_.CreateAlloca (builder_.getInt32Ty (), nullptr , node->token );
644
- }
663
+ auto ident = node->token ;
664
+ builder_.CreateAlloca (builder_.getInt32Ty (), nullptr , ident);
645
665
}
646
666
}
647
667
@@ -650,13 +670,24 @@ struct LLVM {
650
670
auto ident = ast->nodes [i]->token ;
651
671
auto block = ast->nodes [i + 1 ];
652
672
653
- auto fn = cast<Function>(
654
- module_->getOrInsertFunction (ident, builder_.getVoidTy (), nullptr ));
673
+ std::vector<Type*> pt (block->scope ->free_variables .size (),
674
+ Type::getInt32PtrTy (context_));
675
+ auto ft = FunctionType::get (builder_.getVoidTy (), pt, false );
676
+ auto fn = cast<Function>(module_->getOrInsertFunction (ident, ft));
677
+
678
+ {
679
+ auto it = block->scope ->free_variables .begin ();
680
+ for (auto & arg : fn->args ()) {
681
+ arg.setName (*it);
682
+ ++it;
683
+ }
684
+ }
685
+
655
686
{
656
687
auto prevBB = builder_.GetInsertBlock ();
657
688
auto BB = BasicBlock::Create (context_, " entry" , fn);
658
689
builder_.SetInsertPoint (BB);
659
- compile_block (block, false );
690
+ compile_block (block);
660
691
builder_.CreateRetVoid ();
661
692
builder_.SetInsertPoint (prevBB);
662
693
}
@@ -670,22 +701,38 @@ struct LLVM {
670
701
}
671
702
672
703
void compile_assignment (const shared_ptr<AstPL0> ast) {
673
- auto name = ast->nodes [0 ]->token ;
704
+ auto ident = ast->nodes [0 ]->token ;
674
705
675
706
auto fn = builder_.GetInsertBlock ()->getParent ();
676
707
auto tbl = fn->getValueSymbolTable ();
677
- auto var = tbl->lookup (name );
708
+ auto var = tbl->lookup (ident );
678
709
if (!var) {
679
- var = module_-> getGlobalVariable (name );
710
+ throw_runtime_error (ast, " ' " + ident + " ' is not defined... " );
680
711
}
681
712
682
713
auto val = compile_expression (ast->nodes [1 ]);
683
714
builder_.CreateStore (val, var);
684
715
}
685
716
686
717
void compile_call (const shared_ptr<AstPL0> ast) {
687
- auto fn = module_->getFunction (ast->nodes [0 ]->token );
688
- builder_.CreateCall (fn);
718
+ auto ident = ast->nodes [0 ]->token ;
719
+
720
+ auto scope = get_closest_scope (ast);
721
+ auto block = scope->get_procedure (ident);
722
+
723
+ std::vector<Value*> args;
724
+ for (auto & free : block->scope ->free_variables ) {
725
+ auto fn = builder_.GetInsertBlock ()->getParent ();
726
+ auto tbl = fn->getValueSymbolTable ();
727
+ auto var = tbl->lookup (free);
728
+ if (!var) {
729
+ throw_runtime_error (ast, " '" + free + " ' is not defined..." );
730
+ }
731
+ args.push_back (var);
732
+ }
733
+
734
+ auto fn = module_->getFunction (ident);
735
+ builder_.CreateCall (fn, args);
689
736
}
690
737
691
738
void compile_statements (const shared_ptr<AstPL0> ast) {
@@ -831,13 +878,13 @@ struct LLVM {
831
878
}
832
879
833
880
Value* compile_ident (const shared_ptr<AstPL0> ast) {
834
- auto name = ast->token ;
881
+ auto ident = ast->token ;
835
882
836
883
auto fn = builder_.GetInsertBlock ()->getParent ();
837
884
auto tbl = fn->getValueSymbolTable ();
838
- auto var = tbl->lookup (name );
885
+ auto var = tbl->lookup (ident );
839
886
if (!var) {
840
- var = module_-> getGlobalVariable (name );
887
+ throw_runtime_error (ast, " ' " + ident + " ' is not defined... " );
841
888
}
842
889
843
890
return builder_.CreateLoad (var);
@@ -900,13 +947,13 @@ int main(int argc, const char** argv) {
900
947
}
901
948
}
902
949
903
- if (opt_ast) {
904
- cout << ast_to_s (ast);
905
- }
906
-
907
950
try {
908
951
SymbolTable::build_on_ast (ast);
909
952
953
+ if (opt_ast) {
954
+ cout << ast_to_s<AstPL0>(ast);
955
+ }
956
+
910
957
if (opt_llvm || opt_jit) {
911
958
if (opt_llvm) {
912
959
LLVM::dump (ast);
0 commit comments