package monitor

import (
	"context"
	"fmt"
	"net"
	"strings"
	"time"

	"github.com/Fuwn/kaze/internal/config"
)

// dnsDefaultPort is the port used when the configured DNS server omits one.
const dnsDefaultPort = "53"

// DNSMonitor monitors DNS resolution of a single domain name.
type DNSMonitor struct {
	id                string
	name              string
	group             string
	target            string // Domain to resolve
	interval          time.Duration
	timeout           time.Duration // Bounds both the dial to a custom server and the whole lookup
	retries           int
	roundResponseTime bool
	roundUptime       bool
	dnsServer         string   // Optional DNS server ("host" or "host:port")
	expectedIPs       []string // Expected IP addresses (A/AAAA checks)
	expectedCNAME     string   // Expected CNAME (CNAME checks)
	recordType        string   // Upper-cased DNS record type; Check handles A, AAAA, CNAME, MX and TXT
}

// NewDNSMonitor creates a new DNS monitor from cfg.
//
// The record type defaults to "A" when cfg.RecordType is empty and is
// otherwise upper-cased. The returned error is currently always nil.
func NewDNSMonitor(cfg config.MonitorConfig) (*DNSMonitor, error) {
	recordType := "A"
	if cfg.RecordType != "" {
		recordType = strings.ToUpper(cfg.RecordType)
	}

	return &DNSMonitor{
		id:                cfg.ID(),
		name:              cfg.Name,
		group:             cfg.Group,
		target:            cfg.Target,
		interval:          cfg.Interval.Duration,
		timeout:           cfg.Timeout.Duration,
		retries:           cfg.Retries,
		roundResponseTime: cfg.RoundResponseTime,
		roundUptime:       cfg.RoundUptime,
		dnsServer:         cfg.DNSServer,
		expectedIPs:       cfg.ExpectedIPs,
		expectedCNAME:     cfg.ExpectedCNAME,
		recordType:        recordType,
	}, nil
}

// ID returns the unique identifier for this monitor.
func (m *DNSMonitor) ID() string { return m.id }

// Name returns the monitor's name.
func (m *DNSMonitor) Name() string { return m.name }

// Group returns the group this monitor belongs to.
func (m *DNSMonitor) Group() string { return m.group }

// Type returns the monitor type, which is always "dns".
func (m *DNSMonitor) Type() string { return "dns" }

// Target returns the monitor target (the domain being resolved).
func (m *DNSMonitor) Target() string { return m.target }

// Interval returns the check interval.
func (m *DNSMonitor) Interval() time.Duration { return m.interval }

// Retries returns the number of retry attempts.
func (m *DNSMonitor) Retries() int { return m.retries }

// HideSSLDays returns whether to hide SSL days from display.
// It is always false because DNS checks do not use SSL.
func (m *DNSMonitor) HideSSLDays() bool { return false }

// RoundResponseTime returns whether to round response time.
func (m *DNSMonitor) RoundResponseTime() bool { return m.roundResponseTime }

// RoundUptime returns whether to round uptime percentage.
func (m *DNSMonitor) RoundUptime() bool { return m.roundUptime }

