diff options
| -rw-r--r-- | config.example.yaml | 33 | ||||
| -rw-r--r-- | internal/config/config.go | 2 | ||||
| -rw-r--r-- | internal/monitor/database.go | 529 | ||||
| -rw-r--r-- | internal/monitor/monitor.go | 4 |
4 files changed, 562 insertions, 6 deletions
diff --git a/config.example.yaml b/config.example.yaml
index bc6a772..50f6692 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -98,15 +98,38 @@ groups:
       # default_collapsed: false
       # show_group_uptime: true
     monitors:
-      - name: "Database"
-        type: tcp
-        target: "localhost:5432"
+      - name: "PostgreSQL"
+        type: database
+        db_type: postgres
+        target: "localhost:5432" # or postgres://user:pass@host:5432/db
+        interval: 30s
+        timeout: 5s
+
+      - name: "MySQL"
+        type: database
+        db_type: mysql
+        target: "localhost:3306" # or mysql://user:pass@host:3306/db
         interval: 30s
         timeout: 5s
 
       - name: "Redis"
-        type: tcp
-        target: "localhost:6379"
+        type: database
+        db_type: redis
+        target: "localhost:6379" # or redis://host:6379
+        interval: 30s
+        timeout: 5s
+
+      - name: "Memcached"
+        type: database
+        db_type: memcached
+        target: "localhost:11211"
+        interval: 30s
+        timeout: 5s
+
+      - name: "MongoDB"
+        type: database
+        db_type: mongodb
+        target: "localhost:27017" # or mongodb://user:pass@host:27017
         interval: 30s
         timeout: 5s
 
diff --git a/internal/config/config.go b/internal/config/config.go
index 076a203..5b9d468 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -105,6 +105,8 @@ type MonitorConfig struct {
 	// GraphQL specific fields
 	GraphQLQuery     string            `yaml:"graphql_query,omitempty"`     // GraphQL query to execute
 	GraphQLVariables map[string]string `yaml:"graphql_variables,omitempty"` // GraphQL query variables
+	// Database specific fields
+	DBType string `yaml:"db_type,omitempty"` // Database type: postgres, mysql, redis, memcached, mongodb
 }
 
 // IncidentConfig represents an incident or maintenance
