aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorMustafa Quraish <[email protected]>2022-02-02 07:20:53 -0500
committerMustafa Quraish <[email protected]>2022-02-02 07:37:39 -0500
commit1a8f96c65f94227faa9747ef876a60f3c313c6f1 (patch)
treed80a396958ff2fc752b620cc5314e27e40b58ecb /src
parentUse `type*` instead of `type&` to denote a pointer type (for now) (diff)
downloadcup-1a8f96c65f94227faa9747ef876a60f3c313c6f1.tar.xz
cup-1a8f96c65f94227faa9747ef876a60f3c313c6f1.zip
Type checking of expressions / functions!
This is a bit of a chonky commit, but it adds in the basics of checking the types of expressions / function calls / return types. There's still a lot of work to be done, including: (1) Adding new core types, and casting between allowed types automatically (2) Picking the corrent output type based on input types (for instance float+int == float) (3) We need much better error reporting, the error messages are really vague and unhelpful as-is (4) We also need to work to ensure that a function with a return type actually returns (5) Possible re-factoring to make stuff less hacky when we have more types / structs / arrays / etc.
Diffstat (limited to 'src')
-rw-r--r--src/ast.c29
-rw-r--r--src/ast.h5
-rw-r--r--src/generator.c12
-rw-r--r--src/parser.c168
-rw-r--r--src/utils.c7
-rw-r--r--src/utils.h3
6 files changed, 210 insertions, 14 deletions
diff --git a/src/ast.c b/src/ast.c
index 4d4eb9b..7d9f053 100644
--- a/src/ast.c
+++ b/src/ast.c
@@ -55,6 +55,15 @@ char *data_type_to_str(DataType type)
}
}
+bool type_equals(Type *a, Type *b)
+{
+ if (a == NULL && b == NULL)
+ return true;
+ if (a == NULL || b == NULL)
+ return false;
+ return a->type == b->type && type_equals(a->ptr, b->ptr);
+}
+
i64 size_for_type(Type *type)
{
switch (type->type)
@@ -64,12 +73,24 @@ i64 size_for_type(Type *type)
default: assert(false && "Unreachable type");
}
}
-
Type *type_new(DataType type)
{
- Type *t = calloc(sizeof(Type), 1);
- t->type = type;
- return t;
+ // For the core types, we don't need to allocate any memory, just
+ // return a pointer to a static instance.
+ static Type type_int = {.type = TYPE_INT, .ptr = NULL};
+ if (type == TYPE_INT) return &type_int;
+
+ Type *self = calloc(sizeof(Type), 1);
+ self->type = type;
+ return self;
+}
+
+Node *Node_from_int_literal(i64 value)
+{
+ Node *self = Node_new(AST_LITERAL);
+ self->literal.type = self->expr_type = type_new(TYPE_INT);
+ self->literal.as_int = value;
+ return self;
}
void print_type_to_file(FILE *out, Type *type)
diff --git a/src/ast.h b/src/ast.h
index 0966550..6773c0b 100644
--- a/src/ast.h
+++ b/src/ast.h
@@ -73,6 +73,8 @@ typedef struct data_type_node {
Type *type_new(DataType type);
i64 size_for_type(Type *type);
+bool type_equals(Type *a, Type *b);
+void print_type_to_file(FILE *out, Type *type);
typedef struct {
char *name;
@@ -83,6 +85,7 @@ typedef struct {
typedef struct ast_node Node;
typedef struct ast_node {
NodeType type;
+ Type *expr_type;
union {
// Binary expr
@@ -163,4 +166,6 @@ typedef struct ast_node {
void Node_add_child(Node *parent, Node *child);
Node *Node_new(NodeType type);
+Node *Node_from_int_literal(i64 value);
+
void print_ast(Node *node); \ No newline at end of file
diff --git a/src/generator.c b/src/generator.c
index 8663116..f9aa787 100644
--- a/src/generator.c
+++ b/src/generator.c
@@ -24,6 +24,16 @@ void make_syscall(i64 syscall_no, FILE *out) {
fprintf(out, " syscall\n");
}
+char *specifier_for_type(Type *type) {
+ switch (size_for_type(type)) {
+ case 1: return "byte";
+ case 2: return "word";
+ case 4: return "dword";
+ case 8: return "qword";
+ default: assert(false && "Unreachable");
+ }
+}
+
void generate_expr_into_rax(Node *expr, FILE *out);
void generate_lvalue_into_rax(Node *node, FILE *out)
@@ -85,7 +95,7 @@ void generate_expr_into_rax(Node *expr, FILE *out)
fprintf(out, " push rax\n");
generate_expr_into_rax(expr->assign.value, out);
fprintf(out, " pop rbx\n");
- fprintf(out, " mov [rbx], rax\n");
+ fprintf(out, " mov %s [rbx], rax\n", specifier_for_type(var->expr_type));
} else if (expr->type == OP_NEG) {
generate_expr_into_rax(expr->unary_expr, out);
diff --git a/src/parser.c b/src/parser.c
index 85c4d6a..e391500 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -144,7 +144,7 @@ Node *find_function_definition(Token *token)
void add_global_variable(Variable *var)
{
var->offset = global_vars_offset;
- int var_size = size_for_type(var->type);
+ int var_size = align_up(size_for_type(var->type), 8);
global_vars_offset += var_size;
global_vars[global_vars_count++] = var;
}
@@ -157,7 +157,7 @@ void add_variable_to_current_block(Variable *var)
int new_len = (cur_block->block.num_locals + 1);
// TODO: Align the stack to a certain size?
- int var_size = size_for_type(var->type);
+ int var_size = align_up(size_for_type(var->type), 8);
// Add to the block
// FIXME: Use a map here
@@ -205,6 +205,7 @@ Node *parse_literal(Lexer *lexer)
Node *node = Node_new(AST_LITERAL);
Token token = assert_token(Lexer_next(lexer), TOKEN_INTLIT);
node->literal.type = type_new(TYPE_INT);
+ node->expr_type = node->literal.type;
node->literal.as_int = token.value.as_int;
return node;
}
@@ -243,6 +244,10 @@ Node *parse_var_declaration(Lexer *lexer)
if (is_global)
die_location(token.loc, "Cannot initialize global variable `%s` outside function", node->var_decl.var.name);
node->var_decl.value = parse_expression(lexer);
+
+ if (!type_equals(node->var_decl.var.type, node->var_decl.value->expr_type))
+ die_location(token.loc, "Type mismatch for variable declaration `%s` initalizer", node->var_decl.var.name);
+
assert_token(Lexer_next(lexer), TOKEN_SEMICOLON);
} else {
assert_token(token, TOKEN_SEMICOLON);
@@ -266,9 +271,6 @@ Node *parse_function_call_args(Lexer *lexer, Node *func)
call->call.args = realloc(call->call.args, sizeof(Node *) * new_size);
call->call.args[call->call.num_args++] = arg;
- if (new_size > func->func.num_args)
- die_location(identifier.loc, "Too many arguments to function `%s`", func->func.name);
-
token = Lexer_peek(lexer);
if (token.type == TOKEN_COMMA) {
Lexer_next(lexer);
@@ -277,8 +279,14 @@ Node *parse_function_call_args(Lexer *lexer, Node *func)
}
if (call->call.num_args != func->func.num_args)
- die_location(identifier.loc, "Too few arguments to function `%s`", func->func.name);
+ die_location(identifier.loc, "Function `%s` expects %d arguments, got %d", func->func.name, func->func.num_args, call->call.num_args);
+ for (int i = 0; i < call->call.num_args; i++) {
+ if (!type_equals(func->func.args[i].type, call->call.args[i]->expr_type)) {
+ die_location(identifier.loc, "Type mismatch for argument %d in function call `%s`", i, func->func.name);
+ }
+ }
+ call->expr_type = call->call.func->func.return_type;
assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN);
return call;
}
@@ -294,6 +302,7 @@ Node *parse_identifier(Lexer *lexer)
Lexer_next(lexer);
expr = Node_new(AST_LOCAL_VAR);
expr->variable = var;
+ expr->expr_type = var->type;
return expr;
}
@@ -302,6 +311,7 @@ Node *parse_identifier(Lexer *lexer)
Lexer_next(lexer);
expr = Node_new(AST_GLOBAL_VAR);
expr->variable = gvar;
+ expr->expr_type = gvar->type;
return expr;
}
@@ -319,23 +329,145 @@ Node *parse_identifier(Lexer *lexer)
return NULL;
}
+Node *handle_unary_expr_types(Node *node, Token *token)
+{
+ Type *old_type = node->unary_expr->expr_type;
+
+ if (node->type == OP_NOT) {
+ node->expr_type = type_new(TYPE_INT);
+ } else if (node->type == OP_ADDROF) {
+ Type *ptr = type_new(TYPE_PTR);
+ ptr->ptr = old_type;
+ node->expr_type = ptr;
+ } else if (node->type == OP_DEREF) {
+ if (old_type->type != TYPE_PTR)
+ die_location(token->loc, "Cannot dereference non-pointer type");
+ node->expr_type = old_type->ptr;
+ } else if (node->type == OP_NEG) {
+ if (old_type->type != TYPE_INT)
+ die_location(token->loc, "Cannot negate non-integer type");
+ node->expr_type = type_new(TYPE_INT);
+ } else {
+ // Default to not changing the type
+ node->expr_type = old_type;
+ }
+ // die_location(token->loc, "Unknown unary expression type in handle_unary_expr_types\n");
+ return node;
+}
+
+Node *handle_binary_expr_types(Node *node, Token *token)
+{
+ Type *left = node->binary.left->expr_type;
+ Type *right = node->binary.right->expr_type;
+
+ switch (node->type)
+ {
+ case OP_PLUS: {
+ if (left->type == TYPE_INT && right->type == TYPE_INT) {
+ node->expr_type = type_new(TYPE_INT);
+ } else if (left->type == TYPE_PTR && right->type == TYPE_INT) {
+ node->expr_type = left;
+ // Pointer arithmetic!
+ Node *mul = Node_new(OP_MUL);
+ mul->binary.left = node->binary.right;
+ mul->binary.right = Node_new(AST_LITERAL);
+ mul->binary.right->literal.type = type_new(TYPE_INT);
+ mul->binary.right->literal.as_int = size_for_type(left->ptr);
+ node->binary.right = mul;
+ } else if (left->type == TYPE_INT && right->type == TYPE_PTR) {
+ node->expr_type = right;
+ // Pointer arithmetic!
+ Node *mul = Node_new(OP_MUL);
+ mul->binary.left = node->binary.left;
+ mul->binary.right = Node_new(AST_LITERAL);
+ mul->binary.right->literal.type = type_new(TYPE_INT);
+ mul->binary.right->literal.as_int = size_for_type(right->ptr);
+ node->binary.left = mul;
+ } else {
+ die_location(token->loc, "Cannot add non-integer types");
+ }
+ } break;
+
+ case OP_MINUS: {
+ if (left->type == TYPE_INT && right->type == TYPE_INT) {
+ node->expr_type = type_new(TYPE_INT);
+ } else if (left->type == TYPE_PTR && right->type == TYPE_INT) {
+ node->expr_type = left;
+ // Pointer arithmetic!
+ Node *mul = Node_new(OP_MUL);
+ mul->binary.left = node->binary.right;
+ mul->binary.right = Node_from_int_literal(size_for_type(left->ptr));
+ node->binary.right = mul;
+ } else if (left->type == TYPE_INT && right->type == TYPE_PTR) {
+ node->expr_type = right;
+ // Pointer arithmetic!
+ Node *mul = Node_new(OP_MUL);
+ mul->binary.left = node->binary.left;
+ mul->binary.right = Node_from_int_literal(size_for_type(right->ptr));
+ node->binary.left = mul;
+ } else if (left->type == TYPE_PTR && right->type == TYPE_PTR) {
+ // TODO: Check for different pointer types
+ node->expr_type = type_new(TYPE_INT);
+ // Divide by size of pointer
+ Node *div = Node_new(OP_DIV);
+ div->binary.left = node;
+ div->binary.right = Node_from_int_literal(size_for_type(left->ptr));
+ div->expr_type = node->expr_type;
+ node = div;
+ } else {
+ die_location(token->loc, "Cannot subtract non-integer types");
+ }
+ } break;
+
+ case OP_DIV:
+ case OP_MOD:
+ case OP_MUL: {
+ if (left->type == TYPE_INT && right->type == TYPE_INT) {
+ node->expr_type = left;
+ } else {
+ die_location(token->loc, "Cannot do operation `%s` non-integer types", node_type_to_str(node->type));
+ }
+ } break;
+
+ case OP_EQ:
+ case OP_NEQ:
+ case OP_LT:
+ case OP_GT:
+ case OP_LEQ:
+ case OP_GEQ:
+ case OP_AND:
+ case OP_OR: {
+ node->expr_type = type_new(TYPE_INT);
+ } break;
+
+ default:
+ die_location(token->loc, "Unknown binary expression type in handle_binary_expr_types\n");
+ }
+ return node;
+}
+
Node *parse_factor(Lexer *lexer)
{
- // TODO: Parse more complicated things
+ // TODO: We need to properly handle type conversions / operations with different types
+ // where we need to cast one of the operands / etc. Perhaps have a separate
+ // type-checking / adding casts/conversions pass?
Token token = Lexer_peek(lexer);
Node *expr = NULL;
if (token.type == TOKEN_MINUS) {
Lexer_next(lexer);
expr = Node_new(OP_NEG);
expr->unary_expr = parse_factor(lexer);
+ expr = handle_unary_expr_types(expr, &token);
} else if (token.type == TOKEN_TILDE) {
Lexer_next(lexer);
expr = Node_new(OP_BWINV);
expr->unary_expr = parse_factor(lexer);
+ expr = handle_unary_expr_types(expr, &token);
} else if (token.type == TOKEN_EXCLAMATION) {
Lexer_next(lexer);
expr = Node_new(OP_NOT);
expr->unary_expr = parse_factor(lexer);
+ expr = handle_unary_expr_types(expr, &token);
} else if (token.type == TOKEN_OPEN_PAREN) {
Lexer_next(lexer);
expr = parse_expression(lexer);
@@ -350,6 +482,7 @@ Node *parse_factor(Lexer *lexer)
expr->unary_expr = parse_factor(lexer);
if (!is_lvalue(expr->unary_expr->type))
die_location(token.loc, "Cannot take address of non-lvalue");
+ expr = handle_unary_expr_types(expr, &token);
} else if (token.type == TOKEN_STAR) {
Lexer_next(lexer);
expr = Node_new(OP_DEREF);
@@ -357,6 +490,7 @@ Node *parse_factor(Lexer *lexer)
// to work, we need to to be able to evaluate the type for complex expressions,
// which we do not support as of now.
expr->unary_expr = parse_factor(lexer);
+ expr = handle_unary_expr_types(expr, &token);
} else {
die_location(token.loc, ": Expected token found in parse_factor: `%s`", token_type_to_str(token.type));
}
@@ -372,6 +506,7 @@ Node *parse_factor(Lexer *lexer)
Node *right = next_parser(lexer); \
op->binary.left = expr; \
op->binary.right = right; \
+ op = handle_binary_expr_types(op, &token); \
expr = op; \
token = Lexer_peek(lexer); \
} \
@@ -408,7 +543,13 @@ Node *parse_conditional_exp(Lexer *lexer)
conditional->conditional.cond = expr;
conditional->conditional.do_then = then_expr;
conditional->conditional.do_else = else_expr;
+
+ if (!type_equals(then_expr->expr_type, else_expr->expr_type))
+ die_location(token.loc, "Type mismatch in conditional expression cases");
+
+ conditional->expr_type = then_expr->expr_type;
expr = conditional;
+ expr->expr_type = then_expr->expr_type;
}
return expr;
}
@@ -425,7 +566,12 @@ Node *parse_expression(Lexer *lexer)
Node *assign = Node_new(OP_ASSIGN);
assign->assign.var = node;
assign->assign.value = parse_expression(lexer);
+
+ if (!type_equals(node->expr_type, assign->assign.value->expr_type))
+ die_location(token.loc, "Type mismatch in assignment expression");
+
node = assign;
+ node->expr_type = node->assign.var->expr_type;
}
}
return node;
@@ -442,6 +588,10 @@ Node *parse_statement(Lexer *lexer)
assert_token(Lexer_next(lexer), TOKEN_RETURN);
node = Node_new(AST_RETURN);
node->unary_expr = parse_expression(lexer);
+
+ if (!type_equals(node->unary_expr->expr_type, current_function->func.return_type))
+ die_location(token.loc, "Return expression does not match function's return type");
+
assert_token(Lexer_next(lexer), TOKEN_SEMICOLON);
} else if (token.type == TOKEN_IF) {
Lexer_next(lexer);
@@ -569,7 +719,9 @@ void parse_func_args(Lexer *lexer, Node *func)
Variable *var = &func->func.args[i];
var->offset = offset;
// TODO: Do we need to align the stack here?
- int var_size = size_for_type(var->type);
+ // TODO: (Here and other uses of `size_for_type`):
+ // Should we only align to max(8, type->size) instead?
+ int var_size = align_up(size_for_type(var->type), 8);
offset -= var_size;
}
}
diff --git a/src/utils.c b/src/utils.c
index 2335981..939021a 100644
--- a/src/utils.c
+++ b/src/utils.c
@@ -27,4 +27,9 @@ void _die_location(char *file, int line, Location loc, const char *fmt, ...)
}
i64 i64max(i64 a, i64 b) { return a > b ? a : b; }
-i64 i64min(i64 a, i64 b) { return a < b ? a : b; } \ No newline at end of file
+i64 i64min(i64 a, i64 b) { return a < b ? a : b; }
+
+i64 align_up(i64 val, i64 align)
+{
+ return (val + align - 1) & ~(align - 1);
+} \ No newline at end of file
diff --git a/src/utils.h b/src/utils.h
index dfc018c..7718c1c 100644
--- a/src/utils.h
+++ b/src/utils.h
@@ -9,4 +9,7 @@ void _die_location(char *file, int line, Location loc, const char *fmt, ...);
i64 i64max(i64 a, i64 b);
i64 i64min(i64 a, i64 b);
+// Assumes alignment is a power of 2
+i64 align_up(i64 val, i64 align);
+
#define die_location(loc, ...) _die_location(__FILE__, __LINE__, loc, __VA_ARGS__) \ No newline at end of file