From 92c742987e8125f233d6352547de72537ba83940 Mon Sep 17 00:00:00 2001 From: Alireza Ahmadi Date: Fri, 13 Jun 2025 00:57:32 +0200 Subject: [PATCH] fix old link removal on inbound tag change #633 --- cronjob/depleteJob.go | 10 +++++- service/client.go | 32 ++++++------------- service/config.go | 71 ++++++------------------------------------- service/inbounds.go | 65 +++++++++++++++++++++++++-------------- service/services.go | 3 ++ service/tls.go | 61 +++++++++++++++++++++++++++---------- 6 files changed, 119 insertions(+), 123 deletions(-) diff --git a/cronjob/depleteJob.go b/cronjob/depleteJob.go index d7adcffa..2a83b7d 100644 --- a/cronjob/depleteJob.go +++ b/cronjob/depleteJob.go @@ -1,12 +1,14 @@ package cronjob import ( + "s-ui/database" "s-ui/logger" "s-ui/service" ) type DepleteJob struct { service.ClientService + service.InboundService } func NewDepleteJob() *DepleteJob { @@ -14,9 +16,15 @@ func NewDepleteJob() *DepleteJob { } func (s *DepleteJob) Run() { - err := s.ClientService.DepleteClients() + inboundIds, err := s.ClientService.DepleteClients() if err != nil { logger.Warning("Disable depleted users failed: ", err) return } + if len(inboundIds) > 0 { + err := s.InboundService.RestartInbounds(database.GetDB(), inboundIds) + if err != nil { + logger.Error("unable to restart inbounds: ", err) + } + } } diff --git a/service/client.go b/service/client.go index 82cdbe1..1480661 100644 --- a/service/client.go +++ b/service/client.go @@ -13,9 +13,7 @@ import ( "gorm.io/gorm" ) -type ClientService struct { - InboundService -} +type ClientService struct{} func (s *ClientService) Get(id string) (*[]model.Client, error) { if id == "" { @@ -248,13 +246,9 @@ func (s *ClientService) UpdateClientsOnInboundDelete(tx *gorm.DB, id uint, tag s return nil } -func (s *ClientService) UpdateLinksByInboundChange(tx *gorm.DB, inbounIds []uint, hostname string) error { - var inbounds []model.Inbound - err := tx.Model(model.Inbound{}).Preload("Tls").Where("id in ? and type in ?", inbounIds, util.InboundTypeWithLink).Find(&inbounds).Error - if err != nil && database.IsNotFound(err) { - return err - } - for _, inbound := range inbounds { +func (s *ClientService) UpdateLinksByInboundChange(tx *gorm.DB, inbounds *[]model.Inbound, hostname string, oldTag string) error { + var err error + for _, inbound := range *inbounds { var clients []model.Client err = tx.Table("clients"). Where("EXISTS (SELECT 1 FROM json_each(clients.inbounds) WHERE json_each.value = ?)", inbound.Id). @@ -274,7 +268,7 @@ func (s *ClientService) UpdateLinksByInboundChange(tx *gorm.DB, inbounIds []uint }) } for _, clientLink := range clientLinks { - if clientLink["remark"] != inbound.Tag { + if clientLink["remark"] != inbound.Tag && clientLink["remark"] != oldTag { newClientLinks = append(newClientLinks, clientLink) } } @@ -292,7 +286,7 @@ func (s *ClientService) UpdateLinksByInboundChange(tx *gorm.DB, inbounIds []uint return nil } -func (s *ClientService) DepleteClients() error { +func (s *ClientService) DepleteClients() ([]uint, error) { var err error var clients []model.Client var changes []model.Changes @@ -306,12 +300,6 @@ func (s *ClientService) DepleteClients() error { defer func() { if err == nil { tx.Commit() - if len(inboundIds) > 0 && corePtr.IsRunning() { - err1 := s.InboundService.RestartInbounds(db, inboundIds) - if err1 != nil { - logger.Error("unable to restart inbounds: ", err1) - } - } } else { tx.Rollback() } @@ -319,7 +307,7 @@ func (s *ClientService) DepleteClients() error { err = tx.Model(model.Client{}).Where("enable = true AND ((volume >0 AND up+down > volume) OR (expiry > 0 AND expiry < ?))", now).Scan(&clients).Error if err != nil { - return err + return nil, err } dt := time.Now().Unix() @@ -342,16 +330,16 @@ func (s *ClientService) DepleteClients() error { if len(changes) > 0 { err = tx.Model(model.Client{}).Where("enable = true AND ((volume >0 AND up+down > volume) OR (expiry > 0 AND expiry < ?))", now).Update("enable", false).Error if err != nil { - return err + return nil, err } err = tx.Model(model.Changes{}).Create(&changes).Error if err != nil { - return err + return nil, err } LastUpdate = dt } - return nil + return inboundIds, nil } // avoid duplicate inboundIds diff --git a/service/config.go b/service/config.go index 47f9c40..80bc5f8 100644 --- a/service/config.go +++ b/service/config.go @@ -124,9 +124,6 @@ func (s *ConfigService) StopCore() error { func (s *ConfigService) Save(obj string, act string, data json.RawMessage, initUsers string, loginUser string, hostname string) ([]string, error) { var err error - var inboundIds []uint - var serviceIds []uint - var inboundId uint var objs []string = []string{obj} db := database.GetDB() @@ -134,18 +131,6 @@ func (s *ConfigService) Save(obj string, act string, data json.RawMessage, initU defer func() { if err == nil { tx.Commit() - if len(inboundIds) > 0 && corePtr.IsRunning() { - err1 := s.InboundService.RestartInbounds(db, inboundIds) - if err1 != nil { - logger.Error("unable to restart inbounds: ", err1) - } - } - if len(serviceIds) > 0 && corePtr.IsRunning() { - err1 := s.ServicesService.RestartServices(db, serviceIds) - if err1 != nil { - logger.Error("unable to restart services: ", err1) - } - } // Try to start core if it is not running if !corePtr.IsRunning() { s.StartCore("") @@ -157,12 +142,17 @@ func (s *ConfigService) Save(obj string, act string, data json.RawMessage, initU switch obj { case "clients": - inboundIds, err = s.ClientService.Save(tx, act, data, hostname) - objs = append(objs, "inbounds") + inboundIds, err := s.ClientService.Save(tx, act, data, hostname) + if err == nil && len(inboundIds) > 0 { + objs = append(objs, "inbounds") + err = s.InboundService.RestartInbounds(tx, inboundIds) + } case "tls": - serviceIds, inboundIds, err = s.TlsService.Save(tx, act, data) + err = s.TlsService.Save(tx, act, data, hostname) + objs = append(objs, "clients", "inbounds") case "inbounds": - inboundId, err = s.InboundService.Save(tx, act, data, initUsers, hostname) + err = s.InboundService.Save(tx, act, data, initUsers, hostname) + objs = append(objs, "clients") case "outbounds": err = s.OutboundService.Save(tx, act, data) case "services": @@ -195,49 +185,8 @@ func (s *ConfigService) Save(obj string, act string, data json.RawMessage, initU if err != nil { return nil, err } - // Commit changes so far - tx.Commit() + LastUpdate = time.Now().Unix() - tx = db.Begin() - - // Update side changes - - // Update client links - if obj == "tls" && len(inboundIds) > 0 { - err = s.ClientService.UpdateLinksByInboundChange(tx, inboundIds, hostname) - if err != nil { - return nil, err - } - objs = append(objs, "clients") - } - if obj == "inbounds" { - switch act { - case "new": - err = s.ClientService.UpdateClientsOnInboundAdd(tx, initUsers, inboundId, hostname) - case "edit": - err = s.ClientService.UpdateLinksByInboundChange(tx, []uint{inboundId}, hostname) - case "del": - var tag string - err = json.Unmarshal(data, &tag) - if err != nil { - return nil, err - } - err = s.ClientService.UpdateClientsOnInboundDelete(tx, inboundId, tag) - } - if err != nil { - return nil, err - } - objs = append(objs, "clients") - } - - // Update out_json of inbounds when tls is changed - if obj == "tls" && len(inboundIds) > 0 { - err = s.InboundService.UpdateOutJsons(tx, inboundIds, hostname) - if err != nil { - return nil, common.NewError("unable to update out_json of inbounds: ", err.Error()) - } - objs = append(objs, "inbounds") - } return objs, nil } diff --git a/service/inbounds.go b/service/inbounds.go index 0d96a8f..a0963c6 100644 --- a/service/inbounds.go +++ b/service/inbounds.go @@ -13,7 +13,9 @@ import ( "gorm.io/gorm" ) -type InboundService struct{} +type InboundService struct { + ClientService +} func (s *InboundService) Get(ids string) (*[]map[string]interface{}, error) { if ids == "" { @@ -97,40 +99,41 @@ func (s *InboundService) FromIds(ids []uint) ([]*model.Inbound, error) { return inbounds, nil } -func (s *InboundService) Save(tx *gorm.DB, act string, data json.RawMessage, initUserIds string, hostname string) (uint, error) { +func (s *InboundService) Save(tx *gorm.DB, act string, data json.RawMessage, initUserIds string, hostname string) error { var err error - var id uint switch act { case "new", "edit": var inbound model.Inbound err = inbound.UnmarshalJSON(data) if err != nil { - return 0, err + return err } if inbound.TlsId > 0 { err = tx.Model(model.Tls{}).Where("id = ?", inbound.TlsId).Find(&inbound.Tls).Error if err != nil { - return 0, err + return err + } + } + var oldTag string + if act == "edit" { + err = tx.Model(model.Inbound{}).Select("tag").Where("id = ?", inbound.Id).Find(&oldTag).Error + if err != nil { + return err } } if corePtr.IsRunning() { if act == "edit" { - var oldTag string - err = tx.Model(model.Inbound{}).Select("tag").Where("id = ?", inbound.Id).Find(&oldTag).Error - if err != nil { - return 0, err - } err = corePtr.RemoveInbound(oldTag) if err != nil && err != os.ErrInvalid { - return 0, err + return err } } inboundConfig, err := inbound.MarshalJSON() if err != nil { - return 0, err + return err } if act == "edit" { @@ -139,49 +142,62 @@ func (s *InboundService) Save(tx *gorm.DB, act string, data json.RawMessage, ini inboundConfig, err = s.initUsers(tx, inboundConfig, initUserIds, inbound.Type) } if err != nil { - return 0, err + return err } err = corePtr.AddInbound(inboundConfig) if err != nil { - return 0, err + return err } } err = util.FillOutJson(&inbound, hostname) if err != nil { - return 0, err + return err } err = tx.Save(&inbound).Error if err != nil { - return 0, err + return err + } + switch act { + case "new": + err = s.ClientService.UpdateClientsOnInboundAdd(tx, initUserIds, inbound.Id, hostname) + case "edit": + err = s.ClientService.UpdateLinksByInboundChange(tx, &[]model.Inbound{inbound}, hostname, oldTag) + } + if err != nil { + return err } - id = inbound.Id case "del": var tag string err = json.Unmarshal(data, &tag) if err != nil { - return 0, err + return err } if corePtr.IsRunning() { err = corePtr.RemoveInbound(tag) if err != nil && err != os.ErrInvalid { - return 0, err + return err } } + var id uint err = tx.Model(model.Inbound{}).Select("id").Where("tag = ?", tag).Scan(&id).Error if err != nil { - return 0, err + return err + } + err = s.ClientService.UpdateClientsOnInboundDelete(tx, id, tag) + if err != nil { + return err } err = tx.Where("tag = ?", tag).Delete(model.Inbound{}).Error if err != nil { - return 0, err + return err } default: - return 0, common.NewErrorf("unknown action: %s", act) + return common.NewErrorf("unknown action: %s", act) } - return id, nil + return nil } func (s *InboundService) UpdateOutJsons(tx *gorm.DB, inboundIds []uint, hostname string) error { @@ -312,6 +328,9 @@ func (s *InboundService) initUsers(db *gorm.DB, inboundJson []byte, clientIds st } func (s *InboundService) RestartInbounds(tx *gorm.DB, ids []uint) error { + if !corePtr.IsRunning() { + return nil + } var inbounds []*model.Inbound err := tx.Model(model.Inbound{}).Preload("Tls").Where("id in ?", ids).Find(&inbounds).Error if err != nil { diff --git a/service/services.go b/service/services.go index 4ac2aaf..30a4621 100644 --- a/service/services.go +++ b/service/services.go @@ -126,6 +126,9 @@ func (s *ServicesService) Save(tx *gorm.DB, act string, data json.RawMessage) er } func (s *ServicesService) RestartServices(tx *gorm.DB, ids []uint) error { + if !corePtr.IsRunning() { + return nil + } var services []*model.Service err := tx.Model(model.Service{}).Preload("Tls").Where("id in ?", ids).Find(&services).Error if err != nil { diff --git a/service/tls.go b/service/tls.go index bb24202..684d0bf 100644 --- a/service/tls.go +++ b/service/tls.go @@ -11,6 +11,7 @@ import ( type TlsService struct { InboundService + ServicesService } func (s *TlsService) GetAll() ([]model.Tls, error) { @@ -24,52 +25,80 @@ func (s *TlsService) GetAll() ([]model.Tls, error) { return tlsConfig, nil } -func (s *TlsService) Save(tx *gorm.DB, action string, data json.RawMessage) ([]uint, []uint, error) { +func (s *TlsService) Save(tx *gorm.DB, action string, data json.RawMessage, hostname string) error { var err error - var inboundIds []uint - var serviceIds []uint switch action { case "new", "edit": var tls model.Tls err = json.Unmarshal(data, &tls) if err != nil { - return nil, nil, err + return err } err = tx.Save(&tls).Error if err != nil { - return nil, nil, err + return err } - err = tx.Model(model.Inbound{}).Select("id").Where("tls_id = ?", tls.Id).Scan(&inboundIds).Error - if err != nil { - return nil, nil, err + if action == "edit" { + var inbounds []model.Inbound + err = tx.Model(model.Inbound{}).Preload("Tls").Where("tls_id = ?", tls.Id).Find(&inbounds).Error + if err != nil { + return err + } + if len(inbounds) > 0 { + err = s.ClientService.UpdateLinksByInboundChange(tx, &inbounds, hostname, "") + if err != nil { + return err + } + var inboundIds []uint + for _, inbound := range inbounds { + inboundIds = append(inboundIds, inbound.Id) + } + err = s.InboundService.UpdateOutJsons(tx, inboundIds, hostname) + if err != nil { + return common.NewError("unable to update out_json of inbounds: ", err.Error()) + } + err = s.InboundService.RestartInbounds(tx, inboundIds) + if err != nil { + return err + } + } + var serviceIds []uint + err = tx.Model(model.Service{}).Where("tls_id = ?", tls.Id).Scan(&serviceIds).Error + if err != nil { + return err + } + if len(serviceIds) > 0 { + err = s.ServicesService.RestartServices(tx, serviceIds) + if err != nil { + return err + } + } } - err = tx.Model(model.Service{}).Where("tls_id = ?", tls.Id).Scan(&serviceIds).Error - return serviceIds, inboundIds, nil case "del": var id uint err = json.Unmarshal(data, &id) if err != nil { - return nil, nil, err + return err } var inboundCount int64 err = tx.Model(model.Inbound{}).Where("tls_id = ?", id).Count(&inboundCount).Error if err != nil { - return nil, nil, err + return err } var serviceCount int64 err = tx.Model(model.Service{}).Where("tls_id = ?", id).Count(&serviceCount).Error if err != nil { - return nil, nil, err + return err } if inboundCount > 0 || serviceCount > 0 { - return nil, nil, common.NewError("tls in use") + return common.NewError("tls in use") } err = tx.Where("id = ?", id).Delete(model.Tls{}).Error if err != nil { - return nil, nil, err + return err } } - return nil, nil, nil + return nil }