diff --git a/internal/monitor/database.go b/internal/monitor/database.go
new file mode 100644
index 0000000..0b4d2a9
--- /dev/null
+++ b/internal/monitor/database.go
@@ -0,0 +1,529 @@
+package monitor
+
+import (
+	"bufio"
+	"context"
+	"database/sql"
+	"fmt"
+	"net"
+	"strings"
+	"time"
+
+	"github.com/Fuwn/kaze/internal/config"
+)
+
+// DatabaseMonitor monitors database connections
+type DatabaseMonitor struct {
+	name              string
+	target            string // Connection string or host:port
+	dbType            string // Normalized type: postgres, mysql, redis, memcached, mongodb
+	interval          time.Duration
+	timeout           time.Duration
+	retries           int
+	roundResponseTime bool
+	roundUptime       bool
+}
+
+// NewDatabaseMonitor creates a new database monitor
+func NewDatabaseMonitor(cfg config.MonitorConfig) (*DatabaseMonitor, error) {
+	dbType := strings.ToLower(cfg.DBType)
+	if dbType == "" {
+		return nil, fmt.Errorf("db_type is required for database monitors")
+	}
+
+	// Validate supported database types and collapse aliases to one canonical name
+	switch dbType {
+	case "postgres", "postgresql":
+		dbType = "postgres"
+	case "mysql", "mariadb":
+		dbType = "mysql"
+	case "redis":
+		// Redis uses simple protocol check
+	case "memcached":
+		// Memcached uses simple protocol check
+	case "mongodb", "mongo":
+		dbType = "mongodb"
+	default:
+		return nil, fmt.Errorf("unsupported database type: %s (supported: postgres, mysql, redis, memcached, mongodb)", cfg.DBType)
+	}
+
+	return &DatabaseMonitor{
+		name:              cfg.Name,
+		target:            cfg.Target,
+		dbType:            dbType,
+		interval:          cfg.Interval.Duration,
+		timeout:           cfg.Timeout.Duration,
+		retries:           cfg.Retries,
+		roundResponseTime: cfg.RoundResponseTime,
+		roundUptime:       cfg.RoundUptime,
+	}, nil
+}
+
+// Name returns the monitor's name
+func (m *DatabaseMonitor) Name() string {
+	return m.name
+}
+
+// Type returns the monitor type
+func (m *DatabaseMonitor) Type() string {
+	return "database"
+}
+
+// Target returns the monitor target
+func (m *DatabaseMonitor) Target() string {
+	return m.target
+}
+
+// Interval returns the check interval
+func (m *DatabaseMonitor) Interval() time.Duration {
+	return m.interval
+}
+
+// Retries returns the number of retry attempts
+func (m *DatabaseMonitor) Retries() int {
+	return m.retries
+}
+
+// HideSSLDays returns whether to hide SSL days from display
+func (m *DatabaseMonitor) HideSSLDays() bool {
+	return true // Databases don't expose SSL info this way
+}
+
+// RoundResponseTime returns whether to round response time
+func (m *DatabaseMonitor) RoundResponseTime() bool {
+	return m.roundResponseTime
+}
+
+// RoundUptime returns whether to round uptime percentage
+func (m *DatabaseMonitor) RoundUptime() bool {
+	return m.roundUptime
+}
+
+// Check performs the database connection check
+func (m *DatabaseMonitor) Check(ctx context.Context) *Result {
+	result := &Result{
+		MonitorName: m.name,
+		Timestamp:   time.Now(),
+	}
+
+	start := time.Now()
+	var err error
+
+	switch m.dbType {
+	case "redis":
+		err = m.checkRedis(ctx)
+	case "memcached":
+		err = m.checkMemcached(ctx)
+	case "postgres":
+		err = m.checkPostgres(ctx)
+	case "mysql":
+		err = m.checkMySQL(ctx)
+	case "mongodb":
+		err = m.checkMongoDB(ctx)
+	default:
+		err = fmt.Errorf("unsupported database type: %s", m.dbType)
+	}
+
+	result.ResponseTime = time.Since(start)
+
+	if err != nil {
+		result.Status = StatusDown
+		result.Error = err
+		return result
+	}
+
+	result.Status = StatusUp
+
+	// A successful check that took longer than 1 second is reported as degraded
+	if result.ResponseTime > 1*time.Second {
+		result.Status = StatusDegraded
+		result.Error = fmt.Errorf("slow response: %v", result.ResponseTime)
+	}
+
+	return result
+}
+
+// checkRedis performs a Redis PING check using the RESP protocol
+func (m *DatabaseMonitor) checkRedis(ctx context.Context) error {
+	// Parse target - could be host:port or redis://...
+	target := m.target
+	if strings.HasPrefix(target, "redis://") {
+		target = strings.TrimPrefix(target, "redis://")
+		// Drop credentials (before '@') and the db path (after '/'); no AUTH is sent
+		if idx := strings.Index(target, "@"); idx != -1 {
+			target = target[idx+1:]
+		}
+		if idx := strings.Index(target, "/"); idx != -1 {
+			target = target[:idx]
+		}
+	}
+
+	// Default port if not specified
+	if _, _, err := net.SplitHostPort(target); err != nil {
+		target = target + ":6379"
+	}
+
+	// Connect with timeout
+	dialer := &net.Dialer{Timeout: m.timeout}
+	conn, err := dialer.DialContext(ctx, "tcp", target)
+	if err != nil {
+		return fmt.Errorf("connection failed: %w", err)
+	}
+	defer conn.Close()
+
+	// Set read/write deadline
+	if err := conn.SetDeadline(time.Now().Add(m.timeout)); err != nil {
+		return fmt.Errorf("failed to set deadline: %w", err)
+	}
+
+	// Send PING command (RESP protocol)
+	_, err = conn.Write([]byte("*1\r\n$4\r\nPING\r\n"))
+	if err != nil {
+		return fmt.Errorf("failed to send PING: %w", err)
+	}
+
+	// Read the first line of the response
+	reader := bufio.NewReader(conn)
+	response, err := reader.ReadString('\n')
+	if err != nil {
+		return fmt.Errorf("failed to read response: %w", err)
+	}
+
+	// Accept "+PONG" or any line starting with '$' (bulk string header; its payload is not inspected)
+	response = strings.TrimSpace(response)
+	if response != "+PONG" && !strings.HasPrefix(response, "$") {
+		// If it's an error response
+		if strings.HasPrefix(response, "-") {
+			return fmt.Errorf("redis error: %s", strings.TrimPrefix(response, "-"))
+		}
+		return fmt.Errorf("unexpected response: %s", response)
+	}
+
+	return nil
+}
+
+// checkMemcached performs a Memcached version check
+func (m *DatabaseMonitor) checkMemcached(ctx context.Context) error {
+	target := m.target
+
+	// Default port if not specified
+	if _, _, err := net.SplitHostPort(target); err != nil {
+		target = target + ":11211"
+	}
+
+	// Connect with timeout
+	dialer := &net.Dialer{Timeout: m.timeout}
+	conn, err := dialer.DialContext(ctx, "tcp", target)
+	if err != nil {
+		return fmt.Errorf("connection failed: %w", err)
+	}
+	defer conn.Close()
+
+	// Set read/write deadline
+	if err := conn.SetDeadline(time.Now().Add(m.timeout)); err != nil {
+		return fmt.Errorf("failed to set deadline: %w", err)
+	}
+
+	// Send version command
+	_, err = conn.Write([]byte("version\r\n"))
+	if err != nil {
+		return fmt.Errorf("failed to send version command: %w", err)
+	}
+
+	// Read the first line of the response
+	reader := bufio.NewReader(conn)
+	response, err := reader.ReadString('\n')
+	if err != nil {
+		return fmt.Errorf("failed to read response: %w", err)
+	}
+
+	// Check for VERSION response
+	if !strings.HasPrefix(response, "VERSION") {
+		return fmt.Errorf("unexpected response: %s", strings.TrimSpace(response))
+	}
+
+	return nil
+}
+
+// checkPostgres performs a PostgreSQL connection check
+func (m *DatabaseMonitor) checkPostgres(ctx context.Context) error {
+	// Build connection string
+	connStr := m.target
+	if !strings.Contains(connStr, "://") && !strings.Contains(connStr, "=") {
+		// Assume it's host:port format, build a connection string
+		host, port, err := net.SplitHostPort(connStr)
+		if err != nil {
+			host = connStr
+			port = "5432"
+		}
+		connStr = fmt.Sprintf("host=%s port=%s sslmode=disable connect_timeout=%d", host, port, int(m.timeout.Seconds()))
+	}
+
+	// NOTE(review): connStr built above is never used; this file imports no SQL driver,
+	// so the check is a raw TCP + startup message probe that re-parses m.target itself
+	return m.checkPostgresProtocol(ctx)
+}
+
+// checkPostgresProtocol performs a basic PostgreSQL protocol check
+func (m *DatabaseMonitor) checkPostgresProtocol(ctx context.Context) error {
+	target := m.target
+
+	// Parse connection string or host:port
+	if strings.HasPrefix(target, "postgres://") || strings.HasPrefix(target, "postgresql://") {
+		// Extract host:port from URL (credentials, path and query string are discarded)
+		target = strings.TrimPrefix(target, "postgres://")
+		target = strings.TrimPrefix(target, "postgresql://")
+		if idx := strings.Index(target, "@"); idx != -1 {
+			target = target[idx+1:]
+		}
+		if idx := strings.Index(target, "/"); idx != -1 {
+			target = target[:idx]
+		}
+		if idx := strings.Index(target, "?"); idx != -1 {
+			target = target[:idx]
+		}
+	}
+
+	// Default port if not specified
+	if _, _, err := net.SplitHostPort(target); err != nil {
+		target = target + ":5432"
+	}
+
+	// Connect with timeout
+	dialer := &net.Dialer{Timeout: m.timeout}
+	conn, err := dialer.DialContext(ctx, "tcp", target)
+	if err != nil {
+		return fmt.Errorf("connection failed: %w", err)
+	}
+	defer conn.Close()
+
+	// Set read/write deadline
+	if err := conn.SetDeadline(time.Now().Add(m.timeout)); err != nil {
+		return fmt.Errorf("failed to set deadline: %w", err)
+	}
+
+	// Send a minimal startup message (protocol version 3.0)
+	// This will prompt the server to respond, even if auth fails
+	// Message format: length (4 bytes) + protocol version (4 bytes) + parameters (none are sent here)
+	startupMsg := []byte{
+		0, 0, 0, 8, // Length: 8 bytes
+		0, 3, 0, 0, // Protocol version 3.0
+	}
+
+	_, err = conn.Write(startupMsg)
+	if err != nil {
+		return fmt.Errorf("failed to send startup message: %w", err)
+	}
+
+	// Read only the first byte of the reply - we expect either 'R' (auth request) or 'E' (error)
+	response := make([]byte, 1)
+	_, err = conn.Read(response)
+	if err != nil {
+		return fmt.Errorf("failed to read response: %w", err)
+	}
+
+	// 'R' = AuthenticationRequest, 'E' = ErrorResponse, 'N' = NoticeResponse, 'S' = presumably ParameterStatus (TODO confirm)
+	// Any of these means PostgreSQL is running
+	switch response[0] {
+	case 'R', 'E', 'N', 'S':
+		return nil // Server is responding
+	default:
+		return fmt.Errorf("unexpected response type: %c", response[0])
+	}
+}
+
+// checkMySQL performs a MySQL connection check using the protocol
+func (m *DatabaseMonitor) checkMySQL(ctx context.Context) error {
+	target := m.target
+
+	// Parse connection string or host:port
+	if strings.HasPrefix(target, "mysql://") {
+		target = strings.TrimPrefix(target, "mysql://")
+		if idx := strings.Index(target, "@"); idx != -1 {
+			target = target[idx+1:]
+		}
+		if idx := strings.Index(target, "/"); idx != -1 {
+			target = target[:idx]
+		}
+	}
+
+	// Handle tcp(host:port) format
+	if strings.HasPrefix(target, "tcp(") {
+		target = strings.TrimPrefix(target, "tcp(")
+		target = strings.TrimSuffix(target, ")")
+	}
+
+	// Default port if not specified
+	if _, _, err := net.SplitHostPort(target); err != nil {
+		target = target + ":3306"
+	}
+
+	// Connect with timeout
+	dialer := &net.Dialer{Timeout: m.timeout}
+	conn, err := dialer.DialContext(ctx, "tcp", target)
+	if err != nil {
+		return fmt.Errorf("connection failed: %w", err)
+	}
+	defer conn.Close()
+
+	// Set read/write deadline
+	if err := conn.SetDeadline(time.Now().Add(m.timeout)); err != nil {
+		return fmt.Errorf("failed to set deadline: %w", err)
+	}
+
+	// MySQL sends a handshake packet upon connection
+	// Read the 4-byte header of the initial handshake packet
+	header := make([]byte, 4)
+	_, err = conn.Read(header)
+	if err != nil {
+		return fmt.Errorf("failed to read handshake header: %w", err)
+	}
+
+	// Packet length is first 3 bytes (little-endian); the 4th header byte is not used here
+	packetLen := int(header[0]) | int(header[1])<<8 | int(header[2])<<16
+	if packetLen <= 0 || packetLen > 65535 {
+		return fmt.Errorf("invalid packet length: %d", packetLen)
+	}
+
+	// Read the rest of the handshake packet (NOTE(review): one Read may return fewer than packetLen bytes; io.ReadFull would be safer)
+	packet := make([]byte, packetLen)
+	_, err = conn.Read(packet)
+	if err != nil {
+		return fmt.Errorf("failed to read handshake packet: %w", err)
+	}
+
+	// First byte of packet is protocol version (should be 10 for MySQL 5.x+)
+	// or 0xFF for error packet
+	if len(packet) > 0 {
+		if packet[0] == 0xFF {
+			// Error packet: bytes from offset 3 onward are reported as the message
+			if len(packet) > 3 {
+				errMsg := string(packet[3:])
+				return fmt.Errorf("mysql error: %s", errMsg)
+			}
+			return fmt.Errorf("mysql error response")
+		}
+		// Protocol version 10 or 9 means MySQL is running
+		if packet[0] == 10 || packet[0] == 9 {
+			return nil
+		}
+	}
+
+	return fmt.Errorf("unexpected handshake response")
+}
+
+// checkMongoDB performs a MongoDB connection check
+func (m *DatabaseMonitor) checkMongoDB(ctx context.Context) error {
+	target := m.target
+
+	// Parse connection string or host:port
+	if strings.HasPrefix(target, "mongodb://") || strings.HasPrefix(target, "mongodb+srv://") {
+		target = strings.TrimPrefix(target, "mongodb://")
+		target = strings.TrimPrefix(target, "mongodb+srv://")
+		if idx := strings.Index(target, "@"); idx != -1 {
+			target = target[idx+1:]
+		}
+		if idx := strings.Index(target, "/"); idx != -1 {
+			target = target[:idx]
+		}
+		if idx := strings.Index(target, "?"); idx != -1 {
+			target = target[:idx]
+		}
+	}
+
+	// Default port if not specified
+	if _, _, err := net.SplitHostPort(target); err != nil {
+		target = target + ":27017"
+	}
+
+	// For MongoDB, we'll send an isMaster command using the wire protocol
+	// This is a simple way to check if MongoDB is responding
+
+	// Connect with timeout
+	dialer := &net.Dialer{Timeout: m.timeout}
+	conn, err := dialer.DialContext(ctx, "tcp", target)
+	if err != nil {
+		return fmt.Errorf("connection failed: %w", err)
+	}
+	defer conn.Close()
+
+	// Set read/write deadline
+	if err := conn.SetDeadline(time.Now().Add(m.timeout)); err != nil {
+		return fmt.Errorf("failed to set deadline: %w", err)
+	}
+
+	// Build the isMaster command (OP_MSG is the MongoDB 3.6+ framing; it is not used here)
+	// This is a simplified ping - we just check if we can establish connection
+	// and the server responds to our message
+
+	// For simplicity, we'll use the legacy OP_QUERY with isMaster
+	// Document: { isMaster: 1 }
+	// BSON: \x13\x00\x00\x00 (length 19) \x10 (int32) isMaster\x00 \x01\x00\x00\x00 \x00
+	isMasterDoc := []byte{
+		0x13, 0x00, 0x00, 0x00, // document length (19 bytes)
+		0x10, // type: int32
+		'i', 's', 'M', 'a', 's', 't', 'e', 'r', 0x00, // "isMaster\0"
+		0x01, 0x00, 0x00, 0x00, // value: 1
+		0x00, // document terminator
+	}
+
+	// OP_QUERY message
+	// Header: length (4) + requestID (4) + responseTo (4) + opCode (4)
+	// Body: flags (4) + fullCollectionName + numberToSkip (4) + numberToReturn (4) + query
+	collName := []byte("admin.$cmd\x00")
+	msgLen := 16 + 4 + len(collName) + 4 + 4 + len(isMasterDoc)
+
+	msg := make([]byte, 0, msgLen)
+	// Header
+	msg = appendInt32(msg, int32(msgLen)) // length
+	msg = appendInt32(msg, 1)             // requestID
+	msg = appendInt32(msg, 0)             // responseTo
+	msg = appendInt32(msg, 2004)          // opCode: OP_QUERY
+	// Body
+	msg = appendInt32(msg, 0)         // flags
+	msg = append(msg, collName...)    // collection name
+	msg = appendInt32(msg, 0)         // numberToSkip
+	msg = appendInt32(msg, 1)         // numberToReturn
+	msg = append(msg, isMasterDoc...) // query document
+
+	_, err = conn.Write(msg)
+	if err != nil {
+		return fmt.Errorf("failed to send isMaster: %w", err)
+	}
+
+	// Read the 16-byte response header (NOTE(review): one Read may return fewer bytes; io.ReadFull would be safer)
+	header := make([]byte, 16)
+	_, err = conn.Read(header)
+	if err != nil {
+		return fmt.Errorf("failed to read response header: %w", err)
+	}
+
+	// Check opCode in response (bytes 12-15, little-endian); the reply body is not read
+	opCode := int32(header[12]) | int32(header[13])<<8 | int32(header[14])<<16 | int32(header[15])<<24
+	if opCode == 1 { // OP_REPLY
+		return nil // MongoDB responded
+	}
+
+	return fmt.Errorf("unexpected response opCode: %d", opCode)
+}
+
+// appendInt32 appends a little-endian int32 to the byte slice
+func appendInt32(b []byte, v int32) []byte {
+	return append(b, byte(v), byte(v>>8), byte(v>>16), byte(v>>24))
+}
+
+// openDatabase opens a database connection using database/sql
+// This requires the appropriate driver to be imported (none is imported by this file)
+func openDatabase(driverName, dataSourceName string, timeout time.Duration) (*sql.DB, error) {
+	db, err := sql.Open(driverName, dataSourceName)
+	if err != nil {
+		return nil, err
+	}
+
+	// Pool settings: at most one open connection, no idle connections, lifetime capped at timeout
+	db.SetMaxOpenConns(1)
+	db.SetMaxIdleConns(0)
+	db.SetConnMaxLifetime(timeout)
+
+	return db, nil
+}
diff --git a/internal/monitor/monitor.go b/internal/monitor/monitor.go
index 5ec283a..9a1ec15 100644
--- a/internal/monitor/monitor.go
+++ b/internal/monitor/monitor.go
@@ -34,7 +34,7 @@ type Monitor interface {
 	// Name returns the monitor's name
 	Name() string
 
-	// Type returns the monitor type (http, https, tcp, gemini, icmp, dns, graphql)
+	// Type returns the monitor type (http, https, tcp, gemini, icmp, dns, graphql, database)
 	Type() string
 
 	// Target returns the monitor target (URL or host:port)
@@ -74,6 +74,8 @@ func New(cfg config.MonitorConfig) (Monitor, error) {
 		return NewDNSMonitor(cfg)
 	case "graphql":
 		return NewGraphQLMonitor(cfg)
+	case "database", "db":
+		return NewDatabaseMonitor(cfg)
 	default:
 		return nil, &UnsupportedTypeError{Type: cfg.Type}
 	}
|