diff options
Diffstat (limited to 'backend/db/mongo.go')
| -rw-r--r-- | backend/db/mongo.go | 64 |
1 files changed, 34 insertions, 30 deletions
diff --git a/backend/db/mongo.go b/backend/db/mongo.go index 4c8a739..b29f5f2 100644 --- a/backend/db/mongo.go +++ b/backend/db/mongo.go @@ -1,17 +1,18 @@ package db import ( - "crypto/tls" + "context" "fmt" - "net" + "time" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" log "github.com/sirupsen/logrus" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) -var Session *mgo.Session -var pastes *mgo.Collection +var Session *mongo.Session +var pastes *mongo.Collection func initSessions(user, pass, ip string) { log.Infof("attempting connection to %s", ip) @@ -19,49 +20,52 @@ func initSessions(user, pass, ip string) { // build uri string URIfmt := "mongodb://%s:%s@%s:27017" mongoURI := fmt.Sprintf(URIfmt, user, pass, ip) - dialInfo, err := mgo.ParseURL(mongoURI) - if err != nil { - log.Fatalf("error parsing uri: %s", err.Error()) - } - - tlsConfig := &tls.Config{} - dialInfo.DialServer = func(addr *mgo.ServerAddr) (net.Conn, error) { - conn, err := tls.Dial("tcp", addr.String(), tlsConfig) - return conn, err - } - Session, err = mgo.DialWithInfo(dialInfo) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoURI)) if err != nil { log.Fatalf("error establishing connection to mongo: %s", err.Error()) } // ensure expiry check - sessionTTL := mgo.Index{ - Key: []string{"expiry"}, - ExpireAfter: 0, + expiryIndex := options.Index().SetExpireAfterSeconds(0) + sessionTTL := mongo.IndexModel{ + Keys: []string{"expiry"}, + Options: expiryIndex, } // ensure hashes are unique - uniqueHashes := mgo.Index{ - Key: []string{"hash"}, - Unique: true, + uniqueIndex := options.Index().SetUnique(true) + uniqueHashes := mongo.IndexModel{ + Keys: []string{"hash"}, + Options: uniqueIndex, } - _ = Session.DB("main").C("pastes").EnsureIndex(sessionTTL) - _ = Session.DB("main").C("pastes").EnsureIndex(uniqueHashes) + _, _ = client.Database("main").Collection("pastes").Indexes().CreateOne(ctx, sessionTTL) + _, _ = client.Database("main").Collection("pastes").Indexes().CreateOne(ctx, uniqueHashes) // Define connection to Databases - pastes = Session.DB("main").C("pastes") + pastes = client.Database("main").Collection("pastes") } func insert(new Paste) error { - return pastes.Insert(new) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + _, err := pastes.InsertOne(ctx, new) + return err } func fetch(hash string) (Paste, error) { p := Paste{} - q := bson.M{"hash": hash} - err := pastes.Find(q).One(&p) - return p, err + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + result := pastes.FindOne(ctx, q) + if (result.Err() != nil) { + return p, result.Err() + } else { + result.Decode(&p) + return p, nil + } } |