From ed48cdca33729a10c556efa658d7b8eaf44cedd6 Mon Sep 17 00:00:00 2001 From: Alireza Ahmadi Date: Wed, 25 Dec 2024 10:57:17 +0100 Subject: [PATCH] load from database --- backend/api/api.go | 6 +- backend/app/app.go | 21 +-- backend/config/config.go | 7 +- backend/core/main.go | 21 +-- backend/database/model/endpoints.go | 1 - backend/database/model/inbounds.go | 5 +- backend/database/model/outbounds.go | 1 - backend/go.mod | 4 +- backend/go.sum | 4 + backend/service/config.go | 198 +++++++++++++++------------- backend/service/endpoints.go | 48 +++++++ backend/service/inbounds.go | 56 +++++++- backend/service/outbounds.go | 48 +++++++ 13 files changed, 290 insertions(+), 130 deletions(-) create mode 100644 backend/service/endpoints.go create mode 100644 backend/service/outbounds.go diff --git a/backend/api/api.go b/backend/api/api.go index 9e77827..dd596ed 100644 --- a/backend/api/api.go +++ b/backend/api/api.go @@ -16,7 +16,7 @@ type APIHandler struct { service.ConfigService service.ClientService service.TlsService - service.InDataService + service.InboundService service.PanelService service.StatsService service.ServerService @@ -207,7 +207,7 @@ func (a *APIHandler) loadData(c *gin.Context) (interface{}, error) { if err != nil { return "", err } - inData, err := a.InDataService.GetAll() + inbounds, err := a.InboundService.GetAll() if err != nil { return "", err } @@ -218,7 +218,7 @@ func (a *APIHandler) loadData(c *gin.Context) (interface{}, error) { data["config"] = *config data["clients"] = clients data["tls"] = tlsConfigs - data["inData"] = inData + data["inbounds"] = inbounds data["subURI"] = subURI data["onlines"] = onlines } else { diff --git a/backend/app/app.go b/backend/app/app.go index df926ba..c9762a0 100644 --- a/backend/app/app.go +++ b/backend/app/app.go @@ -16,11 +16,12 @@ import ( type APP struct { service.SettingService - webServer *web.Server - subServer *sub.Server - cronJob *cronjob.CronJob - logger *logging.Logger - core *core.Core + configService *service.ConfigService + webServer *web.Server + subServer *sub.Server + cronJob *cronjob.CronJob + logger *logging.Logger + core *core.Core } func NewApp() *APP { @@ -46,8 +47,8 @@ func (a *APP) Init() error { a.webServer = web.NewServer() a.subServer = sub.NewServer() - configService := service.NewConfigService(a.core) - err = configService.InitConfig() + a.configService = service.NewConfigService(a.core) + err = a.configService.InitConfig() if err != nil { return err } @@ -80,7 +81,7 @@ func (a *APP) Start() error { return err } - err = a.core.Start() + err = a.configService.StartCore() if err != nil { logger.Error(err) } @@ -98,6 +99,10 @@ func (a *APP) Stop() { if err != nil { logger.Warning("stop Web Server err:", err) } + err = a.configService.StopCore() + if err != nil { + logger.Warning("stop Core err:", err) + } } func (a *APP) initLog() { diff --git a/backend/config/config.go b/backend/config/config.go index 8917bd3..12349b3 100644 --- a/backend/config/config.go +++ b/backend/config/config.go @@ -4,6 +4,7 @@ import ( _ "embed" "fmt" "os" + "path/filepath" "strings" ) @@ -59,7 +60,11 @@ func GetBinFolderPath() string { func GetDBFolderPath() string { dbFolderPath := os.Getenv("SUI_DB_FOLDER") if dbFolderPath == "" { - dbFolderPath = "/usr/local/s-ui/db" + dir, err := filepath.Abs(filepath.Dir(os.Args[0])) + if err != nil { + dbFolderPath = "/usr/local/s-ui/db" + } + dbFolderPath = dir + "/db" } return dbFolderPath } diff --git a/backend/core/main.go b/backend/core/main.go index ac62a43..ef3ae29 100644 --- a/backend/core/main.go +++ b/backend/core/main.go @@ -2,8 +2,6 @@ package core import ( "context" - "os" - "s-ui/config" "s-ui/logger" sb "github.com/sagernet/sing-box" @@ -49,14 +47,9 @@ func (c *Core) GetInstance() *Box { return c.instance } -func (c *Core) Start() error { - filepath := config.GetBinFolderPath() + "/config.json" - configFile, err := os.ReadFile(filepath) - if err != nil { - return err - } +func (c *Core) Start(sbConfig []byte) error { var opt option.Options - err = opt.UnmarshalJSONContext(globalCtx, configFile) + err := opt.UnmarshalJSONContext(globalCtx, sbConfig) if err != nil { logger.Error("Unmarshal config err:", err.Error()) } @@ -92,16 +85,6 @@ func (c *Core) Stop() error { return nil } -func (c *Core) Restart() error { - err := c.Stop() - if err != nil { - logger.Error("stop sing-box err:", err.Error()) - return err - } - logger.Info("sing-box stopped") - return c.Start() -} - func (c *Core) IsRunning() bool { return c.isRunning } diff --git a/backend/database/model/endpoints.go b/backend/database/model/endpoints.go index 9c92ade..d6eae27 100644 --- a/backend/database/model/endpoints.go +++ b/backend/database/model/endpoints.go @@ -35,7 +35,6 @@ func (o *Endpoint) UnmarshalJSON(data []byte) error { func (o Endpoint) MarshalJSON() ([]byte, error) { // Combine fixed fields and dynamic fields into one map combined := make(map[string]interface{}) - combined["id"] = o.Id combined["type"] = o.Type combined["tag"] = o.Tag diff --git a/backend/database/model/inbounds.go b/backend/database/model/inbounds.go index 7d2e3b0..56c0c7c 100644 --- a/backend/database/model/inbounds.go +++ b/backend/database/model/inbounds.go @@ -1,6 +1,8 @@ package model -import "encoding/json" +import ( + "encoding/json" +) type Inbound struct { Id uint `json:"id" form:"id" gorm:"primaryKey;autoIncrement"` @@ -58,7 +60,6 @@ func (i *Inbound) UnmarshalJSON(data []byte) error { func (i Inbound) MarshalJSON() ([]byte, error) { // Combine fixed fields and dynamic fields into one map combined := make(map[string]interface{}) - combined["id"] = i.Id combined["type"] = i.Type combined["tag"] = i.Tag if i.Tls != nil { diff --git a/backend/database/model/outbounds.go b/backend/database/model/outbounds.go index 830c9ad..0ae84ab 100644 --- a/backend/database/model/outbounds.go +++ b/backend/database/model/outbounds.go @@ -35,7 +35,6 @@ func (o *Outbound) UnmarshalJSON(data []byte) error { func (o Outbound) MarshalJSON() ([]byte, error) { // Combine fixed fields and dynamic fields into one map combined := make(map[string]interface{}) - combined["id"] = o.Id combined["type"] = o.Type combined["tag"] = o.Tag diff --git a/backend/go.mod b/backend/go.mod index d99f888..ad867e6 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -12,7 +12,7 @@ require ( github.com/sagernet/sing-dns v0.4.0-beta.1 github.com/shirou/gopsutil/v3 v3.24.5 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 - gorm.io/driver/sqlite v1.5.6 + gorm.io/driver/sqlite v1.5.7 gorm.io/gorm v1.25.12 ) @@ -64,7 +64,7 @@ require ( github.com/logrusorgru/aurora v2.0.3+incompatible // indirect github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect + github.com/mattn/go-sqlite3 v1.14.24 // indirect github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.4.1 // indirect github.com/metacubex/tfo-go v0.0.0-20241006021335-daedaf0ca7aa // indirect diff --git a/backend/go.sum b/backend/go.sum index 8083f22..2bc2b2a 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -110,6 +110,8 @@ github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683 h1:7UMa6KCCMjZEMD github.com/lufia/plan9stats v0.0.0-20240909124753-873cd0166683/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= +github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U= github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= @@ -302,6 +304,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE= gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= +gorm.io/driver/sqlite v1.5.7 h1:8NvsrhP0ifM7LX9G4zPB97NwovUakUxc+2V2uuf3Z1I= +gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= lukechampine.com/blake3 v1.3.0 h1:sJ3XhFINmHSrYCgl958hscfIa3bw8x4DqMP3u1YvoYE= diff --git a/backend/service/config.go b/backend/service/config.go index e816f77..754de3d 100644 --- a/backend/service/config.go +++ b/backend/service/config.go @@ -21,8 +21,10 @@ var ( type ConfigService struct { ClientService TlsService - InDataService SettingService + InboundService + OutboundService + EndpointService } type SingBoxConfig struct { @@ -31,6 +33,7 @@ type SingBoxConfig struct { Ntp json.RawMessage `json:"ntp"` Inbounds []json.RawMessage `json:"inbounds"` Outbounds []json.RawMessage `json:"outbounds"` + Endpoints []json.RawMessage `json:"endpoints"` Route json.RawMessage `json:"route"` Experimental json.RawMessage `json:"experimental"` } @@ -42,47 +45,70 @@ func NewConfigService(core *core.Core) *ConfigService { func (s *ConfigService) InitConfig() error { IsSystemd = config.IsSystemd() - configPath := config.GetBinFolderPath() - data, err := os.ReadFile(configPath + "/config.json") - if err != nil { - if os.IsNotExist(err) { - defaultConfig := []byte(config.GetDefaultConfig()) - err = os.MkdirAll(configPath, 01764) - if err != nil { - return err - } - err = os.WriteFile(configPath+"/config.json", defaultConfig, 0764) - if err != nil { - return err - } - data = defaultConfig - } else { - return err - } - } - var singboxConfig SingBoxConfig - err = json.Unmarshal(data, &singboxConfig) - if err != nil { - return err - } - return nil } func (s *ConfigService) GetConfig() (*SingBoxConfig, error) { - configPath := config.GetBinFolderPath() - data, err := os.ReadFile(configPath + "/config.json") + data, err := s.SettingService.GetConfig() if err != nil { return nil, err } singboxConfig := SingBoxConfig{} - err = json.Unmarshal(data, &singboxConfig) + err = json.Unmarshal([]byte(data), &singboxConfig) + if err != nil { + return nil, err + } + + singboxConfig.Inbounds, err = s.InboundService.GetAllConfig(database.GetDB()) + if err != nil { + return nil, err + } + singboxConfig.Outbounds, err = s.OutboundService.GetAllConfig(database.GetDB()) + if err != nil { + return nil, err + } + singboxConfig.Endpoints, err = s.EndpointService.GetAllConfig(database.GetDB()) if err != nil { return nil, err } return &singboxConfig, nil } +func (s *ConfigService) StartCore() error { + singboxConfig, err := s.GetConfig() + if err != nil { + return err + } + rawConfig, err := json.MarshalIndent(singboxConfig, "", " ") + if err != nil { + return err + } + err = corePtr.Start(rawConfig) + if err != nil { + logger.Error("start sing-box err:", err.Error()) + return err + } + logger.Info("sing-box started") + return nil +} + +func (s *ConfigService) RestartCore() error { + err := s.StartCore() + if err != nil { + return err + } + return s.StartCore() +} + +func (s *ConfigService) StopCore() error { + err := corePtr.Stop() + if err != nil { + return err + } + logger.Info("sing-box stopped") + return nil +} + func (s *ConfigService) SaveChanges(changes map[string]string, loginUser string) error { var err error var clientChanges, tlsChanges, inChanges, settingChanges, configChanges []model.Changes @@ -139,12 +165,12 @@ func (s *ConfigService) SaveChanges(changes map[string]string, loginUser string) return err } } - if len(inChanges) > 0 { - err = s.InDataService.Save(tx, inChanges) - if err != nil { - return err - } - } + // if len(inChanges) > 0 { + // err = s.InDataService.Save(tx, inChanges) + // if err != nil { + // return err + // } + // } if len(settingChanges) > 0 { err = s.SettingService.Save(tx, settingChanges) if err != nil { @@ -342,7 +368,7 @@ func (s *ConfigService) Save(singboxConfig *SingBoxConfig, needRestart bool) err } if needRestart { - err = corePtr.Restart() + err = s.RestartCore() if err != nil { return err } @@ -353,61 +379,59 @@ func (s *ConfigService) Save(singboxConfig *SingBoxConfig, needRestart bool) err } func (s *ConfigService) DepleteClients() error { - users, inbounds, err := s.ClientService.DepleteClients() - if err != nil || len(users) == 0 || len(inbounds) == 0 { + users, inboundIds, err := s.ClientService.DepleteClients() + if err != nil || len(users) == 0 || len(inboundIds) == 0 { return err } - singboxConfig, err := s.GetConfig() - if err != nil { - return err - } - for inbound_index, inbound := range singboxConfig.Inbounds { - var inboundJson map[string]interface{} - json.Unmarshal(inbound, &inboundJson) - if s.contains(inbounds, inboundJson["tag"].(string)) { - inbound_users, ok := inboundJson["users"].([]interface{}) - if ok { - var updatedUsers []interface{} - for _, user := range inbound_users { - userMap, ok := user.(map[string]interface{}) - if ok { - name, exists := userMap["name"].(string) - if exists && s.contains(users, name) { - // Skip the user exists - continue - } - username, exists := userMap["username"].(string) - if exists && s.contains(users, username) { - // Skip the username exists - continue - } - } - updatedUsers = append(updatedUsers, user) - } - // Exception for Naive and ShadowTLSv3 - if len(updatedUsers) == 0 { - if inboundJson["type"].(string) == "naive" || - (inboundJson["type"].(string) == "shadowtls" && - inboundJson["version"].(float64) == 3) { - updatedUsers = append(updatedUsers, make(map[string]interface{})) - } - } + // inbounds, err := s.InboundService.FromIds(inboundIds) + // if err != nil { + // return err + // } + // for inbound_index, inbound := range inbounds { + // var inboundJson map[string]interface{} + // json.Unmarshal(inbound.Options, &inboundJson) + // inbound_users, ok := inboundJson["users"].([]interface{}) + // if ok { + // var updatedUsers []interface{} + // for _, user := range inbound_users { + // userMap, ok := user.(map[string]interface{}) + // if ok { + // name, exists := userMap["name"].(string) + // if exists && s.contains(users, name) { + // // Skip the user exists + // continue + // } + // username, exists := userMap["username"].(string) + // if exists && s.contains(users, username) { + // // Skip the username exists + // continue + // } + // } + // updatedUsers = append(updatedUsers, user) + // } + // // Exception for Naive and ShadowTLSv3 + // if len(updatedUsers) == 0 { + // if inboundJson["type"].(string) == "naive" || + // (inboundJson["type"].(string) == "shadowtls" && + // inboundJson["version"].(float64) == 3) { + // updatedUsers = append(updatedUsers, make(map[string]interface{})) + // } + // } - inboundJson["users"] = updatedUsers - } - } - modifiedInbound, err := json.MarshalIndent(inboundJson, "", " ") - if err != nil { - return err - } - singboxConfig.Inbounds[inbound_index] = modifiedInbound - } + // inboundJson["users"] = updatedUsers + // } + // modifiedInbound, err := json.MarshalIndent(inboundJson, "", " ") + // if err != nil { + // return err + // } + // inbounds[inbound_index] = modifiedInbound + // } - err = s.Save(singboxConfig, true) - if err != nil { - return err - } + // err = s.Save(singboxConfig, true) + // if err != nil { + // return err + // } return nil } @@ -437,7 +461,3 @@ func (s *ConfigService) GetChanges(actor string, chngKey string, count string) [ } return chngs } - -func (s *ConfigService) RestartCore() error { - return corePtr.Restart() -} diff --git a/backend/service/endpoints.go b/backend/service/endpoints.go new file mode 100644 index 0000000..704b48b --- /dev/null +++ b/backend/service/endpoints.go @@ -0,0 +1,48 @@ +package service + +import ( + "encoding/json" + "s-ui/database" + "s-ui/database/model" + + "gorm.io/gorm" +) + +type EndpointService struct{} + +func (o *EndpointService) GetAll() ([]*model.Endpoint, error) { + db := database.GetDB() + endpoints := []*model.Endpoint{} + err := db.Model(model.Endpoint{}).Scan(&endpoints).Error + if err != nil { + return nil, err + } + return endpoints, nil +} + +func (o *EndpointService) Get(id uint) (*model.Endpoint, error) { + db := database.GetDB() + endpoint := &model.Endpoint{} + err := db.First(endpoint, id).Error + if err != nil { + return nil, err + } + return endpoint, nil +} + +func (o *EndpointService) GetAllConfig(db *gorm.DB) ([]json.RawMessage, error) { + var endpointsJson []json.RawMessage + var endpoints []*model.Endpoint + err := db.Model(model.Endpoint{}).Scan(&endpoints).Error + if err != nil { + return nil, err + } + for _, endpoint := range endpoints { + endpointJson, err := endpoint.MarshalJSON() + if err != nil { + return nil, err + } + endpointsJson = append(endpointsJson, endpointJson) + } + return endpointsJson, nil +} diff --git a/backend/service/inbounds.go b/backend/service/inbounds.go index bd51898..de4c706 100644 --- a/backend/service/inbounds.go +++ b/backend/service/inbounds.go @@ -1,6 +1,7 @@ package service import ( + "encoding/json" "s-ui/database" "s-ui/database/model" @@ -9,14 +10,14 @@ import ( type InboundService struct{} -func (s *InboundService) GetAll() ([]model.Inbound, error) { +func (s *InboundService) GetAll() (*[]map[string]interface{}, error) { db := database.GetDB() - inbounds := []model.Inbound{} - err := db.Model(model.Inbound{}).Scan(&inbounds).Error + inbounds := []map[string]interface{}{} + err := db.Model(model.Inbound{}).Select("id, tag, type, address, port, tls_id , count(users) as ucount").Scan(&inbounds).Error if err != nil { return nil, err } - return inbounds, nil + return &inbounds, nil } func (s *InboundService) FromIds(ids []uint) ([]*model.Inbound, error) { @@ -32,3 +33,50 @@ func (s *InboundService) FromIds(ids []uint) ([]*model.Inbound, error) { func (s *InboundService) Save(db *gorm.DB, inbounds []*model.Inbound) error { return db.Save(inbounds).Error } + +func (s *InboundService) GetAllConfig(db *gorm.DB) ([]json.RawMessage, error) { + var inboundsJson []json.RawMessage + var inbounds []*model.Inbound + err := db.Model(model.Inbound{}).Preload("Tls").Find(&inbounds).Error + if err != nil { + return nil, err + } + for _, inbound := range inbounds { + inboundJson, err := inbound.MarshalJSON() + if err != nil { + return nil, err + } + switch inbound.Type { + case "mixed", "socks", "http", "shadowsocks", "vmess", "trojan", "naive", "hysteria", "shadowtls", "tuic", "hysteria2", "vless": + inboundJson, err = s.addUsers(db, inboundJson, inbound.Id, inbound.Type) + if err != nil { + return nil, err + } + } + inboundsJson = append(inboundsJson, inboundJson) + } + return inboundsJson, nil +} + +func (s *InboundService) addUsers(db *gorm.DB, inboundJson []byte, inboundId uint, inboundType string) ([]byte, error) { + var inbound map[string]interface{} + err := json.Unmarshal(inboundJson, &inbound) + if err != nil { + return nil, err + } + var users []string + err = db.Raw(`SELECT json_extract(clients.config, ?) + FROM clients, json_each(clients.inbounds) as je + WHERE clients.enable = true AND je.value = ?;`, + "$."+inboundType, inboundId).Scan(&users).Error + if err != nil { + return nil, err + } + var usersJson []json.RawMessage + for _, user := range users { + usersJson = append(usersJson, json.RawMessage(user)) + } + + inbound["users"] = usersJson + return json.Marshal(inbound) +} diff --git a/backend/service/outbounds.go b/backend/service/outbounds.go new file mode 100644 index 0000000..00419da --- /dev/null +++ b/backend/service/outbounds.go @@ -0,0 +1,48 @@ +package service + +import ( + "encoding/json" + "s-ui/database" + "s-ui/database/model" + + "gorm.io/gorm" +) + +type OutboundService struct{} + +func (o *OutboundService) GetAll() ([]*model.Outbound, error) { + db := database.GetDB() + outbounds := []*model.Outbound{} + err := db.Model(model.Outbound{}).Scan(&outbounds).Error + if err != nil { + return nil, err + } + return outbounds, nil +} + +func (o *OutboundService) Get(id uint) (*model.Outbound, error) { + db := database.GetDB() + outbound := &model.Outbound{} + err := db.First(outbound, id).Error + if err != nil { + return nil, err + } + return outbound, nil +} + +func (o *OutboundService) GetAllConfig(db *gorm.DB) ([]json.RawMessage, error) { + var outboundsJson []json.RawMessage + var outbounds []*model.Outbound + err := db.Model(model.Outbound{}).Scan(&outbounds).Error + if err != nil { + return nil, err + } + for _, outbound := range outbounds { + outboundJson, err := outbound.MarshalJSON() + if err != nil { + return nil, err + } + outboundsJson = append(outboundsJson, outboundJson) + } + return outboundsJson, nil +}