diff --git a/database/backup.go b/database/backup.go index 53846a0..9ca3bd4 100644 --- a/database/backup.go +++ b/database/backup.go @@ -40,6 +40,7 @@ func GetDb(exclude string) ([]byte, error) { if err != nil { return nil, err } + defer os.Remove(dbPath) err = backupDb.AutoMigrate( &model.Setting{}, @@ -69,29 +70,50 @@ func GetDb(exclude string) ([]byte, error) { // Perform scans and handle errors if err := db.Model(&model.Setting{}).Scan(&settings).Error; err != nil { return nil, err + } else if len(settings) > 0 { + if err := backupDb.Save(settings).Error; err != nil { + return nil, err + } } if err := db.Model(&model.Tls{}).Scan(&tls).Error; err != nil { return nil, err + } else if len(tls) > 0 { + if err := backupDb.Save(tls).Error; err != nil { + return nil, err + } } if err := db.Model(&model.Inbound{}).Scan(&inbound).Error; err != nil { return nil, err + } else if len(inbound) > 0 { + if err := backupDb.Save(inbound).Error; err != nil { + return nil, err + } } if err := db.Model(&model.Outbound{}).Scan(&outbound).Error; err != nil { return nil, err + } else if len(outbound) > 0 { + if err := backupDb.Save(outbound).Error; err != nil { + return nil, err + } } if err := db.Model(&model.Endpoint{}).Scan(&endpoint).Error; err != nil { return nil, err + } else if len(endpoint) > 0 { + if err := backupDb.Save(endpoint).Error; err != nil { + return nil, err + } } if err := db.Model(&model.User{}).Scan(&users).Error; err != nil { return nil, err + } else if len(users) > 0 { + if err := backupDb.Save(users).Error; err != nil { + return nil, err + } } if err := db.Model(&model.Client{}).Scan(&clients).Error; err != nil { return nil, err - } - - // Save each model - for _, mdl := range []interface{}{settings, tls, inbound, outbound, endpoint, users, clients} { - if err := backupDb.Save(mdl).Error; err != nil { + } else if len(clients) > 0 { + if err := backupDb.Save(clients).Error; err != nil { return nil, err } } @@ -132,7 +154,6 @@ func GetDb(exclude string) ([]byte, error) { return nil, err } defer file.Close() - defer os.Remove(dbPath) // Read the file contents fileContents, err := io.ReadAll(file)