diff options
| author | Mustafa Quraish <[email protected]> | 2022-02-02 07:20:53 -0500 |
|---|---|---|
| committer | Mustafa Quraish <[email protected]> | 2022-02-02 07:37:39 -0500 |
| commit | 1a8f96c65f94227faa9747ef876a60f3c313c6f1 (patch) | |
| tree | d80a396958ff2fc752b620cc5314e27e40b58ecb /src/parser.c | |
| parent | Use `type*` instead of `type&` to denote a pointer type (for now) (diff) | |
| download | cup-1a8f96c65f94227faa9747ef876a60f3c313c6f1.tar.xz cup-1a8f96c65f94227faa9747ef876a60f3c313c6f1.zip | |
Type checking of expressions / functions!
This is a bit of a chonky commit, but it adds in the basics of checking
the types of expressions / function calls / return types. There's still
a lot of work to be done, including:
(1) Adding new core types, and casting between allowed types
automatically
(2) Picking the correct output type based on input types (for instance
float+int == float)
(3) We need much better error reporting; the error messages are really
vague and unhelpful as-is
(4) We also need to work to ensure that a function with a return type
actually returns
(5) Possible re-factoring to make stuff less hacky when we have more
types / structs / arrays / etc.
Diffstat (limited to 'src/parser.c')
| -rw-r--r-- | src/parser.c | 168 |
1 files changed, 160 insertions, 8 deletions
diff --git a/src/parser.c b/src/parser.c index 85c4d6a..e391500 100644 --- a/src/parser.c +++ b/src/parser.c @@ -144,7 +144,7 @@ Node *find_function_definition(Token *token) void add_global_variable(Variable *var) { var->offset = global_vars_offset; - int var_size = size_for_type(var->type); + int var_size = align_up(size_for_type(var->type), 8); global_vars_offset += var_size; global_vars[global_vars_count++] = var; } @@ -157,7 +157,7 @@ void add_variable_to_current_block(Variable *var) int new_len = (cur_block->block.num_locals + 1); // TODO: Align the stack to a certain size? - int var_size = size_for_type(var->type); + int var_size = align_up(size_for_type(var->type), 8); // Add to the block // FIXME: Use a map here @@ -205,6 +205,7 @@ Node *parse_literal(Lexer *lexer) Node *node = Node_new(AST_LITERAL); Token token = assert_token(Lexer_next(lexer), TOKEN_INTLIT); node->literal.type = type_new(TYPE_INT); + node->expr_type = node->literal.type; node->literal.as_int = token.value.as_int; return node; } @@ -243,6 +244,10 @@ Node *parse_var_declaration(Lexer *lexer) if (is_global) die_location(token.loc, "Cannot initialize global variable `%s` outside function", node->var_decl.var.name); node->var_decl.value = parse_expression(lexer); + + if (!type_equals(node->var_decl.var.type, node->var_decl.value->expr_type)) + die_location(token.loc, "Type mismatch for variable declaration `%s` initalizer", node->var_decl.var.name); + assert_token(Lexer_next(lexer), TOKEN_SEMICOLON); } else { assert_token(token, TOKEN_SEMICOLON); @@ -266,9 +271,6 @@ Node *parse_function_call_args(Lexer *lexer, Node *func) call->call.args = realloc(call->call.args, sizeof(Node *) * new_size); call->call.args[call->call.num_args++] = arg; - if (new_size > func->func.num_args) - die_location(identifier.loc, "Too many arguments to function `%s`", func->func.name); - token = Lexer_peek(lexer); if (token.type == TOKEN_COMMA) { Lexer_next(lexer); @@ -277,8 +279,14 @@ Node 
*parse_function_call_args(Lexer *lexer, Node *func) } if (call->call.num_args != func->func.num_args) - die_location(identifier.loc, "Too few arguments to function `%s`", func->func.name); + die_location(identifier.loc, "Function `%s` expects %d arguments, got %d", func->func.name, func->func.num_args, call->call.num_args); + for (int i = 0; i < call->call.num_args; i++) { + if (!type_equals(func->func.args[i].type, call->call.args[i]->expr_type)) { + die_location(identifier.loc, "Type mismatch for argument %d in function call `%s`", i, func->func.name); + } + } + call->expr_type = call->call.func->func.return_type; assert_token(Lexer_next(lexer), TOKEN_CLOSE_PAREN); return call; } @@ -294,6 +302,7 @@ Node *parse_identifier(Lexer *lexer) Lexer_next(lexer); expr = Node_new(AST_LOCAL_VAR); expr->variable = var; + expr->expr_type = var->type; return expr; } @@ -302,6 +311,7 @@ Node *parse_identifier(Lexer *lexer) Lexer_next(lexer); expr = Node_new(AST_GLOBAL_VAR); expr->variable = gvar; + expr->expr_type = gvar->type; return expr; } @@ -319,23 +329,145 @@ Node *parse_identifier(Lexer *lexer) return NULL; } +Node *handle_unary_expr_types(Node *node, Token *token) +{ + Type *old_type = node->unary_expr->expr_type; + + if (node->type == OP_NOT) { + node->expr_type = type_new(TYPE_INT); + } else if (node->type == OP_ADDROF) { + Type *ptr = type_new(TYPE_PTR); + ptr->ptr = old_type; + node->expr_type = ptr; + } else if (node->type == OP_DEREF) { + if (old_type->type != TYPE_PTR) + die_location(token->loc, "Cannot dereference non-pointer type"); + node->expr_type = old_type->ptr; + } else if (node->type == OP_NEG) { + if (old_type->type != TYPE_INT) + die_location(token->loc, "Cannot negate non-integer type"); + node->expr_type = type_new(TYPE_INT); + } else { + // Default to not changing the type + node->expr_type = old_type; + } + // die_location(token->loc, "Unknown unary expression type in handle_unary_expr_types\n"); + return node; +} + +Node 
*handle_binary_expr_types(Node *node, Token *token) +{ + Type *left = node->binary.left->expr_type; + Type *right = node->binary.right->expr_type; + + switch (node->type) + { + case OP_PLUS: { + if (left->type == TYPE_INT && right->type == TYPE_INT) { + node->expr_type = type_new(TYPE_INT); + } else if (left->type == TYPE_PTR && right->type == TYPE_INT) { + node->expr_type = left; + // Pointer arithmetic! + Node *mul = Node_new(OP_MUL); + mul->binary.left = node->binary.right; + mul->binary.right = Node_new(AST_LITERAL); + mul->binary.right->literal.type = type_new(TYPE_INT); + mul->binary.right->literal.as_int = size_for_type(left->ptr); + node->binary.right = mul; + } else if (left->type == TYPE_INT && right->type == TYPE_PTR) { + node->expr_type = right; + // Pointer arithmetic! + Node *mul = Node_new(OP_MUL); + mul->binary.left = node->binary.left; + mul->binary.right = Node_new(AST_LITERAL); + mul->binary.right->literal.type = type_new(TYPE_INT); + mul->binary.right->literal.as_int = size_for_type(right->ptr); + node->binary.left = mul; + } else { + die_location(token->loc, "Cannot add non-integer types"); + } + } break; + + case OP_MINUS: { + if (left->type == TYPE_INT && right->type == TYPE_INT) { + node->expr_type = type_new(TYPE_INT); + } else if (left->type == TYPE_PTR && right->type == TYPE_INT) { + node->expr_type = left; + // Pointer arithmetic! + Node *mul = Node_new(OP_MUL); + mul->binary.left = node->binary.right; + mul->binary.right = Node_from_int_literal(size_for_type(left->ptr)); + node->binary.right = mul; + } else if (left->type == TYPE_INT && right->type == TYPE_PTR) { + node->expr_type = right; + // Pointer arithmetic! 
+ Node *mul = Node_new(OP_MUL); + mul->binary.left = node->binary.left; + mul->binary.right = Node_from_int_literal(size_for_type(right->ptr)); + node->binary.left = mul; + } else if (left->type == TYPE_PTR && right->type == TYPE_PTR) { + // TODO: Check for different pointer types + node->expr_type = type_new(TYPE_INT); + // Divide by size of pointer + Node *div = Node_new(OP_DIV); + div->binary.left = node; + div->binary.right = Node_from_int_literal(size_for_type(left->ptr)); + div->expr_type = node->expr_type; + node = div; + } else { + die_location(token->loc, "Cannot subtract non-integer types"); + } + } break; + + case OP_DIV: + case OP_MOD: + case OP_MUL: { + if (left->type == TYPE_INT && right->type == TYPE_INT) { + node->expr_type = left; + } else { + die_location(token->loc, "Cannot do operation `%s` non-integer types", node_type_to_str(node->type)); + } + } break; + + case OP_EQ: + case OP_NEQ: + case OP_LT: + case OP_GT: + case OP_LEQ: + case OP_GEQ: + case OP_AND: + case OP_OR: { + node->expr_type = type_new(TYPE_INT); + } break; + + default: + die_location(token->loc, "Unknown binary expression type in handle_binary_expr_types\n"); + } + return node; +} + Node *parse_factor(Lexer *lexer) { - // TODO: Parse more complicated things + // TODO: We need to properly handle type conversions / operations with different types + // where we need to cast one of the operands / etc. Perhaps have a separate + // type-checking / adding casts/conversions pass? 
Token token = Lexer_peek(lexer); Node *expr = NULL; if (token.type == TOKEN_MINUS) { Lexer_next(lexer); expr = Node_new(OP_NEG); expr->unary_expr = parse_factor(lexer); + expr = handle_unary_expr_types(expr, &token); } else if (token.type == TOKEN_TILDE) { Lexer_next(lexer); expr = Node_new(OP_BWINV); expr->unary_expr = parse_factor(lexer); + expr = handle_unary_expr_types(expr, &token); } else if (token.type == TOKEN_EXCLAMATION) { Lexer_next(lexer); expr = Node_new(OP_NOT); expr->unary_expr = parse_factor(lexer); + expr = handle_unary_expr_types(expr, &token); } else if (token.type == TOKEN_OPEN_PAREN) { Lexer_next(lexer); expr = parse_expression(lexer); @@ -350,6 +482,7 @@ Node *parse_factor(Lexer *lexer) expr->unary_expr = parse_factor(lexer); if (!is_lvalue(expr->unary_expr->type)) die_location(token.loc, "Cannot take address of non-lvalue"); + expr = handle_unary_expr_types(expr, &token); } else if (token.type == TOKEN_STAR) { Lexer_next(lexer); expr = Node_new(OP_DEREF); @@ -357,6 +490,7 @@ Node *parse_factor(Lexer *lexer) // to work, we need to to be able to evaluate the type for complex expressions, // which we do not support as of now. 
expr->unary_expr = parse_factor(lexer); + expr = handle_unary_expr_types(expr, &token); } else { die_location(token.loc, ": Expected token found in parse_factor: `%s`", token_type_to_str(token.type)); } @@ -372,6 +506,7 @@ Node *parse_factor(Lexer *lexer) Node *right = next_parser(lexer); \ op->binary.left = expr; \ op->binary.right = right; \ + op = handle_binary_expr_types(op, &token); \ expr = op; \ token = Lexer_peek(lexer); \ } \ @@ -408,7 +543,13 @@ Node *parse_conditional_exp(Lexer *lexer) conditional->conditional.cond = expr; conditional->conditional.do_then = then_expr; conditional->conditional.do_else = else_expr; + + if (!type_equals(then_expr->expr_type, else_expr->expr_type)) + die_location(token.loc, "Type mismatch in conditional expression cases"); + + conditional->expr_type = then_expr->expr_type; expr = conditional; + expr->expr_type = then_expr->expr_type; } return expr; } @@ -425,7 +566,12 @@ Node *parse_expression(Lexer *lexer) Node *assign = Node_new(OP_ASSIGN); assign->assign.var = node; assign->assign.value = parse_expression(lexer); + + if (!type_equals(node->expr_type, assign->assign.value->expr_type)) + die_location(token.loc, "Type mismatch in assignment expression"); + node = assign; + node->expr_type = node->assign.var->expr_type; } } return node; @@ -442,6 +588,10 @@ Node *parse_statement(Lexer *lexer) assert_token(Lexer_next(lexer), TOKEN_RETURN); node = Node_new(AST_RETURN); node->unary_expr = parse_expression(lexer); + + if (!type_equals(node->unary_expr->expr_type, current_function->func.return_type)) + die_location(token.loc, "Return expression does not match function's return type"); + assert_token(Lexer_next(lexer), TOKEN_SEMICOLON); } else if (token.type == TOKEN_IF) { Lexer_next(lexer); @@ -569,7 +719,9 @@ void parse_func_args(Lexer *lexer, Node *func) Variable *var = &func->func.args[i]; var->offset = offset; // TODO: Do we need to align the stack here? 
- int var_size = size_for_type(var->type); + // TODO: (Here and other uses of `size_for_type`): + // Should we only align to max(8, type->size) instead? + int var_size = align_up(size_for_type(var->type), 8); offset -= var_size; } } |