diff options
| author | Mustafa Quraish <[email protected]> | 2022-02-03 21:01:03 -0500 |
|---|---|---|
| committer | Mustafa Quraish <[email protected]> | 2022-02-03 21:01:03 -0500 |
| commit | 9f3edcbeea51dd8841b9d89fa7ef98f11682a1c7 (patch) | |
| tree | b2a4b6fa51ad558b36facde2cf952a6af5a25669 /src | |
| parent | Add automatic type inference for initialized variable declarations (diff) | |
| download | cup-9f3edcbeea51dd8841b9d89fa7ef98f11682a1c7.tar.xz cup-9f3edcbeea51dd8841b9d89fa7ef98f11682a1c7.zip | |
Add support for basic structs
Structs for now (and probably for the near future) are not allowed
to be passed by value, and instead you just pass a pointer to it.
Nested structs can also be defined, and they can be either anonymous,
or named (in which case only the members can access the type).
Diffstat (limited to 'src')
| -rw-r--r-- | src/ast.c | 1 | ||||
| -rw-r--r-- | src/ast.h | 7 | ||||
| -rw-r--r-- | src/generator.c | 10 | ||||
| -rw-r--r-- | src/lexer.c | 1 | ||||
| -rw-r--r-- | src/parser.c | 109 | ||||
| -rw-r--r-- | src/tokens.h | 2 | ||||
| -rw-r--r-- | src/types.c | 54 | ||||
| -rw-r--r-- | src/types.h | 14 |
8 files changed, 192 insertions, 6 deletions
@@ -112,6 +112,7 @@ bool is_lvalue(NodeType type) { case AST_LOCAL_VAR: case AST_GLOBAL_VAR: + case OP_MEMBER: case OP_DEREF: // FIXME: Should this be the case? return true; default: return false; @@ -26,6 +26,7 @@ F(OP_GT, ">") \ F(OP_GEQ, ">=") \ F(OP_ASSIGN, "=") \ + F(OP_MEMBER, ".") \ F(AST_LITERAL, "literal") \ F(AST_FUNCCALL, "Function call") \ F(AST_CONDITIONAL, "conditional expression") \ @@ -143,6 +144,12 @@ typedef struct ast_node { Node **args; int num_args; } call; + + struct { + i64 offset; + Node *expr; + bool is_ptr; + } member; }; } Node; diff --git a/src/generator.c b/src/generator.c index 6fcccbf..2fb89b7 100644 --- a/src/generator.c +++ b/src/generator.c @@ -50,13 +50,21 @@ void generate_expr_into_rax(Node *expr, FILE *out); void generate_lvalue_into_rax(Node *node, FILE *out) { assert(is_lvalue(node->type)); - i64 offset = node->variable->offset; if (node->type == AST_LOCAL_VAR) { + i64 offset = node->variable->offset; fprintf(out, " mov rax, rbp\n"); fprintf(out, " sub rax, %lld\n", offset); } else if (node->type == AST_GLOBAL_VAR) { + i64 offset = node->variable->offset; fprintf(out, " mov rax, global_vars\n"); fprintf(out, " add rax, %lld\n", offset); + } else if (node->type == OP_MEMBER) { + i64 offset = node->member.offset; + if (node->member.is_ptr) + generate_expr_into_rax(node->member.expr, out); + else + generate_lvalue_into_rax(node->member.expr, out); + fprintf(out, " add rax, %lld\n", offset); } else if (node->type == OP_DEREF) { generate_expr_into_rax(node->unary_expr, out); } else { diff --git a/src/lexer.c b/src/lexer.c index d1201b2..f02ef94 100644 --- a/src/lexer.c +++ b/src/lexer.c @@ -128,6 +128,7 @@ Token Lexer_next(Lexer *lexer) case ':': return Lexer_make_token(lexer, TOKEN_COLON, 1); case '~': return Lexer_make_token(lexer, TOKEN_TILDE, 1); case '?': return Lexer_make_token(lexer, TOKEN_QUESTION, 1); + case '.': return Lexer_make_token(lexer, TOKEN_DOT, 1); case ',': return Lexer_make_token(lexer, TOKEN_COMMA, 1); case '*': return Lexer_make_token(lexer, TOKEN_STAR, 1); case '%': return Lexer_make_token(lexer, TOKEN_PERCENT, 1); diff --git a/src/parser.c b/src/parser.c index 6bf2f1d..7c01c75 100644 --- a/src/parser.c +++ b/src/parser.c @@ -26,6 +26,11 @@ static i64 global_vars_offset = 0; static Lexer *lexer_stack[LEXER_STACK_SIZE]; static i64 lexer_stack_count = 0; +#define DEFINED_STRUCT_SIZE 128 +static Type *defined_structs[DEFINED_STRUCT_SIZE]; +static i64 defined_structs_count = 0; + + Token do_assert_token(Token token, TokenType type, char *filename, int line) { if (token.type != type) { @@ -43,6 +48,22 @@ Token do_assert_token(Token token, TokenType type, char *filename, int line) * Some helpers */ +void push_struct_definition(Type *type) +{ + assert(defined_structs_count < DEFINED_STRUCT_SIZE); + defined_structs[defined_structs_count++] = type; +} + +Type *find_custom_type_definition(Token *token) +{ + for (i64 i = 0; i < defined_structs_count; i++) { + if (strcmp(defined_structs[i]->struct_name, token->value.as_string) == 0) { + return defined_structs[i]; + } + } + return NULL; +} + void block_stack_push(Node *block) { assert(block_stack_count < BLOCK_STACK_SIZE); @@ -164,7 +185,13 @@ Type *parse_type(Lexer *lexer) Lexer_next(lexer); type = type_new(TYPE_CHAR); } else { - die_location(token.loc, "Unexpected type found: %s", token_type_to_str(token.type)); + assert_token(token, TOKEN_IDENTIFIER); + // TODO: Don't allow a type to contain itself. + // TODO: Don't allow a type to contain an array of itself. + type = find_custom_type_definition(&token); + if (!type) + die_location(token.loc, "Could not find what type `%s` is referencing", token.value.as_string); + Lexer_next(lexer); } for (;;) { @@ -370,7 +397,7 @@ Node *parse_factor(Lexer *lexer) expr = Node_new(OP_BWINV); expr->unary_expr = parse_factor(lexer); expr = handle_unary_expr_types(expr, &token); - + // ++x is changed to (x = x + 1) } else if (token.type == TOKEN_PLUSPLUS) { Lexer_next(lexer); @@ -454,6 +481,26 @@ Node *parse_factor(Lexer *lexer) die_location(token.loc, "Post-incrementing is not supported\n"); } else if (token.type == TOKEN_MINUSMINUS) { die_location(token.loc, "Post-decrementing is not supported\n"); + } else if (token.type == TOKEN_DOT) { + // TODO: Pointer to struct + if (!is_struct_or_struct_ptr(expr->expr_type)) + die_location(token.loc, "Cannot access member of non-struct type"); + + bool is_ptr = expr->expr_type->type == TYPE_PTR; + Type *struct_type = is_ptr ? expr->expr_type->ptr : expr->expr_type; + Lexer_next(lexer); + Token field_token = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER); + i64 index = find_field_index(struct_type, field_token.value.as_string); + if (index == -1) + die_location(field_token.loc, "Struct `%s` does not have a field named `%s`", type_to_str(struct_type), field_token.value.as_string); + + Node *member = Node_new(OP_MEMBER); + member->expr_type = struct_type->fields.type[index]; + member->member.expr = expr; + member->member.offset = struct_type->fields.offset[index]; + member->member.is_ptr = (expr->expr_type->type == TYPE_PTR); + expr = member; + } else { break; } @@ -683,6 +730,9 @@ void parse_func_args(Lexer *lexer, Node *func) assert_token(Lexer_next(lexer), TOKEN_COLON); Type *type = parse_type(lexer); + if (type->type == TYPE_STRUCT) + die_location(token.loc, "Structs cannot be passed as arguments, maybe pass a pointer?"); + i64 new_count = func->func.num_args + 1; func->func.args = realloc(func->func.args, sizeof(Variable) * new_count); Variable *var = &func->func.args[func->func.num_args++]; @@ -763,6 +813,57 @@ Lexer *remove_lexer() return lexer_stack[lexer_stack_count - 1]; } +Type *parse_struct_declaration(Lexer *lexer, bool is_global) { + i64 prev_struct_count = defined_structs_count; + + assert_token(Lexer_next(lexer), TOKEN_STRUCT); + + Type *struct_type = type_new(TYPE_STRUCT); + + Token token = Lexer_peek(lexer); + // For nested temporary structs we don't need a name + if (token.type != TOKEN_IDENTIFIER && is_global) + die_location(token.loc, "You need to specify a name for the struct defined globally."); + + // But if they do provide one, we'll add it to the list of defined structs so they + // it can referenced internally. + bool has_name = false; + if (token.type == TOKEN_IDENTIFIER) { + struct_type->struct_name = token.value.as_string; + push_struct_definition(struct_type); + Lexer_next(lexer); + has_name = true; + } + + assert_token(Lexer_next(lexer), TOKEN_OPEN_BRACE); + + token = Lexer_peek(lexer); + while (token.type != TOKEN_CLOSE_BRACE) { + token = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER); + assert_token(Lexer_next(lexer), TOKEN_COLON); + + // We want to allow nested temporary structs. + Type *type; + Token next = Lexer_peek(lexer); + if (next.type == TOKEN_STRUCT) { + type = parse_struct_declaration(lexer, false); + } else { + type = parse_type(lexer); + } + + push_field(struct_type, token.value.as_string, type); + assert_token(Lexer_next(lexer), TOKEN_SEMICOLON); + token = Lexer_peek(lexer); + } + assert_token(Lexer_next(lexer), TOKEN_CLOSE_BRACE); + + // If this is not being defined globally, we want to remove it from the namespace. + if (!is_global) + defined_structs_count = prev_struct_count; + + return struct_type; +} + Node *parse_program(Lexer *lexer) { initialize_builtins(); @@ -778,6 +879,8 @@ Node *parse_program(Lexer *lexer) } else if (token.type == TOKEN_LET) { Node *var_decl = parse_var_declaration(lexer); Node_add_child(program, var_decl); + } else if (token.type == TOKEN_STRUCT) { + parse_struct_declaration(lexer, true); } else if (token.type == TOKEN_IMPORT) { // TODO: Handle circular imports // TODO: Handle complex import graphs (#pragma once) @@ -788,6 +891,8 @@ Node *parse_program(Lexer *lexer) char *filename = token.value.as_string; lexer = Lexer_new_open_file(filename); push_new_lexer(lexer); + } else if (token.type == TOKEN_SEMICOLON) { + Lexer_next(lexer); } else { die_location(token.loc, "Unexpected token in parse_program: `%s`\n", token_type_to_str(token.type)); exit(1); diff --git a/src/tokens.h b/src/tokens.h index f02cfb0..6eaad8e 100644 --- a/src/tokens.h +++ b/src/tokens.h @@ -14,6 +14,7 @@ F(TOKEN_CLOSE_PAREN, ")") \ F(TOKEN_COLON, ":") \ F(TOKEN_COMMA, ",") \ + F(TOKEN_DOT, ".") \ F(TOKEN_EOF, "EOF") \ F(TOKEN_EQ, "==") \ F(TOKEN_EXCLAMATION, "!") \ @@ -55,6 +56,7 @@ F(TOKEN_INT, "int") \ F(TOKEN_LET, "let") \ F(TOKEN_RETURN, "return") \ + F(TOKEN_STRUCT, "struct") \ F(TOKEN_WHILE, "while") \ F(TOKEN_IMPORT, "import") \ diff --git a/src/types.c b/src/types.c index abbba86..5138164 100644 --- a/src/types.c +++ b/src/types.c @@ -36,6 +36,7 @@ i64 size_for_type(Type *type) case TYPE_PTR: return 8; case TYPE_CHAR: return 1; case TYPE_ARRAY: return type->array_size * size_for_type(type->ptr); + case TYPE_STRUCT: return type->fields.size; default: { printf("Unknown type: %d\n", type->type); assert(false && "Unreachable type"); @@ -86,9 +87,18 @@ bool is_int_type(Type *type) } } -static char *data_type_to_str(DataType type) +bool is_struct_or_struct_ptr(Type *type) { - switch (type) + if (type->type == TYPE_STRUCT) + return true; + if (type->type == TYPE_PTR && type->ptr->type == TYPE_STRUCT) + return true; + return false; +} + +static char *data_type_to_str(Type *type) +{ + switch (type->type) { case TYPE_NONE: return "void"; case TYPE_INT: return "int"; @@ -96,6 +106,7 @@ static char *data_type_to_str(DataType type) case TYPE_ARRAY: return "array"; case TYPE_CHAR: return "char"; case TYPE_ANY: return "<@>"; + case TYPE_STRUCT: return type->struct_name; default: assert(false && "Unreachable"); } } @@ -114,16 +125,50 @@ char *type_to_str(Type *type) // FIXME: This is inefficient as all hell but this will only really be // used for error reporting. - strcat(str, data_type_to_str(type->type)); + strcat(str, data_type_to_str(type)); for (int i = 0; i < ptr_count; i++) strcat(str, "*"); return str; } +i64 push_field(Type *type, char *field_name, Type *field_type) +{ + assert(type->type == TYPE_STRUCT); + type->fields.type = realloc(type->fields.type, sizeof(Type *) * (type->fields.num_fields + 1)); + type->fields.offset = realloc(type->fields.offset, sizeof(i64) * (type->fields.num_fields + 1)); + type->fields.name = realloc(type->fields.name, sizeof(char *) * (type->fields.num_fields + 1)); + + i64 field_size = size_for_type(field_type); + i64 offset_factor = i64min(field_size, 8); + i64 offset = align_up(type->fields.size, offset_factor); + + type->fields.type[type->fields.num_fields] = field_type; + type->fields.offset[type->fields.num_fields] = offset; + type->fields.name[type->fields.num_fields] = field_name; + type->fields.size = offset + field_size; + type->fields.num_fields++; + + return offset; +} + +i64 find_field_index(Type *type, char *field_name) +{ + assert(type->type == TYPE_STRUCT); + for (int i = 0; i < type->fields.num_fields; i++) { + if (strcmp(type->fields.name[i], field_name) == 0) + return i; + } + return -1; +} + + Node *handle_unary_expr_types(Node *node, Token *token) { Type *old_type = node->unary_expr->expr_type; + if (node->type != OP_ADDROF && old_type->type == TYPE_STRUCT) + die_location(token->loc, "Performing invalid unary operation on struct type"); + if (node->type == OP_NOT) { node->expr_type = type_new(TYPE_INT); } else if (node->type == OP_ADDROF) { @@ -154,6 +199,9 @@ Node *handle_binary_expr_types(Node *node, Token *token) { Type *left = node->binary.left->expr_type; Type *right = node->binary.right->expr_type; + + if (left->type == TYPE_STRUCT || right->type == TYPE_STRUCT) + die_location(token->loc, "Performing invalid binary operation on struct type"); switch (node->type) { diff --git a/src/types.h b/src/types.h index 73b5b63..e8b9488 100644 --- a/src/types.h +++ b/src/types.h @@ -10,12 +10,21 @@ typedef enum { TYPE_CHAR, TYPE_PTR, TYPE_ARRAY, + TYPE_STRUCT, } DataType; typedef struct data_type_node { DataType type; struct data_type_node *ptr; i64 array_size; + char *struct_name; + struct { + char **name; + struct data_type_node **type; + i64 *offset; + i64 num_fields; + i64 size; + } fields; } Type; Type *type_new(DataType type); @@ -28,6 +37,11 @@ bool type_equals(Type *a, Type *b); bool is_int_type(Type *type); bool is_string_type(Type *type); bool is_convertible(Type *from, Type *to); +bool is_struct_or_struct_ptr(Type *type); + +// Returns the offset of the field in the struct. +i64 push_field(Type *type, char *field_name, Type *field_type); +i64 find_field_index(Type *type, char *field_name); // Type checking / casting expressions to right types typedef struct ast_node Node; |