diff options
Diffstat (limited to 'backend')
| -rw-r--r-- | backend/api/api.go | 7 | ||||
| -rw-r--r-- | backend/db/mongo.go | 14 |
2 files changed, 12 insertions, 9 deletions
diff --git a/backend/api/api.go b/backend/api/api.go index 9dd68a9..183c090 100644 --- a/backend/api/api.go +++ b/backend/api/api.go @@ -1,6 +1,7 @@ package api import ( + "context" "net/http" "os" "os/signal" @@ -10,9 +11,15 @@ import ( mux "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + "github.com/jackyzha0/ctrl-v/db" ) func cleanup() { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if err := db.Client.Disconnect(ctx); err != nil { + panic(err) + } log.Print("Shutting down server...") } diff --git a/backend/db/mongo.go b/backend/db/mongo.go index dc1dd25..f7870ac 100644 --- a/backend/db/mongo.go +++ b/backend/db/mongo.go @@ -13,6 +13,7 @@ import ( "go.mongodb.org/mongo-driver/mongo/readpref" ) +var Client *mongo.Client var Session *mongo.Session var pastes *mongo.Collection @@ -25,21 +26,16 @@ func initSessions(user, pass, ip string) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - client, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoURI).SetTLSConfig(&tls.Config{})) + c, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoURI).SetTLSConfig(&tls.Config{})) + Client = c if err != nil { log.Fatalf("error establishing connection to mongo: %s", err.Error()) } - err = client.Ping(ctx, readpref.Primary()) + err = Client.Ping(ctx, readpref.Primary()) if err != nil { log.Fatalf("error pinging mongo: %s", err.Error()) } - defer func() { - if err = client.Disconnect(ctx); err != nil { - panic(err) - } - }() - // ensure expiry check expiryIndex := options.Index().SetExpireAfterSeconds(0) sessionTTL := mongo.IndexModel{ @@ -55,7 +51,7 @@ func initSessions(user, pass, ip string) { } // Define connection to Databases - pastes = client.Database("main").Collection("pastes") + pastes = Client.Database("main").Collection("pastes") _, _ = pastes.Indexes().CreateOne(ctx, sessionTTL) _, _ = pastes.Indexes().CreateOne(ctx, uniqueHashes) } |