diff options
| author | Mustafa Quraish <[email protected]> | 2022-02-03 21:01:03 -0500 |
|---|---|---|
| committer | Mustafa Quraish <[email protected]> | 2022-02-03 21:01:03 -0500 |
| commit | 9f3edcbeea51dd8841b9d89fa7ef98f11682a1c7 (patch) | |
| tree | b2a4b6fa51ad558b36facde2cf952a6af5a25669 | |
| parent | Add automatic type inference for initialized variable declarations (diff) | |
| download | cup-9f3edcbeea51dd8841b9d89fa7ef98f11682a1c7.tar.xz cup-9f3edcbeea51dd8841b9d89fa7ef98f11682a1c7.zip | |
Add support for basic structs
Structs for now (and probably for the near future) are not allowed
to be passed by value, and instead you just pass a pointer to it.
Nested structs can also be defined, and they can be either anonymous,
or named (in which case only the members can access the type).
| -rw-r--r-- | src/ast.c | 1 | ||||
| -rw-r--r-- | src/ast.h | 7 | ||||
| -rw-r--r-- | src/generator.c | 10 | ||||
| -rw-r--r-- | src/lexer.c | 1 | ||||
| -rw-r--r-- | src/parser.c | 109 | ||||
| -rw-r--r-- | src/tokens.h | 2 | ||||
| -rw-r--r-- | src/types.c | 54 | ||||
| -rw-r--r-- | src/types.h | 14 |
8 files changed, 192 insertions, 6 deletions
@@ -112,6 +112,7 @@ bool is_lvalue(NodeType type) { case AST_LOCAL_VAR: case AST_GLOBAL_VAR: + case OP_MEMBER: case OP_DEREF: // FIXME: Should this be the case? return true; default: return false; @@ -26,6 +26,7 @@ F(OP_GT, ">") \ F(OP_GEQ, ">=") \ F(OP_ASSIGN, "=") \ + F(OP_MEMBER, ".") \ F(AST_LITERAL, "literal") \ F(AST_FUNCCALL, "Function call") \ F(AST_CONDITIONAL, "conditional expression") \ @@ -143,6 +144,12 @@ typedef struct ast_node { Node **args; int num_args; } call; + + struct { + i64 offset; + Node *expr; + bool is_ptr; + } member; }; } Node; diff --git a/src/generator.c b/src/generator.c index 6fcccbf..2fb89b7 100644 --- a/src/generator.c +++ b/src/generator.c @@ -50,13 +50,21 @@ void generate_expr_into_rax(Node *expr, FILE *out); void generate_lvalue_into_rax(Node *node, FILE *out) { assert(is_lvalue(node->type)); - i64 offset = node->variable->offset; if (node->type == AST_LOCAL_VAR) { + i64 offset = node->variable->offset; fprintf(out, " mov rax, rbp\n"); fprintf(out, " sub rax, %lld\n", offset); } else if (node->type == AST_GLOBAL_VAR) { + i64 offset = node->variable->offset; fprintf(out, " mov rax, global_vars\n"); fprintf(out, " add rax, %lld\n", offset); + } else if (node->type == OP_MEMBER) { + i64 offset = node->member.offset; + if (node->member.is_ptr) + generate_expr_into_rax(node->member.expr, out); + else + generate_lvalue_into_rax(node->member.expr, out); + fprintf(out, " add rax, %lld\n", offset); } else if (node->type == OP_DEREF) { generate_expr_into_rax(node->unary_expr, out); } else { diff --git a/src/lexer.c b/src/lexer.c index d1201b2..f02ef94 100644 --- a/src/lexer.c +++ b/src/lexer.c @@ -128,6 +128,7 @@ Token Lexer_next(Lexer *lexer) case ':': return Lexer_make_token(lexer, TOKEN_COLON, 1); case '~': return Lexer_make_token(lexer, TOKEN_TILDE, 1); case '?': return Lexer_make_token(lexer, TOKEN_QUESTION, 1); + case '.': return Lexer_make_token(lexer, TOKEN_DOT, 1); case ',': return Lexer_make_token(lexer, TOKEN_COMMA, 1); case '*': return Lexer_make_token(lexer, TOKEN_STAR, 1); case '%': return Lexer_make_token(lexer, TOKEN_PERCENT, 1); diff --git a/src/parser.c b/src/parser.c index 6bf2f1d..7c01c75 100644 --- a/src/parser.c +++ b/src/parser.c @@ -26,6 +26,11 @@ static i64 global_vars_offset = 0; static Lexer *lexer_stack[LEXER_STACK_SIZE]; static i64 lexer_stack_count = 0; +#define DEFINED_STRUCT_SIZE 128 +static Type *defined_structs[DEFINED_STRUCT_SIZE]; +static i64 defined_structs_count = 0; + + Token do_assert_token(Token token, TokenType type, char *filename, int line) { if (token.type != type) { @@ -43,6 +48,22 @@ Token do_assert_token(Token token, TokenType type, char *filename, int line) * Some helpers */ +void push_struct_definition(Type *type) +{ + assert(defined_structs_count < DEFINED_STRUCT_SIZE); + defined_structs[defined_structs_count++] = type; +} + +Type *find_custom_type_definition(Token *token) +{ + for (i64 i = 0; i < defined_structs_count; i++) { + if (strcmp(defined_structs[i]->struct_name, token->value.as_string) == 0) { + return defined_structs[i]; + } + } + return NULL; +} + void block_stack_push(Node *block) { assert(block_stack_count < BLOCK_STACK_SIZE); @@ -164,7 +185,13 @@ Type *parse_type(Lexer *lexer) Lexer_next(lexer); type = type_new(TYPE_CHAR); } else { - die_location(token.loc, "Unexpected type found: %s", token_type_to_str(token.type)); + assert_token(token, TOKEN_IDENTIFIER); + // TODO: Don't allow a type to contain itself. + // TODO: Don't allow a type to contain an array of itself. + type = find_custom_type_definition(&token); + if (!type) + die_location(token.loc, "Could not find what type `%s` is referencing", token.value.as_string); + Lexer_next(lexer); } for (;;) { @@ -370,7 +397,7 @@ Node *parse_factor(Lexer *lexer) expr = Node_new(OP_BWINV); expr->unary_expr = parse_factor(lexer); expr = handle_unary_expr_types(expr, &token); - + // ++x is changed to (x = x + 1) } else if (token.type == TOKEN_PLUSPLUS) { Lexer_next(lexer); @@ -454,6 +481,26 @@ Node *parse_factor(Lexer *lexer) die_location(token.loc, "Post-incrementing is not supported\n"); } else if (token.type == TOKEN_MINUSMINUS) { die_location(token.loc, "Post-decrementing is not supported\n"); + } else if (token.type == TOKEN_DOT) { + // TODO: Pointer to struct + if (!is_struct_or_struct_ptr(expr->expr_type)) + die_location(token.loc, "Cannot access member of non-struct type"); + + bool is_ptr = expr->expr_type->type == TYPE_PTR; + Type *struct_type = is_ptr ? expr->expr_type->ptr : expr->expr_type; + Lexer_next(lexer); + Token field_token = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER); + i64 index = find_field_index(struct_type, field_token.value.as_string); + if (index == -1) + die_location(field_token.loc, "Struct `%s` does not have a field named `%s`", type_to_str(struct_type), field_token.value.as_string); + + Node *member = Node_new(OP_MEMBER); + member->expr_type = struct_type->fields.type[index]; + member->member.expr = expr; + member->member.offset = struct_type->fields.offset[index]; + member->member.is_ptr = (expr->expr_type->type == TYPE_PTR); + expr = member; + } else { break; } @@ -683,6 +730,9 @@ void parse_func_args(Lexer *lexer, Node *func) assert_token(Lexer_next(lexer), TOKEN_COLON); Type *type = parse_type(lexer); + if (type->type == TYPE_STRUCT) + die_location(token.loc, "Structs cannot be passed as arguments, maybe pass a pointer?"); + i64 new_count = func->func.num_args + 1; func->func.args = realloc(func->func.args, sizeof(Variable) * new_count); Variable *var = &func->func.args[func->func.num_args++]; @@ -763,6 +813,57 @@ Lexer *remove_lexer() return lexer_stack[lexer_stack_count - 1]; } +Type *parse_struct_declaration(Lexer *lexer, bool is_global) { + i64 prev_struct_count = defined_structs_count; + + assert_token(Lexer_next(lexer), TOKEN_STRUCT); + + Type *struct_type = type_new(TYPE_STRUCT); + + Token token = Lexer_peek(lexer); + // For nested temporary structs we don't need a name + if (token.type != TOKEN_IDENTIFIER && is_global) + die_location(token.loc, "You need to specify a name for the struct defined globally."); + + // But if they do provide one, we'll add it to the list of defined structs so they + // it can referenced internally. + bool has_name = false; + if (token.type == TOKEN_IDENTIFIER) { + struct_type->struct_name = token.value.as_string; + push_struct_definition(struct_type); + Lexer_next(lexer); + has_name = true; + } + + assert_token(Lexer_next(lexer), TOKEN_OPEN_BRACE); + + token = Lexer_peek(lexer); + while (token.type != TOKEN_CLOSE_BRACE) { + token = assert_token(Lexer_next(lexer), TOKEN_IDENTIFIER); + assert_token(Lexer_next(lexer), TOKEN_COLON); + + // We want to allow nested temporary structs. + Type *type; + Token next = Lexer_peek(lexer); + if (next.type == TOKEN_STRUCT) { + type = parse_struct_declaration(lexer, false); + } else { + type = parse_type(lexer); + } + + push_field(struct_type, token.value.as_string, type); + assert_token(Lexer_next(lexer), TOKEN_SEMICOLON); + token = Lexer_peek(lexer); + } + assert_token(Lexer_next(lexer), TOKEN_CLOSE_BRACE); + + // If this is not being defined globally, we want to remove it from the namespace. + if (!is_global) + defined_structs_count = prev_struct_count; + + return struct_type; +} + Node *parse_program(Lexer *lexer) { initialize_builtins(); @@ -778,6 +879,8 @@ Node *parse_program(Lexer *lexer) } else if (token.type == TOKEN_LET) { Node *var_decl = parse_var_declaration(lexer); Node_add_child(program, var_decl); + } else if (token.type == TOKEN_STRUCT) { + parse_struct_declaration(lexer, true); } else if (token.type == TOKEN_IMPORT) { // TODO: Handle circular imports // TODO: Handle complex import graphs (#pragma once) @@ -788,6 +891,8 @@ Node *parse_program(Lexer *lexer) char *filename = token.value.as_string; lexer = Lexer_new_open_file(filename); push_new_lexer(lexer); + } else if (token.type == TOKEN_SEMICOLON) { + Lexer_next(lexer); } else { die_location(token.loc, "Unexpected token in parse_program: `%s`\n", token_type_to_str(token.type)); exit(1); diff --git a/src/tokens.h b/src/tokens.h index f02cfb0..6eaad8e 100644 --- a/src/tokens.h +++ b/src/tokens.h @@ -14,6 +14,7 @@ F(TOKEN_CLOSE_PAREN, ")") \ F(TOKEN_COLON, ":") \ F(TOKEN_COMMA, ",") \ + F(TOKEN_DOT, ".") \ F(TOKEN_EOF, "EOF") \ F(TOKEN_EQ, "==") \ F(TOKEN_EXCLAMATION, "!") \ @@ -55,6 +56,7 @@ F(TOKEN_INT, "int") \ F(TOKEN_LET, "let") \ F(TOKEN_RETURN, "return") \ + F(TOKEN_STRUCT, "struct") \ F(TOKEN_WHILE, "while") \ F(TOKEN_IMPORT, "import") \ diff --git a/src/types.c b/src/types.c index abbba86..5138164 100644 --- a/src/types.c +++ b/src/types.c @@ -36,6 +36,7 @@ i64 size_for_type(Type *type) case TYPE_PTR: return 8; case TYPE_CHAR: return 1; case TYPE_ARRAY: return type->array_size * size_for_type(type->ptr); + case TYPE_STRUCT: return type->fields.size; default: { printf("Unknown type: %d\n", type->type); assert(false && "Unreachable type"); @@ -86,9 +87,18 @@ bool is_int_type(Type *type) } } -static char *data_type_to_str(DataType type) +bool is_struct_or_struct_ptr(Type *type) { - switch (type) + if (type->type == TYPE_STRUCT) + return true; + if (type->type == TYPE_PTR && type->ptr->type == TYPE_STRUCT) + return true; + return false; +} + +static char *data_type_to_str(Type *type) +{ + switch (type->type) { case TYPE_NONE: return "void"; case TYPE_INT: return "int"; @@ -96,6 +106,7 @@ static char *data_type_to_str(DataType type) case TYPE_ARRAY: return "array"; case TYPE_CHAR: return "char"; case TYPE_ANY: return "<@>"; + case TYPE_STRUCT: return type->struct_name; default: assert(false && "Unreachable"); } } @@ -114,16 +125,50 @@ char *type_to_str(Type *type) // FIXME: This is inefficient as all hell but this will only really be // used for error reporting. - strcat(str, data_type_to_str(type->type)); + strcat(str, data_type_to_str(type)); for (int i = 0; i < ptr_count; i++) strcat(str, "*"); return str; } +i64 push_field(Type *type, char *field_name, Type *field_type) +{ + assert(type->type == TYPE_STRUCT); + type->fields.type = realloc(type->fields.type, sizeof(Type *) * (type->fields.num_fields + 1)); + type->fields.offset = realloc(type->fields.offset, sizeof(i64) * (type->fields.num_fields + 1)); + type->fields.name = realloc(type->fields.name, sizeof(char *) * (type->fields.num_fields + 1)); + + i64 field_size = size_for_type(field_type); + i64 offset_factor = i64min(field_size, 8); + i64 offset = align_up(type->fields.size, offset_factor); + + type->fields.type[type->fields.num_fields] = field_type; + type->fields.offset[type->fields.num_fields] = offset; + type->fields.name[type->fields.num_fields] = field_name; + type->fields.size = offset + field_size; + type->fields.num_fields++; + + return offset; +} + +i64 find_field_index(Type *type, char *field_name) +{ + assert(type->type == TYPE_STRUCT); + for (int i = 0; i < type->fields.num_fields; i++) { + if (strcmp(type->fields.name[i], field_name) == 0) + return i; + } + return -1; +} + + Node *handle_unary_expr_types(Node *node, Token *token) { Type *old_type = node->unary_expr->expr_type; + if (node->type != OP_ADDROF && old_type->type == TYPE_STRUCT) + die_location(token->loc, "Performing invalid unary operation on struct type"); + if (node->type == OP_NOT) { node->expr_type = type_new(TYPE_INT); } else if (node->type == OP_ADDROF) { @@ -154,6 +199,9 @@ Node *handle_binary_expr_types(Node *node, Token *token) { Type *left = node->binary.left->expr_type; Type *right = node->binary.right->expr_type; + + if (left->type == TYPE_STRUCT || right->type == TYPE_STRUCT) + die_location(token->loc, "Performing invalid binary operation on struct type"); switch (node->type) { diff --git a/src/types.h b/src/types.h index 73b5b63..e8b9488 100644 --- a/src/types.h +++ b/src/types.h @@ -10,12 +10,21 @@ typedef enum { TYPE_CHAR, TYPE_PTR, TYPE_ARRAY, + TYPE_STRUCT, } DataType; typedef struct data_type_node { DataType type; struct data_type_node *ptr; i64 array_size; + char *struct_name; + struct { + char **name; + struct data_type_node **type; + i64 *offset; + i64 num_fields; + i64 size; + } fields; } Type; Type *type_new(DataType type); @@ -28,6 +37,11 @@ bool type_equals(Type *a, Type *b); bool is_int_type(Type *type); bool is_string_type(Type *type); bool is_convertible(Type *from, Type *to); +bool is_struct_or_struct_ptr(Type *type); + +// Returns the offset of the field in the struct. +i64 push_field(Type *type, char *field_name, Type *field_type); +i64 find_field_index(Type *type, char *field_name); // Type checking / casting expressions to right types typedef struct ast_node Node; |