package storage

import (
	"fmt"
	"os"
	"regexp"
	"strconv"
	"strings"
	"time"

	"github.com/Fuwn/kaze/internal/config"
)

// maintenanceSizeRe matches a non-negative decimal number followed by an
// optional unit (B, KB, MB, GB or TB). parseSize upper-cases its input before
// matching, so units are effectively case-insensitive.
var maintenanceSizeRe = regexp.MustCompile(`^(\d+(?:\.\d+)?)\s*(B|KB|MB|GB|TB)?$`)

// MaintenanceState tracks when database maintenance should run and performs
// it. Triggers (size, check count, cron, daily) are read from cfg.Triggers;
// the action taken is selected by cfg.Mode.
//
// The type performs no locking of its own.
type MaintenanceState struct {
	lastCheck    time.Time // set at construction; not read anywhere in this file
	lastDaily    time.Time // construction time or time of the last maintenance run
	nextCronTime time.Time // next instant at which the cron trigger fires
	totalChecks  int64     // checks counted since construction or the last run
	cfg          config.MaintenanceConfig
	dbPath       string
	sizeBytes    int64       // parsed size threshold; 0 disables the size trigger
	cronFields   []cronField // parsed cron expression; empty disables the cron trigger
}

// cronField is the parsed form of one field of a cron expression.
type cronField struct {
	values map[int]bool // explicitly allowed values
	any    bool         // true when the field was "*"
}

// matches reports whether v satisfies the field.
func (f cronField) matches(v int) bool {
	return f.any || f.values[v]
}

// NewMaintenanceState builds the maintenance tracker for the database at
// dbPath. It returns an error when the configured size or cron trigger cannot
// be parsed.
func NewMaintenanceState(dbPath string, cfg config.MaintenanceConfig) (*MaintenanceState, error) {
	now := time.Now()
	m := &MaintenanceState{
		cfg:       cfg,
		dbPath:    dbPath,
		lastCheck: now,
		lastDaily: now,
	}

	if cfg.Triggers.Size != "" {
		size, err := parseSize(cfg.Triggers.Size)
		if err != nil {
			return nil, fmt.Errorf("invalid size %q: %w", cfg.Triggers.Size, err)
		}
		m.sizeBytes = size
	}

	if cfg.Triggers.Cron != "" {
		fields, err := parseCron(cfg.Triggers.Cron)
		if err != nil {
			return nil, fmt.Errorf("invalid cron %q: %w", cfg.Triggers.Cron, err)
		}
		m.cronFields = fields
		m.nextCronTime = m.calculateNextCron(now)
	}

	return m, nil
}

// IncrementChecks records one more check for the check-count trigger.
func (m *MaintenanceState) IncrementChecks() {
	m.totalChecks++
}

// ShouldRun reports whether any configured trigger currently calls for
// maintenance. It always returns false when the mode is empty or "never".
func (m *MaintenanceState) ShouldRun() bool {
	if m.cfg.Mode == "" || m.cfg.Mode == "never" {
		return false
	}
	return m.checkSizeTrigger() ||
		m.checkChecksTrigger() ||
		m.checkCronTrigger() ||
		m.checkDailyTrigger()
}

// checkSizeTrigger reports whether the database file has reached the
// configured size. A file that cannot be stat'ed never triggers.
func (m *MaintenanceState) checkSizeTrigger() bool {
	if m.sizeBytes == 0 {
		return false
	}
	info, err := os.Stat(m.dbPath)
	if err != nil {
		return false
	}
	return info.Size() >= m.sizeBytes
}

// checkChecksTrigger reports whether the number of recorded checks has
// reached the configured threshold. A zero or negative threshold disables the
// trigger; a negative one would otherwise fire on every call.
func (m *MaintenanceState) checkChecksTrigger() bool {
	threshold := m.cfg.Triggers.Checks
	if threshold <= 0 {
		return false
	}
	return m.totalChecks >= threshold
}

// checkCronTrigger reports whether the next scheduled cron time has been
// reached.
func (m *MaintenanceState) checkCronTrigger() bool {
	if len(m.cronFields) == 0 {
		return false
	}
	return !time.Now().Before(m.nextCronTime)
}