// Check performs the DNS resolution check and reports the outcome.
//
// The lookup is bounded by the monitor's timeout. The result status is:
//   - StatusDown when the lookup fails, returns no records, or the record
//     type is unsupported;
//   - StatusDegraded when the lookup succeeds but none of the expected IPs
//     (A/AAAA) is present, or the CNAME differs from the expected one;
//   - StatusUp otherwise.
//
// ResponseTime measures only the lookup itself.
func (m *DNSMonitor) Check(ctx context.Context) *Result {
	result := &Result{
		MonitorName: m.id,
		Timestamp:   time.Now(),
	}

	resolver := m.resolver()

	timeoutCtx, cancel := context.WithTimeout(ctx, m.timeout)
	defer cancel()

	start := time.Now()

	switch m.recordType {
	case "A", "AAAA":
		// Restrict the lookup to the requested address family; a plain "ip"
		// lookup would let an AAAA check pass on A records and vice versa.
		network := "ip4"
		if m.recordType == "AAAA" {
			network = "ip6"
		}

		ips, err := resolver.LookupIP(timeoutCtx, network, m.target)
		result.ResponseTime = time.Since(start)
		if err != nil {
			result.Status = StatusDown
			result.Error = fmt.Errorf("DNS lookup failed: %w", err)
			return result
		}
		if len(ips) == 0 {
			result.Status = StatusDown
			result.Error = fmt.Errorf("no %s records found", m.recordType)
			return result
		}
		// One matching address is enough when expectations are configured.
		if len(m.expectedIPs) > 0 && !dnsAnyExpectedIP(ips, m.expectedIPs) {
			result.Status = StatusDegraded
			result.Error = fmt.Errorf("expected IPs not found in response")
			return result
		}
		result.Status = StatusUp

	case "CNAME":
		cname, err := resolver.LookupCNAME(timeoutCtx, m.target)
		result.ResponseTime = time.Since(start)
		if err != nil {
			result.Status = StatusDown
			result.Error = fmt.Errorf("CNAME lookup failed: %w", err)
			return result
		}
		if m.expectedCNAME != "" && !dnsNamesEqual(cname, m.expectedCNAME) {
			result.Status = StatusDegraded
			result.Error = fmt.Errorf("CNAME mismatch: got %s, expected %s", cname, m.expectedCNAME)
			return result
		}
		result.Status = StatusUp

	case "MX":
		mxs, err := resolver.LookupMX(timeoutCtx, m.target)
		result.ResponseTime = time.Since(start)
		if err != nil {
			result.Status = StatusDown
			result.Error = fmt.Errorf("MX lookup failed: %w", err)
			return result
		}
		if len(mxs) == 0 {
			result.Status = StatusDown
			result.Error = fmt.Errorf("no MX records found")
			return result
		}
		result.Status = StatusUp

	case "TXT":
		txts, err := resolver.LookupTXT(timeoutCtx, m.target)
		result.ResponseTime = time.Since(start)
		if err != nil {
			result.Status = StatusDown
			result.Error = fmt.Errorf("TXT lookup failed: %w", err)
			return result
		}
		if len(txts) == 0 {
			result.Status = StatusDown
			result.Error = fmt.Errorf("no TXT records found")
			return result
		}
		result.Status = StatusUp

	default:
		result.ResponseTime = time.Since(start)
		result.Status = StatusDown
		result.Error = fmt.Errorf("unsupported record type: %s", m.recordType)
	}

	return result
}

// resolver returns the resolver used by Check: a zero-value net.Resolver, or
// one that sends every query to the configured DNS server.
func (m *DNSMonitor) resolver() *net.Resolver {
	if m.dnsServer == "" {
		return &net.Resolver{}
	}

	server := dnsServerAddress(m.dnsServer)

	return &net.Resolver{
		// The Dial hook is only consulted by Go's built-in resolver.
		PreferGo: true,
		Dial: func(ctx context.Context, network, _ string) (net.Conn, error) {
			// Keep the network the resolver asked for (UDP or TCP) and only
			// swap the destination for the configured server.
			d := net.Dialer{Timeout: m.timeout}
			return d.DialContext(ctx, network, server)
		},
	}
}

// dnsServerAddress returns server in "host:port" form, appending the default
// DNS port when server carries no port (bare hostname, IPv4 or IPv6 literal).
func dnsServerAddress(server string) string {
	if _, _, err := net.SplitHostPort(server); err == nil {
		return server
	}

	// Accept a bracketed IPv6 literal without a port, e.g. "[::1]".
	host := strings.TrimSuffix(strings.TrimPrefix(server, "["), "]")

	return net.JoinHostPort(host, dnsDefaultPort)
}

// dnsAnyExpectedIP reports whether at least one resolved address matches one
// of the expected addresses. Expected values that parse as IP addresses are
// compared by value, so alternative spellings of the same address match;
// anything else falls back to a plain text comparison.
func dnsAnyExpectedIP(resolved []net.IP, expected []string) bool {
	for _, want := range expected {
		wantIP := net.ParseIP(want)

		for _, got := range resolved {
			if wantIP != nil {
				if got.Equal(wantIP) {
					return true
				}
				continue
			}
			if got.String() == want {
				return true
			}
		}
	}

	return false
}

// dnsNamesEqual compares two domain names case-insensitively, ignoring one
// trailing dot on either side, so "example.com" matches "Example.com.".
func dnsNamesEqual(a, b string) bool {
	return strings.EqualFold(strings.TrimSuffix(a, "."), strings.TrimSuffix(b, "."))
}