fix old link removal on inbound tag change #633

This commit is contained in:
Alireza Ahmadi
2025-06-13 00:57:32 +02:00
parent 4dabe656c9
commit 92c742987e
6 changed files with 119 additions and 123 deletions
+9 -1
View File
@@ -1,12 +1,14 @@
package cronjob package cronjob
import ( import (
"s-ui/database"
"s-ui/logger" "s-ui/logger"
"s-ui/service" "s-ui/service"
) )
type DepleteJob struct { type DepleteJob struct {
service.ClientService service.ClientService
service.InboundService
} }
func NewDepleteJob() *DepleteJob { func NewDepleteJob() *DepleteJob {
@@ -14,9 +16,15 @@ func NewDepleteJob() *DepleteJob {
} }
func (s *DepleteJob) Run() { func (s *DepleteJob) Run() {
err := s.ClientService.DepleteClients() inboundIds, err := s.ClientService.DepleteClients()
if err != nil { if err != nil {
logger.Warning("Disable depleted users failed: ", err) logger.Warning("Disable depleted users failed: ", err)
return return
} }
if len(inboundIds) > 0 {
err := s.InboundService.RestartInbounds(database.GetDB(), inboundIds)
if err != nil {
logger.Error("unable to restart inbounds: ", err)
}
}
} }
+10 -22
View File
@@ -13,9 +13,7 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
type ClientService struct { type ClientService struct{}
InboundService
}
func (s *ClientService) Get(id string) (*[]model.Client, error) { func (s *ClientService) Get(id string) (*[]model.Client, error) {
if id == "" { if id == "" {
@@ -248,13 +246,9 @@ func (s *ClientService) UpdateClientsOnInboundDelete(tx *gorm.DB, id uint, tag s
return nil return nil
} }
func (s *ClientService) UpdateLinksByInboundChange(tx *gorm.DB, inbounIds []uint, hostname string) error { func (s *ClientService) UpdateLinksByInboundChange(tx *gorm.DB, inbounds *[]model.Inbound, hostname string, oldTag string) error {
var inbounds []model.Inbound var err error
err := tx.Model(model.Inbound{}).Preload("Tls").Where("id in ? and type in ?", inbounIds, util.InboundTypeWithLink).Find(&inbounds).Error for _, inbound := range *inbounds {
if err != nil && database.IsNotFound(err) {
return err
}
for _, inbound := range inbounds {
var clients []model.Client var clients []model.Client
err = tx.Table("clients"). err = tx.Table("clients").
Where("EXISTS (SELECT 1 FROM json_each(clients.inbounds) WHERE json_each.value = ?)", inbound.Id). Where("EXISTS (SELECT 1 FROM json_each(clients.inbounds) WHERE json_each.value = ?)", inbound.Id).
@@ -274,7 +268,7 @@ func (s *ClientService) UpdateLinksByInboundChange(tx *gorm.DB, inbounIds []uint
}) })
} }
for _, clientLink := range clientLinks { for _, clientLink := range clientLinks {
if clientLink["remark"] != inbound.Tag { if clientLink["remark"] != inbound.Tag && clientLink["remark"] != oldTag {
newClientLinks = append(newClientLinks, clientLink) newClientLinks = append(newClientLinks, clientLink)
} }
} }
@@ -292,7 +286,7 @@ func (s *ClientService) UpdateLinksByInboundChange(tx *gorm.DB, inbounIds []uint
return nil return nil
} }
func (s *ClientService) DepleteClients() error { func (s *ClientService) DepleteClients() ([]uint, error) {
var err error var err error
var clients []model.Client var clients []model.Client
var changes []model.Changes var changes []model.Changes
@@ -306,12 +300,6 @@ func (s *ClientService) DepleteClients() error {
defer func() { defer func() {
if err == nil { if err == nil {
tx.Commit() tx.Commit()
if len(inboundIds) > 0 && corePtr.IsRunning() {
err1 := s.InboundService.RestartInbounds(db, inboundIds)
if err1 != nil {
logger.Error("unable to restart inbounds: ", err1)
}
}
} else { } else {
tx.Rollback() tx.Rollback()
} }
@@ -319,7 +307,7 @@ func (s *ClientService) DepleteClients() error {
err = tx.Model(model.Client{}).Where("enable = true AND ((volume >0 AND up+down > volume) OR (expiry > 0 AND expiry < ?))", now).Scan(&clients).Error err = tx.Model(model.Client{}).Where("enable = true AND ((volume >0 AND up+down > volume) OR (expiry > 0 AND expiry < ?))", now).Scan(&clients).Error
if err != nil { if err != nil {
return err return nil, err
} }
dt := time.Now().Unix() dt := time.Now().Unix()
@@ -342,16 +330,16 @@ func (s *ClientService) DepleteClients() error {
if len(changes) > 0 { if len(changes) > 0 {
err = tx.Model(model.Client{}).Where("enable = true AND ((volume >0 AND up+down > volume) OR (expiry > 0 AND expiry < ?))", now).Update("enable", false).Error err = tx.Model(model.Client{}).Where("enable = true AND ((volume >0 AND up+down > volume) OR (expiry > 0 AND expiry < ?))", now).Update("enable", false).Error
if err != nil { if err != nil {
return err return nil, err
} }
err = tx.Model(model.Changes{}).Create(&changes).Error err = tx.Model(model.Changes{}).Create(&changes).Error
if err != nil { if err != nil {
return err return nil, err
} }
LastUpdate = dt LastUpdate = dt
} }
return nil return inboundIds, nil
} }
// avoid duplicate inboundIds // avoid duplicate inboundIds
+10 -61
View File
@@ -124,9 +124,6 @@ func (s *ConfigService) StopCore() error {
func (s *ConfigService) Save(obj string, act string, data json.RawMessage, initUsers string, loginUser string, hostname string) ([]string, error) { func (s *ConfigService) Save(obj string, act string, data json.RawMessage, initUsers string, loginUser string, hostname string) ([]string, error) {
var err error var err error
var inboundIds []uint
var serviceIds []uint
var inboundId uint
var objs []string = []string{obj} var objs []string = []string{obj}
db := database.GetDB() db := database.GetDB()
@@ -134,18 +131,6 @@ func (s *ConfigService) Save(obj string, act string, data json.RawMessage, initU
defer func() { defer func() {
if err == nil { if err == nil {
tx.Commit() tx.Commit()
if len(inboundIds) > 0 && corePtr.IsRunning() {
err1 := s.InboundService.RestartInbounds(db, inboundIds)
if err1 != nil {
logger.Error("unable to restart inbounds: ", err1)
}
}
if len(serviceIds) > 0 && corePtr.IsRunning() {
err1 := s.ServicesService.RestartServices(db, serviceIds)
if err1 != nil {
logger.Error("unable to restart services: ", err1)
}
}
// Try to start core if it is not running // Try to start core if it is not running
if !corePtr.IsRunning() { if !corePtr.IsRunning() {
s.StartCore("") s.StartCore("")
@@ -157,12 +142,17 @@ func (s *ConfigService) Save(obj string, act string, data json.RawMessage, initU
switch obj { switch obj {
case "clients": case "clients":
inboundIds, err = s.ClientService.Save(tx, act, data, hostname) inboundIds, err := s.ClientService.Save(tx, act, data, hostname)
objs = append(objs, "inbounds") if err == nil && len(inboundIds) > 0 {
objs = append(objs, "inbounds")
err = s.InboundService.RestartInbounds(tx, inboundIds)
}
case "tls": case "tls":
serviceIds, inboundIds, err = s.TlsService.Save(tx, act, data) err = s.TlsService.Save(tx, act, data, hostname)
objs = append(objs, "clients", "inbounds")
case "inbounds": case "inbounds":
inboundId, err = s.InboundService.Save(tx, act, data, initUsers, hostname) err = s.InboundService.Save(tx, act, data, initUsers, hostname)
objs = append(objs, "clients")
case "outbounds": case "outbounds":
err = s.OutboundService.Save(tx, act, data) err = s.OutboundService.Save(tx, act, data)
case "services": case "services":
@@ -195,49 +185,8 @@ func (s *ConfigService) Save(obj string, act string, data json.RawMessage, initU
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Commit changes so far
tx.Commit()
LastUpdate = time.Now().Unix() LastUpdate = time.Now().Unix()
tx = db.Begin()
// Update side changes
// Update client links
if obj == "tls" && len(inboundIds) > 0 {
err = s.ClientService.UpdateLinksByInboundChange(tx, inboundIds, hostname)
if err != nil {
return nil, err
}
objs = append(objs, "clients")
}
if obj == "inbounds" {
switch act {
case "new":
err = s.ClientService.UpdateClientsOnInboundAdd(tx, initUsers, inboundId, hostname)
case "edit":
err = s.ClientService.UpdateLinksByInboundChange(tx, []uint{inboundId}, hostname)
case "del":
var tag string
err = json.Unmarshal(data, &tag)
if err != nil {
return nil, err
}
err = s.ClientService.UpdateClientsOnInboundDelete(tx, inboundId, tag)
}
if err != nil {
return nil, err
}
objs = append(objs, "clients")
}
// Update out_json of inbounds when tls is changed
if obj == "tls" && len(inboundIds) > 0 {
err = s.InboundService.UpdateOutJsons(tx, inboundIds, hostname)
if err != nil {
return nil, common.NewError("unable to update out_json of inbounds: ", err.Error())
}
objs = append(objs, "inbounds")
}
return objs, nil return objs, nil
} }
+42 -23
View File
@@ -13,7 +13,9 @@ import (
"gorm.io/gorm" "gorm.io/gorm"
) )
type InboundService struct{} type InboundService struct {
ClientService
}
func (s *InboundService) Get(ids string) (*[]map[string]interface{}, error) { func (s *InboundService) Get(ids string) (*[]map[string]interface{}, error) {
if ids == "" { if ids == "" {
@@ -97,40 +99,41 @@ func (s *InboundService) FromIds(ids []uint) ([]*model.Inbound, error) {
return inbounds, nil return inbounds, nil
} }
func (s *InboundService) Save(tx *gorm.DB, act string, data json.RawMessage, initUserIds string, hostname string) (uint, error) { func (s *InboundService) Save(tx *gorm.DB, act string, data json.RawMessage, initUserIds string, hostname string) error {
var err error var err error
var id uint
switch act { switch act {
case "new", "edit": case "new", "edit":
var inbound model.Inbound var inbound model.Inbound
err = inbound.UnmarshalJSON(data) err = inbound.UnmarshalJSON(data)
if err != nil { if err != nil {
return 0, err return err
} }
if inbound.TlsId > 0 { if inbound.TlsId > 0 {
err = tx.Model(model.Tls{}).Where("id = ?", inbound.TlsId).Find(&inbound.Tls).Error err = tx.Model(model.Tls{}).Where("id = ?", inbound.TlsId).Find(&inbound.Tls).Error
if err != nil { if err != nil {
return 0, err return err
}
}
var oldTag string
if act == "edit" {
err = tx.Model(model.Inbound{}).Select("tag").Where("id = ?", inbound.Id).Find(&oldTag).Error
if err != nil {
return err
} }
} }
if corePtr.IsRunning() { if corePtr.IsRunning() {
if act == "edit" { if act == "edit" {
var oldTag string
err = tx.Model(model.Inbound{}).Select("tag").Where("id = ?", inbound.Id).Find(&oldTag).Error
if err != nil {
return 0, err
}
err = corePtr.RemoveInbound(oldTag) err = corePtr.RemoveInbound(oldTag)
if err != nil && err != os.ErrInvalid { if err != nil && err != os.ErrInvalid {
return 0, err return err
} }
} }
inboundConfig, err := inbound.MarshalJSON() inboundConfig, err := inbound.MarshalJSON()
if err != nil { if err != nil {
return 0, err return err
} }
if act == "edit" { if act == "edit" {
@@ -139,49 +142,62 @@ func (s *InboundService) Save(tx *gorm.DB, act string, data json.RawMessage, ini
inboundConfig, err = s.initUsers(tx, inboundConfig, initUserIds, inbound.Type) inboundConfig, err = s.initUsers(tx, inboundConfig, initUserIds, inbound.Type)
} }
if err != nil { if err != nil {
return 0, err return err
} }
err = corePtr.AddInbound(inboundConfig) err = corePtr.AddInbound(inboundConfig)
if err != nil { if err != nil {
return 0, err return err
} }
} }
err = util.FillOutJson(&inbound, hostname) err = util.FillOutJson(&inbound, hostname)
if err != nil { if err != nil {
return 0, err return err
} }
err = tx.Save(&inbound).Error err = tx.Save(&inbound).Error
if err != nil { if err != nil {
return 0, err return err
}
switch act {
case "new":
err = s.ClientService.UpdateClientsOnInboundAdd(tx, initUserIds, inbound.Id, hostname)
case "edit":
err = s.ClientService.UpdateLinksByInboundChange(tx, &[]model.Inbound{inbound}, hostname, oldTag)
}
if err != nil {
return err
} }
id = inbound.Id
case "del": case "del":
var tag string var tag string
err = json.Unmarshal(data, &tag) err = json.Unmarshal(data, &tag)
if err != nil { if err != nil {
return 0, err return err
} }
if corePtr.IsRunning() { if corePtr.IsRunning() {
err = corePtr.RemoveInbound(tag) err = corePtr.RemoveInbound(tag)
if err != nil && err != os.ErrInvalid { if err != nil && err != os.ErrInvalid {
return 0, err return err
} }
} }
var id uint
err = tx.Model(model.Inbound{}).Select("id").Where("tag = ?", tag).Scan(&id).Error err = tx.Model(model.Inbound{}).Select("id").Where("tag = ?", tag).Scan(&id).Error
if err != nil { if err != nil {
return 0, err return err
}
err = s.ClientService.UpdateClientsOnInboundDelete(tx, id, tag)
if err != nil {
return err
} }
err = tx.Where("tag = ?", tag).Delete(model.Inbound{}).Error err = tx.Where("tag = ?", tag).Delete(model.Inbound{}).Error
if err != nil { if err != nil {
return 0, err return err
} }
default: default:
return 0, common.NewErrorf("unknown action: %s", act) return common.NewErrorf("unknown action: %s", act)
} }
return id, nil return nil
} }
func (s *InboundService) UpdateOutJsons(tx *gorm.DB, inboundIds []uint, hostname string) error { func (s *InboundService) UpdateOutJsons(tx *gorm.DB, inboundIds []uint, hostname string) error {
@@ -312,6 +328,9 @@ func (s *InboundService) initUsers(db *gorm.DB, inboundJson []byte, clientIds st
} }
func (s *InboundService) RestartInbounds(tx *gorm.DB, ids []uint) error { func (s *InboundService) RestartInbounds(tx *gorm.DB, ids []uint) error {
if !corePtr.IsRunning() {
return nil
}
var inbounds []*model.Inbound var inbounds []*model.Inbound
err := tx.Model(model.Inbound{}).Preload("Tls").Where("id in ?", ids).Find(&inbounds).Error err := tx.Model(model.Inbound{}).Preload("Tls").Where("id in ?", ids).Find(&inbounds).Error
if err != nil { if err != nil {
+3
View File
@@ -126,6 +126,9 @@ func (s *ServicesService) Save(tx *gorm.DB, act string, data json.RawMessage) er
} }
func (s *ServicesService) RestartServices(tx *gorm.DB, ids []uint) error { func (s *ServicesService) RestartServices(tx *gorm.DB, ids []uint) error {
if !corePtr.IsRunning() {
return nil
}
var services []*model.Service var services []*model.Service
err := tx.Model(model.Service{}).Preload("Tls").Where("id in ?", ids).Find(&services).Error err := tx.Model(model.Service{}).Preload("Tls").Where("id in ?", ids).Find(&services).Error
if err != nil { if err != nil {
+45 -16
View File
@@ -11,6 +11,7 @@ import (
type TlsService struct { type TlsService struct {
InboundService InboundService
ServicesService
} }
func (s *TlsService) GetAll() ([]model.Tls, error) { func (s *TlsService) GetAll() ([]model.Tls, error) {
@@ -24,52 +25,80 @@ func (s *TlsService) GetAll() ([]model.Tls, error) {
return tlsConfig, nil return tlsConfig, nil
} }
func (s *TlsService) Save(tx *gorm.DB, action string, data json.RawMessage) ([]uint, []uint, error) { func (s *TlsService) Save(tx *gorm.DB, action string, data json.RawMessage, hostname string) error {
var err error var err error
var inboundIds []uint
var serviceIds []uint
switch action { switch action {
case "new", "edit": case "new", "edit":
var tls model.Tls var tls model.Tls
err = json.Unmarshal(data, &tls) err = json.Unmarshal(data, &tls)
if err != nil { if err != nil {
return nil, nil, err return err
} }
err = tx.Save(&tls).Error err = tx.Save(&tls).Error
if err != nil { if err != nil {
return nil, nil, err return err
} }
err = tx.Model(model.Inbound{}).Select("id").Where("tls_id = ?", tls.Id).Scan(&inboundIds).Error if action == "edit" {
if err != nil { var inbounds []model.Inbound
return nil, nil, err err = tx.Model(model.Inbound{}).Preload("Tls").Where("tls_id = ?", tls.Id).Find(&inbounds).Error
if err != nil {
return err
}
if len(inbounds) > 0 {
err = s.ClientService.UpdateLinksByInboundChange(tx, &inbounds, hostname, "")
if err != nil {
return err
}
var inboundIds []uint
for _, inbound := range inbounds {
inboundIds = append(inboundIds, inbound.Id)
}
err = s.InboundService.UpdateOutJsons(tx, inboundIds, hostname)
if err != nil {
return common.NewError("unable to update out_json of inbounds: ", err.Error())
}
err = s.InboundService.RestartInbounds(tx, inboundIds)
if err != nil {
return err
}
}
var serviceIds []uint
err = tx.Model(model.Service{}).Where("tls_id = ?", tls.Id).Scan(&serviceIds).Error
if err != nil {
return err
}
if len(serviceIds) > 0 {
err = s.ServicesService.RestartServices(tx, serviceIds)
if err != nil {
return err
}
}
} }
err = tx.Model(model.Service{}).Where("tls_id = ?", tls.Id).Scan(&serviceIds).Error
return serviceIds, inboundIds, nil
case "del": case "del":
var id uint var id uint
err = json.Unmarshal(data, &id) err = json.Unmarshal(data, &id)
if err != nil { if err != nil {
return nil, nil, err return err
} }
var inboundCount int64 var inboundCount int64
err = tx.Model(model.Inbound{}).Where("tls_id = ?", id).Count(&inboundCount).Error err = tx.Model(model.Inbound{}).Where("tls_id = ?", id).Count(&inboundCount).Error
if err != nil { if err != nil {
return nil, nil, err return err
} }
var serviceCount int64 var serviceCount int64
err = tx.Model(model.Service{}).Where("tls_id = ?", id).Count(&serviceCount).Error err = tx.Model(model.Service{}).Where("tls_id = ?", id).Count(&serviceCount).Error
if err != nil { if err != nil {
return nil, nil, err return err
} }
if inboundCount > 0 || serviceCount > 0 { if inboundCount > 0 || serviceCount > 0 {
return nil, nil, common.NewError("tls in use") return common.NewError("tls in use")
} }
err = tx.Where("id = ?", id).Delete(model.Tls{}).Error err = tx.Where("id = ?", id).Delete(model.Tls{}).Error
if err != nil { if err != nil {
return nil, nil, err return err
} }
} }
return nil, nil, nil return nil
} }