migrate database

This commit is contained in:
Alireza Ahmadi
2024-12-22 11:53:44 +01:00
parent ecd9348a0f
commit 7a047daf6f
14 changed files with 567 additions and 113 deletions
+3
View File
@@ -37,6 +37,9 @@ func (a *APP) Init() error {
return err return err
} }
// Init Setting
a.SettingService.GetAllSetting()
a.core = core.NewCore() a.core = core.NewCore()
a.cronJob = cronjob.NewCronJob(a.core) a.cronJob = cronjob.NewCronJob(a.core)
+2 -1
View File
@@ -4,6 +4,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"os" "os"
"s-ui/cmd/migration"
"s-ui/config" "s-ui/config"
) )
@@ -72,7 +73,7 @@ func ParseCmd() {
} }
case "migrate": case "migrate":
migrateDb() migration.MigrateDb()
case "setting": case "setting":
err := settingCmd.Parse(os.Args[2:]) err := settingCmd.Parse(os.Args[2:])
@@ -1,50 +1,14 @@
package cmd package migration
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log"
"os"
"s-ui/config"
"s-ui/database/model" "s-ui/database/model"
"strings" "strings"
"gorm.io/driver/sqlite"
"gorm.io/gorm" "gorm.io/gorm"
) )
func migrateDb() {
// void running on first install
path := config.GetDBPath()
_, err := os.Stat(path)
if err != nil {
return
}
db, err := gorm.Open(sqlite.Open(path))
if err != nil {
log.Fatal(err)
}
tx := db.Begin()
defer func() {
if err == nil {
tx.Commit()
} else {
tx.Rollback()
}
}()
fmt.Println("Start migrating database...")
err = migrateClientSchema(tx)
if err != nil {
log.Fatal(err)
}
err = changesObj(tx)
if err != nil {
log.Fatal(err)
}
fmt.Println("Migration done!")
}
func migrateClientSchema(db *gorm.DB) error { func migrateClientSchema(db *gorm.DB) error {
rows, err := db.Raw("PRAGMA table_info(clients)").Rows() rows, err := db.Raw("PRAGMA table_info(clients)").Rows()
if err != nil { if err != nil {
@@ -95,10 +59,21 @@ func migrateClientSchema(db *gorm.DB) error {
} }
} }
} }
db.AutoMigrate(model.Client{})
return nil return nil
} }
func changesObj(db *gorm.DB) error { func changesObj(db *gorm.DB) error {
return db.Exec("UPDATE changes SET obj = CAST('\"' || CAST(obj AS TEXT) || '\"' AS BLOB) WHERE actor = ? and obj not like ?", "DepleteJob", "\"%\"").Error return db.Exec("UPDATE changes SET obj = CAST('\"' || CAST(obj AS TEXT) || '\"' AS BLOB) WHERE actor = ? and obj not like ?", "DepleteJob", "\"%\"").Error
} }
func to1_1(db *gorm.DB) error {
err := migrateClientSchema(db)
if err != nil {
return err
}
err = changesObj(db)
if err != nil {
return err
}
return nil
}
+199
View File
@@ -0,0 +1,199 @@
package migration
import (
"encoding/json"
"errors"
"os"
"path/filepath"
"s-ui/database/model"
"gorm.io/gorm"
)
type InboundData struct {
Id uint
Tag string
Addrs json.RawMessage
OutJson json.RawMessage
}
func moveJsonToDb(db *gorm.DB) error {
binFolderPath := os.Getenv("SUI_BIN_FOLDER")
if binFolderPath == "" {
binFolderPath = "bin"
}
dir, err := filepath.Abs(filepath.Dir(os.Args[0]))
if err != nil {
return err
}
configPath := dir + "/" + binFolderPath + "/config.json"
if _, err := os.Stat(configPath); errors.Is(err, os.ErrNotExist) {
return nil
}
data, err := os.ReadFile(configPath)
if err != nil {
return err
}
var oldConfig map[string]interface{}
err = json.Unmarshal(data, &oldConfig)
if err != nil {
return err
}
oldInbounds := oldConfig["inbounds"].([]interface{})
db.Migrator().DropTable(&model.Inbound{})
db.AutoMigrate(&model.Inbound{})
for _, inbound := range oldInbounds {
inbObj, _ := inbound.(map[string]interface{})
tag, _ := inbObj["tag"].(string)
if tlsObj, ok := inbObj["tls"]; ok {
var tls_id uint
err = db.Raw("SELECT id FROM tls WHERE inbounds like ?", `%"`+tag+`"%`).Find(&tls_id).Error
if err != nil {
return err
}
// Bind or Create tls_id
if tls_id > 0 {
inbObj["tls_id"] = tls_id
} else {
tls_server, _ := json.MarshalIndent(tlsObj, "", " ")
if len(tls_server) > 5 {
newTls := &model.Tls{
Name: tag,
Server: tls_server,
}
err = db.Create(newTls).Error
if err != nil {
return err
}
inbObj["tls_id"] = newTls.Id
}
}
}
var inbData InboundData
db.Raw("select id,addrs,out_json from inbound_data where tag = ?", tag).Find(&inbData)
if inbData.Id > 0 {
inbObj["outJson"] = inbData.OutJson
inbObj["addrs"] = inbData.Addrs
} else {
inbObj["outJson"] = json.RawMessage("{}")
inbObj["addrs"] = json.RawMessage("[]")
}
inbJson, _ := json.Marshal(inbObj)
var newInbound model.Inbound
err = newInbound.UnmarshalJSON(inbJson)
if err != nil {
return err
}
err = db.Create(&newInbound).Error
if err != nil {
return err
}
}
delete(oldConfig, "inbounds")
oldOutbounds := oldConfig["outbounds"].([]interface{})
db.Migrator().DropTable(&model.Outbound{}, &model.Endpoint{})
db.AutoMigrate(&model.Outbound{}, &model.Endpoint{})
for _, outbound := range oldOutbounds {
outType, _ := outbound.(map[string]interface{})["type"].(string)
outboundRaw, _ := json.MarshalIndent(outbound, "", " ")
if outType == "wireguard" { // Check if it is Entrypoint
var newEntrypoint model.Endpoint
err = newEntrypoint.UnmarshalJSON(outboundRaw)
if err != nil {
return err
}
err = db.Create(&newEntrypoint).Error
if err != nil {
return err
}
} else { // It is Outbound
var newOutbound model.Outbound
err = newOutbound.UnmarshalJSON(outboundRaw)
if err != nil {
return err
}
err = db.Create(&newOutbound).Error
if err != nil {
return err
}
}
}
delete(oldConfig, "outbounds")
// Remove v2rayapi and clashapi from experimental config
experimental := oldConfig["experimental"].(map[string]interface{})
delete(experimental, "v2ray_api")
delete(experimental, "clash_api")
oldConfig["experimental"] = experimental
// Save the other configs
var otherConfigs json.RawMessage
otherConfigs, err = json.MarshalIndent(oldConfig, "", " ")
if err != nil {
return err
}
return db.Save(&model.Setting{
Key: "config",
Value: string(otherConfigs),
}).Error
}
func migrateTls(db *gorm.DB) error {
if !db.Migrator().HasColumn(&model.Tls{}, "inbounds") {
return nil
}
return db.Migrator().DropColumn(&model.Tls{}, "inbounds")
}
func dropInboundData(db *gorm.DB) error {
if !db.Migrator().HasTable(&InboundData{}) {
return nil
}
return db.Migrator().DropTable(&InboundData{})
}
func migrateClients(db *gorm.DB) error {
var oldClients []model.Client
err := db.Model(model.Client{}).Scan(&oldClients).Error
if err != nil {
return err
}
for index, oldClient := range oldClients {
var old_inbounds []string
err = json.Unmarshal(oldClient.Inbounds, &old_inbounds)
if err != nil {
return err
}
var inbound_ids []uint
err = db.Raw("SELECT id FROM inbounds WHERE tag in ?", old_inbounds).Find(&inbound_ids).Error
if err != nil {
return err
}
oldClients[index].Inbounds, _ = json.Marshal(inbound_ids)
}
return db.Save(oldClients).Error
}
func to1_2(db *gorm.DB) error {
err := moveJsonToDb(db)
if err != nil {
return err
}
err = migrateTls(db)
if err != nil {
return err
}
err = dropInboundData(db)
if err != nil {
return err
}
return migrateClients(db)
}
+68
View File
@@ -0,0 +1,68 @@
package migration
import (
"fmt"
"log"
"os"
"s-ui/config"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
func MigrateDb() {
// void running on first install
path := config.GetDBPath()
_, err := os.Stat(path)
if err != nil {
println("Database not found")
return
}
db, err := gorm.Open(sqlite.Open(path))
if err != nil {
log.Fatal(err)
return
}
tx := db.Begin()
defer func() {
if err == nil {
tx.Commit()
} else {
tx.Rollback()
}
}()
currentVersion := config.GetVersion()
dbVersion := ""
tx.Raw("SELECT value FROM settings WHERE key = ?", "version").Find(&dbVersion)
fmt.Println("Current version:", currentVersion, "\nDatabase version:", dbVersion)
if currentVersion == dbVersion {
fmt.Println("Database is up to date, no need to migrate")
return
}
fmt.Println("Start migrating database...")
// Before 1.2
if dbVersion == "" {
err = to1_1(tx)
if err != nil {
log.Fatal("Migration to 1.1 failed: ", err)
return
}
err = to1_2(tx)
if err != nil {
log.Fatal("Migration to 1.2 failed: ", err)
return
}
}
// Set version
err = tx.Raw("UPDATE settings SET value = ? WHERE key = ?", currentVersion, "version").Error
if err != nil {
log.Fatal("Update version failed: ", err)
return
}
fmt.Println("Migration done!")
}
+6 -15
View File
@@ -17,22 +17,13 @@
"route": { "route": {
"rules": [ "rules": [
{ {
"protocol": "dns", "protocol": [
"outbound": "dns-out" "dns"
],
"outbound": "dns-out",
"action": "route"
} }
] ]
}, },
"experimental": { "experimental": {}
"v2ray_api": {
"listen": "127.0.0.1:1080",
"stats": {
"enabled": true,
"inbounds": [],
"outbounds": [
"direct"
],
"users": []
}
}
}
} }
+18 -1
View File
@@ -1,6 +1,7 @@
package database package database
import ( import (
"encoding/json"
"os" "os"
"path" "path"
"s-ui/config" "s-ui/config"
@@ -48,6 +49,10 @@ func OpenDB(dbPath string) error {
Logger: gormLogger, Logger: gormLogger,
} }
db, err = gorm.Open(sqlite.Open(dbPath), c) db, err = gorm.Open(sqlite.Open(dbPath), c)
if config.IsDebug() {
db = db.Debug()
}
return err return err
} }
@@ -57,10 +62,22 @@ func InitDB(dbPath string) error {
return err return err
} }
// Default Outbounds
if !db.Migrator().HasTable(&model.Outbound{}) {
db.Migrator().CreateTable(&model.Outbound{})
defaultOutbound := []model.Outbound{
{Type: "direct", Tag: "direct", Options: json.RawMessage(`{}`)},
{Type: "dns", Tag: "dns-out", Options: json.RawMessage(`{}`)},
}
db.Create(&defaultOutbound)
}
err = db.AutoMigrate( err = db.AutoMigrate(
&model.Setting{}, &model.Setting{},
&model.Tls{}, &model.Tls{},
&model.InboundData{}, &model.Inbound{},
&model.Outbound{},
&model.Endpoint{},
&model.User{}, &model.User{},
&model.Stats{}, &model.Stats{},
&model.Client{}, &model.Client{},
+54
View File
@@ -0,0 +1,54 @@
package model
import "encoding/json"
type Endpoint struct {
Id uint `json:"id" form:"id" gorm:"primaryKey;autoIncrement"`
Type string `json:"type" form:"type"`
Tag string `json:"tag" form:"tag"`
Options json.RawMessage `json:"-" form:"-"`
}
func (o *Endpoint) UnmarshalJSON(data []byte) error {
var err error
var raw map[string]interface{}
if err = json.Unmarshal(data, &raw); err != nil {
return err
}
// Extract fixed fields and store the rest in Options
if val, exists := raw["id"]; exists {
o.Id = val.(uint)
delete(raw, "id")
}
o.Type, _ = raw["type"].(string)
delete(raw, "type")
o.Tag = raw["tag"].(string)
delete(raw, "tag")
// Remaining fields
o.Options, err = json.Marshal(raw)
return err
}
// MarshalJSON customizes marshalling
func (o Endpoint) MarshalJSON() ([]byte, error) {
// Combine fixed fields and dynamic fields into one map
combined := make(map[string]interface{})
combined["id"] = o.Id
combined["type"] = o.Type
combined["tag"] = o.Tag
if o.Options != nil {
var restFields map[string]json.RawMessage
if err := json.Unmarshal(o.Options, &restFields); err != nil {
return nil, err
}
for k, v := range restFields {
combined[k] = v
}
}
return json.Marshal(combined)
}
+80
View File
@@ -0,0 +1,80 @@
package model
import "encoding/json"
type Inbound struct {
Id uint `json:"id" form:"id" gorm:"primaryKey;autoIncrement"`
Type string `json:"type" form:"type"`
Tag string `json:"tag" form:"tag"`
// Foreign key to tls table
TlsId uint `json:"tls_id" form:"tls_id"`
Tls *Tls `json:"tls" form:"tls" gorm:"foreignKey:TlsId;references:Id"`
Addrs json.RawMessage `json:"addrs" form:"addrs"`
OutJson json.RawMessage `json:"outJson" form:"outJson"`
Options json.RawMessage `json:"-" form:"-"`
}
func (i *Inbound) UnmarshalJSON(data []byte) error {
var err error
var raw map[string]interface{}
if err = json.Unmarshal(data, &raw); err != nil {
return err
}
// Extract fixed fields and store the rest in Options
if val, exists := raw["id"].(uint); exists {
i.Id = val
delete(raw, "id")
}
i.Type, _ = raw["type"].(string)
delete(raw, "type")
i.Tag, _ = raw["tag"].(string)
delete(raw, "tag")
// TlsId
if val, exists := raw["tls_id"].(float64); exists {
i.TlsId = uint(val)
}
delete(raw, "tls_id")
delete(raw, "tls")
delete(raw, "users")
// Addrs
i.Addrs, _ = json.MarshalIndent(raw["addrs"], "", " ")
delete(raw, "addrs")
// OutJson
i.OutJson, _ = json.MarshalIndent(raw["outJson"], "", " ")
delete(raw, "outJson")
// Remaining fields
i.Options, err = json.MarshalIndent(raw, "", " ")
return err
}
// MarshalJSON customizes marshalling
func (i Inbound) MarshalJSON() ([]byte, error) {
// Combine fixed fields and dynamic fields into one map
combined := make(map[string]interface{})
combined["id"] = i.Id
combined["type"] = i.Type
combined["tag"] = i.Tag
if i.Tls != nil {
combined["tls"] = i.Tls.Server
}
if i.Options != nil {
var restFields map[string]json.RawMessage
if err := json.Unmarshal(i.Options, &restFields); err != nil {
return nil, err
}
for k, v := range restFields {
combined[k] = v
}
}
return json.Marshal(combined)
}
+4 -12
View File
@@ -9,18 +9,10 @@ type Setting struct {
} }
type Tls struct { type Tls struct {
Id uint `json:"id" form:"id" gorm:"primaryKey;autoIncrement"` Id uint `json:"id" form:"id" gorm:"primaryKey;autoIncrement"`
Name string `json:"name" form:"name"` Name string `json:"name" form:"name"`
Inbounds json.RawMessage `json:"inbounds" form:"inbounds"` Server json.RawMessage `json:"server" form:"server"`
Server json.RawMessage `json:"server" form:"server"` Client json.RawMessage `json:"client" form:"client"`
Client json.RawMessage `json:"client" form:"client"`
}
type InboundData struct {
Id uint `json:"id" form:"id" gorm:"primaryKey;autoIncrement"`
Tag string `json:"tag" form:"tag"`
Addrs json.RawMessage `json:"addrs" form:"addrs"`
OutJson json.RawMessage `json:"outJson" form:"outJson"`
} }
type User struct { type User struct {
+54
View File
@@ -0,0 +1,54 @@
package model
import "encoding/json"
type Outbound struct {
Id uint `json:"id" form:"id" gorm:"primaryKey;autoIncrement"`
Type string `json:"type" form:"type"`
Tag string `json:"tag" form:"tag"`
Options json.RawMessage `json:"-" form:"-"`
}
func (o *Outbound) UnmarshalJSON(data []byte) error {
var err error
var raw map[string]interface{}
if err = json.Unmarshal(data, &raw); err != nil {
return err
}
// Extract fixed fields and store the rest in Options
if val, exists := raw["id"]; exists {
o.Id = val.(uint)
delete(raw, "id")
}
o.Type, _ = raw["type"].(string)
delete(raw, "type")
o.Tag = raw["tag"].(string)
delete(raw, "tag")
// Remaining fields
o.Options, err = json.Marshal(raw)
return err
}
// MarshalJSON customizes marshalling
func (o Outbound) MarshalJSON() ([]byte, error) {
// Combine fixed fields and dynamic fields into one map
combined := make(map[string]interface{})
combined["id"] = o.Id
combined["type"] = o.Type
combined["tag"] = o.Tag
if o.Options != nil {
var restFields map[string]json.RawMessage
if err := json.Unmarshal(o.Options, &restFields); err != nil {
return nil, err
}
for k, v := range restFields {
combined[k] = v
}
}
return json.Marshal(combined)
}
-46
View File
@@ -1,46 +0,0 @@
package service
import (
"encoding/json"
"s-ui/database"
"s-ui/database/model"
"gorm.io/gorm"
)
type InDataService struct {
}
func (s *InDataService) GetAll() ([]model.InboundData, error) {
db := database.GetDB()
inData := []model.InboundData{}
err := db.Model(model.InboundData{}).Scan(&inData).Error
if err != nil {
return nil, err
}
return inData, nil
}
func (s *InDataService) Save(tx *gorm.DB, changes []model.Changes) error {
var err error
for _, change := range changes {
inData := model.InboundData{}
err = json.Unmarshal(change.Obj, &inData)
if err != nil {
return err
}
switch change.Action {
case "new":
err = tx.Create(&inData).Error
case "del":
err = tx.Where("id = ?", change.Index).Delete(model.InboundData{}).Error
default:
err = tx.Save(inData).Error
}
if err != nil {
return err
}
}
return err
}
+34
View File
@@ -0,0 +1,34 @@
package service
import (
"s-ui/database"
"s-ui/database/model"
"gorm.io/gorm"
)
type InboundService struct{}
func (s *InboundService) GetAll() ([]model.Inbound, error) {
db := database.GetDB()
inbounds := []model.Inbound{}
err := db.Model(model.Inbound{}).Scan(&inbounds).Error
if err != nil {
return nil, err
}
return inbounds, nil
}
func (s *InboundService) FromIds(ids []uint) ([]*model.Inbound, error) {
db := database.GetDB()
inbounds := []*model.Inbound{}
err := db.Model(model.Inbound{}).Where("id in ?", ids).Scan(&inbounds).Error
if err != nil {
return nil, err
}
return inbounds, nil
}
func (s *InboundService) Save(db *gorm.DB, inbounds []*model.Inbound) error {
return db.Save(inbounds).Error
}
+32
View File
@@ -3,6 +3,7 @@ package service
import ( import (
"encoding/json" "encoding/json"
"os" "os"
"s-ui/config"
"s-ui/database" "s-ui/database"
"s-ui/database/model" "s-ui/database/model"
"s-ui/logger" "s-ui/logger"
@@ -14,6 +15,25 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
var defaultConfig = `{
"log": {
"level": "info"
},
"dns": {},
"route": {
"rules": [
{
"protocol": [
"dns"
],
"outbound": "dns-out",
"action": "route"
}
]
},
"experimental": {}
}`
var defaultValueMap = map[string]string{ var defaultValueMap = map[string]string{
"webListen": "", "webListen": "",
"webDomain": "", "webDomain": "",
@@ -37,6 +57,8 @@ var defaultValueMap = map[string]string{
"subShowInfo": "false", "subShowInfo": "false",
"subURI": "", "subURI": "",
"subJsonExt": "", "subJsonExt": "",
"config": defaultConfig,
"version": config.GetVersion(),
} }
type SettingService struct { type SettingService struct {
@@ -67,6 +89,8 @@ func (s *SettingService) GetAllSetting() (*map[string]string, error) {
// Due to security principles // Due to security principles
delete(allSetting, "secret") delete(allSetting, "secret")
delete(allSetting, "config")
delete(allSetting, "version")
return &allSetting, nil return &allSetting, nil
} }
@@ -311,6 +335,14 @@ func (s *SettingService) GetFinalSubURI(host string) (string, error) {
return protocol + "://" + host + port + (*allSetting)["subPath"], nil return protocol + "://" + host + port + (*allSetting)["subPath"], nil
} }
func (s *SettingService) GetConfig() (string, error) {
return s.getString("config")
}
func (s *SettingService) SetConfig(config string) error {
return s.setString("config", config)
}
func (s *SettingService) Save(tx *gorm.DB, changes []model.Changes) error { func (s *SettingService) Save(tx *gorm.DB, changes []model.Changes) error {
var err error var err error
for _, change := range changes { for _, change := range changes {