improve client's inbound changes

This commit is contained in:
Alireza Ahmadi
2025-07-13 12:29:21 +02:00
parent f239574e41
commit d2827d013b
12 changed files with 848 additions and 21 deletions
+182
View File
@@ -0,0 +1,182 @@
package hysteria
import (
"context"
"net"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/common/listener"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-quic/hysteria"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func RegisterInbound(registry *inbound.Registry) {
inbound.Register[option.HysteriaInboundOptions](registry, C.TypeHysteria, NewInbound)
}
type Inbound struct {
inbound.Adapter
router adapter.Router
logger log.ContextLogger
listener *listener.Listener
tlsConfig tls.ServerConfig
service *hysteria.Service[int]
userNameList []string
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (adapter.Inbound, error) {
options.UDPFragmentDefault = true
if options.TLS == nil || !options.TLS.Enabled {
return nil, C.ErrTLSRequired
}
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
inbound := &Inbound{
Adapter: inbound.NewAdapter(C.TypeHysteria, tag),
router: router,
logger: logger,
listener: listener.New(listener.Options{
Context: ctx,
Logger: logger,
Listen: options.ListenOptions,
}),
tlsConfig: tlsConfig,
}
var sendBps, receiveBps uint64
if options.Up.Value() > 0 {
sendBps = options.Up.Value()
} else {
sendBps = uint64(options.UpMbps) * hysteria.MbpsToBps
}
if options.Down.Value() > 0 {
receiveBps = options.Down.Value()
} else {
receiveBps = uint64(options.DownMbps) * hysteria.MbpsToBps
}
var udpTimeout time.Duration
if options.UDPTimeout != 0 {
udpTimeout = time.Duration(options.UDPTimeout)
} else {
udpTimeout = C.UDPTimeout
}
service, err := hysteria.NewService[int](hysteria.ServiceOptions{
Context: ctx,
Logger: logger,
SendBPS: sendBps,
ReceiveBPS: receiveBps,
XPlusPassword: options.Obfs,
TLSConfig: tlsConfig,
UDPTimeout: udpTimeout,
Handler: inbound,
// Legacy options
ConnReceiveWindow: options.ReceiveWindowConn,
StreamReceiveWindow: options.ReceiveWindowClient,
MaxIncomingStreams: int64(options.MaxConnClient),
DisableMTUDiscovery: options.DisableMTUDiscovery,
})
if err != nil {
return nil, err
}
userList := make([]int, 0, len(options.Users))
userNameList := make([]string, 0, len(options.Users))
userPasswordList := make([]string, 0, len(options.Users))
for index, user := range options.Users {
userList = append(userList, index)
userNameList = append(userNameList, user.Name)
var password string
if user.AuthString != "" {
password = user.AuthString
} else {
password = string(user.Auth)
}
userPasswordList = append(userPasswordList, password)
}
service.UpdateUsers(userList, userPasswordList)
inbound.service = service
inbound.userNameList = userNameList
return inbound, nil
}
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
ctx = log.ContextWithNewID(ctx)
var metadata adapter.InboundContext
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
//nolint:staticcheck
metadata.InboundDetour = h.listener.ListenOptions().Detour
//nolint:staticcheck
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.OriginDestination = h.listener.UDPAddr()
metadata.Source = source
metadata.Destination = destination
h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)
userID, _ := auth.UserFromContext[int](ctx)
if userName := h.userNameList[userID]; userName != "" {
metadata.User = userName
h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", metadata.Destination)
} else {
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
}
h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
}
func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
ctx = log.ContextWithNewID(ctx)
var metadata adapter.InboundContext
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
//nolint:staticcheck
metadata.InboundDetour = h.listener.ListenOptions().Detour
//nolint:staticcheck
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.OriginDestination = h.listener.UDPAddr()
metadata.Source = source
metadata.Destination = destination
h.logger.InfoContext(ctx, "inbound packet connection from ", metadata.Source)
userID, _ := auth.UserFromContext[int](ctx)
if userName := h.userNameList[userID]; userName != "" {
metadata.User = userName
h.logger.InfoContext(ctx, "[", userName, "] inbound packet connection to ", metadata.Destination)
} else {
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
}
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
}
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil {
err := h.tlsConfig.Start()
if err != nil {
return err
}
}
packetConn, err := h.listener.ListenUDP()
if err != nil {
return err
}
return h.service.Start(packetConn)
}
func (h *Inbound) Close() error {
return common.Close(
h.listener,
h.tlsConfig,
common.PtrOrNil(h.service),
)
}
+126
View File
@@ -0,0 +1,126 @@
package hysteria
import (
"context"
"net"
"os"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/protocol/tuic"
"github.com/sagernet/sing-quic/hysteria"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.HysteriaOutboundOptions](registry, C.TypeHysteria, NewOutbound)
}
var (
_ adapter.Outbound = (*tuic.Outbound)(nil)
_ adapter.InterfaceUpdateListener = (*tuic.Outbound)(nil)
)
type Outbound struct {
outbound.Adapter
logger logger.ContextLogger
client *hysteria.Client
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (adapter.Outbound, error) {
options.UDPFragmentDefault = true
if options.TLS == nil || !options.TLS.Enabled {
return nil, C.ErrTLSRequired
}
tlsConfig, err := tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain())
if err != nil {
return nil, err
}
networkList := options.Network.Build()
var password string
if options.AuthString != "" {
password = options.AuthString
} else {
password = string(options.Auth)
}
var sendBps, receiveBps uint64
if options.Up.Value() > 0 {
sendBps = options.Up.Value()
} else {
sendBps = uint64(options.UpMbps) * hysteria.MbpsToBps
}
if options.Down.Value() > 0 {
receiveBps = options.Down.Value()
} else {
receiveBps = uint64(options.DownMbps) * hysteria.MbpsToBps
}
client, err := hysteria.NewClient(hysteria.ClientOptions{
Context: ctx,
Dialer: outboundDialer,
Logger: logger,
ServerAddress: options.ServerOptions.Build(),
ServerPorts: options.ServerPorts,
HopInterval: time.Duration(options.HopInterval),
SendBPS: sendBps,
ReceiveBPS: receiveBps,
XPlusPassword: options.Obfs,
Password: password,
TLSConfig: tlsConfig,
UDPDisabled: !common.Contains(networkList, N.NetworkUDP),
ConnReceiveWindow: options.ReceiveWindowConn,
StreamReceiveWindow: options.ReceiveWindow,
DisableMTUDiscovery: options.DisableMTUDiscovery,
})
if err != nil {
return nil, err
}
return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria, tag, networkList, options.DialerOptions),
logger: logger,
client: client,
}, nil
}
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
h.logger.InfoContext(ctx, "outbound connection to ", destination)
return h.client.DialConn(ctx, destination)
case N.NetworkUDP:
conn, err := h.ListenPacket(ctx, destination)
if err != nil {
return nil, err
}
return bufio.NewBindPacketConn(conn, destination), nil
default:
return nil, E.New("unsupported network: ", network)
}
}
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
return h.client.ListenPacket(ctx, destination)
}
func (h *Outbound) InterfaceUpdated() {
h.client.CloseWithError(E.New("network changed"))
}
func (h *Outbound) Close() error {
return h.client.CloseWithError(os.ErrClosed)
}
+28
View File
@@ -0,0 +1,28 @@
package hysteria
import (
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
)
func (h *Inbound) UpdateUsers(users []option.HysteriaUser) error {
h.Close()
userList := make([]int, 0, len(users))
userNameList := make([]string, 0, len(users))
userPasswordList := make([]string, 0, len(users))
for index, user := range users {
userList = append(userList, index)
userNameList = append(userNameList, user.Name)
var password string
if user.AuthString != "" {
password = user.AuthString
} else {
password = string(user.Auth)
}
userPasswordList = append(userPasswordList, password)
}
h.service.UpdateUsers(userList, userPasswordList)
h.userNameList = userNameList
h.Start(adapter.StartStateStart)
return nil
}
+215
View File
@@ -0,0 +1,215 @@
package hysteria2
import (
"context"
"net"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/common/listener"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-quic/hysteria"
"github.com/sagernet/sing-quic/hysteria2"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func RegisterInbound(registry *inbound.Registry) {
inbound.Register[option.Hysteria2InboundOptions](registry, C.TypeHysteria2, NewInbound)
}
type Inbound struct {
inbound.Adapter
router adapter.Router
logger log.ContextLogger
listener *listener.Listener
tlsConfig tls.ServerConfig
service *hysteria2.Service[int]
userNameList []string
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2InboundOptions) (adapter.Inbound, error) {
options.UDPFragmentDefault = true
if options.TLS == nil || !options.TLS.Enabled {
return nil, C.ErrTLSRequired
}
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
var salamanderPassword string
if options.Obfs != nil {
if options.Obfs.Password == "" {
return nil, E.New("missing obfs password")
}
switch options.Obfs.Type {
case hysteria2.ObfsTypeSalamander:
salamanderPassword = options.Obfs.Password
default:
return nil, E.New("unknown obfs type: ", options.Obfs.Type)
}
}
var masqueradeHandler http.Handler
if options.Masquerade != nil && options.Masquerade.Type != "" {
switch options.Masquerade.Type {
case C.Hysterai2MasqueradeTypeFile:
masqueradeHandler = http.FileServer(http.Dir(options.Masquerade.FileOptions.Directory))
case C.Hysterai2MasqueradeTypeProxy:
masqueradeURL, err := url.Parse(options.Masquerade.ProxyOptions.URL)
if err != nil {
return nil, E.Cause(err, "parse masquerade URL")
}
masqueradeHandler = &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
r.SetURL(masqueradeURL)
if !options.Masquerade.ProxyOptions.RewriteHost {
r.Out.Host = r.In.Host
}
},
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
w.WriteHeader(http.StatusBadGateway)
},
}
case C.Hysterai2MasqueradeTypeString:
masqueradeHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if options.Masquerade.StringOptions.StatusCode != 0 {
w.WriteHeader(options.Masquerade.StringOptions.StatusCode)
}
for key, values := range options.Masquerade.StringOptions.Headers {
for _, value := range values {
w.Header().Add(key, value)
}
}
w.Write([]byte(options.Masquerade.StringOptions.Content))
})
default:
return nil, E.New("unknown masquerade type: ", options.Masquerade.Type)
}
}
inbound := &Inbound{
Adapter: inbound.NewAdapter(C.TypeHysteria2, tag),
router: router,
logger: logger,
listener: listener.New(listener.Options{
Context: ctx,
Logger: logger,
Listen: options.ListenOptions,
}),
tlsConfig: tlsConfig,
}
var udpTimeout time.Duration
if options.UDPTimeout != 0 {
udpTimeout = time.Duration(options.UDPTimeout)
} else {
udpTimeout = C.UDPTimeout
}
service, err := hysteria2.NewService[int](hysteria2.ServiceOptions{
Context: ctx,
Logger: logger,
BrutalDebug: options.BrutalDebug,
SendBPS: uint64(options.UpMbps * hysteria.MbpsToBps),
ReceiveBPS: uint64(options.DownMbps * hysteria.MbpsToBps),
SalamanderPassword: salamanderPassword,
TLSConfig: tlsConfig,
IgnoreClientBandwidth: options.IgnoreClientBandwidth,
UDPTimeout: udpTimeout,
Handler: inbound,
MasqueradeHandler: masqueradeHandler,
})
if err != nil {
return nil, err
}
userList := make([]int, 0, len(options.Users))
userNameList := make([]string, 0, len(options.Users))
userPasswordList := make([]string, 0, len(options.Users))
for index, user := range options.Users {
userList = append(userList, index)
userNameList = append(userNameList, user.Name)
userPasswordList = append(userPasswordList, user.Password)
}
service.UpdateUsers(userList, userPasswordList)
inbound.service = service
inbound.userNameList = userNameList
return inbound, nil
}
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
ctx = log.ContextWithNewID(ctx)
var metadata adapter.InboundContext
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
//nolint:staticcheck
metadata.InboundDetour = h.listener.ListenOptions().Detour
//nolint:staticcheck
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.OriginDestination = h.listener.UDPAddr()
metadata.Source = source
metadata.Destination = destination
h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source)
userID, _ := auth.UserFromContext[int](ctx)
if userName := h.userNameList[userID]; userName != "" {
metadata.User = userName
h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", metadata.Destination)
} else {
h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination)
}
h.router.RouteConnectionEx(ctx, conn, metadata, onClose)
}
func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) {
ctx = log.ContextWithNewID(ctx)
var metadata adapter.InboundContext
metadata.Inbound = h.Tag()
metadata.InboundType = h.Type()
//nolint:staticcheck
metadata.InboundDetour = h.listener.ListenOptions().Detour
//nolint:staticcheck
metadata.InboundOptions = h.listener.ListenOptions().InboundOptions
metadata.OriginDestination = h.listener.UDPAddr()
metadata.Source = source
metadata.Destination = destination
h.logger.InfoContext(ctx, "inbound packet connection from ", metadata.Source)
userID, _ := auth.UserFromContext[int](ctx)
if userName := h.userNameList[userID]; userName != "" {
metadata.User = userName
h.logger.InfoContext(ctx, "[", userName, "] inbound packet connection to ", metadata.Destination)
} else {
h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination)
}
h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose)
}
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil {
err := h.tlsConfig.Start()
if err != nil {
return err
}
}
packetConn, err := h.listener.ListenUDP()
if err != nil {
return err
}
return h.service.Start(packetConn)
}
func (h *Inbound) Close() error {
return common.Close(
h.listener,
h.tlsConfig,
common.PtrOrNil(h.service),
)
}
+120
View File
@@ -0,0 +1,120 @@
package hysteria2
import (
"context"
"net"
"os"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/protocol/tuic"
"github.com/sagernet/sing-quic/hysteria"
"github.com/sagernet/sing-quic/hysteria2"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.Hysteria2OutboundOptions](registry, C.TypeHysteria2, NewOutbound)
}
var (
_ adapter.Outbound = (*tuic.Outbound)(nil)
_ adapter.InterfaceUpdateListener = (*tuic.Outbound)(nil)
)
type Outbound struct {
outbound.Adapter
logger logger.ContextLogger
client *hysteria2.Client
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2OutboundOptions) (adapter.Outbound, error) {
options.UDPFragmentDefault = true
if options.TLS == nil || !options.TLS.Enabled {
return nil, C.ErrTLSRequired
}
tlsConfig, err := tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
var salamanderPassword string
if options.Obfs != nil {
if options.Obfs.Password == "" {
return nil, E.New("missing obfs password")
}
switch options.Obfs.Type {
case hysteria2.ObfsTypeSalamander:
salamanderPassword = options.Obfs.Password
default:
return nil, E.New("unknown obfs type: ", options.Obfs.Type)
}
}
outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain())
if err != nil {
return nil, err
}
networkList := options.Network.Build()
client, err := hysteria2.NewClient(hysteria2.ClientOptions{
Context: ctx,
Dialer: outboundDialer,
Logger: logger,
BrutalDebug: options.BrutalDebug,
ServerAddress: options.ServerOptions.Build(),
ServerPorts: options.ServerPorts,
HopInterval: time.Duration(options.HopInterval),
SendBPS: uint64(options.UpMbps * hysteria.MbpsToBps),
ReceiveBPS: uint64(options.DownMbps * hysteria.MbpsToBps),
SalamanderPassword: salamanderPassword,
Password: options.Password,
TLSConfig: tlsConfig,
UDPDisabled: !common.Contains(networkList, N.NetworkUDP),
})
if err != nil {
return nil, err
}
return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria2, tag, networkList, options.DialerOptions),
logger: logger,
client: client,
}, nil
}
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
h.logger.InfoContext(ctx, "outbound connection to ", destination)
return h.client.DialConn(ctx, destination)
case N.NetworkUDP:
conn, err := h.ListenPacket(ctx, destination)
if err != nil {
return nil, err
}
return bufio.NewBindPacketConn(conn, destination), nil
default:
return nil, E.New("unsupported network: ", network)
}
}
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
return h.client.ListenPacket(ctx)
}
func (h *Outbound) InterfaceUpdated() {
h.client.CloseWithError(E.New("network changed"))
}
func (h *Outbound) Close() error {
return h.client.CloseWithError(os.ErrClosed)
}
+22
View File
@@ -0,0 +1,22 @@
package hysteria2
import (
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
)
func (h *Inbound) UpdateUsers(users []option.Hysteria2User) error {
h.Close()
userList := make([]int, 0, len(users))
userNameList := make([]string, 0, len(users))
userPasswordList := make([]string, 0, len(users))
for index, user := range users {
userList = append(userList, index)
userNameList = append(userNameList, user.Name)
userPasswordList = append(userPasswordList, user.Password)
}
h.service.UpdateUsers(userList, userPasswordList)
h.userNameList = userNameList
h.Start(adapter.StartStateStart)
return nil
}
+3 -2
View File
@@ -1,6 +1,9 @@
package core
import (
"s-ui/core/protocol/hysteria"
"s-ui/core/protocol/hysteria2"
"github.com/sagernet/sing-box/adapter/endpoint"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound"
@@ -18,8 +21,6 @@ import (
protocolDNS "github.com/sagernet/sing-box/protocol/dns"
"github.com/sagernet/sing-box/protocol/group"
"github.com/sagernet/sing-box/protocol/http"
"github.com/sagernet/sing-box/protocol/hysteria"
"github.com/sagernet/sing-box/protocol/hysteria2"
"github.com/sagernet/sing-box/protocol/mixed"
"github.com/sagernet/sing-box/protocol/naive"
_ "github.com/sagernet/sing-box/protocol/naive/quic"
+1 -1
View File
@@ -22,7 +22,7 @@ func (s *DepleteJob) Run() {
return
}
if len(inboundIds) > 0 {
err := s.InboundService.RestartInbounds(database.GetDB(), inboundIds)
err := s.InboundService.UpdateUsers(database.GetDB(), inboundIds)
if err != nil {
logger.Error("unable to restart inbounds: ", err)
}
+41 -17
View File
@@ -1,6 +1,7 @@
package service
import (
"bytes"
"encoding/json"
"s-ui/database"
"s-ui/database/model"
@@ -54,14 +55,22 @@ func (s *ClientService) Save(tx *gorm.DB, act string, data json.RawMessage, host
if err != nil {
return nil, err
}
err = json.Unmarshal(client.Inbounds, &inboundIds)
if err != nil {
return nil, err
}
err = s.updateLinksWithFixedInbounds(tx, []*model.Client{&client}, inboundIds, hostname)
if err != nil {
return nil, err
}
if act == "edit" {
// Find changed inbounds
inboundIds, err = s.findInboundsChanges(tx, client)
if err != nil {
return nil, err
}
} else {
err = json.Unmarshal(client.Inbounds, &inboundIds)
if err != nil {
return nil, err
}
}
err = tx.Save(&client).Error
if err != nil {
return nil, err
@@ -140,7 +149,7 @@ func (s *ClientService) updateLinksWithFixedInbounds(tx *gorm.DB, clients []*mod
}
}
// Add no local links
// Add non local links
for _, clientLink := range clientLinks {
if clientLink["type"] != "local" {
newClientLinks = append(newClientLinks, clientLink)
@@ -316,7 +325,8 @@ func (s *ClientService) DepleteClients() ([]uint, error) {
users = append(users, client.Name)
var userInbounds []uint
json.Unmarshal(client.Inbounds, &userInbounds)
inboundIds = s.uniqueAppendInboundIds(inboundIds, userInbounds)
// Find changed inbounds
inboundIds = common.UnionUintArray(inboundIds, userInbounds)
changes = append(changes, model.Changes{
DateTime: dt,
Actor: "DepleteJob",
@@ -342,18 +352,32 @@ func (s *ClientService) DepleteClients() ([]uint, error) {
return inboundIds, nil
}
// avoid duplicate inboundIds
func (s *ClientService) uniqueAppendInboundIds(a []uint, b []uint) []uint {
m := make(map[uint]bool)
for _, v := range a {
m[v] = true
func (s *ClientService) findInboundsChanges(tx *gorm.DB, client model.Client) ([]uint, error) {
var err error
var oldClient model.Client
var oldInboundIds, newInboundIds []uint
err = tx.Model(model.Client{}).Where("id = ?", client.Id).First(&oldClient).Error
if err != nil {
return nil, err
}
for _, v := range b {
m[v] = true
err = json.Unmarshal(oldClient.Inbounds, &oldInboundIds)
if err != nil {
return nil, err
}
var res []uint
for k := range m {
res = append(res, k)
err = json.Unmarshal(client.Inbounds, &newInboundIds)
if err != nil {
return nil, err
}
return res
// Check client.Config changes
if !bytes.Equal(oldClient.Config, client.Config) ||
oldClient.Name != client.Name ||
oldClient.Enable != client.Enable {
return common.UnionUintArray(oldInboundIds, newInboundIds), nil
}
// Check client.Inbounds changes
diffInbounds := common.DiffUintArray(oldInboundIds, newInboundIds)
return diffInbounds, nil
}
+4 -1
View File
@@ -145,7 +145,10 @@ func (s *ConfigService) Save(obj string, act string, data json.RawMessage, initU
inboundIds, err := s.ClientService.Save(tx, act, data, hostname)
if err == nil && len(inboundIds) > 0 {
objs = append(objs, "inbounds")
err = s.InboundService.RestartInbounds(tx, inboundIds)
err = s.InboundService.UpdateUsers(tx, inboundIds)
if err != nil {
return nil, common.NewErrorf("failed to update users for inbounds: %v", err)
}
}
case "tls":
err = s.TlsService.Save(tx, act, data, hostname)
+62
View File
@@ -4,12 +4,16 @@ import (
"encoding/json"
"fmt"
"os"
"s-ui/core/protocol/hysteria"
"s-ui/core/protocol/hysteria2"
"s-ui/database"
"s-ui/database/model"
"s-ui/logger"
"s-ui/util"
"s-ui/util/common"
"strings"
"github.com/sagernet/sing-box/option"
"gorm.io/gorm"
)
@@ -327,6 +331,64 @@ func (s *InboundService) initUsers(db *gorm.DB, inboundJson []byte, clientIds st
return json.Marshal(inbound)
}
func (s *InboundService) UpdateUsers(tx *gorm.DB, ids []uint) error {
var inbounds []model.Inbound
err := tx.Model(model.Inbound{}).Preload("Tls").Where("id in ?", ids).Find(&inbounds).Error
if err != nil {
return err
}
for _, inbound := range inbounds {
inboundConfig, err := inbound.MarshalJSON()
if err != nil {
return err
}
inboundConfig, err = s.addUsers(tx, inboundConfig, inbound.Id, inbound.Type)
if err != nil {
return err
}
inb, ok := corePtr.GetInstance().Inbound().Get(inbound.Tag)
if !ok {
return common.NewErrorf("inbound %s not found", inbound.Tag)
}
switch inbound.Type {
case "hysteria":
var hysteriaOptions option.HysteriaInboundOptions
err = json.Unmarshal(inboundConfig, &hysteriaOptions)
if err != nil {
return common.NewErrorf("failed to unmarshal hysteria options for inbound %s: %v", inbound.Tag, err)
}
err = inb.(*hysteria.Inbound).UpdateUsers(hysteriaOptions.Users)
if err != nil {
return common.NewErrorf("failed to update users for hysteria inbound %s: %v", inbound.Tag, err)
}
logger.Info("Updated users for hysteria inbound:", inbound.Tag)
case "hysteria2":
var hy2Options option.Hysteria2InboundOptions
err = json.Unmarshal(inboundConfig, &hy2Options)
if err != nil {
return common.NewErrorf("failed to unmarshal hysteria2 options for inbound %s: %v", inbound.Tag, err)
}
err = inb.(*hysteria2.Inbound).UpdateUsers(hy2Options.Users)
if err != nil {
return common.NewErrorf("failed to update users for hysteria2 inbound %s: %v", inbound.Tag, err)
}
logger.Info("Updated users for hysteria2 inbound:", inbound.Tag)
default:
err = corePtr.RemoveInbound(inbound.Tag)
if err != nil && err != os.ErrInvalid {
return err
}
err = corePtr.AddInbound(inboundConfig)
if err != nil {
return err
}
}
}
return nil
}
func (s *InboundService) RestartInbounds(tx *gorm.DB, ids []uint) error {
if !corePtr.IsRunning() {
return nil
+44
View File
@@ -0,0 +1,44 @@
package common
func UnionUintArray(a []uint, b []uint) []uint {
m := make(map[uint]bool)
for _, v := range a {
m[v] = true
}
for _, v := range b {
m[v] = true
}
var res []uint
for k := range m {
res = append(res, k)
}
return res
}
// Find different elements in two slices
// Returns elements in 'a' that are not in 'b' and elements in 'b' that are not in 'a'
func DiffUintArray(a []uint, b []uint) []uint {
different := []uint{}
set := make(map[uint]bool)
for _, item := range a {
set[item] = true
}
for _, item := range b {
if !set[item] {
different = append(different, item)
}
}
set = make(map[uint]bool)
for _, item := range b {
set[item] = true
}
for _, item := range a {
if !set[item] {
different = append(different, item)
}
}
return different
}