diff --git a/core/protocol/hysteria/inbound.go b/core/protocol/hysteria/inbound.go new file mode 100644 index 0000000..5afc440 --- /dev/null +++ b/core/protocol/hysteria/inbound.go @@ -0,0 +1,182 @@ +package hysteria + +import ( + "context" + "net" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + "github.com/sagernet/sing-box/common/listener" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-quic/hysteria" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/auth" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func RegisterInbound(registry *inbound.Registry) { + inbound.Register[option.HysteriaInboundOptions](registry, C.TypeHysteria, NewInbound) +} + +type Inbound struct { + inbound.Adapter + router adapter.Router + logger log.ContextLogger + listener *listener.Listener + tlsConfig tls.ServerConfig + service *hysteria.Service[int] + userNameList []string +} + +func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaInboundOptions) (adapter.Inbound, error) { + options.UDPFragmentDefault = true + if options.TLS == nil || !options.TLS.Enabled { + return nil, C.ErrTLSRequired + } + tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err + } + inbound := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeHysteria, tag), + router: router, + logger: logger, + listener: listener.New(listener.Options{ + Context: ctx, + Logger: logger, + Listen: options.ListenOptions, + }), + tlsConfig: tlsConfig, + } + var sendBps, receiveBps uint64 + if options.Up.Value() > 0 { + sendBps = options.Up.Value() + } else { + sendBps = uint64(options.UpMbps) * hysteria.MbpsToBps + } + if options.Down.Value() > 0 { + receiveBps = options.Down.Value() + } else { + receiveBps = uint64(options.DownMbps) * hysteria.MbpsToBps + } + var udpTimeout time.Duration + if options.UDPTimeout != 0 { + udpTimeout = time.Duration(options.UDPTimeout) + } else { + udpTimeout = C.UDPTimeout + } + service, err := hysteria.NewService[int](hysteria.ServiceOptions{ + Context: ctx, + Logger: logger, + SendBPS: sendBps, + ReceiveBPS: receiveBps, + XPlusPassword: options.Obfs, + TLSConfig: tlsConfig, + UDPTimeout: udpTimeout, + Handler: inbound, + + // Legacy options + + ConnReceiveWindow: options.ReceiveWindowConn, + StreamReceiveWindow: options.ReceiveWindowClient, + MaxIncomingStreams: int64(options.MaxConnClient), + DisableMTUDiscovery: options.DisableMTUDiscovery, + }) + if err != nil { + return nil, err + } + userList := make([]int, 0, len(options.Users)) + userNameList := make([]string, 0, len(options.Users)) + userPasswordList := make([]string, 0, len(options.Users)) + for index, user := range options.Users { + userList = append(userList, index) + userNameList = append(userNameList, user.Name) + var password string + if user.AuthString != "" { + password = user.AuthString + } else { + password = string(user.Auth) + } + userPasswordList = append(userPasswordList, password) + } + service.UpdateUsers(userList, userPasswordList) + inbound.service = service + inbound.userNameList = userNameList + return inbound, nil +} + +func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { + ctx = log.ContextWithNewID(ctx) + var metadata adapter.InboundContext + metadata.Inbound = h.Tag() + metadata.InboundType = h.Type() + //nolint:staticcheck + metadata.InboundDetour = h.listener.ListenOptions().Detour + //nolint:staticcheck + metadata.InboundOptions = h.listener.ListenOptions().InboundOptions + metadata.OriginDestination = h.listener.UDPAddr() + metadata.Source = source + metadata.Destination = destination + h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source) + userID, _ := auth.UserFromContext[int](ctx) + if userName := h.userNameList[userID]; userName != "" { + metadata.User = userName + h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", metadata.Destination) + } else { + h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) + } + h.router.RouteConnectionEx(ctx, conn, metadata, onClose) +} + +func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { + ctx = log.ContextWithNewID(ctx) + var metadata adapter.InboundContext + metadata.Inbound = h.Tag() + metadata.InboundType = h.Type() + //nolint:staticcheck + metadata.InboundDetour = h.listener.ListenOptions().Detour + //nolint:staticcheck + metadata.InboundOptions = h.listener.ListenOptions().InboundOptions + metadata.OriginDestination = h.listener.UDPAddr() + metadata.Source = source + metadata.Destination = destination + h.logger.InfoContext(ctx, "inbound packet connection from ", metadata.Source) + userID, _ := auth.UserFromContext[int](ctx) + if userName := h.userNameList[userID]; userName != "" { + metadata.User = userName + h.logger.InfoContext(ctx, "[", userName, "] inbound packet connection to ", metadata.Destination) + } else { + h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) + } + h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) +} + +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + if h.tlsConfig != nil { + err := h.tlsConfig.Start() + if err != nil { + return err + } + } + packetConn, err := h.listener.ListenUDP() + if err != nil { + return err + } + return h.service.Start(packetConn) +} + +func (h *Inbound) Close() error { + return common.Close( + h.listener, + h.tlsConfig, + common.PtrOrNil(h.service), + ) +} diff --git a/core/protocol/hysteria/outbound.go b/core/protocol/hysteria/outbound.go new file mode 100644 index 0000000..42a37ee --- /dev/null +++ b/core/protocol/hysteria/outbound.go @@ -0,0 +1,126 @@ +package hysteria + +import ( + "context" + "net" + "os" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/outbound" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/protocol/tuic" + "github.com/sagernet/sing-quic/hysteria" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func RegisterOutbound(registry *outbound.Registry) { + outbound.Register[option.HysteriaOutboundOptions](registry, C.TypeHysteria, NewOutbound) +} + +var ( + _ adapter.Outbound = (*tuic.Outbound)(nil) + _ adapter.InterfaceUpdateListener = (*tuic.Outbound)(nil) +) + +type Outbound struct { + outbound.Adapter + logger logger.ContextLogger + client *hysteria.Client +} + +func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HysteriaOutboundOptions) (adapter.Outbound, error) { + options.UDPFragmentDefault = true + if options.TLS == nil || !options.TLS.Enabled { + return nil, C.ErrTLSRequired + } + tlsConfig, err := tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err + } + outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain()) + if err != nil { + return nil, err + } + networkList := options.Network.Build() + var password string + if options.AuthString != "" { + password = options.AuthString + } else { + password = string(options.Auth) + } + var sendBps, receiveBps uint64 + if options.Up.Value() > 0 { + sendBps = options.Up.Value() + } else { + sendBps = uint64(options.UpMbps) * hysteria.MbpsToBps + } + if options.Down.Value() > 0 { + receiveBps = options.Down.Value() + } else { + receiveBps = uint64(options.DownMbps) * hysteria.MbpsToBps + } + client, err := hysteria.NewClient(hysteria.ClientOptions{ + Context: ctx, + Dialer: outboundDialer, + Logger: logger, + ServerAddress: options.ServerOptions.Build(), + ServerPorts: options.ServerPorts, + HopInterval: time.Duration(options.HopInterval), + SendBPS: sendBps, + ReceiveBPS: receiveBps, + XPlusPassword: options.Obfs, + Password: password, + TLSConfig: tlsConfig, + UDPDisabled: !common.Contains(networkList, N.NetworkUDP), + ConnReceiveWindow: options.ReceiveWindowConn, + StreamReceiveWindow: options.ReceiveWindow, + DisableMTUDiscovery: options.DisableMTUDiscovery, + }) + if err != nil { + return nil, err + } + return &Outbound{ + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria, tag, networkList, options.DialerOptions), + logger: logger, + client: client, + }, nil +} + +func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + h.logger.InfoContext(ctx, "outbound connection to ", destination) + return h.client.DialConn(ctx, destination) + case N.NetworkUDP: + conn, err := h.ListenPacket(ctx, destination) + if err != nil { + return nil, err + } + return bufio.NewBindPacketConn(conn, destination), nil + default: + return nil, E.New("unsupported network: ", network) + } +} + +func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + h.logger.InfoContext(ctx, "outbound packet connection to ", destination) + return h.client.ListenPacket(ctx, destination) +} + +func (h *Outbound) InterfaceUpdated() { + h.client.CloseWithError(E.New("network changed")) +} + +func (h *Outbound) Close() error { + return h.client.CloseWithError(os.ErrClosed) +} diff --git a/core/protocol/hysteria/users.go b/core/protocol/hysteria/users.go new file mode 100644 index 0000000..91d1a90 --- /dev/null +++ b/core/protocol/hysteria/users.go @@ -0,0 +1,28 @@ +package hysteria + +import ( + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/option" +) + +func (h *Inbound) UpdateUsers(users []option.HysteriaUser) error { + h.Close() + userList := make([]int, 0, len(users)) + userNameList := make([]string, 0, len(users)) + userPasswordList := make([]string, 0, len(users)) + for index, user := range users { + userList = append(userList, index) + userNameList = append(userNameList, user.Name) + var password string + if user.AuthString != "" { + password = user.AuthString + } else { + password = string(user.Auth) + } + userPasswordList = append(userPasswordList, password) + } + h.service.UpdateUsers(userList, userPasswordList) + h.userNameList = userNameList + h.Start(adapter.StartStateStart) + return nil +} diff --git a/core/protocol/hysteria2/inbound.go b/core/protocol/hysteria2/inbound.go new file mode 100644 index 0000000..f55b6ae --- /dev/null +++ b/core/protocol/hysteria2/inbound.go @@ -0,0 +1,215 @@ +package hysteria2 + +import ( + "context" + "net" + "net/http" + "net/http/httputil" + "net/url" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + "github.com/sagernet/sing-box/common/listener" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-quic/hysteria" + "github.com/sagernet/sing-quic/hysteria2" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/auth" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func RegisterInbound(registry *inbound.Registry) { + inbound.Register[option.Hysteria2InboundOptions](registry, C.TypeHysteria2, NewInbound) +} + +type Inbound struct { + inbound.Adapter + router adapter.Router + logger log.ContextLogger + listener *listener.Listener + tlsConfig tls.ServerConfig + service *hysteria2.Service[int] + userNameList []string +} + +func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2InboundOptions) (adapter.Inbound, error) { + options.UDPFragmentDefault = true + if options.TLS == nil || !options.TLS.Enabled { + return nil, C.ErrTLSRequired + } + tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err + } + var salamanderPassword string + if options.Obfs != nil { + if options.Obfs.Password == "" { + return nil, E.New("missing obfs password") + } + switch options.Obfs.Type { + case hysteria2.ObfsTypeSalamander: + salamanderPassword = options.Obfs.Password + default: + return nil, E.New("unknown obfs type: ", options.Obfs.Type) + } + } + var masqueradeHandler http.Handler + if options.Masquerade != nil && options.Masquerade.Type != "" { + switch options.Masquerade.Type { + case C.Hysterai2MasqueradeTypeFile: + masqueradeHandler = http.FileServer(http.Dir(options.Masquerade.FileOptions.Directory)) + case C.Hysterai2MasqueradeTypeProxy: + masqueradeURL, err := url.Parse(options.Masquerade.ProxyOptions.URL) + if err != nil { + return nil, E.Cause(err, "parse masquerade URL") + } + masqueradeHandler = &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(masqueradeURL) + if !options.Masquerade.ProxyOptions.RewriteHost { + r.Out.Host = r.In.Host + } + }, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + w.WriteHeader(http.StatusBadGateway) + }, + } + case C.Hysterai2MasqueradeTypeString: + masqueradeHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if options.Masquerade.StringOptions.StatusCode != 0 { + w.WriteHeader(options.Masquerade.StringOptions.StatusCode) + } + for key, values := range options.Masquerade.StringOptions.Headers { + for _, value := range values { + w.Header().Add(key, value) + } + } + w.Write([]byte(options.Masquerade.StringOptions.Content)) + }) + default: + return nil, E.New("unknown masquerade type: ", options.Masquerade.Type) + } + } + inbound := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeHysteria2, tag), + router: router, + logger: logger, + listener: listener.New(listener.Options{ + Context: ctx, + Logger: logger, + Listen: options.ListenOptions, + }), + tlsConfig: tlsConfig, + } + var udpTimeout time.Duration + if options.UDPTimeout != 0 { + udpTimeout = time.Duration(options.UDPTimeout) + } else { + udpTimeout = C.UDPTimeout + } + service, err := hysteria2.NewService[int](hysteria2.ServiceOptions{ + Context: ctx, + Logger: logger, + BrutalDebug: options.BrutalDebug, + SendBPS: uint64(options.UpMbps * hysteria.MbpsToBps), + ReceiveBPS: uint64(options.DownMbps * hysteria.MbpsToBps), + SalamanderPassword: salamanderPassword, + TLSConfig: tlsConfig, + IgnoreClientBandwidth: options.IgnoreClientBandwidth, + UDPTimeout: udpTimeout, + Handler: inbound, + MasqueradeHandler: masqueradeHandler, + }) + if err != nil { + return nil, err + } + userList := make([]int, 0, len(options.Users)) + userNameList := make([]string, 0, len(options.Users)) + userPasswordList := make([]string, 0, len(options.Users)) + for index, user := range options.Users { + userList = append(userList, index) + userNameList = append(userNameList, user.Name) + userPasswordList = append(userPasswordList, user.Password) + } + service.UpdateUsers(userList, userPasswordList) + inbound.service = service + inbound.userNameList = userNameList + return inbound, nil +} + +func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { + ctx = log.ContextWithNewID(ctx) + var metadata adapter.InboundContext + metadata.Inbound = h.Tag() + metadata.InboundType = h.Type() + //nolint:staticcheck + metadata.InboundDetour = h.listener.ListenOptions().Detour + //nolint:staticcheck + metadata.InboundOptions = h.listener.ListenOptions().InboundOptions + metadata.OriginDestination = h.listener.UDPAddr() + metadata.Source = source + metadata.Destination = destination + h.logger.InfoContext(ctx, "inbound connection from ", metadata.Source) + userID, _ := auth.UserFromContext[int](ctx) + if userName := h.userNameList[userID]; userName != "" { + metadata.User = userName + h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", metadata.Destination) + } else { + h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) + } + h.router.RouteConnectionEx(ctx, conn, metadata, onClose) +} + +func (h *Inbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose N.CloseHandlerFunc) { + ctx = log.ContextWithNewID(ctx) + var metadata adapter.InboundContext + metadata.Inbound = h.Tag() + metadata.InboundType = h.Type() + //nolint:staticcheck + metadata.InboundDetour = h.listener.ListenOptions().Detour + //nolint:staticcheck + metadata.InboundOptions = h.listener.ListenOptions().InboundOptions + metadata.OriginDestination = h.listener.UDPAddr() + metadata.Source = source + metadata.Destination = destination + h.logger.InfoContext(ctx, "inbound packet connection from ", metadata.Source) + userID, _ := auth.UserFromContext[int](ctx) + if userName := h.userNameList[userID]; userName != "" { + metadata.User = userName + h.logger.InfoContext(ctx, "[", userName, "] inbound packet connection to ", metadata.Destination) + } else { + h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) + } + h.router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) +} + +func (h *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + if h.tlsConfig != nil { + err := h.tlsConfig.Start() + if err != nil { + return err + } + } + packetConn, err := h.listener.ListenUDP() + if err != nil { + return err + } + return h.service.Start(packetConn) +} + +func (h *Inbound) Close() error { + return common.Close( + h.listener, + h.tlsConfig, + common.PtrOrNil(h.service), + ) +} diff --git a/core/protocol/hysteria2/outbound.go b/core/protocol/hysteria2/outbound.go new file mode 100644 index 0000000..c805f07 --- /dev/null +++ b/core/protocol/hysteria2/outbound.go @@ -0,0 +1,120 @@ +package hysteria2 + +import ( + "context" + "net" + "os" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/outbound" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/protocol/tuic" + "github.com/sagernet/sing-quic/hysteria" + "github.com/sagernet/sing-quic/hysteria2" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func RegisterOutbound(registry *outbound.Registry) { + outbound.Register[option.Hysteria2OutboundOptions](registry, C.TypeHysteria2, NewOutbound) +} + +var ( + _ adapter.Outbound = (*tuic.Outbound)(nil) + _ adapter.InterfaceUpdateListener = (*tuic.Outbound)(nil) +) + +type Outbound struct { + outbound.Adapter + logger logger.ContextLogger + client *hysteria2.Client +} + +func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.Hysteria2OutboundOptions) (adapter.Outbound, error) { + options.UDPFragmentDefault = true + if options.TLS == nil || !options.TLS.Enabled { + return nil, C.ErrTLSRequired + } + tlsConfig, err := tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS)) + if err != nil { + return nil, err + } + var salamanderPassword string + if options.Obfs != nil { + if options.Obfs.Password == "" { + return nil, E.New("missing obfs password") + } + switch options.Obfs.Type { + case hysteria2.ObfsTypeSalamander: + salamanderPassword = options.Obfs.Password + default: + return nil, E.New("unknown obfs type: ", options.Obfs.Type) + } + } + outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain()) + if err != nil { + return nil, err + } + networkList := options.Network.Build() + client, err := hysteria2.NewClient(hysteria2.ClientOptions{ + Context: ctx, + Dialer: outboundDialer, + Logger: logger, + BrutalDebug: options.BrutalDebug, + ServerAddress: options.ServerOptions.Build(), + ServerPorts: options.ServerPorts, + HopInterval: time.Duration(options.HopInterval), + SendBPS: uint64(options.UpMbps * hysteria.MbpsToBps), + ReceiveBPS: uint64(options.DownMbps * hysteria.MbpsToBps), + SalamanderPassword: salamanderPassword, + Password: options.Password, + TLSConfig: tlsConfig, + UDPDisabled: !common.Contains(networkList, N.NetworkUDP), + }) + if err != nil { + return nil, err + } + return &Outbound{ + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeHysteria2, tag, networkList, options.DialerOptions), + logger: logger, + client: client, + }, nil +} + +func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + switch N.NetworkName(network) { + case N.NetworkTCP: + h.logger.InfoContext(ctx, "outbound connection to ", destination) + return h.client.DialConn(ctx, destination) + case N.NetworkUDP: + conn, err := h.ListenPacket(ctx, destination) + if err != nil { + return nil, err + } + return bufio.NewBindPacketConn(conn, destination), nil + default: + return nil, E.New("unsupported network: ", network) + } +} + +func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + h.logger.InfoContext(ctx, "outbound packet connection to ", destination) + return h.client.ListenPacket(ctx) +} + +func (h *Outbound) InterfaceUpdated() { + h.client.CloseWithError(E.New("network changed")) +} + +func (h *Outbound) Close() error { + return h.client.CloseWithError(os.ErrClosed) +} diff --git a/core/protocol/hysteria2/users.go b/core/protocol/hysteria2/users.go new file mode 100644 index 0000000..7d29f63 --- /dev/null +++ b/core/protocol/hysteria2/users.go @@ -0,0 +1,22 @@ +package hysteria2 + +import ( + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/option" +) + +func (h *Inbound) UpdateUsers(users []option.Hysteria2User) error { + h.Close() + userList := make([]int, 0, len(users)) + userNameList := make([]string, 0, len(users)) + userPasswordList := make([]string, 0, len(users)) + for index, user := range users { + userList = append(userList, index) + userNameList = append(userNameList, user.Name) + userPasswordList = append(userPasswordList, user.Password) + } + h.service.UpdateUsers(userList, userPasswordList) + h.userNameList = userNameList + h.Start(adapter.StartStateStart) + return nil +} diff --git a/core/register.go b/core/register.go index e4e8d81..4c7a7b7 100644 --- a/core/register.go +++ b/core/register.go @@ -1,6 +1,9 @@ package core import ( + "s-ui/core/protocol/hysteria" + "s-ui/core/protocol/hysteria2" + "github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/outbound" @@ -18,8 +21,6 @@ import ( protocolDNS "github.com/sagernet/sing-box/protocol/dns" "github.com/sagernet/sing-box/protocol/group" "github.com/sagernet/sing-box/protocol/http" - "github.com/sagernet/sing-box/protocol/hysteria" - "github.com/sagernet/sing-box/protocol/hysteria2" "github.com/sagernet/sing-box/protocol/mixed" "github.com/sagernet/sing-box/protocol/naive" _ "github.com/sagernet/sing-box/protocol/naive/quic" diff --git a/cronjob/depleteJob.go b/cronjob/depleteJob.go index 2a83b7d..273039a 100644 --- a/cronjob/depleteJob.go +++ b/cronjob/depleteJob.go @@ -22,7 +22,7 @@ func (s *DepleteJob) Run() { return } if len(inboundIds) > 0 { - err := s.InboundService.RestartInbounds(database.GetDB(), inboundIds) + err := s.InboundService.UpdateUsers(database.GetDB(), inboundIds) if err != nil { logger.Error("unable to restart inbounds: ", err) } diff --git a/service/client.go b/service/client.go index 1480661..90c2256 100644 --- a/service/client.go +++ b/service/client.go @@ -1,6 +1,7 @@ package service import ( + "bytes" "encoding/json" "s-ui/database" "s-ui/database/model" @@ -54,14 +55,22 @@ func (s *ClientService) Save(tx *gorm.DB, act string, data json.RawMessage, host if err != nil { return nil, err } - err = json.Unmarshal(client.Inbounds, &inboundIds) - if err != nil { - return nil, err - } err = s.updateLinksWithFixedInbounds(tx, []*model.Client{&client}, inboundIds, hostname) if err != nil { return nil, err } + if act == "edit" { + // Find changed inbounds + inboundIds, err = s.findInboundsChanges(tx, client) + if err != nil { + return nil, err + } + } else { + err = json.Unmarshal(client.Inbounds, &inboundIds) + if err != nil { + return nil, err + } + } err = tx.Save(&client).Error if err != nil { return nil, err @@ -140,7 +149,7 @@ func (s *ClientService) updateLinksWithFixedInbounds(tx *gorm.DB, clients []*mod } } - // Add no local links + // Add non local links for _, clientLink := range clientLinks { if clientLink["type"] != "local" { newClientLinks = append(newClientLinks, clientLink) @@ -316,7 +325,8 @@ func (s *ClientService) DepleteClients() ([]uint, error) { users = append(users, client.Name) var userInbounds []uint json.Unmarshal(client.Inbounds, &userInbounds) - inboundIds = s.uniqueAppendInboundIds(inboundIds, userInbounds) + // Find changed inbounds + inboundIds = common.UnionUintArray(inboundIds, userInbounds) changes = append(changes, model.Changes{ DateTime: dt, Actor: "DepleteJob", @@ -342,18 +352,32 @@ func (s *ClientService) DepleteClients() ([]uint, error) { return inboundIds, nil } -// avoid duplicate inboundIds -func (s *ClientService) uniqueAppendInboundIds(a []uint, b []uint) []uint { - m := make(map[uint]bool) - for _, v := range a { - m[v] = true +func (s *ClientService) findInboundsChanges(tx *gorm.DB, client model.Client) ([]uint, error) { + var err error + var oldClient model.Client + var oldInboundIds, newInboundIds []uint + err = tx.Model(model.Client{}).Where("id = ?", client.Id).First(&oldClient).Error + if err != nil { + return nil, err } - for _, v := range b { - m[v] = true + err = json.Unmarshal(oldClient.Inbounds, &oldInboundIds) + if err != nil { + return nil, err } - var res []uint - for k := range m { - res = append(res, k) + err = json.Unmarshal(client.Inbounds, &newInboundIds) + if err != nil { + return nil, err } - return res + + // Check client.Config changes + if !bytes.Equal(oldClient.Config, client.Config) || + oldClient.Name != client.Name || + oldClient.Enable != client.Enable { + return common.UnionUintArray(oldInboundIds, newInboundIds), nil + } + + // Check client.Inbounds changes + diffInbounds := common.DiffUintArray(oldInboundIds, newInboundIds) + + return diffInbounds, nil } diff --git a/service/config.go b/service/config.go index 80bc5f8..f86e174 100644 --- a/service/config.go +++ b/service/config.go @@ -145,7 +145,10 @@ func (s *ConfigService) Save(obj string, act string, data json.RawMessage, initU inboundIds, err := s.ClientService.Save(tx, act, data, hostname) if err == nil && len(inboundIds) > 0 { objs = append(objs, "inbounds") - err = s.InboundService.RestartInbounds(tx, inboundIds) + err = s.InboundService.UpdateUsers(tx, inboundIds) + if err != nil { + return nil, common.NewErrorf("failed to update users for inbounds: %v", err) + } } case "tls": err = s.TlsService.Save(tx, act, data, hostname) diff --git a/service/inbounds.go b/service/inbounds.go index a0963c6..8bc6517 100644 --- a/service/inbounds.go +++ b/service/inbounds.go @@ -4,12 +4,16 @@ import ( "encoding/json" "fmt" "os" + "s-ui/core/protocol/hysteria" + "s-ui/core/protocol/hysteria2" "s-ui/database" "s-ui/database/model" + "s-ui/logger" "s-ui/util" "s-ui/util/common" "strings" + "github.com/sagernet/sing-box/option" "gorm.io/gorm" ) @@ -327,6 +331,64 @@ func (s *InboundService) initUsers(db *gorm.DB, inboundJson []byte, clientIds st return json.Marshal(inbound) } +func (s *InboundService) UpdateUsers(tx *gorm.DB, ids []uint) error { + var inbounds []model.Inbound + err := tx.Model(model.Inbound{}).Preload("Tls").Where("id in ?", ids).Find(&inbounds).Error + if err != nil { + return err + } + for _, inbound := range inbounds { + inboundConfig, err := inbound.MarshalJSON() + if err != nil { + return err + } + inboundConfig, err = s.addUsers(tx, inboundConfig, inbound.Id, inbound.Type) + if err != nil { + return err + } + inb, ok := corePtr.GetInstance().Inbound().Get(inbound.Tag) + if !ok { + return common.NewErrorf("inbound %s not found", inbound.Tag) + } + switch inbound.Type { + case "hysteria": + var hysteriaOptions option.HysteriaInboundOptions + err = json.Unmarshal(inboundConfig, &hysteriaOptions) + if err != nil { + return common.NewErrorf("failed to unmarshal hysteria options for inbound %s: %v", inbound.Tag, err) + } + err = inb.(*hysteria.Inbound).UpdateUsers(hysteriaOptions.Users) + if err != nil { + return common.NewErrorf("failed to update users for hysteria inbound %s: %v", inbound.Tag, err) + } + logger.Info("Updated users for hysteria inbound:", inbound.Tag) + case "hysteria2": + var hy2Options option.Hysteria2InboundOptions + err = json.Unmarshal(inboundConfig, &hy2Options) + if err != nil { + return common.NewErrorf("failed to unmarshal hysteria2 options for inbound %s: %v", inbound.Tag, err) + } + err = inb.(*hysteria2.Inbound).UpdateUsers(hy2Options.Users) + if err != nil { + return common.NewErrorf("failed to update users for hysteria2 inbound %s: %v", inbound.Tag, err) + } + logger.Info("Updated users for hysteria2 inbound:", inbound.Tag) + default: + err = corePtr.RemoveInbound(inbound.Tag) + if err != nil && err != os.ErrInvalid { + return err + } + + err = corePtr.AddInbound(inboundConfig) + if err != nil { + return err + } + } + + } + return nil +} + func (s *InboundService) RestartInbounds(tx *gorm.DB, ids []uint) error { if !corePtr.IsRunning() { return nil diff --git a/util/common/array.go b/util/common/array.go new file mode 100644 index 0000000..6a69c1b --- /dev/null +++ b/util/common/array.go @@ -0,0 +1,44 @@ +package common + +func UnionUintArray(a []uint, b []uint) []uint { + m := make(map[uint]bool) + for _, v := range a { + m[v] = true + } + for _, v := range b { + m[v] = true + } + var res []uint + for k := range m { + res = append(res, k) + } + return res +} + +// Find different elements in two slices +// Returns elements in 'a' that are not in 'b' and elements in 'b' that are not in 'a' +func DiffUintArray(a []uint, b []uint) []uint { + different := []uint{} + set := make(map[uint]bool) + + for _, item := range a { + set[item] = true + } + for _, item := range b { + if !set[item] { + different = append(different, item) + } + } + + set = make(map[uint]bool) + for _, item := range b { + set[item] = true + } + for _, item := range a { + if !set[item] { + different = append(different, item) + } + } + + return different +}