// checkDailyTrigger reports whether today's configured "HH:MM" time has
// passed without a maintenance run (or construction) since it. Malformed or
// out-of-range values never trigger.
func (m *MaintenanceState) checkDailyTrigger() bool {
	if m.cfg.Triggers.Daily == "" {
		return false
	}

	parts := strings.Split(m.cfg.Triggers.Daily, ":")
	if len(parts) != 2 {
		return false
	}
	hour, errHour := strconv.Atoi(parts[0])
	minute, errMinute := strconv.Atoi(parts[1])
	if errHour != nil || errMinute != nil {
		return false
	}
	// time.Date silently normalizes out-of-range fields (hour 25 becomes 01:00
	// on the following day), which would fire at an unintended time.
	if hour < 0 || hour > 23 || minute < 0 || minute > 59 {
		return false
	}

	now := time.Now()
	today := time.Date(now.Year(), now.Month(), now.Day(), hour, minute, 0, 0, now.Location())
	return now.After(today) && m.lastDaily.Before(today)
}

// Execute performs the maintenance action selected by the configured mode:
// "backup" or "reset". Any other mode is a no-op.
func (m *MaintenanceState) Execute() error {
	switch m.cfg.Mode {
	case "backup":
		return m.executeBackup()
	case "reset":
		return m.executeReset()
	default:
		return nil
	}
}

// executeBackup renames the database file to "<dbPath>.<unix-epoch>" and
// resets the trigger state.
//
// NOTE(review): only the main database file is renamed; any "-wal"/"-shm"
// companions (which executeReset removes) are left in place. Confirm this is
// intended.
func (m *MaintenanceState) executeBackup() error {
	backupPath := fmt.Sprintf("%s.%d", m.dbPath, time.Now().Unix())
	if err := os.Rename(m.dbPath, backupPath); err != nil {
		return fmt.Errorf("failed to backup database: %w", err)
	}
	m.resetState()
	return nil
}

// executeReset removes the database file and its "-wal" and "-shm"
// companions, then resets the trigger state. Files that do not exist are
// ignored.
func (m *MaintenanceState) executeReset() error {
	targets := []struct{ path, label string }{
		{m.dbPath, "database"},
		{m.dbPath + "-wal", "WAL file"},
		{m.dbPath + "-shm", "SHM file"},
	}
	for _, target := range targets {
		if err := os.Remove(target.path); err != nil && !os.IsNotExist(err) {
			return fmt.Errorf("failed to remove %s: %w", target.label, err)
		}
	}
	m.resetState()
	return nil
}

// resetState clears the check counter, marks the daily trigger as satisfied
// as of now, and schedules the next cron time.
func (m *MaintenanceState) resetState() {
	now := time.Now()
	m.totalChecks = 0
	m.lastDaily = now
	if len(m.cronFields) > 0 {
		m.nextCronTime = m.calculateNextCron(now)
	}
}

// calculateNextCron returns the first whole minute strictly after from's
// minute that matches the cron expression, searching minute by minute for up
// to 366 days. If nothing matches in that window it falls back to
// from + 24 hours.
func (m *MaintenanceState) calculateNextCron(from time.Time) time.Time {
	const searchMinutes = 366 * 24 * 60

	t := from.Add(time.Minute).Truncate(time.Minute)
	for i := 0; i < searchMinutes; i++ {
		if m.cronMatches(t) {
			return t
		}
		t = t.Add(time.Minute)
	}
	return from.Add(24 * time.Hour)
}

// cronMatches reports whether t satisfies all five cron fields (minute, hour,
// day of month, month, day of week). Day of month and day of week must both
// match.
func (m *MaintenanceState) cronMatches(t time.Time) bool {
	if len(m.cronFields) != 5 {
		return false
	}
	return m.cronFields[0].matches(t.Minute()) &&
		m.cronFields[1].matches(t.Hour()) &&
		m.cronFields[2].matches(t.Day()) &&
		m.cronFields[3].matches(int(t.Month())) &&
		m.cronFields[4].matches(int(t.Weekday()))
}

