aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/ast.c29
-rw-r--r--src/ast.h5
-rw-r--r--src/generator.c12
-rw-r--r--src/parser.c168
-rw-r--r--src/utils.c7
-rw-r--r--src/utils.h3
-rwxr-xr-xtests/conditions.sh18
-rwxr-xr-xtests/core.sh116
-rwxr-xr-xtests/functions.sh44
-rwxr-xr-xtests/loops.sh20
-rwxr-xr-xtests/variables.sh36
11 files changed, 327 insertions, 131 deletions
diff --git a/src/ast.c b/src/ast.c
index 4d4eb9b..7d9f053 100644
--- a/src/ast.c
+++ b/src/ast.c
@@ -55,6 +55,15 @@ char *data_type_to_str(DataType type)
}
}
+bool type_equals(Type *a, Type *b)
+{
+ if (a == NULL && b == NULL)
+ return true;
+ if (a == NULL || b == NULL)
+ return false;
+ return a->type == b->type && type_equals(a->ptr, b->ptr);
+}
+
i64 size_for_type(Type *type)
{
switch (type->type)
@@ -64,12 +73,24 @@ i64 size_for_type(Type *type)
default: assert(false && "Unreachable type");
}
}
-
Type *type_new(DataType type)
{
- Type *t = calloc(sizeof(Type), 1);
- t->type = type;
- return t;
+ // For the core types, we don't need to allocate any memory, just
+ // return a pointer to a static instance.
+ static Type type_int = {.type = TYPE_INT, .ptr = NULL};
+ if (type == TYPE_INT) return &type_int;
+
+ Type *self = calloc(sizeof(Type), 1);
+ self->type = type;
+ return self;
+}
+
+Node *Node_from_int_literal(i64 value)
+{
+ Node *self = Node_new(AST_LITERAL);
+ self->literal.type = self->expr_type = type_new(TYPE_INT);
+ self->literal.as_int = value;
+ return self;
}
void print_type_to_file(FILE *out, Type *type)
diff --git a/src/ast.h b/src/ast.h
index 0966550..6773c0b 100644
--- a/src/ast.h
+++ b/src/ast.h
@@ -73,6 +73,8 @@ typedef struct data_type_node {
Type *type_new(DataType type);
i64 size_for_type(Type *type);
+bool type_equals(Type *a, Type *b);
+void print_type_to_file(FILE *out, Type *type);
typedef struct {
char *name;
@@ -83,6 +85,7 @@ typedef struct {
typedef struct ast_node Node;
typedef struct ast_node {
NodeType type;
+ Type *expr_type;
union {
// Binary expr
@@ -163,4 +166,6 @@ typedef struct ast_node {
void Node_add_child(Node *parent, Node *child);
Node *Node_new(NodeType type);
+Node *Node_from_int_literal(i64 value);
+
void print_ast(Node *node); \ No newline at end of file
diff --git a/src/generator.c b/src/generator.c
index 8663116..f9aa787 100644
--- a/src/generator.c
+++ b/src/generator.c
@@ -24,6 +24,16 @@ void make_syscall(i64 syscall_no, FILE *out) {
fprintf(out, " syscall\n");
}
+char *specifier_for_type(Type *type) {
+ switch (size_for_type(type)) {
+ case 1: return "byte";
+ case 2: return "word";
+ case 4: return "dword";
+ case 8: return "qword";
+ default: assert(false && "Unreachable");
+ }
+}
+
void generate_expr_into_rax(Node *expr, FILE *out);
void generate_lvalue_into_rax(Node *node, FILE *out)
@@ -85,7 +95,7 @@ void generate_expr_into_rax(Node *expr, FILE *out)
fprintf(out, " push rax\n");
generate_expr_into_rax(expr->assign.value, out);
fprintf(out, " pop rbx\n");
- fprintf(out, " mov [rbx], rax\n");
+ fprintf(out, " mov %s [rbx], rax\n", specifier_for_type(var->expr_type));
} else if (expr->type == OP_NEG) {
generate_expr_into_rax(expr->unary_expr, out);
diff --git a/src/parser.c b/src/parser.c
index 85c4d6a..e391500 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -144,7 +144,7 @@ Node *find_function_definition(Token *token)
void add_global_variable(Variable *var)
{
var->offset = global_vars_offset;
- int var_size = size_for_type(var->type);
+ int var_size = align_up(size_for_type(var->type), 8);
global_vars_offset += var_size;
global_vars[global_vars_count++] = var;
}
@@ -157,7 +157,7 @@ void add_variable_to_current_block(Variable *var)
int new_len = (cur_block->block.num_locals + 1);
// TODO: Align the stack to a certain size?
- int var_size = size_for_type(var->type);
+ int var_size = align_up(size_for_type(var->type), 8);
// Add to the block
// FIXME: Use a map here
@@ -205,6 +205,7 @@ Node *parse_literal(Lexer *lexer)
Node *node = Node_new(AST_LITERAL);
Token token = assert_token(Lexer_next(lexer), TOKEN_INTLIT);
node->literal.type = type_new(TYPE_INT);
+ node->expr_type = node->literal.type;
node->literal.as_int = token.value.as_int;
return node;
}
@@ -243,6 +244,10 @@ Node *parse_var_declaration(Lexer *lexer)
if (is_global)
die_location(token.loc, "Cannot initialize global variable `%s` outside function", node->var_decl.var.name);
node->var_decl.value = parse_expression(lexer);
+
+ if (!type_equals(node->var_decl.var.type, node->var_decl.value->expr_type))
+ die_location(token.loc, "Type mismatch for variable declaration `%s` initalizer", node->var_decl.var.name);
+
assert_token(Lexer_next(lexer), TOKEN_SEMICOLON);
} else {
assert_token(token, TOKEN_SEMICOLON);
@@ -266,9 +271,6 @@ Node *parse_function_call_args(Lexer *lexer, Node *func)
call->call.args = realloc(call->call.args, sizeof(Node *) * new_size);
call->call.args[call->call.num_args++] = arg;
- if (new_size > func->func.num_args)
- die_location(identifier.loc, "Too many arguments to function `%s`", func->func.name);
-
token = Lexer_peek(lexer);
if (token.type == TOKEN_COMMA) {
Lexer_next(lexer);
@@ -277,8 +279,14 @@ Node *parse_function_call_args(Lexer *lexer, Node *func)
}
if (call->call.num_args != func->func.num_args)
- die_location(identifier.loc, "Too few arguments to function `%s`", func->func.name);
+ die_location(identifier.loc, "Function `%s` expects %d arguments, got %d", func->func.name, func->func.num_args, call->call.num_args);
+ for (int i = 0; i < call->call.num_args; i++) {
+ if (!type_equals(func->func.args[i].type, call->call.args[i]->expr_type)) {
+ die_location(identifier.loc, "Type mismatch for argument %d in function call `%s`", i, func->func.name);
+ }
+ }
+ call->expr_type = call->call.func->func.return_type;
assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN);
return call;
}
@@ -294,6 +302,7 @@ Node *parse_identifier(Lexer *lexer)
Lexer_next(lexer);
expr = Node_new(AST_LOCAL_VAR);
expr->variable = var;
+ expr->expr_type = var->type;
return expr;
}
@@ -302,6 +311,7 @@ Node *parse_identifier(Lexer *lexer)
Lexer_next(lexer);
expr = Node_new(AST_GLOBAL_VAR);
expr->variable = gvar;
+ expr->expr_type = gvar->type;
return expr;
}
@@ -319,23 +329,145 @@ Node *parse_identifier(Lexer *lexer)
return NULL;
}
+Node *handle_unary_expr_types(Node *node, Token *token)
+{
+ Type *old_type = node->unary_expr->expr_type;
+
+ if (node->type == OP_NOT) {
+ node->expr_type = type_new(TYPE_INT);
+ } else if (node->type == OP_ADDROF) {
+ Type *ptr = type_new(TYPE_PTR);
+ ptr->ptr = old_type;
+ node->expr_type = ptr;
+ } else if (node->type == OP_DEREF) {
+ if (old_type->type != TYPE_PTR)
+ die_location(token->loc, "Cannot dereference non-pointer type");
+ node->expr_type = old_type->ptr;
+ } else if (node->type == OP_NEG) {
+ if (old_type->type != TYPE_INT)
+ die_location(token->loc, "Cannot negate non-integer type");
+ node->expr_type = type_new(TYPE_INT);
+ } else {
+ // Default to not changing the type
+ node->expr_type = old_type;
+ }
+ // die_location(token->loc, "Unknown unary expression type in handle_unary_expr_types\n");
+ return node;
+}
+
+Node *handle_binary_expr_types(Node *node, Token *token)
+{
+ Type *left = node->binary.left->expr_type;
+ Type *right = node->binary.right->expr_type;
+
+ switch (node->type)
+ {
+ case OP_PLUS: {
+ if (left->type == TYPE_INT && right->type == TYPE_INT) {
+ node->expr_type = type_new(TYPE_INT);
+ } else if (left->type == TYPE_PTR && right->type == TYPE_INT) {
+ node->expr_type = left;
+ // Pointer arithmetic!
+ Node *mul = Node_new(OP_MUL);
+ mul->binary.left = node->binary.right;
+ mul->binary.right = Node_new(AST_LITERAL);
+ mul->binary.right->literal.type = type_new(TYPE_INT);
+ mul->binary.right->literal.as_int = size_for_type(left->ptr);
+ node->binary.right = mul;
+ } else if (left->type == TYPE_INT && right->type == TYPE_PTR) {
+ node->expr_type = right;
+ // Pointer arithmetic!
+ Node *mul = Node_new(OP_MUL);
+ mul->binary.left = node->binary.left;
+ mul->binary.right = Node_new(AST_LITERAL);
+ mul->binary.right->literal.type = type_new(TYPE_INT);
+ mul->binary.right->literal.as_int = size_for_type(right->ptr);
+ node->binary.left = mul;
+ } else {
+ die_location(token->loc, "Cannot add non-integer types");
+ }
+ } break;
+
+ case OP_MINUS: {
+ if (left->type == TYPE_INT && right->type == TYPE_INT) {
+ node->expr_type = type_new(TYPE_INT);
+ } else if (left->type == TYPE_PTR && right->type == TYPE_INT) {
+ node->expr_type = left;
+ // Pointer arithmetic!
+ Node *mul = Node_new(OP_MUL);
+ mul->binary.left = node->binary.right;
+ mul->binary.right = Node_from_int_literal(size_for_type(left->ptr));
+ node->binary.right = mul;
+ } else if (left->type == TYPE_INT && right->type == TYPE_PTR) {
+ node->expr_type = right;
+ // Pointer arithmetic!
+ Node *mul = Node_new(OP_MUL);
+ mul->binary.left = node->binary.left;
+ mul->binary.right = Node_from_int_literal(size_for_type(right->ptr));
+ node->binary.left = mul;
+ } else if (left->type == TYPE_PTR && right->type == TYPE_PTR) {
+ // TODO: Check for different pointer types
+ node->expr_type = type_new(TYPE_INT);
+ // Divide by size of pointer
+ Node *div = Node_new(OP_DIV);
+ div->binary.left = node;
+ div->binary.right = Node_from_int_literal(size_for_type(left->ptr));
+ div->expr_type = node->expr_type;
+ node = div;
+ } else {
+ die_location(token->loc, "Cannot subtract non-integer types");
+ }
+ } break;
+
+ case OP_DIV:
+ case OP_MOD:
+ case OP_MUL: {
+ if (left->type == TYPE_INT && right->type == TYPE_INT) {
+ node->expr_type = left;
+ } else {
+ die_location(token->loc, "Cannot do operation `%s` non-integer types", node_type_to_str(node->type));
+ }
+ } break;
+
+ case OP_EQ:
+ case OP_NEQ:
+ case OP_LT:
+ case OP_GT:
+ case OP_LEQ:
+ case OP_GEQ:
+ case OP_AND:
+ case OP_OR: {
+ node->expr_type = type_new(TYPE_INT);
+ } break;
+
+ default:
+ die_location(token->loc, "Unknown binary expression type in handle_binary_expr_types\n");
+ }
+ return node;
+}
+
Node *parse_factor(Lexer *lexer)
{
- // TODO: Parse more complicated things
+ // TODO: We need to properly handle type conversions / operations with different types
+ // where we need to cast one of the operands / etc. Perhaps have a separate
+ // type-checking / adding casts/conversions pass?
Token token = Lexer_peek(lexer);
Node *expr = NULL;
if (token.type == TOKEN_MINUS) {
Lexer_next(lexer);
expr = Node_new(OP_NEG);
expr->unary_expr = parse_factor(lexer);
+ expr = handle_unary_expr_types(expr, &token);
} else if (token.type == TOKEN_TILDE) {
Lexer_next(lexer);
expr = Node_new(OP_BWINV);
expr->unary_expr = parse_factor(lexer);
+ expr = handle_unary_expr_types(expr, &token);
} else if (token.type == TOKEN_EXCLAMATION) {
Lexer_next(lexer);
expr = Node_new(OP_NOT);
expr->unary_expr = parse_factor(lexer);
+ expr = handle_unary_expr_types(expr, &token);
} else if (token.type == TOKEN_OPEN_PAREN) {
Lexer_next(lexer);
expr = parse_expression(lexer);
@@ -350,6 +482,7 @@ Node *parse_factor(Lexer *lexer)
expr->unary_expr = parse_factor(lexer);
if (!is_lvalue(expr->unary_expr->type))
die_location(token.loc, "Cannot take address of non-lvalue");
+ expr = handle_unary_expr_types(expr, &token);
} else if (token.type == TOKEN_STAR) {
Lexer_next(lexer);
expr = Node_new(OP_DEREF);
@@ -357,6 +490,7 @@ Node *parse_factor(Lexer *lexer)
// to work, we need to to be able to evaluate the type for complex expressions,
// which we do not support as of now.
expr->unary_expr = parse_factor(lexer);
+ expr = handle_unary_expr_types(expr, &token);
} else {
die_location(token.loc, ": Expected token found in parse_factor: `%s`", token_type_to_str(token.type));
}
@@ -372,6 +506,7 @@ Node *parse_factor(Lexer *lexer)
Node *right = next_parser(lexer); \
op->binary.left = expr; \
op->binary.right = right; \
+ op = handle_binary_expr_types(op, &token); \
expr = op; \
token = Lexer_peek(lexer); \
} \
@@ -408,7 +543,13 @@ Node *parse_conditional_exp(Lexer *lexer)
conditional->conditional.cond = expr;
conditional->conditional.do_then = then_expr;
conditional->conditional.do_else = else_expr;
+
+ if (!type_equals(then_expr->expr_type, else_expr->expr_type))
+ die_location(token.loc, "Type mismatch in conditional expression cases");
+
+ conditional->expr_type = then_expr->expr_type;
expr = conditional;
+ expr->expr_type = then_expr->expr_type;
}
return expr;
}
@@ -425,7 +566,12 @@ Node *parse_expression(Lexer *lexer)
Node *assign = Node_new(OP_ASSIGN);
assign->assign.var = node;
assign->assign.value = parse_expression(lexer);
+
+ if (!type_equals(node->expr_type, assign->assign.value->expr_type))
+ die_location(token.loc, "Type mismatch in assignment expression");
+
node = assign;
+ node->expr_type = node->assign.var->expr_type;
}
}
return node;
@@ -442,6 +588,10 @@ Node *parse_statement(Lexer *lexer)
assert_token(Lexer_next(lexer), TOKEN_RETURN);
node = Node_new(AST_RETURN);
node->unary_expr = parse_expression(lexer);
+
+ if (!type_equals(node->unary_expr->expr_type, current_function->func.return_type))
+ die_location(token.loc, "Return expression does not match function's return type");
+
assert_token(Lexer_next(lexer), TOKEN_SEMICOLON);
} else if (token.type == TOKEN_IF) {
Lexer_next(lexer);
@@ -569,7 +719,9 @@ void parse_func_args(Lexer *lexer, Node *func)
Variable *var = &func->func.args[i];
var->offset = offset;
// TODO: Do we need to align the stack here?
- int var_size = size_for_type(var->type);
+ // TODO: (Here and other uses of `size_for_type`):
+ // Should we only align to max(8, type->size) instead?
+ int var_size = align_up(size_for_type(var->type), 8);
offset -= var_size;
}
}
diff --git a/src/utils.c b/src/utils.c
index 2335981..939021a 100644
--- a/src/utils.c
+++ b/src/utils.c
@@ -27,4 +27,9 @@ void _die_location(char *file, int line, Location loc, const char *fmt, ...)
}
i64 i64max(i64 a, i64 b) { return a > b ? a : b; }
-i64 i64min(i64 a, i64 b) { return a < b ? a : b; } \ No newline at end of file
+i64 i64min(i64 a, i64 b) { return a < b ? a : b; }
+
+i64 align_up(i64 val, i64 align)
+{
+ return (val + align - 1) & ~(align - 1);
+} \ No newline at end of file
diff --git a/src/utils.h b/src/utils.h
index dfc018c..7718c1c 100644
--- a/src/utils.h
+++ b/src/utils.h
@@ -9,4 +9,7 @@ void _die_location(char *file, int line, Location loc, const char *fmt, ...);
i64 i64max(i64 a, i64 b);
i64 i64min(i64 a, i64 b);
+// Assumes alignment is a power of 2
+i64 align_up(i64 val, i64 align);
+
#define die_location(loc, ...) _die_location(__FILE__, __LINE__, loc, __VA_ARGS__) \ No newline at end of file
diff --git a/tests/conditions.sh b/tests/conditions.sh
index be80d96..b10948e 100755
--- a/tests/conditions.sh
+++ b/tests/conditions.sh
@@ -5,12 +5,12 @@
set -e
echo -n "- Conditionals: "
-assert_exit_status 'fn main() { return 1 ? 5 : 10; }' 5
-assert_exit_status 'fn main() { return 0 ? 5 : 10; }' 10
-assert_exit_status 'fn main() { return 1 < 2 ? 10 : 20; }' 10
+assert_exit_status 'fn main(): int { return 1 ? 5 : 10; }' 5
+assert_exit_status 'fn main(): int { return 0 ? 5 : 10; }' 10
+assert_exit_status 'fn main(): int { return 1 < 2 ? 10 : 20; }' 10
assert_exit_status_stdin 5 <<EOF
-fn main() {
+fn main(): int {
let flag: int = 1;
let a: int;
flag ? a = 5 : a = 10;
@@ -19,7 +19,7 @@ fn main() {
EOF
assert_exit_status_stdin 10 <<EOF
-fn main() {
+fn main(): int {
let flag: int = 0;
let a: int;
flag ? a = 5 : a = 10;
@@ -30,21 +30,21 @@ echo " OK"
echo -n "- If statement: "
assert_exit_status_stdin 10 <<EOF
-fn main() {
+fn main(): int {
if (5 < 20) return 10;
return 3;
}
EOF
assert_exit_status_stdin 3 <<EOF
-fn main() {
+fn main(): int {
if (5 > 20) return 10;
return 3;
}
EOF
assert_exit_status_stdin 20 <<EOF
-fn main() {
+fn main(): int {
let x: int;
if (0)
x = 3;
@@ -55,7 +55,7 @@ fn main() {
EOF
assert_exit_status_stdin 3 <<EOF
-fn main() {
+fn main(): int {
let x: int;
if (1)
x = 3;
diff --git a/tests/core.sh b/tests/core.sh
index 71c539d..7b62208 100755
--- a/tests/core.sh
+++ b/tests/core.sh
@@ -5,75 +5,75 @@
set -e
echo -n "- Basic return: "
-assert_exit_status 'fn main() { return 0; }' 0
-assert_exit_status 'fn main() { return 1; }' 1
-assert_exit_status 'fn main() { return 100; }' 100
+assert_exit_status 'fn main(): int { return 0; }' 0
+assert_exit_status 'fn main(): int { return 1; }' 1
+assert_exit_status 'fn main(): int { return 100; }' 100
echo " OK"
echo -n "- Unary ops: "
-assert_exit_status 'fn main() { return -1; }' 255
-assert_exit_status 'fn main() { return -100; }' 156
-assert_exit_status 'fn main() { return !0; }' 1
-assert_exit_status 'fn main() { return !1; }' 0
-assert_exit_status 'fn main() { return !34; }' 0
-assert_exit_status 'fn main() { return !-1; }' 0
-assert_exit_status 'fn main() { return ~34; }' 221
+assert_exit_status 'fn main(): int { return -1; }' 255
+assert_exit_status 'fn main(): int { return -100; }' 156
+assert_exit_status 'fn main(): int { return !0; }' 1
+assert_exit_status 'fn main(): int { return !1; }' 0
+assert_exit_status 'fn main(): int { return !34; }' 0
+assert_exit_status 'fn main(): int { return !-1; }' 0
+assert_exit_status 'fn main(): int { return ~34; }' 221
echo " OK"
echo -n "- Arith Binary ops: "
-assert_exit_status 'fn main() { return 1 + 1; }' 2
-assert_exit_status 'fn main() { return 1 + 100; }' 101
-assert_exit_status 'fn main() { return 100 + 1; }' 101
-assert_exit_status 'fn main() { return 1 - 1; }' 0
-assert_exit_status 'fn main() { return 1 - 100; }' 157
-assert_exit_status 'fn main() { return 100 - 1; }' 99
-assert_exit_status 'fn main() { return 1 * 1; }' 1
-assert_exit_status 'fn main() { return 1 * 100; }' 100
-assert_exit_status 'fn main() { return 100 * 1; }' 100
-assert_exit_status 'fn main() { return 7 * 3; }' 21
-assert_exit_status 'fn main() { return 1 / 1; }' 1
-assert_exit_status 'fn main() { return 100 / 1; }' 100
-assert_exit_status 'fn main() { return 100 / 7; }' 14
-assert_exit_status 'fn main() { return 100 / 100; }' 1
-assert_exit_status 'fn main() { return 100 / -1; }' 156
+assert_exit_status 'fn main(): int { return 1 + 1; }' 2
+assert_exit_status 'fn main(): int { return 1 + 100; }' 101
+assert_exit_status 'fn main(): int { return 100 + 1; }' 101
+assert_exit_status 'fn main(): int { return 1 - 1; }' 0
+assert_exit_status 'fn main(): int { return 1 - 100; }' 157
+assert_exit_status 'fn main(): int { return 100 - 1; }' 99
+assert_exit_status 'fn main(): int { return 1 * 1; }' 1
+assert_exit_status 'fn main(): int { return 1 * 100; }' 100
+assert_exit_status 'fn main(): int { return 100 * 1; }' 100
+assert_exit_status 'fn main(): int { return 7 * 3; }' 21
+assert_exit_status 'fn main(): int { return 1 / 1; }' 1
+assert_exit_status 'fn main(): int { return 100 / 1; }' 100
+assert_exit_status 'fn main(): int { return 100 / 7; }' 14
+assert_exit_status 'fn main(): int { return 100 / 100; }' 1
+assert_exit_status 'fn main(): int { return 100 / -1; }' 156
echo " OK"
echo -n "- Relational ops: "
-assert_exit_status 'fn main() { return 1 == 1; }' 1
-assert_exit_status 'fn main() { return 1 == 2; }' 0
-assert_exit_status 'fn main() { return 1 != 1; }' 0
-assert_exit_status 'fn main() { return 1 != 2; }' 1
+assert_exit_status 'fn main(): int { return 1 == 1; }' 1
+assert_exit_status 'fn main(): int { return 1 == 2; }' 0
+assert_exit_status 'fn main(): int { return 1 != 1; }' 0
+assert_exit_status 'fn main(): int { return 1 != 2; }' 1
-assert_exit_status 'fn main() { return 1 < 2; }' 1
-assert_exit_status 'fn main() { return 2 < 2; }' 0
+assert_exit_status 'fn main(): int { return 1 < 2; }' 1
+assert_exit_status 'fn main(): int { return 2 < 2; }' 0
-assert_exit_status 'fn main() { return 1 <= 2; }' 1
-assert_exit_status 'fn main() { return 2 <= 2; }' 1
-assert_exit_status 'fn main() { return 3 <= 2; }' 0
+assert_exit_status 'fn main(): int { return 1 <= 2; }' 1
+assert_exit_status 'fn main(): int { return 2 <= 2; }' 1
+assert_exit_status 'fn main(): int { return 3 <= 2; }' 0
-assert_exit_status 'fn main() { return 2 > 2; }' 0
-assert_exit_status 'fn main() { return 3 > 2; }' 1
+assert_exit_status 'fn main(): int { return 2 > 2; }' 0
+assert_exit_status 'fn main(): int { return 3 > 2; }' 1
-assert_exit_status 'fn main() { return 1 >= 2; }' 0
-assert_exit_status 'fn main() { return 2 >= 2; }' 1
-assert_exit_status 'fn main() { return 3 >= 2; }' 1
+assert_exit_status 'fn main(): int { return 1 >= 2; }' 0
+assert_exit_status 'fn main(): int { return 2 >= 2; }' 1
+assert_exit_status 'fn main(): int { return 3 >= 2; }' 1
echo " OK"
echo -n "- Simple logical ops: "
-assert_exit_status 'fn main() { return 0 && 0; }' 0
-assert_exit_status 'fn main() { return 0 && 5; }' 0
-assert_exit_status 'fn main() { return 5 && 0; }' 0
-assert_exit_status 'fn main() { return 5 && 1; }' 1
-
-assert_exit_status 'fn main() { return 0 || 0; }' 0
-assert_exit_status 'fn main() { return 5 || 0; }' 1
-assert_exit_status 'fn main() { return 0 || 3; }' 1
-assert_exit_status 'fn main() { return 2 || 1; }' 1
+assert_exit_status 'fn main(): int { return 0 && 0; }' 0
+assert_exit_status 'fn main(): int { return 0 && 5; }' 0
+assert_exit_status 'fn main(): int { return 5 && 0; }' 0
+assert_exit_status 'fn main(): int { return 5 && 1; }' 1
+
+assert_exit_status 'fn main(): int { return 0 || 0; }' 0
+assert_exit_status 'fn main(): int { return 5 || 0; }' 1
+assert_exit_status 'fn main(): int { return 0 || 3; }' 1
+assert_exit_status 'fn main(): int { return 2 || 1; }' 1
echo " OK"
echo -n "- Short-circuiting: "
assert_exit_status_stdin 5 <<EOF
-fn main() {
+fn main(): int {
let x: int = 5;
let y: int = (1 || (x = 10));
return x;
@@ -81,7 +81,7 @@ fn main() {
EOF
assert_exit_status_stdin 10 <<EOF
-fn main() {
+fn main(): int {
let x: int = 5;
let y: int = (0 || (x = 10));
return x;
@@ -89,7 +89,7 @@ fn main() {
EOF
assert_exit_status_stdin 5 <<EOF
-fn main() {
+fn main(): int {
let x: int = 5;
let y: int = (0 && (x = 10));
return x;
@@ -97,7 +97,7 @@ fn main() {
EOF
assert_exit_status_stdin 10 <<EOF
-fn main() {
+fn main(): int {
let x: int = 5;
let y: int = (1 && (x = 10));
return x;
@@ -109,7 +109,7 @@ echo -n "- Importing file: "
assert_exit_status_stdin 10 <<EOF
import "std/math.cup"
-fn main() {
+fn main(): int {
let x: int = abs(-5);
let y: int = factorial(3);
return x + y - 1;
@@ -117,7 +117,7 @@ fn main() {
EOF
assert_compile_failure_stdin <<EOF
-fn main() {
+fn main(): int {
let x: int = abs(-5);
let y: int = factorial(3);
return x + y - 1;
@@ -127,7 +127,7 @@ echo " OK"
echo -n "- Defer: "
assert_stdout_text \
-"fn main() {
+"fn main(): int {
defer print(5);
print(4);
}" \
@@ -135,7 +135,7 @@ assert_stdout_text \
5"
assert_stdout_text \
-"fn main() {
+"fn main(): int {
defer print(1);
{
defer print(2);
@@ -162,7 +162,7 @@ assert_stdout_text \
}
print(3);
}
-fn main() {
+fn main(): int {
defer print(4);
defer test();
print(10);
@@ -180,7 +180,7 @@ fn test(): int {
defer g = 10;
return g;
}
-fn main() {
+fn main(): int {
print(test());
print(g);
}" \
diff --git a/tests/functions.sh b/tests/functions.sh
index e8c79f2..c188bd1 100755
--- a/tests/functions.sh
+++ b/tests/functions.sh
@@ -8,61 +8,61 @@ set -e
echo -n "- Different argument counts: "
assert_exit_status_stdin 5 <<EOF
-fn test() { return 5; }
-fn main() { return test(); }
+fn test(): int { return 5; }
+fn main(): int { return test(); }
EOF
assert_exit_status_stdin 5 <<EOF
-fn test(a: int) { return a; }
-fn main() { return test(5); }
+fn test(a: int): int { return a; }
+fn main(): int { return test(5); }
EOF
assert_exit_status_stdin 5 <<EOF
-fn test(a: int, b: int) { return a+b; }
-fn main() { return test(2, 3); }
+fn test(a: int, b: int): int { return a+b; }
+fn main(): int { return test(2, 3); }
EOF
assert_exit_status_stdin 5 <<EOF
-fn test(a: int, b: int) { let n: int = a + b; return n; }
-fn main() { return test(2, 3); }
+fn test(a: int, b: int): int { let n: int = a + b; return n; }
+fn main(): int { return test(2, 3); }
EOF
assert_exit_status_stdin 5 <<EOF
-fn test(a: int, b: int, c: int, d: int, e: int) { return a+b+c+d+e; }
-fn main() { return test(1,1,1,1,1); }
+fn test(a: int, b: int, c: int, d: int, e: int): int { return a+b+c+d+e; }
+fn main(): int { return test(1,1,1,1,1); }
EOF
assert_compile_failure_stdin <<EOF
-fn test() { return 5; }
-fn main() { return test(5); }
+fn test(): int { return 5; }
+fn main(): int { return test(5); }
EOF
assert_compile_failure_stdin <<EOF
-fn test(a: int, b: int, c: int) { return 5; }
-fn main() { return test(5); }
+fn test(a: int, b: int, c: int): int { return 5; }
+fn main(): int { return test(5); }
EOF
assert_compile_failure_stdin <<EOF
-fn test(a: int, b: int, c: int) { return 5; }
-fn main() { return test(5, 6, 5, 8); }
+fn test(a: int, b: int, c: int): int { return 5; }
+fn main(): int { return test(5, 6, 5, 8); }
EOF
echo " OK"
echo -n "- Recursion: "
assert_exit_status_stdin 3 <<EOF
-fn test(n: int) { return n == 0 ? 0 : 1 + test(n-1); }
-fn main() { return test(3); }
+fn test(n: int): int { return n == 0 ? 0 : 1 + test(n-1); }
+fn main(): int { return test(3); }
EOF
assert_exit_status_stdin 55 <<EOF
-fn test(n: int) { return n == 0 ? 0 : n + test(n-1); }
-fn main() { return test(10); }
+fn test(n: int): int { return n == 0 ? 0 : n + test(n-1); }
+fn main(): int { return test(10); }
EOF
assert_exit_status_stdin 55 <<EOF
-fn test(n: int, accum: int) { return n == 0 ? accum : test(n-1, n+accum); }
-fn main() { return test(10,0); }
+fn test(n: int, accum: int): int { return n == 0 ? accum : test(n-1, n+accum); }
+fn main(): int { return test(10,0); }
EOF
echo " OK" \ No newline at end of file
diff --git a/tests/loops.sh b/tests/loops.sh
index 3fea60f..ab0a949 100755
--- a/tests/loops.sh
+++ b/tests/loops.sh
@@ -6,7 +6,7 @@ set -e
echo -n "- While loops: "
assert_exit_status_stdin 5 <<EOF
-fn main() {
+fn main(): int {
while (1) {
return 5;
}
@@ -15,7 +15,7 @@ fn main() {
EOF
assert_exit_status_stdin 3 <<EOF
-fn main() {
+fn main(): int {
while (0) {
return 5;
}
@@ -24,7 +24,7 @@ fn main() {
EOF
assert_exit_status_stdin 10 <<EOF
-fn main() {
+fn main(): int {
let sum: int = 0;
while (sum < 10) {
sum = sum + 1;
@@ -34,7 +34,7 @@ fn main() {
EOF
assert_exit_status_stdin 55 <<EOF
-fn main() {
+fn main(): int {
let sum: int = 0;
let N: int = 10;
let i: int = 0;
@@ -49,7 +49,7 @@ echo " OK"
echo -n "- For loops: "
assert_exit_status_stdin 5 <<EOF
-fn main() {
+fn main(): int {
for (;;) {
return 5;
}
@@ -58,7 +58,7 @@ fn main() {
EOF
assert_exit_status_stdin 3 <<EOF
-fn main() {
+fn main(): int {
for (;0;) {
return 5;
}
@@ -67,7 +67,7 @@ fn main() {
EOF
assert_exit_status_stdin 55 <<EOF
-fn main() {
+fn main(): int {
let sum: int = 0;
let i: int;
for (i = 0; i <= 10; i = i + 1) {
@@ -78,7 +78,7 @@ fn main() {
EOF
assert_exit_status_stdin 55 <<EOF
-fn main() {
+fn main(): int {
let sum: int = 0;
let i: int = 0;
for (; i <= 10; i = i + 1) {
@@ -89,7 +89,7 @@ fn main() {
EOF
assert_exit_status_stdin 45 <<EOF
-fn main() {
+fn main(): int {
let sum: int = 0;
let i: int = 0;
for (;i < 10;) {
@@ -101,7 +101,7 @@ fn main() {
EOF
assert_exit_status_stdin 55 <<EOF
-fn main() {
+fn main(): int {
let sum: int = 0;
let i: int = 0;
for (;;) {
diff --git a/tests/variables.sh b/tests/variables.sh
index ffa66c3..93e9bd8 100755
--- a/tests/variables.sh
+++ b/tests/variables.sh
@@ -5,12 +5,12 @@
set -e
echo -n "- One variable: "
-assert_exit_status 'fn main() { let x: int; x = 45; return x; }' 45
-assert_exit_status 'fn main() { let x: int = 45; return x; }' 45
-assert_exit_status 'fn main() { let x: int = 45; return x+x; }' 90
+assert_exit_status 'fn main(): int { let x: int; x = 45; return x; }' 45
+assert_exit_status 'fn main(): int { let x: int = 45; return x; }' 45
+assert_exit_status 'fn main(): int { let x: int = 45; return x+x; }' 90
assert_exit_status_stdin 5 <<EOF
-fn main() {
+fn main(): int {
let x: int;
x = 3;
x = 5;
@@ -19,7 +19,7 @@ fn main() {
EOF
assert_exit_status_stdin 5 <<EOF
-fn main() {
+fn main(): int {
let x: int = 3;
x = x + x - 1;
return x;
@@ -30,7 +30,7 @@ echo " OK"
echo -n "- Multiple variable: "
assert_exit_status_stdin 2 <<EOF
-fn main() {
+fn main(): int {
let x: int = 1;
let y: int = x + x;
return y;
@@ -38,7 +38,7 @@ fn main() {
EOF
assert_exit_status_stdin 23 <<EOF
-fn main() {
+fn main(): int {
let x: int = 1;
let y: int = x + x;
let z: int = y + y;
@@ -49,7 +49,7 @@ fn main() {
EOF
assert_exit_status_stdin 2 <<EOF
-fn main() {
+fn main(): int {
let x: int = 1;
let y: int = x + x;
y = y + x;
@@ -59,7 +59,7 @@ fn main() {
EOF
assert_exit_status_stdin 18 <<EOF
-fn main() {
+fn main(): int {
let x: int = 5;
let y: int;
let z: int = (y = x + 3) + 2;
@@ -71,7 +71,7 @@ echo " OK"
echo -n "- Global variables: "
assert_exit_status_stdin 18 <<EOF
let g: int;
-fn main() {
+fn main(): int {
g = 18;
return g;
}
@@ -80,7 +80,7 @@ EOF
assert_exit_status_stdin 18 <<EOF
let g: int;
let h: int;
-fn main() {
+fn main(): int {
g = 18;
h = g + g;
return h - g;
@@ -96,7 +96,7 @@ fn test() {
h = g + g;
}
-fn main() {
+fn main(): int {
test();
return h - g;
}
@@ -105,7 +105,7 @@ EOF
assert_compile_failure_stdin <<EOF
let g: int = 0;
-fn main() {
+fn main(): int {
return g;
}
EOF
@@ -113,7 +113,7 @@ echo " OK"
echo -n "- Nested Blocks: "
assert_exit_status_stdin 3 <<EOF
-fn main() {
+fn main(): int {
let x: int = 1;
{
let y: int = 3;
@@ -216,7 +216,7 @@ echo " OK"
echo -n "- Conditionals w/ blocks: "
assert_exit_status_stdin 3 <<EOF
-fn main() {
+fn main(): int {
let x: int = 1;
if (x == 1) {
let y: int = 3;
@@ -227,7 +227,7 @@ fn main() {
EOF
assert_exit_status_stdin 1 <<EOF
-fn main() {
+fn main(): int {
let x: int = 1;
if (x != 1) {
let y: int = 3;
@@ -238,7 +238,7 @@ fn main() {
EOF
assert_exit_status_stdin 5 <<EOF
-fn main() {
+fn main(): int {
let x: int = 1;
if (x != 1) {
let y: int = 3;
@@ -252,7 +252,7 @@ fn main() {
EOF
assert_compile_failure_stdin <<EOF
-fn main() {
+fn main(): int {
let x: int = 1;
if (x != 1) {
let y: int = 3;