aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMustafa Quraish <[email protected]>2022-01-30 01:10:35 -0500
committerMustafa Quraish <[email protected]>2022-01-30 01:10:35 -0500
commit33e86c66ce739913d453808d3fecd6670a0e9fe1 (patch)
tree7bfcb4d7e9bf67dd979d47d8c89a9e56bff96eec
parentMake the compiler / scripts work on Linux too (yay!) (diff)
downloadcup-33e86c66ce739913d453808d3fecd6670a0e9fe1.tar.xz
cup-33e86c66ce739913d453808d3fecd6670a0e9fe1.zip
Functions, yay!
We now support function calls! We don't have support for forward declaring functions right now though, so no mutual recursion is possible. The arguments are passed via the stack instead of through registers (unlike the x86_64 calling convention, I think). We'll probably need some sort of primitives built into the language for syscalls eventually because of this. Return types are also not checked, and right now it's possible to have a function that doesn't return anything even when the caller expects it to, error checking and reporting definitely needs to be improved.
-rw-r--r--cup/ast.c19
-rw-r--r--cup/ast.h10
-rw-r--r--cup/common.h2
-rw-r--r--cup/generator.c37
-rw-r--r--cup/parser.c134
-rw-r--r--examples/functions.cup13
-rwxr-xr-xtests/functions.sh63
7 files changed, 259 insertions, 19 deletions
diff --git a/cup/ast.c b/cup/ast.c
index 76ace8a..6e33de1 100644
--- a/cup/ast.c
+++ b/cup/ast.c
@@ -122,6 +122,15 @@ static void do_print_ast(Node *node, int depth)
} else {
printf("\n");
}
+ } else if (node->type == AST_FUNCCALL) {
+ printf("CALL %s(\n", node->call.func->func.name);
+ for (int i = 0; i < node->call.num_args; i++) {
+ do_print_ast(node->call.args[i], depth + 1);
+ }
+ for (int i = 0; i < depth; i++) {
+ printf(" ");
+ }
+ printf(")\n");
} else {
printf("{{ %s }}\n", node_type_to_str(node->type));
}
@@ -129,12 +138,20 @@ static void do_print_ast(Node *node, int depth)
void dump_func(Node *node, int depth)
{
- printf("fn %s()", node->func.name);
+ printf("fn %s(", node->func.name);
+ for (int i = 0; i < node->func.num_args; i++) {
+ if (i > 0) printf(", ");
+ printf("%s: ", node->func.args[i].name);
+ print_type_to_file(stdout, node->func.args[i].type);
+ printf("[[%lld]]", node->func.args[i].offset);
+ }
+ printf(")");
if (node->func.return_type.type != TYPE_NONE) {
// FIXME: Print return type properly
printf(" -> ");
print_type_to_file(stdout, node->func.return_type);
}
+
do_print_ast(node->func.body, depth + 1);
}
diff --git a/cup/ast.h b/cup/ast.h
index 22bedcc..22a7283 100644
--- a/cup/ast.h
+++ b/cup/ast.h
@@ -24,6 +24,7 @@
F(OP_GEQ, ">=") \
F(OP_ASSIGN, "=") \
F(AST_LITERAL, "literal") \
+ F(AST_FUNCCALL, "Function call") \
F(AST_CONDITIONAL, "conditional expression") \
F(AST_IF, "if statement") \
F(AST_WHILE, "while statement") \
@@ -89,7 +90,10 @@ typedef struct ast_node {
// TODO: Should we just dynamically allocate space on the
// stack for each block instead of storing this?
i64 max_locals_size;
+
// TODO: Arguments / etc?
+ Variable *args;
+ int num_args;
} func;
// Block of statements
@@ -134,6 +138,12 @@ typedef struct ast_node {
} loop;
Variable *variable;
+
+ struct {
+ Node *func;
+ Node **args;
+ int num_args;
+ } call;
};
} Node;
diff --git a/cup/common.h b/cup/common.h
index 522483f..20d057f 100644
--- a/cup/common.h
+++ b/cup/common.h
@@ -3,4 +3,4 @@
#include <stdbool.h>
#include <stdint.h>
-typedef int64_t i64;
+typedef long long int i64;
diff --git a/cup/generator.c b/cup/generator.c
index 61e8c69..6a95d0c 100644
--- a/cup/generator.c
+++ b/cup/generator.c
@@ -7,9 +7,30 @@
#include <string.h>
#include <assert.h>
+#include <sys/syscall.h>
+
static int label_counter = 0;
static Node *current_function = NULL;
+void generate_expr_into_rax(Node *expr, FILE *out);
+
+void generate_func_call(Node *node, FILE *out)
+{
+ assert(node->type == AST_FUNCCALL);
+ // FIXME: This seems like a big hack
+ i64 total_size = 0;
+ for (int i = node->call.num_args - 1; i >= 0; i--) {
+ Node *arg = node->call.args[i];
+ generate_expr_into_rax(arg, out);
+ fprintf(out, " push rax\n");
+ // TODO: Compute this for different types
+ // TODO: Also make sure of padding and stuff?
+ total_size += 8;
+ }
+ fprintf(out, " call %s\n", node->call.func->func.name);
+ fprintf(out, " add rsp, %lld\n", total_size);
+}
+
// The evaluated expression is stored into `rax`
void generate_expr_into_rax(Node *expr, FILE *out)
{
@@ -19,9 +40,15 @@ void generate_expr_into_rax(Node *expr, FILE *out)
assert(expr->literal.type.type == TYPE_INT);
fprintf(out, " mov rax, %d\n", expr->literal.as_int);
+ } else if (expr->type == AST_FUNCCALL) {
+ generate_func_call(expr, out);
+
} else if (expr->type == AST_VAR) {
i64 offset = expr->variable->offset;
- fprintf(out, " mov rax, [rbp-%lld]\n", offset);
+ if (offset > 0)
+ fprintf(out, " mov rax, [rbp-%lld]\n", offset);
+ else
+ fprintf(out, " mov rax, [rbp+%lld]\n", -offset);
} else if (expr->type == OP_ASSIGN) {
i64 offset = expr->assign.var->offset;
@@ -212,7 +239,7 @@ void generate_statement(Node *stmt, FILE *out)
assert(stmt->conditional.cond);
assert(stmt->conditional.do_then);
int cur_label = label_counter++;
-
+
generate_expr_into_rax(stmt->conditional.cond, out);
// If we don't have an `else` clause, we can simplify
if (!stmt->conditional.do_else) {
@@ -229,7 +256,7 @@ void generate_statement(Node *stmt, FILE *out)
generate_statement(stmt->conditional.do_else, out);
fprintf(out, ".if_end_%d:\n", cur_label);
}
- } else if (stmt->type == AST_WHILE) {
+ } else if (stmt->type == AST_WHILE) {
int cur_label = label_counter++;
fprintf(out, ".loop_start_%d:\n", cur_label);
fprintf(out, ".loop_continue_%d:\n", cur_label);
@@ -240,7 +267,7 @@ void generate_statement(Node *stmt, FILE *out)
fprintf(out, " jmp .loop_start_%d\n", cur_label);
fprintf(out, ".loop_end_%d:\n", cur_label);
- } else if (stmt->type == AST_FOR) {
+ } else if (stmt->type == AST_FOR) {
int cur_label = label_counter++;
if (stmt->loop.init) {
generate_statement(stmt->loop.init, out);
@@ -259,7 +286,7 @@ void generate_statement(Node *stmt, FILE *out)
fprintf(out, " jmp .loop_start_%d\n", cur_label);
fprintf(out, ".loop_end_%d:\n", cur_label);
- } else if (stmt->type == AST_BLOCK) {
+ } else if (stmt->type == AST_BLOCK) {
generate_block(stmt, out);
} else {
// Once again, default to an expression here...
diff --git a/cup/parser.c b/cup/parser.c
index 698002f..cfb5d97 100644
--- a/cup/parser.c
+++ b/cup/parser.c
@@ -4,6 +4,10 @@
#include <string.h>
#include <assert.h>
+#define MAX_FUNCTION_COUNT 1024
+static Node *all_functions[MAX_FUNCTION_COUNT];
+static i64 function_count = 0;
+
static Node *current_function = NULL;
#define BLOCK_STACK_SIZE 64
@@ -92,6 +96,24 @@ Variable *find_local_variable(Token *token)
}
}
}
+ Node *func = current_function;
+ for (int i = 0; i < func->func.num_args; i++) {
+ if (strcmp(func->func.args[i].name, token->value.as_string) == 0) {
+ return &func->func.args[i];
+ }
+ }
+ return NULL;
+}
+
+Node *find_function_definition(Token *token)
+{
+ assert_token(*token, TOKEN_IDENTIFIER);
+ for (i64 i = 0; i < function_count; i++) {
+ Node *function = all_functions[i];
+ if (strcmp(function->func.name, token->value.as_string) == 0) {
+ return function;
+ }
+ }
return NULL;
}
@@ -203,6 +225,38 @@ Node *parse_var_declaration(Lexer *lexer)
return node;
}
+Node *parse_function_call_args(Lexer *lexer, Node *func)
+{
+ Token identifier = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER);
+ Node *call = Node_new(AST_FUNCCALL);
+ call->call.func = func;
+ assert_token(Lexer_next(lexer), TOKEN_OPEN_PAREN);
+ Token token = Lexer_peek(lexer);
+
+ while (token.type != TOKEN_CLOSE_PAREN) {
+ Node *arg = parse_expression(lexer);
+
+ int new_size = call->call.num_args + 1;
+ call->call.args = realloc(call->call.args, sizeof(Node *) * new_size);
+ call->call.args[call->call.num_args++] = arg;
+
+ if (new_size > func->func.num_args)
+ die_location(identifier.loc, "Too many arguments to function `%s`", func->func.name);
+
+ token = Lexer_peek(lexer);
+ if (token.type == TOKEN_COMMA) {
+ Lexer_next(lexer);
+ token = Lexer_peek(lexer);
+ }
+ }
+
+ if (call->call.num_args != func->func.num_args)
+ die_location(identifier.loc, "Too few arguments to function `%s`", func->func.name);
+
+ assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN);
+ return call;
+}
+
Node *parse_factor(Lexer *lexer)
{
// TODO: Parse more complicated things
@@ -226,13 +280,24 @@ Node *parse_factor(Lexer *lexer)
assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN);
} else if (token.type == TOKEN_INTLIT) {
expr = parse_literal(lexer);
- } else if (token.type == TOKEN_IDENTIFIER) {
- Lexer_next(lexer);
+ } else if (token.type == TOKEN_IDENTIFIER) {
+ // TODO: Check for global variables when added
+
Variable *var = find_local_variable(&token);
- if (var == NULL)
- die_location(token.loc, "Could not find variable `%s`", token.value.as_string);
- expr = Node_new(AST_VAR);
- expr->variable = var;
+ if (var != NULL) {
+ Lexer_next(lexer);
+ expr = Node_new(AST_VAR);
+ expr->variable = var;
+ return expr;
+ }
+
+ Node *func = find_function_definition(&token);
+ if (func != NULL) {
+ return parse_function_call_args(lexer, func);
+ }
+
+ die_location(token.loc, "Unknown identifier `%s`", token.value.as_string);
+ expr = NULL;
} else {
die_location(token.loc, ": Expected token found in parse_factor: `%s`", token_type_to_str(token.type));
exit(1);
@@ -401,20 +466,63 @@ Node *parse_block(Lexer *lexer)
return block;
}
+void push_new_function(Node *func)
+{
+ assert(func->type == AST_FUNC);
+ assert(function_count < MAX_FUNCTION_COUNT);
+ all_functions[function_count++] = func;
+ current_function = func;
+}
+
+void parse_func_args(Lexer *lexer, Node *func)
+{
+ assert_token(Lexer_next(lexer), TOKEN_OPEN_PAREN);
+ Token token = Lexer_peek(lexer);
+ while (token.type != TOKEN_CLOSE_PAREN) {
+ token = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER);
+ // TODO: Check for shadowing with globals
+ assert_token(Lexer_next(lexer), TOKEN_COLON);
+ Type type = parse_type(lexer);
+
+ i64 new_count = func->func.num_args + 1;
+ func->func.args = realloc(func->func.args, sizeof(Variable) * new_count);
+ Variable *var = &func->func.args[func->func.num_args++];
+ var->name = token.value.as_string;
+ var->type = type;
+
+ token = Lexer_peek(lexer);
+ if (token.type == TOKEN_COMMA) {
+ Lexer_next(lexer);
+ token = Lexer_peek(lexer);
+ }
+ }
+ assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN);
+
+ // Set the offsets for the arguments
+
+ // IMPORTANT: We want to skip the saved ret_addr+old_rbp that we
+ // pushed on the stack. Each of these is 8 bytes.
+ int offset = -16;
+ for (int i = 0; i < func->func.num_args; i++) {
+ Variable *var = &func->func.args[i];
+ var->offset = offset;
+ // TODO: Compute this for different types
+ int var_size = 8;
+ offset -= var_size;
+ }
+}
+
Node *parse_func(Lexer *lexer)
{
Token token;
token = assert_token(Lexer_next(lexer), TOKEN_FN);
Node *func = Node_new(AST_FUNC);
- current_function = func;
+ push_new_function(func);
token = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER);
-
func->func.name = token.value.as_string;
- assert_token(Lexer_next(lexer), TOKEN_OPEN_PAREN);
- // TODO: Parse parameters
- assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN);
+ parse_func_args(lexer, func);
token = Lexer_peek(lexer);
if (token.type == TOKEN_COLON) {
@@ -426,9 +534,11 @@ Node *parse_func(Lexer *lexer)
func->func.return_type = (Type){.type = TYPE_NONE};
}
- // Make sure there's no funny business with the stack offet
+ // Make sure there's no funny business with the stack offset
assert(cur_stack_offset == 0);
+ assert(block_stack_count == 0);
func->func.body = parse_block(lexer);
+ assert(block_stack_count == 0);
assert(cur_stack_offset == 0);
return func;
diff --git a/examples/functions.cup b/examples/functions.cup
new file mode 100644
index 0000000..89858c6
--- /dev/null
+++ b/examples/functions.cup
@@ -0,0 +1,13 @@
+fn rec_sum(n: int, accum: int): int {
+ if (n == 0)
+ return accum;
+ return rec_sum(n - 1, accum + n);
+}
+
+fn sum(n: int): int {
+ return rec_sum(n, 0);
+}
+
+fn main() {
+ return sum(10);
+} \ No newline at end of file
diff --git a/tests/functions.sh b/tests/functions.sh
new file mode 100755
index 0000000..3e87b37
--- /dev/null
+++ b/tests/functions.sh
@@ -0,0 +1,63 @@
+#!/bin/bash
+
+# Test compound scopes
+
+. tests/common.sh
+
+set -e
+
+echo -n "- Different argument counts: "
+assert_exit_status_stdin 5 <<EOF
+fn test() { return 5; }
+fn main() { return test(); }
+EOF
+
+assert_exit_status_stdin 5 <<EOF
+fn test(a: int) { return a; }
+fn main() { return test(5); }
+EOF
+
+assert_exit_status_stdin 5 <<EOF
+fn test(a: int, b: int) { return a+b; }
+fn main() { return test(2, 3); }
+EOF
+
+assert_exit_status_stdin 5 <<EOF
+fn test(a: int, b: int, c: int, d: int, e: int) { return a+b+c+d+e; }
+fn main() { return test(1,1,1,1,1); }
+EOF
+
+assert_compile_failure_stdin <<EOF
+fn test() { return 5; }
+fn main() { return test(5); }
+EOF
+
+assert_compile_failure_stdin <<EOF
+fn test(a: int, b: int, c: int) { return 5; }
+fn main() { return test(5); }
+EOF
+
+assert_compile_failure_stdin <<EOF
+fn test(a: int, b: int, c: int) { return 5; }
+fn main() { return test(5, 6, 5, 8); }
+EOF
+
+echo " OK"
+
+echo -n "- Recursion: "
+assert_exit_status_stdin 3 <<EOF
+fn test(n: int) { return n == 0 ? 0 : 1 + test(n-1); }
+fn main() { return test(3); }
+EOF
+
+assert_exit_status_stdin 55 <<EOF
+fn test(n: int) { return n == 0 ? 0 : n + test(n-1); }
+fn main() { return test(10); }
+EOF
+
+assert_exit_status_stdin 55 <<EOF
+fn test(n: int, accum: int) { return n == 0 ? accum : test(n-1, n+accum); }
+fn main() { return test(10,0); }
+EOF
+
+echo " OK" \ No newline at end of file