diff options
| -rw-r--r-- | cup/ast.c | 19 | ||||
| -rw-r--r-- | cup/ast.h | 10 | ||||
| -rw-r--r-- | cup/common.h | 2 | ||||
| -rw-r--r-- | cup/generator.c | 37 | ||||
| -rw-r--r-- | cup/parser.c | 134 | ||||
| -rw-r--r-- | examples/functions.cup | 13 | ||||
| -rwxr-xr-x | tests/functions.sh | 63 |
7 files changed, 259 insertions, 19 deletions
@@ -122,6 +122,15 @@ static void do_print_ast(Node *node, int depth) } else { printf("\n"); } + } else if (node->type == AST_FUNCCALL) { + printf("CALL %s(\n", node->call.func->func.name); + for (int i = 0; i < node->call.num_args; i++) { + do_print_ast(node->call.args[i], depth + 1); + } + for (int i = 0; i < depth; i++) { + printf(" "); + } + printf(")\n"); } else { printf("{{ %s }}\n", node_type_to_str(node->type)); } @@ -129,12 +138,20 @@ static void do_print_ast(Node *node, int depth) void dump_func(Node *node, int depth) { - printf("fn %s()", node->func.name); + printf("fn %s(", node->func.name); + for (int i = 0; i < node->func.num_args; i++) { + if (i > 0) printf(", "); + printf("%s: ", node->func.args[i].name); + print_type_to_file(stdout, node->func.args[i].type); + printf("[[%lld]]", node->func.args[i].offset); + } + printf(")"); if (node->func.return_type.type != TYPE_NONE) { // FIXME: Print return type properly printf(" -> "); print_type_to_file(stdout, node->func.return_type); } + do_print_ast(node->func.body, depth + 1); } @@ -24,6 +24,7 @@ F(OP_GEQ, ">=") \ F(OP_ASSIGN, "=") \ F(AST_LITERAL, "literal") \ + F(AST_FUNCCALL, "Function call") \ F(AST_CONDITIONAL, "conditional expression") \ F(AST_IF, "if statement") \ F(AST_WHILE, "while statement") \ @@ -89,7 +90,10 @@ typedef struct ast_node { // TODO: Should we just dynamically allocate space on the // stack for each block instead of storing this? i64 max_locals_size; + // TODO: Arguments / etc? + Variable *args; + int num_args; } func; // Block of statements @@ -134,6 +138,12 @@ typedef struct ast_node { } loop; Variable *variable; + + struct { + Node *func; + Node **args; + int num_args; + } call; }; } Node; diff --git a/cup/common.h b/cup/common.h index 522483f..20d057f 100644 --- a/cup/common.h +++ b/cup/common.h @@ -3,4 +3,4 @@ #include <stdbool.h> #include <stdint.h> -typedef int64_t i64; +typedef long long int i64; diff --git a/cup/generator.c b/cup/generator.c index 61e8c69..6a95d0c 100644 --- a/cup/generator.c +++ b/cup/generator.c @@ -7,9 +7,30 @@ #include <string.h> #include <assert.h> +#include <sys/syscall.h> + static int label_counter = 0; static Node *current_function = NULL; +void generate_expr_into_rax(Node *expr, FILE *out); + +void generate_func_call(Node *node, FILE *out) +{ + assert(node->type == AST_FUNCCALL); + // FIXME: This seems like a big hack + i64 total_size = 0; + for (int i = node->call.num_args - 1; i >= 0; i--) { + Node *arg = node->call.args[i]; + generate_expr_into_rax(arg, out); + fprintf(out, " push rax\n"); + // TODO: Compute this for different types + // TODO: Also make sure of padding and stuff? + total_size += 8; + } + fprintf(out, " call %s\n", node->call.func->func.name); + fprintf(out, " add rsp, %lld\n", total_size); +} + // The evaluated expression is stored into `rax` void generate_expr_into_rax(Node *expr, FILE *out) { @@ -19,9 +40,15 @@ void generate_expr_into_rax(Node *expr, FILE *out) assert(expr->literal.type.type == TYPE_INT); fprintf(out, " mov rax, %d\n", expr->literal.as_int); + } else if (expr->type == AST_FUNCCALL) { + generate_func_call(expr, out); + } else if (expr->type == AST_VAR) { i64 offset = expr->variable->offset; - fprintf(out, " mov rax, [rbp-%lld]\n", offset); + if (offset > 0) + fprintf(out, " mov rax, [rbp-%lld]\n", offset); + else + fprintf(out, " mov rax, [rbp+%lld]\n", -offset); } else if (expr->type == OP_ASSIGN) { i64 offset = expr->assign.var->offset; @@ -212,7 +239,7 @@ void generate_statement(Node *stmt, FILE *out) assert(stmt->conditional.cond); assert(stmt->conditional.do_then); int cur_label = label_counter++; - + generate_expr_into_rax(stmt->conditional.cond, out); // If we don't have an `else` clause, we can simplify if (!stmt->conditional.do_else) { @@ -229,7 +256,7 @@ void generate_statement(Node *stmt, FILE *out) generate_statement(stmt->conditional.do_else, out); fprintf(out, ".if_end_%d:\n", cur_label); } - } else if (stmt->type == AST_WHILE) { + } else if (stmt->type == AST_WHILE) { int cur_label = label_counter++; fprintf(out, ".loop_start_%d:\n", cur_label); fprintf(out, ".loop_continue_%d:\n", cur_label); @@ -240,7 +267,7 @@ void generate_statement(Node *stmt, FILE *out) fprintf(out, " jmp .loop_start_%d\n", cur_label); fprintf(out, ".loop_end_%d:\n", cur_label); - } else if (stmt->type == AST_FOR) { + } else if (stmt->type == AST_FOR) { int cur_label = label_counter++; if (stmt->loop.init) { generate_statement(stmt->loop.init, out); @@ -259,7 +286,7 @@ void generate_statement(Node *stmt, FILE *out) fprintf(out, " jmp .loop_start_%d\n", cur_label); fprintf(out, ".loop_end_%d:\n", cur_label); - } else if (stmt->type == AST_BLOCK) { + } else if (stmt->type == AST_BLOCK) { generate_block(stmt, out); } else { // Once again, default to an expression here... diff --git a/cup/parser.c b/cup/parser.c index 698002f..cfb5d97 100644 --- a/cup/parser.c +++ b/cup/parser.c @@ -4,6 +4,10 @@ #include <string.h> #include <assert.h> +#define MAX_FUNCTION_COUNT 1024 +static Node *all_functions[MAX_FUNCTION_COUNT]; +static i64 function_count = 0; + static Node *current_function = NULL; #define BLOCK_STACK_SIZE 64 @@ -92,6 +96,24 @@ Variable *find_local_variable(Token *token) } } } + Node *func = current_function; + for (int i = 0; i < func->func.num_args; i++) { + if (strcmp(func->func.args[i].name, token->value.as_string) == 0) { + return &func->func.args[i]; + } + } + return NULL; +} + +Node *find_function_definition(Token *token) +{ + assert_token(*token, TOKEN_IDENTIFIER); + for (i64 i = 0; i < function_count; i++) { + Node *function = all_functions[i]; + if (strcmp(function->func.name, token->value.as_string) == 0) { + return function; + } + } return NULL; } @@ -203,6 +225,38 @@ Node *parse_var_declaration(Lexer *lexer) return node; } +Node *parse_function_call_args(Lexer *lexer, Node *func) +{ + Token identifier = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER); + Node *call = Node_new(AST_FUNCCALL); + call->call.func = func; + assert_token(Lexer_next(lexer), TOKEN_OPEN_PAREN); + Token token = Lexer_peek(lexer); + + while (token.type != TOKEN_CLOSE_PAREN) { + Node *arg = parse_expression(lexer); + + int new_size = call->call.num_args + 1; + call->call.args = realloc(call->call.args, sizeof(Node *) * new_size); + call->call.args[call->call.num_args++] = arg; + + if (new_size > func->func.num_args) + die_location(identifier.loc, "Too many arguments to function `%s`", func->func.name); + + token = Lexer_peek(lexer); + if (token.type == TOKEN_COMMA) { + Lexer_next(lexer); + token = Lexer_peek(lexer); + } + } + + if (call->call.num_args != func->func.num_args) + die_location(identifier.loc, "Too few arguments to function `%s`", func->func.name); + + assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN); + return call; +} + Node *parse_factor(Lexer *lexer) { // TODO: Parse more complicated things @@ -226,13 +280,24 @@ Node *parse_factor(Lexer *lexer) assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN); } else if (token.type == TOKEN_INTLIT) { expr = parse_literal(lexer); - } else if (token.type == TOKEN_IDENTIFIER) { - Lexer_next(lexer); + } else if (token.type == TOKEN_IDENTIFIER) { + // TODO: Check for global variables when added + Variable *var = find_local_variable(&token); - if (var == NULL) - die_location(token.loc, "Could not find variable `%s`", token.value.as_string); - expr = Node_new(AST_VAR); - expr->variable = var; + if (var != NULL) { + Lexer_next(lexer); + expr = Node_new(AST_VAR); + expr->variable = var; + return expr; + } + + Node *func = find_function_definition(&token); + if (func != NULL) { + return parse_function_call_args(lexer, func); + } + + die_location(token.loc, "Unknown identifier `%s`", token.value.as_string); + expr = NULL; } else { die_location(token.loc, ": Expected token found in parse_factor: `%s`", token_type_to_str(token.type)); exit(1); @@ -401,20 +466,63 @@ Node *parse_block(Lexer *lexer) return block; } +void push_new_function(Node *func) +{ + assert(func->type == AST_FUNC); + assert(function_count < MAX_FUNCTION_COUNT); + all_functions[function_count++] = func; + current_function = func; +} + +void parse_func_args(Lexer *lexer, Node *func) +{ + assert_token(Lexer_next(lexer), TOKEN_OPEN_PAREN); + Token token = Lexer_peek(lexer); + while (token.type != TOKEN_CLOSE_PAREN) { + token = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER); + // TODO: Check for shadowing with globals + assert_token(Lexer_next(lexer), TOKEN_COLON); + Type type = parse_type(lexer); + + i64 new_count = func->func.num_args + 1; + func->func.args = realloc(func->func.args, sizeof(Variable) * new_count); + Variable *var = &func->func.args[func->func.num_args++]; + var->name = token.value.as_string; + var->type = type; + + token = Lexer_peek(lexer); + if (token.type == TOKEN_COMMA) { + Lexer_next(lexer); + token = Lexer_peek(lexer); + } + } + assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN); + + // Set the offsets for the arguments + + // IMPORTANT: We want to skip the saved ret_addr+old_rbp that we + // pushed on the stack. Each of these is 8 bytes. + int offset = -16; + for (int i = 0; i < func->func.num_args; i++) { + Variable *var = &func->func.args[i]; + var->offset = offset; + // TODO: Compute this for different types + int var_size = 8; + offset -= var_size; + } +} + Node *parse_func(Lexer *lexer) { Token token; token = assert_token(Lexer_next(lexer), TOKEN_FN); Node *func = Node_new(AST_FUNC); - current_function = func; + push_new_function(func); token = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER); - func->func.name = token.value.as_string; - assert_token(Lexer_next(lexer), TOKEN_OPEN_PAREN); - // TODO: Parse parameters - assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN); + parse_func_args(lexer, func); token = Lexer_peek(lexer); if (token.type == TOKEN_COLON) { @@ -426,9 +534,11 @@ Node *parse_func(Lexer *lexer) func->func.return_type = (Type){.type = TYPE_NONE}; } - // Make sure there's no funny business with the stack offet + // Make sure there's no funny business with the stack offset assert(cur_stack_offset == 0); + assert(block_stack_count == 0); func->func.body = parse_block(lexer); + assert(block_stack_count == 0); assert(cur_stack_offset == 0); return func; diff --git a/examples/functions.cup b/examples/functions.cup new file mode 100644 index 0000000..89858c6 --- /dev/null +++ b/examples/functions.cup @@ -0,0 +1,13 @@ +fn rec_sum(n: int, accum: int): int { + if (n == 0) + return accum; + return rec_sum(n - 1, accum + n); +} + +fn sum(n: int): int { + return rec_sum(n, 0); +} + +fn main() { + return sum(10); +}
\ No newline at end of file diff --git a/tests/functions.sh b/tests/functions.sh new file mode 100755 index 0000000..3e87b37 --- /dev/null +++ b/tests/functions.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +# Test compound scopes + +. tests/common.sh + +set -e + +echo -n "- Different argument counts: " +assert_exit_status_stdin 5 <<EOF +fn test() { return 5; } +fn main() { return test(); } +EOF + +assert_exit_status_stdin 5 <<EOF +fn test(a: int) { return a; } +fn main() { return test(5); } +EOF + +assert_exit_status_stdin 5 <<EOF +fn test(a: int, b: int) { return a+b; } +fn main() { return test(2, 3); } +EOF + +assert_exit_status_stdin 5 <<EOF +fn test(a: int, b: int, c: int, d: int, e: int) { return a+b+c+d+e; } +fn main() { return test(1,1,1,1,1); } +EOF + +assert_compile_failure_stdin <<EOF +fn test() { return 5; } +fn main() { return test(5); } +EOF + +assert_compile_failure_stdin <<EOF +fn test(a: int, b: int, c: int) { return 5; } +fn main() { return test(5); } +EOF + +assert_compile_failure_stdin <<EOF +fn test(a: int, b: int, c: int) { return 5; } +fn main() { return test(5, 6, 5, 8); } +EOF + +echo " OK" + +echo -n "- Recursion: " +assert_exit_status_stdin 3 <<EOF +fn test(n: int) { return n == 0 ? 0 : 1 + test(n-1); } +fn main() { return test(3); } +EOF + +assert_exit_status_stdin 55 <<EOF +fn test(n: int) { return n == 0 ? 0 : n + test(n-1); } +fn main() { return test(10); } +EOF + +assert_exit_status_stdin 55 <<EOF +fn test(n: int, accum: int) { return n == 0 ? accum : test(n-1, n+accum); } +fn main() { return test(10,0); } +EOF + +echo " OK"
\ No newline at end of file |