aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/parser.c19
-rw-r--r--src/tokens.h1
-rw-r--r--src/types.c11
-rw-r--r--src/types.h1
4 files changed, 19 insertions, 13 deletions
diff --git a/src/parser.c b/src/parser.c
index 7c01c75..a9c49f9 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -813,14 +813,15 @@ Lexer *remove_lexer()
return lexer_stack[lexer_stack_count - 1];
}
-Type *parse_struct_declaration(Lexer *lexer, bool is_global) {
+Type *parse_struct_union_declaration(Lexer *lexer, bool is_global) {
i64 prev_struct_count = defined_structs_count;
- assert_token(Lexer_next(lexer), TOKEN_STRUCT);
-
- Type *struct_type = type_new(TYPE_STRUCT);
+ Type *struct_type;
+ Token token = Lexer_next(lexer);
+ assert(token.type == TOKEN_STRUCT || token.type == TOKEN_UNION);
+ struct_type = type_new(token.type == TOKEN_STRUCT ? TYPE_STRUCT : TYPE_UNION);
- Token token = Lexer_peek(lexer);
+ token = Lexer_peek(lexer);
// For nested temporary structs we don't need a name
if (token.type != TOKEN_IDENTIFIER && is_global)
die_location(token.loc, "You need to specify a name for the struct defined globally.");
@@ -845,8 +846,8 @@ Type *parse_struct_declaration(Lexer *lexer, bool is_global) {
// We want to allow nested temporary structs.
Type *type;
Token next = Lexer_peek(lexer);
- if (next.type == TOKEN_STRUCT) {
- type = parse_struct_declaration(lexer, false);
+ if (next.type == TOKEN_STRUCT || next.type == TOKEN_UNION) {
+ type = parse_struct_union_declaration(lexer, false);
} else {
type = parse_type(lexer);
}
@@ -879,8 +880,8 @@ Node *parse_program(Lexer *lexer)
} else if (token.type == TOKEN_LET) {
Node *var_decl = parse_var_declaration(lexer);
Node_add_child(program, var_decl);
- } else if (token.type == TOKEN_STRUCT) {
- parse_struct_declaration(lexer, true);
+ } else if (token.type == TOKEN_STRUCT || token.type == TOKEN_UNION) {
+ parse_struct_union_declaration(lexer, true);
} else if (token.type == TOKEN_IMPORT) {
// TODO: Handle circular imports
// TODO: Handle complex import graphs (#pragma once)
diff --git a/src/tokens.h b/src/tokens.h
index 6eaad8e..455d269 100644
--- a/src/tokens.h
+++ b/src/tokens.h
@@ -57,6 +57,7 @@
F(TOKEN_LET, "let") \
F(TOKEN_RETURN, "return") \
F(TOKEN_STRUCT, "struct") \
+ F(TOKEN_UNION, "union") \
F(TOKEN_WHILE, "while") \
F(TOKEN_IMPORT, "import") \
diff --git a/src/types.c b/src/types.c
index 5138164..008af10 100644
--- a/src/types.c
+++ b/src/types.c
@@ -107,6 +107,7 @@ static char *data_type_to_str(Type *type)
case TYPE_CHAR: return "char";
case TYPE_ANY: return "<@>";
case TYPE_STRUCT: return type->struct_name;
+ case TYPE_UNION: return type->struct_name;
default: assert(false && "Unreachable");
}
}
@@ -133,19 +134,21 @@ char *type_to_str(Type *type)
i64 push_field(Type *type, char *field_name, Type *field_type)
{
- assert(type->type == TYPE_STRUCT);
+ assert(type->type == TYPE_STRUCT || type->type == TYPE_UNION);
+ bool is_union = type->type == TYPE_UNION;
+
type->fields.type = realloc(type->fields.type, sizeof(Type *) * (type->fields.num_fields + 1));
type->fields.offset = realloc(type->fields.offset, sizeof(i64) * (type->fields.num_fields + 1));
type->fields.name = realloc(type->fields.name, sizeof(char *) * (type->fields.num_fields + 1));
i64 field_size = size_for_type(field_type);
i64 offset_factor = i64min(field_size, 8);
- i64 offset = align_up(type->fields.size, offset_factor);
+ i64 offset = is_union ? 0 : align_up(type->fields.size, offset_factor);
type->fields.type[type->fields.num_fields] = field_type;
type->fields.offset[type->fields.num_fields] = offset;
type->fields.name[type->fields.num_fields] = field_name;
- type->fields.size = offset + field_size;
+ type->fields.size = is_union ? i64max(field_size, type->fields.size) : offset + field_size;
type->fields.num_fields++;
return offset;
@@ -153,7 +156,7 @@ i64 push_field(Type *type, char *field_name, Type *field_type)
i64 find_field_index(Type *type, char *field_name)
{
- assert(type->type == TYPE_STRUCT);
+ assert(type->type == TYPE_STRUCT || type->type == TYPE_UNION);
for (int i = 0; i < type->fields.num_fields; i++) {
if (strcmp(type->fields.name[i], field_name) == 0)
return i;
diff --git a/src/types.h b/src/types.h
index e8b9488..22c0d5f 100644
--- a/src/types.h
+++ b/src/types.h
@@ -11,6 +11,7 @@ typedef enum {
TYPE_PTR,
TYPE_ARRAY,
TYPE_STRUCT,
+ TYPE_UNION,
} DataType;
typedef struct data_type_node {