aboutsummaryrefslogtreecommitdiff
path: root/backend
diff options
context:
space:
mode:
Diffstat (limited to 'backend')
-rw-r--r--backend/api/api.go7
-rw-r--r--backend/db/mongo.go14
2 files changed, 12 insertions, 9 deletions
diff --git a/backend/api/api.go b/backend/api/api.go
index 9dd68a9..183c090 100644
--- a/backend/api/api.go
+++ b/backend/api/api.go
@@ -1,6 +1,7 @@
package api
import (
+ "context"
"net/http"
"os"
"os/signal"
@@ -10,9 +11,15 @@ import (
mux "github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
+ "github.com/jackyzha0/ctrl-v/db"
)
func cleanup() {
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ if err := db.Client.Disconnect(ctx); err != nil {
+ panic(err)
+ }
log.Print("Shutting down server...")
}
diff --git a/backend/db/mongo.go b/backend/db/mongo.go
index dc1dd25..f7870ac 100644
--- a/backend/db/mongo.go
+++ b/backend/db/mongo.go
@@ -13,6 +13,7 @@ import (
"go.mongodb.org/mongo-driver/mongo/readpref"
)
+var Client *mongo.Client
var Session *mongo.Session
var pastes *mongo.Collection
@@ -25,21 +26,16 @@ func initSessions(user, pass, ip string) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
- client, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoURI).SetTLSConfig(&tls.Config{}))
+ c, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoURI).SetTLSConfig(&tls.Config{}))
+ Client = c
if err != nil {
log.Fatalf("error establishing connection to mongo: %s", err.Error())
}
- err = client.Ping(ctx, readpref.Primary())
+ err = Client.Ping(ctx, readpref.Primary())
if err != nil {
log.Fatalf("error pinging mongo: %s", err.Error())
}
- defer func() {
- if err = client.Disconnect(ctx); err != nil {
- panic(err)
- }
- }()
-
// ensure expiry check
expiryIndex := options.Index().SetExpireAfterSeconds(0)
sessionTTL := mongo.IndexModel{
@@ -55,7 +51,7 @@ func initSessions(user, pass, ip string) {
}
// Define connection to Databases
- pastes = client.Database("main").Collection("pastes")
+ pastes = Client.Database("main").Collection("pastes")
_, _ = pastes.Indexes().CreateOne(ctx, sessionTTL)
_, _ = pastes.Indexes().CreateOne(ctx, uniqueHashes)
}