// parseSize converts a human-readable size such as "512", "10MB" or "1.5 GB"
// into bytes. Units are powers of 1024 and case-insensitive; a missing unit
// means bytes. Fractional results are truncated toward zero. It returns an
// error for malformed input or for sizes that do not fit in an int64.
func parseSize(s string) (int64, error) {
	s = strings.TrimSpace(strings.ToUpper(s))
	matches := maintenanceSizeRe.FindStringSubmatch(s)
	if matches == nil {
		return 0, fmt.Errorf("invalid size format")
	}

	value, err := strconv.ParseFloat(matches[1], 64)
	if err != nil {
		return 0, err
	}

	var multiplier float64
	switch matches[2] {
	case "KB":
		multiplier = 1024
	case "MB":
		multiplier = 1024 * 1024
	case "GB":
		multiplier = 1024 * 1024 * 1024
	case "TB":
		multiplier = 1024 * 1024 * 1024 * 1024
	default: // "" or "B"
		multiplier = 1
	}

	size := value * multiplier
	// Converting a float64 outside the int64 range is implementation-defined,
	// so reject it explicitly.
	if size >= 1<<63 {
		return 0, fmt.Errorf("size too large")
	}
	return int64(size), nil
}

// parseCron parses a five-field cron expression (minute hour day-of-month
// month day-of-week) into one cronField per field.
func parseCron(expr string) ([]cronField, error) {
	parts := strings.Fields(expr)
	if len(parts) != 5 {
		return nil, fmt.Errorf("cron must have 5 fields (minute hour day month weekday)")
	}

	limits := [5]struct{ min, max int }{
		{0, 59}, // minute
		{0, 23}, // hour
		{1, 31}, // day of month
		{1, 12}, // month
		{0, 6},  // day of week, matched against time.Weekday (Sunday = 0)
	}

	fields := make([]cronField, len(parts))
	for i, part := range parts {
		f, err := parseCronField(part, limits[i].min, limits[i].max)
		if err != nil {
			return nil, fmt.Errorf("field %d: %w", i+1, err)
		}
		fields[i] = f
	}
	return fields, nil
}

// parseCronField parses a single cron field whose legal values lie in
// [min, max]. Supported syntax: "*", or a comma-separated list of items, each
// being a value ("5"), a range ("1-5"), or a stepped form ("*/15", "10-40/5",
// "5/15" meaning from 5 up to max). It returns an error when an item is
// malformed, outside [min, max], or is a range whose start exceeds its end.
func parseCronField(field string, min, max int) (cronField, error) {
	cf := cronField{values: make(map[int]bool)}
	if field == "*" {
		cf.any = true
		return cf, nil
	}

	for _, part := range strings.Split(field, ",") {
		start, end, step, err := parseCronPart(part, min, max)
		if err != nil {
			return cronField{}, err
		}
		// start <= end is guaranteed by parseCronPart. Comparing step against
		// the remaining distance, rather than incrementing first, keeps the
		// counter from overflowing on very large steps.
		for v := start; ; v += step {
			cf.values[v] = true
			if step > end-v {
				break
			}
		}
	}
	return cf, nil
}

// parseCronPart parses one comma-separated item of a cron field into an
// inclusive range and a step (1 when no step is given), validating that
// min <= start <= end <= max.
func parseCronPart(part string, min, max int) (start, end, step int, err error) {
	step = 1
	rangeExpr := part
	stepped := false

	if strings.Contains(part, "/") {
		pieces := strings.Split(part, "/")
		if len(pieces) != 2 {
			return 0, 0, 0, fmt.Errorf("invalid step %q", part)
		}
		step, err = strconv.Atoi(pieces[1])
		if err != nil || step <= 0 {
			return 0, 0, 0, fmt.Errorf("invalid step value %q", pieces[1])
		}
		rangeExpr = pieces[0]
		stepped = true
	}

	switch {
	case stepped && rangeExpr == "*":
		start, end = min, max

	case strings.Contains(rangeExpr, "-"):
		bounds := strings.Split(rangeExpr, "-")
		if len(bounds) != 2 {
			return 0, 0, 0, fmt.Errorf("invalid range %q", part)
		}
		lo, errLo := strconv.Atoi(bounds[0])
		hi, errHi := strconv.Atoi(bounds[1])
		if errLo != nil || errHi != nil {
			return 0, 0, 0, fmt.Errorf("invalid range values")
		}
		start, end = lo, hi

	default:
		val, convErr := strconv.Atoi(rangeExpr)
		if convErr != nil {
			return 0, 0, 0, fmt.Errorf("invalid value %q", part)
		}
		start, end = val, val
		if stepped {
			// "N/step" runs from N to the field maximum.
			end = max
		}
	}

	if start < min || start > max || end < min || end > max {
		return 0, 0, 0, fmt.Errorf("value in %q out of range %d-%d", part, min, max)
	}
	if start > end {
		return 0, 0, 0, fmt.Errorf("invalid range %q: start exceeds end", part)
	}
	return start, end, step, nil
}