package monitor import ( "bufio" "context" "database/sql" "fmt" "net" "strings" "time" "github.com/Fuwn/kaze/internal/config" ) // DatabaseMonitor monitors database connections type DatabaseMonitor struct { id string name string group string target string // Connection string or host:port dbType string // postgres, mysql, redis, mongodb, sqlite interval time.Duration timeout time.Duration retries int roundResponseTime bool roundUptime bool } // NewDatabaseMonitor creates a new database monitor func NewDatabaseMonitor(cfg config.MonitorConfig) (*DatabaseMonitor, error) { dbType := strings.ToLower(cfg.DBType) if dbType == "" { return nil, fmt.Errorf("db_type is required for database monitors") } // Validate supported database types switch dbType { case "postgres", "postgresql": dbType = "postgres" case "mysql", "mariadb": dbType = "mysql" case "redis": // Redis uses simple protocol check case "memcached": // Memcached uses simple protocol check case "mongodb", "mongo": dbType = "mongodb" default: return nil, fmt.Errorf("unsupported database type: %s (supported: postgres, mysql, redis, memcached, mongodb)", cfg.DBType) } return &DatabaseMonitor{ id: cfg.ID(), name: cfg.Name, group: cfg.Group, target: cfg.Target, dbType: dbType, interval: cfg.Interval.Duration, timeout: cfg.Timeout.Duration, retries: cfg.Retries, roundResponseTime: cfg.RoundResponseTime, roundUptime: cfg.RoundUptime, }, nil } // ID returns the unique identifier for this monitor func (m *DatabaseMonitor) ID() string { return m.id } // Name returns the monitor's name func (m *DatabaseMonitor) Name() string { return m.name } // Group returns the group this monitor belongs to func (m *DatabaseMonitor) Group() string { return m.group } // Type returns the monitor type func (m *DatabaseMonitor) Type() string { return "database" } // Target returns the monitor target func (m *DatabaseMonitor) Target() string { return m.target } // Interval returns the check interval func (m *DatabaseMonitor) Interval() time.Duration { return m.interval } // Retries returns the number of retry attempts func (m *DatabaseMonitor) Retries() int { return m.retries } // HideSSLDays returns whether to hide SSL days from display func (m *DatabaseMonitor) HideSSLDays() bool { return true // Databases don't expose SSL info this way } // RoundResponseTime returns whether to round response time func (m *DatabaseMonitor) RoundResponseTime() bool { return m.roundResponseTime } // RoundUptime returns whether to round uptime percentage func (m *DatabaseMonitor) RoundUptime() bool { return m.roundUptime } // Check performs the database connection check func (m *DatabaseMonitor) Check(ctx context.Context) *Result { result := &Result{ MonitorName: m.id, Timestamp: time.Now(), } start := time.Now() var err error switch m.dbType { case "redis": err = m.checkRedis(ctx) case "memcached": err = m.checkMemcached(ctx) case "postgres": err = m.checkPostgres(ctx) case "mysql": err = m.checkMySQL(ctx) case "mongodb": err = m.checkMongoDB(ctx) default: err = fmt.Errorf("unsupported database type: %s", m.dbType) } result.ResponseTime = time.Since(start) if err != nil { result.Status = StatusDown result.Error = err return result } result.Status = StatusUp // Check for slow response (degraded if > 1 second) if result.ResponseTime > 1*time.Second { result.Status = StatusDegraded result.Error = fmt.Errorf("slow response: %v", result.ResponseTime) } return result } // checkRedis performs a Redis PING check using the RESP protocol func (m *DatabaseMonitor) checkRedis(ctx context.Context) error { // Parse target - could be host:port or redis://... target := m.target if strings.HasPrefix(target, "redis://") { target = strings.TrimPrefix(target, "redis://") // Remove any auth/db parts for simple connection if idx := strings.Index(target, "@"); idx != -1 { target = target[idx+1:] } if idx := strings.Index(target, "/"); idx != -1 { target = target[:idx] } } // Default port if not specified if _, _, err := net.SplitHostPort(target); err != nil { target = target + ":6379" } // Connect with timeout dialer := &net.Dialer{Timeout: m.timeout} conn, err := dialer.DialContext(ctx, "tcp", target) if err != nil { return fmt.Errorf("connection failed: %w", err) } defer conn.Close() // Set read/write deadline if err := conn.SetDeadline(time.Now().Add(m.timeout)); err != nil { return fmt.Errorf("failed to set deadline: %w", err) } // Send PING command (RESP protocol) _, err = conn.Write([]byte("*1\r\n$4\r\nPING\r\n")) if err != nil { return fmt.Errorf("failed to send PING: %w", err) } // Read response reader := bufio.NewReader(conn) response, err := reader.ReadString('\n') if err != nil { return fmt.Errorf("failed to read response: %w", err) } // Check for PONG response (simple string +PONG or bulk string) response = strings.TrimSpace(response) if response != "+PONG" && !strings.HasPrefix(response, "$") { // If it's an error response if strings.HasPrefix(response, "-") { return fmt.Errorf("redis error: %s", strings.TrimPrefix(response, "-")) } return fmt.Errorf("unexpected response: %s", response) } return nil } // checkMemcached performs a Memcached version check func (m *DatabaseMonitor) checkMemcached(ctx context.Context) error { target := m.target // Default port if not specified if _, _, err := net.SplitHostPort(target); err != nil { target = target + ":11211" } // Connect with timeout dialer := &net.Dialer{Timeout: m.timeout} conn, err := dialer.DialContext(ctx, "tcp", target) if err != nil { return fmt.Errorf("connection failed: %w", err) } defer conn.Close() // Set read/write deadline if err := conn.SetDeadline(time.Now().Add(m.timeout)); err != nil { return fmt.Errorf("failed to set deadline: %w", err) } // Send version command _, err = conn.Write([]byte("version\r\n")) if err != nil { return fmt.Errorf("failed to send version command: %w", err) } // Read response reader := bufio.NewReader(conn) response, err := reader.ReadString('\n') if err != nil { return fmt.Errorf("failed to read response: %w", err) } // Check for VERSION response if !strings.HasPrefix(response, "VERSION") { return fmt.Errorf("unexpected response: %s", strings.TrimSpace(response)) } return nil } // checkPostgres performs a PostgreSQL connection check func (m *DatabaseMonitor) checkPostgres(ctx context.Context) error { // Build connection string connStr := m.target if !strings.Contains(connStr, "://") && !strings.Contains(connStr, "=") { // Assume it's host:port format, build a connection string host, port, err := net.SplitHostPort(connStr) if err != nil { host = connStr port = "5432" } connStr = fmt.Sprintf("host=%s port=%s sslmode=disable connect_timeout=%d", host, port, int(m.timeout.Seconds())) } // For PostgreSQL, we need the driver // Since we want to avoid heavy dependencies, we'll do a simple TCP + startup message check return m.checkPostgresProtocol(ctx) } // checkPostgresProtocol performs a basic PostgreSQL protocol check func (m *DatabaseMonitor) checkPostgresProtocol(ctx context.Context) error { target := m.target // Parse connection string or host:port if strings.HasPrefix(target, "postgres://") || strings.HasPrefix(target, "postgresql://") { // Extract host:port from URL target = strings.TrimPrefix(target, "postgres://") target = strings.TrimPrefix(target, "postgresql://") if idx := strings.Index(target, "@"); idx != -1 { target = target[idx+1:] } if idx := strings.Index(target, "/"); idx != -1 { target = target[:idx] } if idx := strings.Index(target, "?"); idx != -1 { target = target[:idx] } } // Default port if not specified if _, _, err := net.SplitHostPort(target); err != nil { target = target + ":5432" } // Connect with timeout dialer := &net.Dialer{Timeout: m.timeout} conn, err := dialer.DialContext(ctx, "tcp", target) if err != nil { return fmt.Errorf("connection failed: %w", err) } defer conn.Close() // Set read deadline if err := conn.SetDeadline(time.Now().Add(m.timeout)); err != nil { return fmt.Errorf("failed to set deadline: %w", err) } // Send a minimal startup message (protocol version 3.0) // This will prompt the server to respond, even if auth fails // Message format: length (4 bytes) + protocol version (4 bytes) + parameters startupMsg := []byte{ 0, 0, 0, 8, // Length: 8 bytes 0, 3, 0, 0, // Protocol version 3.0 } _, err = conn.Write(startupMsg) if err != nil { return fmt.Errorf("failed to send startup message: %w", err) } // Read response - we expect either 'R' (auth request) or 'E' (error) response := make([]byte, 1) _, err = conn.Read(response) if err != nil { return fmt.Errorf("failed to read response: %w", err) } // 'R' = AuthenticationRequest, 'E' = ErrorResponse, 'N' = NoticeResponse // Any of these means PostgreSQL is running switch response[0] { case 'R', 'E', 'N', 'S': return nil // Server is responding default: return fmt.Errorf("unexpected response type: %c", response[0]) } } // checkMySQL performs a MySQL connection check using the protocol func (m *DatabaseMonitor) checkMySQL(ctx context.Context) error { target := m.target // Parse connection string or host:port if strings.HasPrefix(target, "mysql://") { target = strings.TrimPrefix(target, "mysql://") if idx := strings.Index(target, "@"); idx != -1 { target = target[idx+1:] } if idx := strings.Index(target, "/"); idx != -1 { target = target[:idx] } } // Handle tcp(host:port) format if strings.HasPrefix(target, "tcp(") { target = strings.TrimPrefix(target, "tcp(") target = strings.TrimSuffix(target, ")") } // Default port if not specified if _, _, err := net.SplitHostPort(target); err != nil { target = target + ":3306" } // Connect with timeout dialer := &net.Dialer{Timeout: m.timeout} conn, err := dialer.DialContext(ctx, "tcp", target) if err != nil { return fmt.Errorf("connection failed: %w", err) } defer conn.Close() // Set read deadline if err := conn.SetDeadline(time.Now().Add(m.timeout)); err != nil { return fmt.Errorf("failed to set deadline: %w", err) } // MySQL sends a handshake packet upon connection // Read the initial handshake packet header := make([]byte, 4) _, err = conn.Read(header) if err != nil { return fmt.Errorf("failed to read handshake header: %w", err) } // Packet length is first 3 bytes (little-endian) packetLen := int(header[0]) | int(header[1])<<8 | int(header[2])<<16 if packetLen <= 0 || packetLen > 65535 { return fmt.Errorf("invalid packet length: %d", packetLen) } // Read the rest of the handshake packet packet := make([]byte, packetLen) _, err = conn.Read(packet) if err != nil { return fmt.Errorf("failed to read handshake packet: %w", err) } // First byte of packet is protocol version (should be 10 for MySQL 5.x+) // or 0xFF for error packet if len(packet) > 0 { if packet[0] == 0xFF { // Error packet if len(packet) > 3 { errMsg := string(packet[3:]) return fmt.Errorf("mysql error: %s", errMsg) } return fmt.Errorf("mysql error response") } // Protocol version 10 or 9 means MySQL is running if packet[0] == 10 || packet[0] == 9 { return nil } } return fmt.Errorf("unexpected handshake response") } // checkMongoDB performs a MongoDB connection check func (m *DatabaseMonitor) checkMongoDB(ctx context.Context) error { target := m.target // Parse connection string or host:port if strings.HasPrefix(target, "mongodb://") || strings.HasPrefix(target, "mongodb+srv://") { target = strings.TrimPrefix(target, "mongodb://") target = strings.TrimPrefix(target, "mongodb+srv://") if idx := strings.Index(target, "@"); idx != -1 { target = target[idx+1:] } if idx := strings.Index(target, "/"); idx != -1 { target = target[:idx] } if idx := strings.Index(target, "?"); idx != -1 { target = target[:idx] } } // Default port if not specified if _, _, err := net.SplitHostPort(target); err != nil { target = target + ":27017" } // For MongoDB, we'll send an isMaster command using the wire protocol // This is a simple way to check if MongoDB is responding // Connect with timeout dialer := &net.Dialer{Timeout: m.timeout} conn, err := dialer.DialContext(ctx, "tcp", target) if err != nil { return fmt.Errorf("connection failed: %w", err) } defer conn.Close() // Set deadline if err := conn.SetDeadline(time.Now().Add(m.timeout)); err != nil { return fmt.Errorf("failed to set deadline: %w", err) } // Build isMaster command using OP_MSG (MongoDB 3.6+) // This is a simplified ping - we just check if we can establish connection // and the server responds to our message // For simplicity, we'll use the legacy OP_QUERY with isMaster // Document: { isMaster: 1 } // BSON: \x13\x00\x00\x00 (length 19) \x10 (int32) isMaster\x00 \x01\x00\x00\x00 \x00 isMasterDoc := []byte{ 0x13, 0x00, 0x00, 0x00, // document length (19 bytes) 0x10, // type: int32 'i', 's', 'M', 'a', 's', 't', 'e', 'r', 0x00, // "isMaster\0" 0x01, 0x00, 0x00, 0x00, // value: 1 0x00, // document terminator } // OP_QUERY message // Header: length (4) + requestID (4) + responseTo (4) + opCode (4) // Body: flags (4) + fullCollectionName + numberToSkip (4) + numberToReturn (4) + query collName := []byte("admin.$cmd\x00") msgLen := 16 + 4 + len(collName) + 4 + 4 + len(isMasterDoc) msg := make([]byte, 0, msgLen) // Header msg = appendInt32(msg, int32(msgLen)) // length msg = appendInt32(msg, 1) // requestID msg = appendInt32(msg, 0) // responseTo msg = appendInt32(msg, 2004) // opCode: OP_QUERY // Body msg = appendInt32(msg, 0) // flags msg = append(msg, collName...) // collection name msg = appendInt32(msg, 0) // numberToSkip msg = appendInt32(msg, 1) // numberToReturn msg = append(msg, isMasterDoc...) // query document _, err = conn.Write(msg) if err != nil { return fmt.Errorf("failed to send isMaster: %w", err) } // Read response header header := make([]byte, 16) _, err = conn.Read(header) if err != nil { return fmt.Errorf("failed to read response header: %w", err) } // Check opCode in response (bytes 12-15, little-endian) opCode := int32(header[12]) | int32(header[13])<<8 | int32(header[14])<<16 | int32(header[15])<<24 if opCode == 1 { // OP_REPLY return nil // MongoDB responded } return fmt.Errorf("unexpected response opCode: %d", opCode) } // appendInt32 appends a little-endian int32 to the byte slice func appendInt32(b []byte, v int32) []byte { return append(b, byte(v), byte(v>>8), byte(v>>16), byte(v>>24)) } // openDatabase opens a database connection using database/sql // This requires the appropriate driver to be imported func openDatabase(driverName, dataSourceName string, timeout time.Duration) (*sql.DB, error) { db, err := sql.Open(driverName, dataSourceName) if err != nil { return nil, err } // Set connection pool settings db.SetMaxOpenConns(1) db.SetMaxIdleConns(0) db.SetConnMaxLifetime(timeout) return db, nil }