diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index cb03ed1..b5ef5a6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -92,7 +92,7 @@ jobs: fi ### Build s-ui - go build -ldflags="-w -s" -tags "with_quic,with_grpc,with_ech,with_utls,with_reality_server,with_acme,with_gvisor" -o sui main.go + go build -ldflags="-w -s" -tags "with_quic,with_grpc,with_utls,with_acme,with_gvisor" -o sui main.go mkdir s-ui cp sui s-ui/ diff --git a/api/apiService.go b/api/apiService.go index 013f5fd..61fafaf 100644 --- a/api/apiService.go +++ b/api/apiService.go @@ -22,6 +22,7 @@ type ApiService struct { service.InboundService service.OutboundService service.EndpointService + service.ServicesService service.PanelService service.StatsService service.ServerService @@ -81,6 +82,10 @@ func (a *ApiService) getData(c *gin.Context) (interface{}, error) { if err != nil { return "", err } + services, err := a.ServicesService.GetAll() + if err != nil { + return "", err + } subURI, err := a.SettingService.GetFinalSubURI(strings.Split(c.Request.Host, ":")[0]) if err != nil { return "", err @@ -91,6 +96,7 @@ func (a *ApiService) getData(c *gin.Context) (interface{}, error) { data["inbounds"] = inbounds data["outbounds"] = outbounds data["endpoints"] = endpoints + data["services"] = services data["subURI"] = subURI data["onlines"] = onlines } else { @@ -124,6 +130,12 @@ func (a *ApiService) LoadPartialData(c *gin.Context, objs []string) error { return err } data[obj] = endpoints + case "services": + services, err := a.ServicesService.GetAll() + if err != nil { + return err + } + data[obj] = services case "tls": tlsConfigs, err := a.TlsService.GetAll() if err != nil { diff --git a/core/box.go b/core/box.go index da269ce..5fb329b 100644 --- a/core/box.go +++ b/core/box.go @@ -12,9 +12,13 @@ import ( "github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/outbound" + boxService "github.com/sagernet/sing-box/adapter/service" + "github.com/sagernet/sing-box/common/certificate" "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/taskmonitor" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/dns/transport/local" "github.com/sagernet/sing-box/experimental/cachefile" "github.com/sagernet/sing-box/experimental/libbox/platform" "github.com/sagernet/sing-box/log" @@ -28,21 +32,24 @@ import ( "github.com/sagernet/sing/service/pause" ) -var _ adapter.Service = (*Box)(nil) +var _ adapter.SimpleLifecycle = (*Box)(nil) type Box struct { - createdAt time.Time - logFactory log.Factory - logger log.ContextLogger - network *route.NetworkManager - endpoint *endpoint.Manager - inbound *inbound.Manager - outbound *outbound.Manager - connection *route.ConnectionManager - router *route.Router - services []adapter.LifecycleService - connTracker *ConnTracker - done chan struct{} + createdAt time.Time + logFactory log.Factory + logger log.ContextLogger + network *route.NetworkManager + endpoint *endpoint.Manager + inbound *inbound.Manager + outbound *outbound.Manager + service *boxService.Manager + dnsTransport *dns.TransportManager + dnsRouter *dns.Router + connection *route.ConnectionManager + router *route.Router + internalService []adapter.LifecycleService + connTracker *ConnTracker + done chan struct{} } type Options struct { @@ -55,6 +62,8 @@ func Context( inboundRegistry adapter.InboundRegistry, outboundRegistry adapter.OutboundRegistry, endpointRegistry adapter.EndpointRegistry, + dnsTransportRegistry adapter.DNSTransportRegistry, + serviceRegistry adapter.ServiceRegistry, ) context.Context { if service.FromContext[option.InboundOptionsRegistry](ctx) == nil || service.FromContext[adapter.InboundRegistry](ctx) == nil { @@ -71,6 +80,14 @@ func Context( ctx = service.ContextWith[option.EndpointOptionsRegistry](ctx, endpointRegistry) ctx = service.ContextWith[adapter.EndpointRegistry](ctx, endpointRegistry) } + if service.FromContext[adapter.DNSTransportRegistry](ctx) == nil { + ctx = service.ContextWith[option.DNSTransportOptionsRegistry](ctx, dnsTransportRegistry) + ctx = service.ContextWith[adapter.DNSTransportRegistry](ctx, dnsTransportRegistry) + } + if service.FromContext[adapter.ServiceRegistry](ctx) == nil { + ctx = service.ContextWith[option.ServiceOptionsRegistry](ctx, serviceRegistry) + ctx = service.ContextWith[adapter.ServiceRegistry](ctx, serviceRegistry) + } return ctx } @@ -86,6 +103,8 @@ func NewBox(options Options) (*Box, error) { endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx) inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx) + dnsTransportRegistry := service.FromContext[adapter.DNSTransportRegistry](ctx) + serviceRegistry := service.FromContext[adapter.ServiceRegistry](ctx) if endpointRegistry == nil { return nil, common.NewError("missing endpoint registry in context") @@ -96,6 +115,12 @@ func NewBox(options Options) (*Box, error) { if outboundRegistry == nil { return nil, common.NewError("missing outbound registry in context") } + if dnsTransportRegistry == nil { + return nil, common.NewError("missing DNS transport registry in context") + } + if serviceRegistry == nil { + return nil, common.NewError("missing service registry in context") + } ctx = pause.WithDefaultManager(ctx) experimentalOptions := sbCommon.PtrValueOrDefault(options.Experimental) @@ -120,13 +145,36 @@ func NewBox(options Options) (*Box, error) { } factory = logFactory + var internalServices []adapter.LifecycleService + certificateOptions := sbCommon.PtrValueOrDefault(options.Certificate) + if C.IsAndroid || certificateOptions.Store != "" && certificateOptions.Store != C.CertificateStoreSystem || + len(certificateOptions.Certificate) > 0 || + len(certificateOptions.CertificatePath) > 0 || + len(certificateOptions.CertificateDirectoryPath) > 0 { + certificateStore, err := certificate.NewStore(ctx, logFactory.NewLogger("certificate"), certificateOptions) + if err != nil { + return nil, err + } + service.MustRegister[adapter.CertificateStore](ctx, certificateStore) + internalServices = append(internalServices, certificateStore) + } + routeOptions := sbCommon.PtrValueOrDefault(options.Route) + dnsOptions := sbCommon.PtrValueOrDefault(options.DNS) endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry) inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager) outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final) + dnsTransportManager := dns.NewTransportManager(logFactory.NewLogger("dns/transport"), dnsTransportRegistry, outboundManager, dnsOptions.Final) + serviceManager := boxService.NewManager(logFactory.NewLogger("service"), serviceRegistry) + service.MustRegister[adapter.EndpointManager](ctx, endpointManager) service.MustRegister[adapter.InboundManager](ctx, inboundManager) service.MustRegister[adapter.OutboundManager](ctx, outboundManager) + service.MustRegister[adapter.DNSTransportManager](ctx, dnsTransportManager) + service.MustRegister[adapter.ServiceManager](ctx, serviceManager) + + dnsRouter := dns.NewRouter(ctx, logFactory, dnsOptions) + service.MustRegister[adapter.DNSRouter](ctx, dnsRouter) networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions) if err != nil { @@ -135,10 +183,34 @@ func NewBox(options Options) (*Box, error) { service.MustRegister[adapter.NetworkManager](ctx, networkManager) connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection")) service.MustRegister[adapter.ConnectionManager](ctx, connectionManager) - router, err := route.NewRouter(ctx, logFactory, routeOptions, sbCommon.PtrValueOrDefault(options.DNS)) + router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions) + service.MustRegister[adapter.Router](ctx, router) + err = router.Initialize(routeOptions.Rules, routeOptions.RuleSet) if err != nil { return nil, common.NewError("initialize router", err) } + for i, transportOptions := range dnsOptions.Servers { + var tag string + if transportOptions.Tag != "" { + tag = transportOptions.Tag + } else { + tag = F.ToString(i) + } + err = dnsTransportManager.Create( + ctx, + logFactory.NewLogger(F.ToString("dns/", transportOptions.Type, "[", tag, "]")), + tag, + transportOptions.Type, + transportOptions.Options, + ) + if err != nil { + return nil, common.NewError("initialize DNS server[", i, "]", err) + } + } + err = dnsRouter.Initialize(dnsOptions.Rules) + if err != nil { + return nil, common.NewError("initialize dns router", err) + } for i, endpointOptions := range options.Endpoints { var tag string if endpointOptions.Tag != "" { @@ -146,7 +218,8 @@ func NewBox(options Options) (*Box, error) { } else { tag = F.ToString(i) } - err = endpointManager.Create(ctx, + err = endpointManager.Create( + ctx, router, logFactory.NewLogger(F.ToString("endpoint/", endpointOptions.Type, "[", tag, "]")), tag, @@ -164,7 +237,8 @@ func NewBox(options Options) (*Box, error) { } else { tag = F.ToString(i) } - err = inboundManager.Create(ctx, + err = inboundManager.Create( + ctx, router, logFactory.NewLogger(F.ToString("inbound/", inboundOptions.Type, "[", tag, "]")), tag, @@ -201,6 +275,24 @@ func NewBox(options Options) (*Box, error) { return nil, common.NewError("initialize outbound["+F.ToString(i)+"] "+tag, err) } } + for i, serviceOptions := range options.Services { + var tag string + if serviceOptions.Tag != "" { + tag = serviceOptions.Tag + } else { + tag = F.ToString(i) + } + err = serviceManager.Create( + ctx, + logFactory.NewLogger(F.ToString("service/", serviceOptions.Type, "[", tag, "]")), + tag, + serviceOptions.Type, + serviceOptions.Options, + ) + if err != nil { + return nil, common.NewError("initialize service["+F.ToString(i)+"]"+tag, err) + } + } outboundManager.Initialize(sbCommon.Must1( direct.NewOutbound( ctx, @@ -210,6 +302,13 @@ func NewBox(options Options) (*Box, error) { option.DirectOutboundOptions{}, ), )) + dnsTransportManager.Initialize(sbCommon.Must1( + local.NewTransport( + ctx, + logFactory.NewLogger("dns/local"), + "local", + option.LocalDNSServerOptions{}, + ))) if platformInterface != nil { err = platformInterface.Initialize(networkManager) if err != nil { @@ -219,18 +318,16 @@ func NewBox(options Options) (*Box, error) { if connTracker == nil { connTracker = NewConnTracker() } - router.SetTracker(connTracker) - - var services []adapter.LifecycleService + router.AppendTracker(connTracker) if needCacheFile { cacheFile := cachefile.New(ctx, sbCommon.PtrValueOrDefault(experimentalOptions.CacheFile)) service.MustRegister[adapter.CacheFile](ctx, cacheFile) - services = append(services, cacheFile) + internalServices = append(internalServices, cacheFile) } ntpOptions := sbCommon.PtrValueOrDefault(options.NTP) if ntpOptions.Enabled { - ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions) + ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions, ntpOptions.ServerIsDomain()) if err != nil { return nil, common.NewError(err, "create NTP service") } @@ -243,21 +340,24 @@ func NewBox(options Options) (*Box, error) { WriteToSystem: ntpOptions.WriteToSystem, }) service.MustRegister[ntp.TimeService](ctx, timeService) - services = append(services, adapter.NewLifecycleService(timeService, "ntp service")) + internalServices = append(internalServices, adapter.NewLifecycleService(timeService, "ntp service")) } return &Box{ - network: networkManager, - endpoint: endpointManager, - inbound: inboundManager, - outbound: outboundManager, - connection: connectionManager, - router: router, - createdAt: createdAt, - logFactory: logFactory, - logger: logFactory.Logger(), - services: services, - connTracker: connTracker, - done: make(chan struct{}), + network: networkManager, + endpoint: endpointManager, + inbound: inboundManager, + outbound: outboundManager, + dnsTransport: dnsTransportManager, + service: serviceManager, + dnsRouter: dnsRouter, + connection: connectionManager, + router: router, + createdAt: createdAt, + logFactory: logFactory, + logger: logFactory.Logger(), + internalService: internalServices, + connTracker: connTracker, + done: make(chan struct{}), }, nil } @@ -305,15 +405,15 @@ func (s *Box) preStart() error { if err != nil { return common.NewError(err, "start logger") } - err = adapter.StartNamed(adapter.StartStateInitialize, s.services) // cache-file + err = adapter.StartNamed(adapter.StartStateInitialize, s.internalService) // cache-file if err != nil { return err } - err = adapter.Start(adapter.StartStateInitialize, s.network, s.connection, s.router, s.outbound, s.inbound, s.endpoint) + err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service) if err != nil { return err } - err = adapter.Start(adapter.StartStateStart, s.outbound, s.network, s.connection, s.router) + err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router) if err != nil { return err } @@ -325,31 +425,27 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.StartNamed(adapter.StartStateStart, s.services) + err = adapter.StartNamed(adapter.StartStateStart, s.internalService) if err != nil { return err } - err = s.inbound.Start(adapter.StartStateStart) + err = adapter.Start(adapter.StartStateStart, s.inbound, s.endpoint, s.service) if err != nil { return err } - err = adapter.Start(adapter.StartStateStart, s.endpoint) + err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint, s.service) if err != nil { return err } - err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.connection, s.router, s.inbound, s.endpoint) + err = adapter.StartNamed(adapter.StartStatePostStart, s.internalService) if err != nil { return err } - err = adapter.StartNamed(adapter.StartStatePostStart, s.services) + err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service) if err != nil { return err } - err = adapter.Start(adapter.StartStateStarted, s.network, s.connection, s.router, s.outbound, s.inbound, s.endpoint) - if err != nil { - return err - } - err = adapter.StartNamed(adapter.StartStateStarted, s.services) + err = adapter.StartNamed(adapter.StartStateStarted, s.internalService) if err != nil { return err } @@ -364,9 +460,9 @@ func (s *Box) Close() error { close(s.done) } err := sbCommon.Close( - s.endpoint, s.inbound, s.outbound, s.router, s.connection, s.network, + s.endpoint, s.inbound, s.outbound, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network, ) - for _, lifecycleService := range s.services { + for _, lifecycleService := range s.internalService { err1 := lifecycleService.Close() if err1 != nil { s.logger.Debug(lifecycleService.Name(), " close error: ", err1) diff --git a/core/endpoint.go b/core/endpoint.go index ea1514e..943a39b 100644 --- a/core/endpoint.go +++ b/core/endpoint.go @@ -112,3 +112,36 @@ func (c *Core) RemoveEndpoint(tag string) error { logger.Info("remove endpoint: ", tag) return endpoint_manager.Remove(tag) } + +func (c *Core) AddService(config []byte) error { + if !c.isRunning { + return common.NewError("sing-box is not running") + } + var err error + var srv_config option.Service + + err = srv_config.UnmarshalJSONContext(c.GetCtx(), config) + if err != nil { + return err + } + + err = service_manager.Create( + c.GetCtx(), + factory.NewLogger("service/"+srv_config.Type+"["+srv_config.Tag+"]"), + srv_config.Tag, + srv_config.Type, + srv_config.Options) + if err != nil { + return err + } + + return nil +} + +func (c *Core) RemoveService(tag string) error { + if !c.isRunning { + return common.NewError("sing-box is not running") + } + logger.Info("remove service: ", tag) + return service_manager.Remove(tag) +} diff --git a/core/main.go b/core/main.go index ef3ae29..679ef1e 100644 --- a/core/main.go +++ b/core/main.go @@ -19,6 +19,7 @@ var ( globalCtx context.Context inbound_manager adapter.InboundManager outbound_manager adapter.OutboundManager + service_manager adapter.ServiceManager endpoint_manager adapter.EndpointManager router adapter.Router connTracker *ConnTracker @@ -32,7 +33,7 @@ type Core struct { func NewCore() *Core { globalCtx = context.Background() - globalCtx = sb.Context(globalCtx, inboundRegistry(), outboundRegistry(), EndpointRegistry()) + globalCtx = sb.Context(globalCtx, InboundRegistry(), OutboundRegistry(), EndpointRegistry(), DNSTransportRegistry(), ServiceRegistry()) return &Core{ isRunning: false, instance: nil, @@ -70,6 +71,7 @@ func (c *Core) Start(sbConfig []byte) error { globalCtx = service.ContextWith(globalCtx, c) inbound_manager = service.FromContext[adapter.InboundManager](globalCtx) outbound_manager = service.FromContext[adapter.OutboundManager](globalCtx) + service_manager = service.FromContext[adapter.ServiceManager](globalCtx) endpoint_manager = service.FromContext[adapter.EndpointManager](globalCtx) router = service.FromContext[adapter.Router](globalCtx) diff --git a/core/register.go b/core/register.go index aaf3e01..e4e8d81 100644 --- a/core/register.go +++ b/core/register.go @@ -4,9 +4,18 @@ import ( "github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/outbound" + "github.com/sagernet/sing-box/adapter/service" + "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/dns/transport" + "github.com/sagernet/sing-box/dns/transport/dhcp" + "github.com/sagernet/sing-box/dns/transport/fakeip" + "github.com/sagernet/sing-box/dns/transport/hosts" + "github.com/sagernet/sing-box/dns/transport/local" + "github.com/sagernet/sing-box/dns/transport/quic" + "github.com/sagernet/sing-box/protocol/anytls" "github.com/sagernet/sing-box/protocol/block" "github.com/sagernet/sing-box/protocol/direct" - "github.com/sagernet/sing-box/protocol/dns" + protocolDNS "github.com/sagernet/sing-box/protocol/dns" "github.com/sagernet/sing-box/protocol/group" "github.com/sagernet/sing-box/protocol/http" "github.com/sagernet/sing-box/protocol/hysteria" @@ -19,6 +28,7 @@ import ( "github.com/sagernet/sing-box/protocol/shadowtls" "github.com/sagernet/sing-box/protocol/socks" "github.com/sagernet/sing-box/protocol/ssh" + "github.com/sagernet/sing-box/protocol/tailscale" "github.com/sagernet/sing-box/protocol/tor" "github.com/sagernet/sing-box/protocol/trojan" "github.com/sagernet/sing-box/protocol/tuic" @@ -26,11 +36,14 @@ import ( "github.com/sagernet/sing-box/protocol/vless" "github.com/sagernet/sing-box/protocol/vmess" "github.com/sagernet/sing-box/protocol/wireguard" + "github.com/sagernet/sing-box/service/derp" + "github.com/sagernet/sing-box/service/resolved" + "github.com/sagernet/sing-box/service/ssmapi" _ "github.com/sagernet/sing-box/transport/v2rayquic" _ "github.com/sagernet/sing-dns/quic" ) -func inboundRegistry() *inbound.Registry { +func InboundRegistry() *inbound.Registry { registry := inbound.NewRegistry() tun.RegisterInbound(registry) @@ -48,6 +61,7 @@ func inboundRegistry() *inbound.Registry { naive.RegisterInbound(registry) shadowtls.RegisterInbound(registry) vless.RegisterInbound(registry) + anytls.RegisterInbound(registry) hysteria.RegisterInbound(registry) tuic.RegisterInbound(registry) @@ -56,13 +70,13 @@ func inboundRegistry() *inbound.Registry { return registry } -func outboundRegistry() *outbound.Registry { +func OutboundRegistry() *outbound.Registry { registry := outbound.NewRegistry() direct.RegisterOutbound(registry) block.RegisterOutbound(registry) - dns.RegisterOutbound(registry) + protocolDNS.RegisterOutbound(registry) group.RegisterSelector(registry) group.RegisterURLTest(registry) @@ -76,6 +90,7 @@ func outboundRegistry() *outbound.Registry { ssh.RegisterOutbound(registry) shadowtls.RegisterOutbound(registry) vless.RegisterOutbound(registry) + anytls.RegisterOutbound(registry) hysteria.RegisterOutbound(registry) tuic.RegisterOutbound(registry) @@ -89,6 +104,57 @@ func EndpointRegistry() *endpoint.Registry { registry := endpoint.NewRegistry() wireguard.RegisterEndpoint(registry) + registerTailscaleEndpoint(registry) + + return registry +} + +func DNSTransportRegistry() *dns.TransportRegistry { + registry := dns.NewTransportRegistry() + + transport.RegisterTCP(registry) + transport.RegisterUDP(registry) + transport.RegisterTLS(registry) + transport.RegisterHTTPS(registry) + hosts.RegisterTransport(registry) + local.RegisterTransport(registry) + fakeip.RegisterTransport(registry) + + registerQUICTransports(registry) + registerDHCPTransport(registry) + registerTailscaleTransport(registry) + + return registry +} + +func registerTailscaleEndpoint(registry *endpoint.Registry) { + tailscale.RegisterEndpoint(registry) +} + +func registerTailscaleTransport(registry *dns.TransportRegistry) { + tailscale.RegistryTransport(registry) +} + +func registerDERPService(registry *service.Registry) { + derp.Register(registry) +} + +func registerQUICTransports(registry *dns.TransportRegistry) { + quic.RegisterTransport(registry) + quic.RegisterHTTP3Transport(registry) +} + +func registerDHCPTransport(registry *dns.TransportRegistry) { + dhcp.RegisterTransport(registry) +} + +func ServiceRegistry() *service.Registry { + registry := service.NewRegistry() + + resolved.RegisterService(registry) + ssmapi.RegisterService(registry) + + registerDERPService(registry) return registry } diff --git a/database/db.go b/database/db.go index fdba0ac..4cc5671 100644 --- a/database/db.go +++ b/database/db.go @@ -76,6 +76,7 @@ func InitDB(dbPath string) error { &model.Tls{}, &model.Inbound{}, &model.Outbound{}, + &model.Service{}, &model.Endpoint{}, &model.User{}, &model.Tokens{}, diff --git a/database/model/services.go b/database/model/services.go new file mode 100644 index 0000000..cf7acc9 --- /dev/null +++ b/database/model/services.go @@ -0,0 +1,90 @@ +package model + +import ( + "encoding/json" +) + +type Service struct { + Id uint `json:"id" form:"id" gorm:"primaryKey;autoIncrement"` + Type string `json:"type" form:"type"` + Tag string `json:"tag" form:"tag" gorm:"unique"` + + // Foreign key to tls table + TlsId uint `json:"tls_id" form:"tls_id"` + Tls *Tls `json:"tls" form:"tls" gorm:"foreignKey:TlsId;references:Id"` + + Options json.RawMessage `json:"-" form:"-"` +} + +func (i *Service) UnmarshalJSON(data []byte) error { + var err error + var raw map[string]interface{} + if err = json.Unmarshal(data, &raw); err != nil { + return err + } + + // Extract fixed fields and store the rest in Options + if val, exists := raw["id"].(float64); exists { + i.Id = uint(val) + } + delete(raw, "id") + i.Type, _ = raw["type"].(string) + delete(raw, "type") + i.Tag, _ = raw["tag"].(string) + delete(raw, "tag") + + // TlsId + if val, exists := raw["tls_id"].(float64); exists { + i.TlsId = uint(val) + } + delete(raw, "tls_id") + delete(raw, "tls") + + // Remaining fields + i.Options, err = json.MarshalIndent(raw, "", " ") + return err +} + +// MarshalJSON customizes marshalling +func (i Service) MarshalJSON() ([]byte, error) { + // Combine fixed fields and dynamic fields into one map + combined := make(map[string]interface{}) + combined["type"] = i.Type + combined["tag"] = i.Tag + if i.Tls != nil { + combined["tls"] = i.Tls.Server + } + + if i.Options != nil { + var restFields map[string]json.RawMessage + if err := json.Unmarshal(i.Options, &restFields); err != nil { + return nil, err + } + + for k, v := range restFields { + combined[k] = v + } + } + + return json.Marshal(combined) +} + +func (i Service) MarshalFull() (*map[string]interface{}, error) { + combined := make(map[string]interface{}) + combined["id"] = i.Id + combined["type"] = i.Type + combined["tag"] = i.Tag + combined["tls_id"] = i.TlsId + + if i.Options != nil { + var restFields map[string]interface{} + if err := json.Unmarshal(i.Options, &restFields); err != nil { + return nil, err + } + + for k, v := range restFields { + combined[k] = v + } + } + return &combined, nil +} diff --git a/service/config.go b/service/config.go index 4551bfa..47f9c40 100644 --- a/service/config.go +++ b/service/config.go @@ -22,6 +22,7 @@ type ConfigService struct { SettingService InboundService OutboundService + ServicesService EndpointService } @@ -31,6 +32,7 @@ type SingBoxConfig struct { Ntp json.RawMessage `json:"ntp"` Inbounds []json.RawMessage `json:"inbounds"` Outbounds []json.RawMessage `json:"outbounds"` + Services []json.RawMessage `json:"services"` Endpoints []json.RawMessage `json:"endpoints"` Route json.RawMessage `json:"route"` Experimental json.RawMessage `json:"experimental"` @@ -63,6 +65,10 @@ func (s *ConfigService) GetConfig(data string) (*SingBoxConfig, error) { if err != nil { return nil, err } + singboxConfig.Services, err = s.ServicesService.GetAllConfig(database.GetDB()) + if err != nil { + return nil, err + } singboxConfig.Endpoints, err = s.EndpointService.GetAllConfig(database.GetDB()) if err != nil { return nil, err @@ -119,6 +125,7 @@ func (s *ConfigService) StopCore() error { func (s *ConfigService) Save(obj string, act string, data json.RawMessage, initUsers string, loginUser string, hostname string) ([]string, error) { var err error var inboundIds []uint + var serviceIds []uint var inboundId uint var objs []string = []string{obj} @@ -133,6 +140,12 @@ func (s *ConfigService) Save(obj string, act string, data json.RawMessage, initU logger.Error("unable to restart inbounds: ", err1) } } + if len(serviceIds) > 0 && corePtr.IsRunning() { + err1 := s.ServicesService.RestartServices(db, serviceIds) + if err1 != nil { + logger.Error("unable to restart services: ", err1) + } + } // Try to start core if it is not running if !corePtr.IsRunning() { s.StartCore("") @@ -147,11 +160,13 @@ func (s *ConfigService) Save(obj string, act string, data json.RawMessage, initU inboundIds, err = s.ClientService.Save(tx, act, data, hostname) objs = append(objs, "inbounds") case "tls": - inboundIds, err = s.TlsService.Save(tx, act, data) + serviceIds, inboundIds, err = s.TlsService.Save(tx, act, data) case "inbounds": inboundId, err = s.InboundService.Save(tx, act, data, initUsers, hostname) case "outbounds": err = s.OutboundService.Save(tx, act, data) + case "services": + err = s.ServicesService.Save(tx, act, data) case "endpoints": err = s.EndpointService.Save(tx, act, data) case "config": diff --git a/service/services.go b/service/services.go new file mode 100644 index 0000000..4ac2aaf --- /dev/null +++ b/service/services.go @@ -0,0 +1,149 @@ +package service + +import ( + "encoding/json" + "os" + "s-ui/database" + "s-ui/database/model" + "s-ui/util/common" + + "gorm.io/gorm" +) + +type ServicesService struct{} + +func (s *ServicesService) GetAll() (*[]map[string]interface{}, error) { + db := database.GetDB() + services := []model.Service{} + err := db.Model(model.Service{}).Scan(&services).Error + if err != nil { + return nil, err + } + var data []map[string]interface{} + for _, srv := range services { + srvData := map[string]interface{}{ + "id": srv.Id, + "type": srv.Type, + "tag": srv.Tag, + "tls_id": srv.TlsId, + } + if srv.Options != nil { + var restFields map[string]json.RawMessage + if err := json.Unmarshal(srv.Options, &restFields); err != nil { + return nil, err + } + for k, v := range restFields { + srvData[k] = v + } + } + + data = append(data, srvData) + } + return &data, nil +} + +func (s *ServicesService) GetAllConfig(db *gorm.DB) ([]json.RawMessage, error) { + var servicesJson []json.RawMessage + var services []*model.Service + err := db.Model(model.Service{}).Preload("Tls").Find(&services).Error + if err != nil { + return nil, err + } + for _, srv := range services { + srvJson, err := srv.MarshalJSON() + if err != nil { + return nil, err + } + servicesJson = append(servicesJson, srvJson) + } + return servicesJson, nil +} + +func (s *ServicesService) Save(tx *gorm.DB, act string, data json.RawMessage) error { + var err error + + switch act { + case "new", "edit": + var srv model.Service + err = srv.UnmarshalJSON(data) + if err != nil { + return err + } + + if srv.TlsId > 0 { + err = tx.Model(model.Tls{}).Where("id = ?", srv.TlsId).Find(&srv.Tls).Error + if err != nil { + return err + } + } + + if corePtr.IsRunning() { + configData, err := srv.MarshalJSON() + if err != nil { + return err + } + if act == "edit" { + var oldTag string + err = tx.Model(model.Service{}).Select("tag").Where("id = ?", srv.Id).Find(&oldTag).Error + if err != nil { + return err + } + err = corePtr.RemoveService(oldTag) + if err != nil && err != os.ErrInvalid { + return err + } + } + err = corePtr.AddService(configData) + if err != nil { + return err + } + } + + err = tx.Save(&srv).Error + if err != nil { + return err + } + case "del": + var tag string + err = json.Unmarshal(data, &tag) + if err != nil { + return err + } + if corePtr.IsRunning() { + err = corePtr.RemoveService(tag) + if err != nil && err != os.ErrInvalid { + return err + } + } + err = tx.Where("tag = ?", tag).Delete(model.Service{}).Error + if err != nil { + return err + } + default: + return common.NewErrorf("unknown action: %s", act) + } + return nil +} + +func (s *ServicesService) RestartServices(tx *gorm.DB, ids []uint) error { + var services []*model.Service + err := tx.Model(model.Service{}).Preload("Tls").Where("id in ?", ids).Find(&services).Error + if err != nil { + return err + } + for _, srv := range services { + err = corePtr.RemoveService(srv.Tag) + if err != nil && err != os.ErrInvalid { + return err + } + srvConfig, err := srv.MarshalJSON() + if err != nil { + return err + } + err = corePtr.AddService(srvConfig) + if err != nil { + return err + } + } + return nil +} diff --git a/service/tls.go b/service/tls.go index 11d23ab..bb24202 100644 --- a/service/tls.go +++ b/service/tls.go @@ -24,45 +24,52 @@ func (s *TlsService) GetAll() ([]model.Tls, error) { return tlsConfig, nil } -func (s *TlsService) Save(tx *gorm.DB, action string, data json.RawMessage) ([]uint, error) { +func (s *TlsService) Save(tx *gorm.DB, action string, data json.RawMessage) ([]uint, []uint, error) { var err error var inboundIds []uint + var serviceIds []uint switch action { case "new", "edit": var tls model.Tls err = json.Unmarshal(data, &tls) if err != nil { - return nil, err + return nil, nil, err } err = tx.Save(&tls).Error if err != nil { - return nil, err + return nil, nil, err } err = tx.Model(model.Inbound{}).Select("id").Where("tls_id = ?", tls.Id).Scan(&inboundIds).Error if err != nil { - return nil, err + return nil, nil, err } - return inboundIds, nil + err = tx.Model(model.Service{}).Where("tls_id = ?", tls.Id).Scan(&serviceIds).Error + return serviceIds, inboundIds, nil case "del": var id uint err = json.Unmarshal(data, &id) if err != nil { - return nil, err + return nil, nil, err } var inboundCount int64 err = tx.Model(model.Inbound{}).Where("tls_id = ?", id).Count(&inboundCount).Error if err != nil { - return nil, err + return nil, nil, err } - if inboundCount > 0 { - return nil, common.NewError("tls in use") + var serviceCount int64 + err = tx.Model(model.Service{}).Where("tls_id = ?", id).Count(&serviceCount).Error + if err != nil { + return nil, nil, err + } + if inboundCount > 0 || serviceCount > 0 { + return nil, nil, common.NewError("tls in use") } err = tx.Where("id = ?", id).Delete(model.Tls{}).Error if err != nil { - return nil, err + return nil, nil, err } } - return nil, nil + return nil, nil, nil }