aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cup/ast.h10
-rw-r--r--cup/generator.c46
-rw-r--r--cup/lexer.c17
-rw-r--r--cup/parser.c29
-rw-r--r--cup/tokens.h2
-rw-r--r--examples/loops.cup22
-rwxr-xr-xtests/loops.sh118
7 files changed, 230 insertions, 14 deletions
diff --git a/cup/ast.h b/cup/ast.h
index 46557da..22bedcc 100644
--- a/cup/ast.h
+++ b/cup/ast.h
@@ -26,6 +26,8 @@
F(AST_LITERAL, "literal") \
F(AST_CONDITIONAL, "conditional expression") \
F(AST_IF, "if statement") \
+ F(AST_WHILE, "while statement") \
+ F(AST_FOR, "for statement") \
F(AST_VARDECL, "variable decl") \
F(AST_VAR, "variable") \
F(AST_RETURN, "return") \
@@ -123,6 +125,14 @@ typedef struct ast_node {
Node *do_else;
} conditional;
+ // Used for all loops
+ struct {
+ Node *cond;
+ Node *init;
+ Node *step;
+ Node *body;
+ } loop;
+
Variable *variable;
};
} Node;
diff --git a/cup/generator.c b/cup/generator.c
index 9e93164..d7b252d 100644
--- a/cup/generator.c
+++ b/cup/generator.c
@@ -211,24 +211,54 @@ void generate_statement(Node *stmt, FILE *out)
} else if (stmt->type == AST_IF) {
assert(stmt->conditional.cond);
assert(stmt->conditional.do_then);
- generate_expr_into_rax(stmt->conditional.cond, out);
+ int cur_label = label_counter++;
+ generate_expr_into_rax(stmt->conditional.cond, out);
// If we don't have an `else` clause, we can simplify
if (!stmt->conditional.do_else) {
fprintf(out, " cmp rax, 0\n");
- fprintf(out, " je .if_end_%d\n", label_counter);
+ fprintf(out, " je .if_end_%d\n", cur_label);
generate_statement(stmt->conditional.do_then, out);
- fprintf(out, ".if_end_%d:\n", label_counter);
+ fprintf(out, ".if_end_%d:\n", cur_label);
} else {
fprintf(out, " cmp rax, 0\n");
- fprintf(out, " je .if_else_%d\n", label_counter);
+ fprintf(out, " je .if_else_%d\n", cur_label);
generate_statement(stmt->conditional.do_then, out);
- fprintf(out, " jmp .if_end_%d\n", label_counter);
- fprintf(out, ".if_else_%d:\n", label_counter);
+ fprintf(out, " jmp .if_end_%d\n", cur_label);
+ fprintf(out, ".if_else_%d:\n", cur_label);
generate_statement(stmt->conditional.do_else, out);
- fprintf(out, ".if_end_%d:\n", label_counter);
+ fprintf(out, ".if_end_%d:\n", cur_label);
}
- label_counter++;
+ } else if (stmt->type == AST_WHILE) {
+ int cur_label = label_counter++;
+ fprintf(out, ".loop_start_%d:\n", cur_label);
+ fprintf(out, ".loop_continue_%d:\n", cur_label);
+ generate_expr_into_rax(stmt->loop.cond, out);
+ fprintf(out, " cmp rax, 0\n");
+ fprintf(out, " je .loop_end_%d\n", cur_label);
+ generate_statement(stmt->loop.body, out);
+ fprintf(out, " jmp .loop_start_%d\n", cur_label);
+ fprintf(out, ".loop_end_%d:\n", cur_label);
+
+ } else if (stmt->type == AST_FOR) {
+ int cur_label = label_counter++;
+ if (stmt->loop.init) {
+ generate_statement(stmt->loop.init, out);
+ }
+ fprintf(out, ".loop_start_%d:\n", cur_label);
+ if (stmt->loop.cond) {
+ generate_expr_into_rax(stmt->loop.cond, out);
+ fprintf(out, " cmp rax, 0\n");
+ fprintf(out, " je .loop_end_%d\n", cur_label);
+ }
+ generate_statement(stmt->loop.body, out);
+ fprintf(out, ".loop_continue_%d:\n", cur_label);
+ if (stmt->loop.step) {
+ generate_expr_into_rax(stmt->loop.step, out);
+ }
+ fprintf(out, " jmp .loop_start_%d\n", cur_label);
+ fprintf(out, ".loop_end_%d:\n", cur_label);
+
} else if (stmt->type == AST_BLOCK) {
generate_block(stmt, out);
} else {
diff --git a/cup/lexer.c b/cup/lexer.c
index 32138cc..ea095ba 100644
--- a/cup/lexer.c
+++ b/cup/lexer.c
@@ -79,6 +79,9 @@ static Token Lexer_make_token(Lexer *lexer, TokenType type, int inc_amount)
return token;
}
+#define LEX_KEYWORD(str, token_type) \
+ if (Lexer_starts_with(lexer, str)) return Lexer_make_token(lexer, token_type, strlen(str));
+
Token Lexer_next(Lexer *lexer)
{
while (lexer->pos < lexer->len) {
@@ -165,12 +168,14 @@ Token Lexer_next(Lexer *lexer)
default: {
// Handle keywords explicitly
- if (Lexer_starts_with(lexer, "fn")) return Lexer_make_token(lexer, TOKEN_FN, 2);
- if (Lexer_starts_with(lexer, "if")) return Lexer_make_token(lexer, TOKEN_IF, 2);
- if (Lexer_starts_with(lexer, "else")) return Lexer_make_token(lexer, TOKEN_ELSE, 4);
- if (Lexer_starts_with(lexer, "return")) return Lexer_make_token(lexer, TOKEN_RETURN, 6);
- if (Lexer_starts_with(lexer, "int")) return Lexer_make_token(lexer, TOKEN_INT, 3);
- if (Lexer_starts_with(lexer, "let")) return Lexer_make_token(lexer, TOKEN_LET, 3);
+ LEX_KEYWORD("fn", TOKEN_FN);
+ LEX_KEYWORD("if", TOKEN_IF);
+ LEX_KEYWORD("int", TOKEN_INT);
+ LEX_KEYWORD("let", TOKEN_LET);
+ LEX_KEYWORD("for", TOKEN_FOR);
+ LEX_KEYWORD("else", TOKEN_ELSE);
+ LEX_KEYWORD("while", TOKEN_WHILE);
+ LEX_KEYWORD("return", TOKEN_RETURN);
if (isdigit(lexer->src[lexer->pos])) {
// TODO: Parse hex and octal numbers
diff --git a/cup/parser.c b/cup/parser.c
index b3113ea..698002f 100644
--- a/cup/parser.c
+++ b/cup/parser.c
@@ -334,6 +334,35 @@ Node *parse_statement(Lexer *lexer)
Lexer_next(lexer);
node->conditional.do_else = parse_statement(lexer);
}
+ } else if (token.type == TOKEN_WHILE) {
+ Lexer_next(lexer);
+ node = Node_new(AST_WHILE);
+ assert_token(Lexer_next(lexer), TOKEN_OPEN_PAREN);
+ node->loop.cond = parse_expression(lexer);
+ assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN);
+ node->loop.body = parse_statement(lexer);
+ } else if (token.type == TOKEN_FOR) {
+ Lexer_next(lexer);
+ node = Node_new(AST_FOR);
+ assert_token(Lexer_next(lexer), TOKEN_OPEN_PAREN);
+
+ // All of the expressions in the for loop are optional
+
+ // TODO: Allow this to be a declaration, need to inject
+ // the variable into the symbol table for the block
+ if (Lexer_peek(lexer).type != TOKEN_SEMICOLON)
+ node->loop.init = parse_expression(lexer);
+ assert_token(Lexer_next(lexer), TOKEN_SEMICOLON);
+
+ if (Lexer_peek(lexer).type != TOKEN_SEMICOLON)
+ node->loop.cond = parse_expression(lexer);
+ assert_token(Lexer_next(lexer), TOKEN_SEMICOLON);
+
+ if (Lexer_peek(lexer).type != TOKEN_CLOSE_PAREN)
+ node->loop.step = parse_expression(lexer);
+ assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN);
+
+ node->loop.body = parse_statement(lexer);
} else if (token.type == TOKEN_OPEN_BRACE) {
node = parse_block(lexer);
} else {
diff --git a/cup/tokens.h b/cup/tokens.h
index 84aacec..44f5e6e 100644
--- a/cup/tokens.h
+++ b/cup/tokens.h
@@ -17,6 +17,7 @@
F(TOKEN_EQ, "==") \
F(TOKEN_EXCLAMATION, "!") \
F(TOKEN_FN, "fn") \
+ F(TOKEN_FOR, "for") \
F(TOKEN_GEQ, ">=") \
F(TOKEN_GT, ">") \
F(TOKEN_IDENTIFIER, "identifier") \
@@ -46,6 +47,7 @@
F(TOKEN_STAR, "*") \
F(TOKEN_STRINGLIT, "string literal") \
F(TOKEN_TILDE, "~") \
+ F(TOKEN_WHILE, "while") \
F(TOKEN_XOR, "^")
typedef enum {
diff --git a/examples/loops.cup b/examples/loops.cup
new file mode 100644
index 0000000..6c4f3a7
--- /dev/null
+++ b/examples/loops.cup
@@ -0,0 +1,22 @@
+fn main(): int {
+ let sum1: int = 0;
+ let sum2: int = 0;
+
+ let N: int = 10;
+ let i: int = 0;
+
+ for (i = 0; i <= N; i = i + 1) {
+ sum1 = sum1 + i;
+ }
+
+ i = 0;
+ while (i <= N) {
+ sum2 = sum2 + i;
+ i = i + 1;
+ }
+
+ if (sum1 == sum2 && sum1 == 55) {
+ return 0;
+ }
+ return 1;
+} \ No newline at end of file
diff --git a/tests/loops.sh b/tests/loops.sh
new file mode 100755
index 0000000..3fea60f
--- /dev/null
+++ b/tests/loops.sh
@@ -0,0 +1,118 @@
+#!/bin/bash
+
+. tests/common.sh
+
+set -e
+
+echo -n "- While loops: "
+assert_exit_status_stdin 5 <<EOF
+fn main() {
+ while (1) {
+ return 5;
+ }
+ return 3;
+}
+EOF
+
+assert_exit_status_stdin 3 <<EOF
+fn main() {
+ while (0) {
+ return 5;
+ }
+ return 3;
+}
+EOF
+
+assert_exit_status_stdin 10 <<EOF
+fn main() {
+ let sum: int = 0;
+ while (sum < 10) {
+ sum = sum + 1;
+ }
+ return sum;
+}
+EOF
+
+assert_exit_status_stdin 55 <<EOF
+fn main() {
+ let sum: int = 0;
+ let N: int = 10;
+ let i: int = 0;
+ while (i <= N) {
+ sum = sum + i;
+ i = i + 1;
+ }
+ return sum;
+}
+EOF
+echo " OK"
+
+echo -n "- For loops: "
+assert_exit_status_stdin 5 <<EOF
+fn main() {
+ for (;;) {
+ return 5;
+ }
+ return 3;
+}
+EOF
+
+assert_exit_status_stdin 3 <<EOF
+fn main() {
+ for (;0;) {
+ return 5;
+ }
+ return 3;
+}
+EOF
+
+assert_exit_status_stdin 55 <<EOF
+fn main() {
+ let sum: int = 0;
+ let i: int;
+ for (i = 0; i <= 10; i = i + 1) {
+ sum = sum + i;
+ }
+ return sum;
+}
+EOF
+
+assert_exit_status_stdin 55 <<EOF
+fn main() {
+ let sum: int = 0;
+ let i: int = 0;
+ for (; i <= 10; i = i + 1) {
+ sum = sum + i;
+ }
+ return sum;
+}
+EOF
+
+assert_exit_status_stdin 45 <<EOF
+fn main() {
+ let sum: int = 0;
+ let i: int = 0;
+ for (;i < 10;) {
+ sum = sum + i;
+ i = i + 1;
+ }
+ return sum;
+}
+EOF
+
+assert_exit_status_stdin 55 <<EOF
+fn main() {
+ let sum: int = 0;
+ let i: int = 0;
+ for (;;) {
+ sum = sum + i;
+ i = i + 1;
+ if (i == 11) {
+ return sum;
+ }
+ }
+ // unreachable, but we don't catch this error yet
+ return -1;
+}
+EOF
+echo " OK"