diff options
| author | Mustafa Quraish <[email protected]> | 2022-01-30 01:10:35 -0500 |
|---|---|---|
| committer | Mustafa Quraish <[email protected]> | 2022-01-30 01:10:35 -0500 |
| commit | 33e86c66ce739913d453808d3fecd6670a0e9fe1 (patch) | |
| tree | 7bfcb4d7e9bf67dd979d47d8c89a9e56bff96eec | |
| parent | Make the compiler / scripts work on Linux too (yay!) (diff) | |
| download | cup-33e86c66ce739913d453808d3fecd6670a0e9fe1.tar.xz cup-33e86c66ce739913d453808d3fecd6670a0e9fe1.zip | |
Functions, yay!
We now support function calls! Forward declarations of functions are
not supported yet, though, so mutual recursion is not possible.
The arguments are passed via the stack instead of through registers
(unlike the System V x86_64 calling convention, which passes the first
six integer arguments in registers). Because syscalls expect their
arguments in registers, we'll probably need some sort of primitives
built into the language for syscalls eventually.
Return types are also not checked, and right now it's possible to have
a function that doesn't return anything even when the caller expects it
to. Error checking and reporting definitely need to be improved.
| -rw-r--r-- | cup/ast.c | 19 | ||||
| -rw-r--r-- | cup/ast.h | 10 | ||||
| -rw-r--r-- | cup/common.h | 2 | ||||
| -rw-r--r-- | cup/generator.c | 37 | ||||
| -rw-r--r-- | cup/parser.c | 134 | ||||
| -rw-r--r-- | examples/functions.cup | 13 | ||||
| -rwxr-xr-x | tests/functions.sh | 63 |
7 files changed, 259 insertions, 19 deletions
@@ -122,6 +122,15 @@ static void do_print_ast(Node *node, int depth) } else { printf("\n"); } + } else if (node->type == AST_FUNCCALL) { + printf("CALL %s(\n", node->call.func->func.name); + for (int i = 0; i < node->call.num_args; i++) { + do_print_ast(node->call.args[i], depth + 1); + } + for (int i = 0; i < depth; i++) { + printf(" "); + } + printf(")\n"); } else { printf("{{ %s }}\n", node_type_to_str(node->type)); } @@ -129,12 +138,20 @@ static void do_print_ast(Node *node, int depth) void dump_func(Node *node, int depth) { - printf("fn %s()", node->func.name); + printf("fn %s(", node->func.name); + for (int i = 0; i < node->func.num_args; i++) { + if (i > 0) printf(", "); + printf("%s: ", node->func.args[i].name); + print_type_to_file(stdout, node->func.args[i].type); + printf("[[%lld]]", node->func.args[i].offset); + } + printf(")"); if (node->func.return_type.type != TYPE_NONE) { // FIXME: Print return type properly printf(" -> "); print_type_to_file(stdout, node->func.return_type); } + do_print_ast(node->func.body, depth + 1); } @@ -24,6 +24,7 @@ F(OP_GEQ, ">=") \ F(OP_ASSIGN, "=") \ F(AST_LITERAL, "literal") \ + F(AST_FUNCCALL, "Function call") \ F(AST_CONDITIONAL, "conditional expression") \ F(AST_IF, "if statement") \ F(AST_WHILE, "while statement") \ @@ -89,7 +90,10 @@ typedef struct ast_node { // TODO: Should we just dynamically allocate space on the // stack for each block instead of storing this? i64 max_locals_size; + // TODO: Arguments / etc? 
+ Variable *args; + int num_args; } func; // Block of statements @@ -134,6 +138,12 @@ typedef struct ast_node { } loop; Variable *variable; + + struct { + Node *func; + Node **args; + int num_args; + } call; }; } Node; diff --git a/cup/common.h b/cup/common.h index 522483f..20d057f 100644 --- a/cup/common.h +++ b/cup/common.h @@ -3,4 +3,4 @@ #include <stdbool.h> #include <stdint.h> -typedef int64_t i64; +typedef long long int i64; diff --git a/cup/generator.c b/cup/generator.c index 61e8c69..6a95d0c 100644 --- a/cup/generator.c +++ b/cup/generator.c @@ -7,9 +7,30 @@ #include <string.h> #include <assert.h> +#include <sys/syscall.h> + static int label_counter = 0; static Node *current_function = NULL; +void generate_expr_into_rax(Node *expr, FILE *out); + +void generate_func_call(Node *node, FILE *out) +{ + assert(node->type == AST_FUNCCALL); + // FIXME: This seems like a big hack + i64 total_size = 0; + for (int i = node->call.num_args - 1; i >= 0; i--) { + Node *arg = node->call.args[i]; + generate_expr_into_rax(arg, out); + fprintf(out, " push rax\n"); + // TODO: Compute this for different types + // TODO: Also make sure of padding and stuff? 
+ total_size += 8; + } + fprintf(out, " call %s\n", node->call.func->func.name); + fprintf(out, " add rsp, %lld\n", total_size); +} + // The evaluated expression is stored into `rax` void generate_expr_into_rax(Node *expr, FILE *out) { @@ -19,9 +40,15 @@ void generate_expr_into_rax(Node *expr, FILE *out) assert(expr->literal.type.type == TYPE_INT); fprintf(out, " mov rax, %d\n", expr->literal.as_int); + } else if (expr->type == AST_FUNCCALL) { + generate_func_call(expr, out); + } else if (expr->type == AST_VAR) { i64 offset = expr->variable->offset; - fprintf(out, " mov rax, [rbp-%lld]\n", offset); + if (offset > 0) + fprintf(out, " mov rax, [rbp-%lld]\n", offset); + else + fprintf(out, " mov rax, [rbp+%lld]\n", -offset); } else if (expr->type == OP_ASSIGN) { i64 offset = expr->assign.var->offset; @@ -212,7 +239,7 @@ void generate_statement(Node *stmt, FILE *out) assert(stmt->conditional.cond); assert(stmt->conditional.do_then); int cur_label = label_counter++; - + generate_expr_into_rax(stmt->conditional.cond, out); // If we don't have an `else` clause, we can simplify if (!stmt->conditional.do_else) { @@ -229,7 +256,7 @@ void generate_statement(Node *stmt, FILE *out) generate_statement(stmt->conditional.do_else, out); fprintf(out, ".if_end_%d:\n", cur_label); } - } else if (stmt->type == AST_WHILE) { + } else if (stmt->type == AST_WHILE) { int cur_label = label_counter++; fprintf(out, ".loop_start_%d:\n", cur_label); fprintf(out, ".loop_continue_%d:\n", cur_label); @@ -240,7 +267,7 @@ void generate_statement(Node *stmt, FILE *out) fprintf(out, " jmp .loop_start_%d\n", cur_label); fprintf(out, ".loop_end_%d:\n", cur_label); - } else if (stmt->type == AST_FOR) { + } else if (stmt->type == AST_FOR) { int cur_label = label_counter++; if (stmt->loop.init) { generate_statement(stmt->loop.init, out); @@ -259,7 +286,7 @@ void generate_statement(Node *stmt, FILE *out) fprintf(out, " jmp .loop_start_%d\n", cur_label); fprintf(out, ".loop_end_%d:\n", cur_label); - } else if 
(stmt->type == AST_BLOCK) { + } else if (stmt->type == AST_BLOCK) { generate_block(stmt, out); } else { // Once again, default to an expression here... diff --git a/cup/parser.c b/cup/parser.c index 698002f..cfb5d97 100644 --- a/cup/parser.c +++ b/cup/parser.c @@ -4,6 +4,10 @@ #include <string.h> #include <assert.h> +#define MAX_FUNCTION_COUNT 1024 +static Node *all_functions[MAX_FUNCTION_COUNT]; +static i64 function_count = 0; + static Node *current_function = NULL; #define BLOCK_STACK_SIZE 64 @@ -92,6 +96,24 @@ Variable *find_local_variable(Token *token) } } } + Node *func = current_function; + for (int i = 0; i < func->func.num_args; i++) { + if (strcmp(func->func.args[i].name, token->value.as_string) == 0) { + return &func->func.args[i]; + } + } + return NULL; +} + +Node *find_function_definition(Token *token) +{ + assert_token(*token, TOKEN_IDENTIFIER); + for (i64 i = 0; i < function_count; i++) { + Node *function = all_functions[i]; + if (strcmp(function->func.name, token->value.as_string) == 0) { + return function; + } + } return NULL; } @@ -203,6 +225,38 @@ Node *parse_var_declaration(Lexer *lexer) return node; } +Node *parse_function_call_args(Lexer *lexer, Node *func) +{ + Token identifier = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER); + Node *call = Node_new(AST_FUNCCALL); + call->call.func = func; + assert_token(Lexer_next(lexer), TOKEN_OPEN_PAREN); + Token token = Lexer_peek(lexer); + + while (token.type != TOKEN_CLOSE_PAREN) { + Node *arg = parse_expression(lexer); + + int new_size = call->call.num_args + 1; + call->call.args = realloc(call->call.args, sizeof(Node *) * new_size); + call->call.args[call->call.num_args++] = arg; + + if (new_size > func->func.num_args) + die_location(identifier.loc, "Too many arguments to function `%s`", func->func.name); + + token = Lexer_peek(lexer); + if (token.type == TOKEN_COMMA) { + Lexer_next(lexer); + token = Lexer_peek(lexer); + } + } + + if (call->call.num_args != func->func.num_args) + 
die_location(identifier.loc, "Too few arguments to function `%s`", func->func.name); + + assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN); + return call; +} + Node *parse_factor(Lexer *lexer) { // TODO: Parse more complicated things @@ -226,13 +280,24 @@ Node *parse_factor(Lexer *lexer) assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN); } else if (token.type == TOKEN_INTLIT) { expr = parse_literal(lexer); - } else if (token.type == TOKEN_IDENTIFIER) { - Lexer_next(lexer); + } else if (token.type == TOKEN_IDENTIFIER) { + // TODO: Check for global variables when added + Variable *var = find_local_variable(&token); - if (var == NULL) - die_location(token.loc, "Could not find variable `%s`", token.value.as_string); - expr = Node_new(AST_VAR); - expr->variable = var; + if (var != NULL) { + Lexer_next(lexer); + expr = Node_new(AST_VAR); + expr->variable = var; + return expr; + } + + Node *func = find_function_definition(&token); + if (func != NULL) { + return parse_function_call_args(lexer, func); + } + + die_location(token.loc, "Unknown identifier `%s`", token.value.as_string); + expr = NULL; } else { die_location(token.loc, ": Expected token found in parse_factor: `%s`", token_type_to_str(token.type)); exit(1); @@ -401,20 +466,63 @@ Node *parse_block(Lexer *lexer) return block; } +void push_new_function(Node *func) +{ + assert(func->type == AST_FUNC); + assert(function_count < MAX_FUNCTION_COUNT); + all_functions[function_count++] = func; + current_function = func; +} + +void parse_func_args(Lexer *lexer, Node *func) +{ + assert_token(Lexer_next(lexer), TOKEN_OPEN_PAREN); + Token token = Lexer_peek(lexer); + while (token.type != TOKEN_CLOSE_PAREN) { + token = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER); + // TODO: Check for shadowing with globals + assert_token(Lexer_next(lexer), TOKEN_COLON); + Type type = parse_type(lexer); + + i64 new_count = func->func.num_args + 1; + func->func.args = realloc(func->func.args, sizeof(Variable) * new_count); + Variable *var 
= &func->func.args[func->func.num_args++]; + var->name = token.value.as_string; + var->type = type; + + token = Lexer_peek(lexer); + if (token.type == TOKEN_COMMA) { + Lexer_next(lexer); + token = Lexer_peek(lexer); + } + } + assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN); + + // Set the offsets for the arguments + + // IMPORTANT: We want to skip the saved ret_addr+old_rbp that we + // pushed on the stack. Each of these is 8 bytes. + int offset = -16; + for (int i = 0; i < func->func.num_args; i++) { + Variable *var = &func->func.args[i]; + var->offset = offset; + // TODO: Compute this for different types + int var_size = 8; + offset -= var_size; + } +} + Node *parse_func(Lexer *lexer) { Token token; token = assert_token(Lexer_next(lexer), TOKEN_FN); Node *func = Node_new(AST_FUNC); - current_function = func; + push_new_function(func); token = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER); - func->func.name = token.value.as_string; - assert_token(Lexer_next(lexer), TOKEN_OPEN_PAREN); - // TODO: Parse parameters - assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN); + parse_func_args(lexer, func); token = Lexer_peek(lexer); if (token.type == TOKEN_COLON) { @@ -426,9 +534,11 @@ Node *parse_func(Lexer *lexer) func->func.return_type = (Type){.type = TYPE_NONE}; } - // Make sure there's no funny business with the stack offet + // Make sure there's no funny business with the stack offset assert(cur_stack_offset == 0); + assert(block_stack_count == 0); func->func.body = parse_block(lexer); + assert(block_stack_count == 0); assert(cur_stack_offset == 0); return func; diff --git a/examples/functions.cup b/examples/functions.cup new file mode 100644 index 0000000..89858c6 --- /dev/null +++ b/examples/functions.cup @@ -0,0 +1,13 @@ +fn rec_sum(n: int, accum: int): int { + if (n == 0) + return accum; + return rec_sum(n - 1, accum + n); +} + +fn sum(n: int): int { + return rec_sum(n, 0); +} + +fn main() { + return sum(10); +}
\ No newline at end of file diff --git a/tests/functions.sh b/tests/functions.sh new file mode 100755 index 0000000..3e87b37 --- /dev/null +++ b/tests/functions.sh @@ -0,0 +1,63 @@ +#!/bin/bash + +# Test compound scopes + +. tests/common.sh + +set -e + +echo -n "- Different argument counts: " +assert_exit_status_stdin 5 <<EOF +fn test() { return 5; } +fn main() { return test(); } +EOF + +assert_exit_status_stdin 5 <<EOF +fn test(a: int) { return a; } +fn main() { return test(5); } +EOF + +assert_exit_status_stdin 5 <<EOF +fn test(a: int, b: int) { return a+b; } +fn main() { return test(2, 3); } +EOF + +assert_exit_status_stdin 5 <<EOF +fn test(a: int, b: int, c: int, d: int, e: int) { return a+b+c+d+e; } +fn main() { return test(1,1,1,1,1); } +EOF + +assert_compile_failure_stdin <<EOF +fn test() { return 5; } +fn main() { return test(5); } +EOF + +assert_compile_failure_stdin <<EOF +fn test(a: int, b: int, c: int) { return 5; } +fn main() { return test(5); } +EOF + +assert_compile_failure_stdin <<EOF +fn test(a: int, b: int, c: int) { return 5; } +fn main() { return test(5, 6, 5, 8); } +EOF + +echo " OK" + +echo -n "- Recursion: " +assert_exit_status_stdin 3 <<EOF +fn test(n: int) { return n == 0 ? 0 : 1 + test(n-1); } +fn main() { return test(3); } +EOF + +assert_exit_status_stdin 55 <<EOF +fn test(n: int) { return n == 0 ? 0 : n + test(n-1); } +fn main() { return test(10); } +EOF + +assert_exit_status_stdin 55 <<EOF +fn test(n: int, accum: int) { return n == 0 ? accum : test(n-1, n+accum); } +fn main() { return test(10,0); } +EOF + +echo " OK"
\ No newline at end of file |