Skip to content

Commit 463ed17

Browse files
committed
Fixed nested procedure problem in PL/0.
1 parent 925a611 commit 463ed17

File tree

1 file changed

+102
-55
lines changed

1 file changed

+102
-55
lines changed

pl0/pl0.cc

Lines changed: 102 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -79,39 +79,56 @@ struct SymbolScope;
7979

8080
struct Annotation {
8181
shared_ptr<SymbolScope> scope;
82+
shared_ptr<vector<string>> freeVariables;
8283
};
8384

8485
typedef AstBase<Annotation> AstPL0;
86+
shared_ptr<SymbolScope> get_closest_scope(shared_ptr<AstPL0> ast) {
87+
ast = ast->parent;
88+
while (ast->tag != "block"_) {
89+
ast = ast->parent;
90+
}
91+
return ast->scope;
92+
}
8593

8694
/*
8795
* Symbol Table
8896
*/
8997
struct SymbolScope {
9098
SymbolScope(shared_ptr<SymbolScope> outer) : outer(outer) {}
9199

92-
bool has_symbol(const string& ident) const {
100+
bool has_symbol(const string& ident, bool extend = true) const {
93101
auto ret = constants.count(ident) || variables.count(ident);
94-
return ret ? true : (outer ? outer->has_symbol(ident) : false);
102+
return ret ? true : (extend && outer ? outer->has_symbol(ident) : false);
103+
}
104+
105+
bool has_constant(const string& ident, bool extend = true) const {
106+
return constants.count(ident)
107+
? true
108+
: (extend && outer ? outer->has_constant(ident) : false);
95109
}
96110

97-
bool has_constant(const string& ident) const {
98-
return constants.count(ident) ? true : (outer ? outer->has_constant(ident)
99-
: false);
111+
bool has_variable(const string& ident, bool extend = true) const {
112+
return variables.count(ident)
113+
? true
114+
: (extend && outer ? outer->has_variable(ident) : false);
100115
}
101116

102-
bool has_variable(const string& ident) const {
103-
return variables.count(ident) ? true : (outer ? outer->has_variable(ident)
104-
: false);
117+
bool has_procedure(const string& ident, bool extend = true) const {
118+
return procedures.count(ident)
119+
? true
120+
: (extend && outer ? outer->has_procedure(ident) : false);
105121
}
106122

107-
bool has_procedure(const string& ident) const {
108-
return procedures.count(ident) ? true : (outer ? outer->has_procedure(ident)
109-
: false);
123+
shared_ptr<AstPL0> get_procedure(const string& ident) const {
124+
auto it = procedures.find(ident);
125+
return it != procedures.end() ? it->second : outer->get_procedure(ident);
110126
}
111127

112128
map<string, int> constants;
113129
set<string> variables;
114130
map<string, shared_ptr<AstPL0>> procedures;
131+
set<string> free_variables;
115132

116133
private:
117134
shared_ptr<SymbolScope> outer;
@@ -161,8 +178,7 @@ struct SymbolTable {
161178

162179
static void constants(const shared_ptr<AstPL0> ast,
163180
shared_ptr<SymbolScope> scope) {
164-
// const <- ('CONST' __ ident '=' _ number(',' _ ident '=' _ number)* ';'
165-
// _)?
181+
// const <- ('CONST' __ ident '=' _ number(',' _ ident '=' _ number)* ';' _)?
166182
const auto& nodes = ast->nodes;
167183
for (auto i = 0u; i < nodes.size(); i += 2) {
168184
const auto& ident = nodes[i + 0]->token;
@@ -210,6 +226,12 @@ struct SymbolTable {
210226
throw_runtime_error(ast->nodes[0],
211227
"undefined variable '" + ident + "'...");
212228
}
229+
230+
build_on_ast(ast->nodes[1], scope);
231+
232+
if (!scope->has_symbol(ident, false)) {
233+
scope->free_variables.emplace(ident);
234+
}
213235
}
214236

215237
static void call(const shared_ptr<AstPL0> ast,
@@ -220,6 +242,15 @@ struct SymbolTable {
220242
throw_runtime_error(ast->nodes[0],
221243
"undefined procedure '" + ident + "'...");
222244
}
245+
246+
auto block = scope->get_procedure(ident);
247+
if (block->scope) {
248+
for (const auto& free : block->scope->free_variables) {
249+
if (!scope->has_symbol(free, false)) {
250+
scope->free_variables.emplace(free);
251+
}
252+
}
253+
}
223254
}
224255

225256
static void ident(const shared_ptr<AstPL0> ast,
@@ -228,6 +259,10 @@ struct SymbolTable {
228259
if (!scope->has_symbol(ident)) {
229260
throw_runtime_error(ast, "undefined variable '" + ident + "'...");
230261
}
262+
263+
if (!scope->has_symbol(ident, false)) {
264+
scope->free_variables.emplace(ident);
265+
}
231266
}
232267
};
233268

@@ -260,9 +295,7 @@ struct Environment {
260295
}
261296

262297
shared_ptr<AstPL0> get_procedure(const string& ident) const {
263-
auto it = scope->procedures.find(ident);
264-
return it != scope->procedures.end() ? it->second
265-
: outer->get_procedure(ident);
298+
return scope->get_procedure(ident);
266299
}
267300

268301
private:
@@ -602,46 +635,33 @@ struct LLVM {
602635
{
603636
auto BB = BasicBlock::Create(context_, "entry", fn);
604637
builder_.SetInsertPoint(BB);
605-
compile_block(ast->nodes[0], true);
638+
compile_block(ast->nodes[0]);
606639
builder_.CreateRetVoid();
607640
}
608641
}
609642

610-
void compile_block(const shared_ptr<AstPL0> ast, bool top) {
611-
compile_const(ast->nodes[0], top);
612-
compile_var(ast->nodes[1], top);
643+
void compile_block(const shared_ptr<AstPL0> ast) {
644+
compile_const(ast->nodes[0]);
645+
compile_var(ast->nodes[1]);
613646
compile_procedure(ast->nodes[2]);
614647
compile_statement(ast->nodes[3]);
615648
}
616649

617-
void compile_const(const shared_ptr<AstPL0> ast, bool top) {
650+
void compile_const(const shared_ptr<AstPL0> ast) {
618651
for (auto i = 0u; i < ast->nodes.size(); i += 2) {
619652
auto ident = ast->nodes[i]->token;
620653
auto number = stoi(ast->nodes[i + 1]->token);
621654

622-
if (top) {
623-
auto gv = cast<GlobalVariable>(
624-
module_->getOrInsertGlobal(ident, builder_.getInt32Ty()));
625-
gv->setAlignment(4);
626-
gv->setInitializer(builder_.getInt32(number));
627-
} else {
628-
auto alloca =
629-
builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, ident);
630-
builder_.CreateStore(builder_.getInt32(number), alloca);
631-
}
655+
auto alloca =
656+
builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, ident);
657+
builder_.CreateStore(builder_.getInt32(number), alloca);
632658
}
633659
}
634660

635-
void compile_var(const shared_ptr<AstPL0> ast, bool top) {
661+
void compile_var(const shared_ptr<AstPL0> ast) {
636662
for (const auto node : ast->nodes) {
637-
if (top) {
638-
auto gv = cast<GlobalVariable>(
639-
module_->getOrInsertGlobal(node->token, builder_.getInt32Ty()));
640-
gv->setAlignment(4);
641-
gv->setInitializer(builder_.getInt32(0));
642-
} else {
643-
builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, node->token);
644-
}
663+
auto ident = node->token;
664+
builder_.CreateAlloca(builder_.getInt32Ty(), nullptr, ident);
645665
}
646666
}
647667

@@ -650,13 +670,24 @@ struct LLVM {
650670
auto ident = ast->nodes[i]->token;
651671
auto block = ast->nodes[i + 1];
652672

653-
auto fn = cast<Function>(
654-
module_->getOrInsertFunction(ident, builder_.getVoidTy(), nullptr));
673+
std::vector<Type*> pt(block->scope->free_variables.size(),
674+
Type::getInt32PtrTy(context_));
675+
auto ft = FunctionType::get(builder_.getVoidTy(), pt, false);
676+
auto fn = cast<Function>(module_->getOrInsertFunction(ident, ft));
677+
678+
{
679+
auto it = block->scope->free_variables.begin();
680+
for (auto& arg : fn->args()) {
681+
arg.setName(*it);
682+
++it;
683+
}
684+
}
685+
655686
{
656687
auto prevBB = builder_.GetInsertBlock();
657688
auto BB = BasicBlock::Create(context_, "entry", fn);
658689
builder_.SetInsertPoint(BB);
659-
compile_block(block, false);
690+
compile_block(block);
660691
builder_.CreateRetVoid();
661692
builder_.SetInsertPoint(prevBB);
662693
}
@@ -670,22 +701,38 @@ struct LLVM {
670701
}
671702

672703
void compile_assignment(const shared_ptr<AstPL0> ast) {
673-
auto name = ast->nodes[0]->token;
704+
auto ident = ast->nodes[0]->token;
674705

675706
auto fn = builder_.GetInsertBlock()->getParent();
676707
auto tbl = fn->getValueSymbolTable();
677-
auto var = tbl->lookup(name);
708+
auto var = tbl->lookup(ident);
678709
if (!var) {
679-
var = module_->getGlobalVariable(name);
710+
throw_runtime_error(ast, "'" + ident + "' is not defined...");
680711
}
681712

682713
auto val = compile_expression(ast->nodes[1]);
683714
builder_.CreateStore(val, var);
684715
}
685716

686717
void compile_call(const shared_ptr<AstPL0> ast) {
687-
auto fn = module_->getFunction(ast->nodes[0]->token);
688-
builder_.CreateCall(fn);
718+
auto ident = ast->nodes[0]->token;
719+
720+
auto scope = get_closest_scope(ast);
721+
auto block = scope->get_procedure(ident);
722+
723+
std::vector<Value*> args;
724+
for (auto& free : block->scope->free_variables) {
725+
auto fn = builder_.GetInsertBlock()->getParent();
726+
auto tbl = fn->getValueSymbolTable();
727+
auto var = tbl->lookup(free);
728+
if (!var) {
729+
throw_runtime_error(ast, "'" + free + "' is not defined...");
730+
}
731+
args.push_back(var);
732+
}
733+
734+
auto fn = module_->getFunction(ident);
735+
builder_.CreateCall(fn, args);
689736
}
690737

691738
void compile_statements(const shared_ptr<AstPL0> ast) {
@@ -831,13 +878,13 @@ struct LLVM {
831878
}
832879

833880
Value* compile_ident(const shared_ptr<AstPL0> ast) {
834-
auto name = ast->token;
881+
auto ident = ast->token;
835882

836883
auto fn = builder_.GetInsertBlock()->getParent();
837884
auto tbl = fn->getValueSymbolTable();
838-
auto var = tbl->lookup(name);
885+
auto var = tbl->lookup(ident);
839886
if (!var) {
840-
var = module_->getGlobalVariable(name);
887+
throw_runtime_error(ast, "'" + ident + "' is not defined...");
841888
}
842889

843890
return builder_.CreateLoad(var);
@@ -900,13 +947,13 @@ int main(int argc, const char** argv) {
900947
}
901948
}
902949

903-
if (opt_ast) {
904-
cout << ast_to_s(ast);
905-
}
906-
907950
try {
908951
SymbolTable::build_on_ast(ast);
909952

953+
if (opt_ast) {
954+
cout << ast_to_s<AstPL0>(ast);
955+
}
956+
910957
if (opt_llvm || opt_jit) {
911958
if (opt_llvm) {
912959
LLVM::dump(ast);

0 commit comments

Comments
 (0)