aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cup/ast.c19
-rw-r--r--cup/ast.h10
-rw-r--r--cup/common.h2
-rw-r--r--cup/generator.c37
-rw-r--r--cup/parser.c134
-rw-r--r--examples/functions.cup13
-rwxr-xr-xtests/functions.sh63
7 files changed, 259 insertions, 19 deletions
diff --git a/cup/ast.c b/cup/ast.c
index 76ace8a..6e33de1 100644
--- a/cup/ast.c
+++ b/cup/ast.c
@@ -122,6 +122,15 @@ static void do_print_ast(Node *node, int depth)
} else {
printf("\n");
}
+ } else if (node->type == AST_FUNCCALL) {
+ printf("CALL %s(\n", node->call.func->func.name);
+ for (int i = 0; i < node->call.num_args; i++) {
+ do_print_ast(node->call.args[i], depth + 1);
+ }
+ for (int i = 0; i < depth; i++) {
+ printf(" ");
+ }
+ printf(")\n");
} else {
printf("{{ %s }}\n", node_type_to_str(node->type));
}
@@ -129,12 +138,20 @@ static void do_print_ast(Node *node, int depth)
void dump_func(Node *node, int depth)
{
- printf("fn %s()", node->func.name);
+ printf("fn %s(", node->func.name);
+ for (int i = 0; i < node->func.num_args; i++) {
+ if (i > 0) printf(", ");
+ printf("%s: ", node->func.args[i].name);
+ print_type_to_file(stdout, node->func.args[i].type);
+ printf("[[%lld]]", node->func.args[i].offset);
+ }
+ printf(")");
if (node->func.return_type.type != TYPE_NONE) {
// FIXME: Print return type properly
printf(" -> ");
print_type_to_file(stdout, node->func.return_type);
}
+
do_print_ast(node->func.body, depth + 1);
}
diff --git a/cup/ast.h b/cup/ast.h
index 22bedcc..22a7283 100644
--- a/cup/ast.h
+++ b/cup/ast.h
@@ -24,6 +24,7 @@
F(OP_GEQ, ">=") \
F(OP_ASSIGN, "=") \
F(AST_LITERAL, "literal") \
+ F(AST_FUNCCALL, "Function call") \
F(AST_CONDITIONAL, "conditional expression") \
F(AST_IF, "if statement") \
F(AST_WHILE, "while statement") \
@@ -89,7 +90,10 @@ typedef struct ast_node {
// TODO: Should we just dynamically allocate space on the
// stack for each block instead of storing this?
i64 max_locals_size;
+
// TODO: Arguments / etc?
+ Variable *args;
+ int num_args;
} func;
// Block of statements
@@ -134,6 +138,12 @@ typedef struct ast_node {
} loop;
Variable *variable;
+
+ struct {
+ Node *func;
+ Node **args;
+ int num_args;
+ } call;
};
} Node;
diff --git a/cup/common.h b/cup/common.h
index 522483f..20d057f 100644
--- a/cup/common.h
+++ b/cup/common.h
@@ -3,4 +3,4 @@
#include <stdbool.h>
#include <stdint.h>
-typedef int64_t i64;
+typedef long long int i64;
diff --git a/cup/generator.c b/cup/generator.c
index 61e8c69..6a95d0c 100644
--- a/cup/generator.c
+++ b/cup/generator.c
@@ -7,9 +7,30 @@
#include <string.h>
#include <assert.h>
+#include <sys/syscall.h>
+
static int label_counter = 0;
static Node *current_function = NULL;
+void generate_expr_into_rax(Node *expr, FILE *out);
+
+void generate_func_call(Node *node, FILE *out)
+{
+ assert(node->type == AST_FUNCCALL);
+ // FIXME: This seems like a big hack
+ i64 total_size = 0;
+ for (int i = node->call.num_args - 1; i >= 0; i--) {
+ Node *arg = node->call.args[i];
+ generate_expr_into_rax(arg, out);
+ fprintf(out, " push rax\n");
+ // TODO: Compute this for different types
+ // TODO: Also make sure of padding and stuff?
+ total_size += 8;
+ }
+ fprintf(out, " call %s\n", node->call.func->func.name);
+ fprintf(out, " add rsp, %lld\n", total_size);
+}
+
// The evaluated expression is stored into `rax`
void generate_expr_into_rax(Node *expr, FILE *out)
{
@@ -19,9 +40,15 @@ void generate_expr_into_rax(Node *expr, FILE *out)
assert(expr->literal.type.type == TYPE_INT);
fprintf(out, " mov rax, %d\n", expr->literal.as_int);
+ } else if (expr->type == AST_FUNCCALL) {
+ generate_func_call(expr, out);
+
} else if (expr->type == AST_VAR) {
i64 offset = expr->variable->offset;
- fprintf(out, " mov rax, [rbp-%lld]\n", offset);
+ if (offset > 0)
+ fprintf(out, " mov rax, [rbp-%lld]\n", offset);
+ else
+ fprintf(out, " mov rax, [rbp+%lld]\n", -offset);
} else if (expr->type == OP_ASSIGN) {
i64 offset = expr->assign.var->offset;
@@ -212,7 +239,7 @@ void generate_statement(Node *stmt, FILE *out)
assert(stmt->conditional.cond);
assert(stmt->conditional.do_then);
int cur_label = label_counter++;
-
+
generate_expr_into_rax(stmt->conditional.cond, out);
// If we don't have an `else` clause, we can simplify
if (!stmt->conditional.do_else) {
@@ -229,7 +256,7 @@ void generate_statement(Node *stmt, FILE *out)
generate_statement(stmt->conditional.do_else, out);
fprintf(out, ".if_end_%d:\n", cur_label);
}
- } else if (stmt->type == AST_WHILE) {
+ } else if (stmt->type == AST_WHILE) {
int cur_label = label_counter++;
fprintf(out, ".loop_start_%d:\n", cur_label);
fprintf(out, ".loop_continue_%d:\n", cur_label);
@@ -240,7 +267,7 @@ void generate_statement(Node *stmt, FILE *out)
fprintf(out, " jmp .loop_start_%d\n", cur_label);
fprintf(out, ".loop_end_%d:\n", cur_label);
- } else if (stmt->type == AST_FOR) {
+ } else if (stmt->type == AST_FOR) {
int cur_label = label_counter++;
if (stmt->loop.init) {
generate_statement(stmt->loop.init, out);
@@ -259,7 +286,7 @@ void generate_statement(Node *stmt, FILE *out)
fprintf(out, " jmp .loop_start_%d\n", cur_label);
fprintf(out, ".loop_end_%d:\n", cur_label);
- } else if (stmt->type == AST_BLOCK) {
+ } else if (stmt->type == AST_BLOCK) {
generate_block(stmt, out);
} else {
// Once again, default to an expression here...
diff --git a/cup/parser.c b/cup/parser.c
index 698002f..cfb5d97 100644
--- a/cup/parser.c
+++ b/cup/parser.c
@@ -4,6 +4,10 @@
#include <string.h>
#include <assert.h>
+#define MAX_FUNCTION_COUNT 1024
+static Node *all_functions[MAX_FUNCTION_COUNT];
+static i64 function_count = 0;
+
static Node *current_function = NULL;
#define BLOCK_STACK_SIZE 64
@@ -92,6 +96,24 @@ Variable *find_local_variable(Token *token)
}
}
}
+ Node *func = current_function;
+ for (int i = 0; i < func->func.num_args; i++) {
+ if (strcmp(func->func.args[i].name, token->value.as_string) == 0) {
+ return &func->func.args[i];
+ }
+ }
+ return NULL;
+}
+
+Node *find_function_definition(Token *token)
+{
+ assert_token(*token, TOKEN_IDENTIFIER);
+ for (i64 i = 0; i < function_count; i++) {
+ Node *function = all_functions[i];
+ if (strcmp(function->func.name, token->value.as_string) == 0) {
+ return function;
+ }
+ }
return NULL;
}
@@ -203,6 +225,38 @@ Node *parse_var_declaration(Lexer *lexer)
return node;
}
+Node *parse_function_call_args(Lexer *lexer, Node *func)
+{
+ Token identifier = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER);
+ Node *call = Node_new(AST_FUNCCALL);
+ call->call.func = func;
+ assert_token(Lexer_next(lexer), TOKEN_OPEN_PAREN);
+ Token token = Lexer_peek(lexer);
+
+ while (token.type != TOKEN_CLOSE_PAREN) {
+ Node *arg = parse_expression(lexer);
+
+ int new_size = call->call.num_args + 1;
+ call->call.args = realloc(call->call.args, sizeof(Node *) * new_size);
+ call->call.args[call->call.num_args++] = arg;
+
+ if (new_size > func->func.num_args)
+ die_location(identifier.loc, "Too many arguments to function `%s`", func->func.name);
+
+ token = Lexer_peek(lexer);
+ if (token.type == TOKEN_COMMA) {
+ Lexer_next(lexer);
+ token = Lexer_peek(lexer);
+ }
+ }
+
+ if (call->call.num_args != func->func.num_args)
+ die_location(identifier.loc, "Too few arguments to function `%s`", func->func.name);
+
+ assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN);
+ return call;
+}
+
Node *parse_factor(Lexer *lexer)
{
// TODO: Parse more complicated things
@@ -226,13 +280,24 @@ Node *parse_factor(Lexer *lexer)
assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN);
} else if (token.type == TOKEN_INTLIT) {
expr = parse_literal(lexer);
- } else if (token.type == TOKEN_IDENTIFIER) {
- Lexer_next(lexer);
+ } else if (token.type == TOKEN_IDENTIFIER) {
+ // TODO: Check for global variables when added
+
Variable *var = find_local_variable(&token);
- if (var == NULL)
- die_location(token.loc, "Could not find variable `%s`", token.value.as_string);
- expr = Node_new(AST_VAR);
- expr->variable = var;
+ if (var != NULL) {
+ Lexer_next(lexer);
+ expr = Node_new(AST_VAR);
+ expr->variable = var;
+ return expr;
+ }
+
+ Node *func = find_function_definition(&token);
+ if (func != NULL) {
+ return parse_function_call_args(lexer, func);
+ }
+
+ die_location(token.loc, "Unknown identifier `%s`", token.value.as_string);
+ expr = NULL;
} else {
die_location(token.loc, ": Expected token found in parse_factor: `%s`", token_type_to_str(token.type));
exit(1);
@@ -401,20 +466,63 @@ Node *parse_block(Lexer *lexer)
return block;
}
+void push_new_function(Node *func)
+{
+ assert(func->type == AST_FUNC);
+ assert(function_count < MAX_FUNCTION_COUNT);
+ all_functions[function_count++] = func;
+ current_function = func;
+}
+
+void parse_func_args(Lexer *lexer, Node *func)
+{
+ assert_token(Lexer_next(lexer), TOKEN_OPEN_PAREN);
+ Token token = Lexer_peek(lexer);
+ while (token.type != TOKEN_CLOSE_PAREN) {
+ token = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER);
+ // TODO: Check for shadowing with globals
+ assert_token(Lexer_next(lexer), TOKEN_COLON);
+ Type type = parse_type(lexer);
+
+ i64 new_count = func->func.num_args + 1;
+ func->func.args = realloc(func->func.args, sizeof(Variable) * new_count);
+ Variable *var = &func->func.args[func->func.num_args++];
+ var->name = token.value.as_string;
+ var->type = type;
+
+ token = Lexer_peek(lexer);
+ if (token.type == TOKEN_COMMA) {
+ Lexer_next(lexer);
+ token = Lexer_peek(lexer);
+ }
+ }
+ assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN);
+
+ // Set the offsets for the arguments
+
+ // IMPORTANT: We want to skip the saved ret_addr+old_rbp that we
+ // pushed on the stack. Each of these is 8 bytes.
+ int offset = -16;
+ for (int i = 0; i < func->func.num_args; i++) {
+ Variable *var = &func->func.args[i];
+ var->offset = offset;
+ // TODO: Compute this for different types
+ int var_size = 8;
+ offset -= var_size;
+ }
+}
+
Node *parse_func(Lexer *lexer)
{
Token token;
token = assert_token(Lexer_next(lexer), TOKEN_FN);
Node *func = Node_new(AST_FUNC);
- current_function = func;
+ push_new_function(func);
token = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER);
-
func->func.name = token.value.as_string;
- assert_token(Lexer_next(lexer), TOKEN_OPEN_PAREN);
- // TODO: Parse parameters
- assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN);
+ parse_func_args(lexer, func);
token = Lexer_peek(lexer);
if (token.type == TOKEN_COLON) {
@@ -426,9 +534,11 @@ Node *parse_func(Lexer *lexer)
func->func.return_type = (Type){.type = TYPE_NONE};
}
- // Make sure there's no funny business with the stack offet
+ // Make sure there's no funny business with the stack offset
assert(cur_stack_offset == 0);
+ assert(block_stack_count == 0);
func->func.body = parse_block(lexer);
+ assert(block_stack_count == 0);
assert(cur_stack_offset == 0);
return func;
diff --git a/examples/functions.cup b/examples/functions.cup
new file mode 100644
index 0000000..89858c6
--- /dev/null
+++ b/examples/functions.cup
@@ -0,0 +1,13 @@
+fn rec_sum(n: int, accum: int): int {
+ if (n == 0)
+ return accum;
+ return rec_sum(n - 1, accum + n);
+}
+
+fn sum(n: int): int {
+ return rec_sum(n, 0);
+}
+
+fn main() {
+ return sum(10);
+} \ No newline at end of file
diff --git a/tests/functions.sh b/tests/functions.sh
new file mode 100755
index 0000000..3e87b37
--- /dev/null
+++ b/tests/functions.sh
@@ -0,0 +1,63 @@
+#!/bin/bash
+
+# Test compound scopes
+
+. tests/common.sh
+
+set -e
+
+echo -n "- Different argument counts: "
+assert_exit_status_stdin 5 <<EOF
+fn test() { return 5; }
+fn main() { return test(); }
+EOF
+
+assert_exit_status_stdin 5 <<EOF
+fn test(a: int) { return a; }
+fn main() { return test(5); }
+EOF
+
+assert_exit_status_stdin 5 <<EOF
+fn test(a: int, b: int) { return a+b; }
+fn main() { return test(2, 3); }
+EOF
+
+assert_exit_status_stdin 5 <<EOF
+fn test(a: int, b: int, c: int, d: int, e: int) { return a+b+c+d+e; }
+fn main() { return test(1,1,1,1,1); }
+EOF
+
+assert_compile_failure_stdin <<EOF
+fn test() { return 5; }
+fn main() { return test(5); }
+EOF
+
+assert_compile_failure_stdin <<EOF
+fn test(a: int, b: int, c: int) { return 5; }
+fn main() { return test(5); }
+EOF
+
+assert_compile_failure_stdin <<EOF
+fn test(a: int, b: int, c: int) { return 5; }
+fn main() { return test(5, 6, 5, 8); }
+EOF
+
+echo " OK"
+
+echo -n "- Recursion: "
+assert_exit_status_stdin 3 <<EOF
+fn test(n: int) { return n == 0 ? 0 : 1 + test(n-1); }
+fn main() { return test(3); }
+EOF
+
+assert_exit_status_stdin 55 <<EOF
+fn test(n: int) { return n == 0 ? 0 : n + test(n-1); }
+fn main() { return test(10); }
+EOF
+
+assert_exit_status_stdin 55 <<EOF
+fn test(n: int, accum: int) { return n == 0 ? accum : test(n-1, n+accum); }
+fn main() { return test(10,0); }
+EOF
+
+echo " OK" \ No newline at end of file