diff --git a/core/box.go b/core/box.go index 6a366a7..68cc043 100644 --- a/core/box.go +++ b/core/box.go @@ -49,6 +49,7 @@ type Box struct { connection *route.ConnectionManager router *route.Router internalService []adapter.LifecycleService + statsTracker *StatsTracker connTracker *ConnTracker done chan struct{} } @@ -324,6 +325,10 @@ func NewBox(options Options) (*Box, error) { return nil, common.NewError("initialize platform interface", err) } } + if statsTracker == nil { + statsTracker = NewStatsTracker() + } + router.AppendTracker(statsTracker) if connTracker == nil { connTracker = NewConnTracker() } @@ -387,6 +392,7 @@ func NewBox(options Options) (*Box, error) { logFactory: logFactory, logger: logFactory.Logger(), internalService: internalServices, + statsTracker: statsTracker, connTracker: connTracker, done: make(chan struct{}), }, nil @@ -530,6 +536,10 @@ func (s *Box) Endpoint() adapter.EndpointManager { return s.endpoint } +func (s *Box) StatsTracker() *StatsTracker { + return s.statsTracker +} + func (s *Box) ConnTracker() *ConnTracker { return s.connTracker } diff --git a/core/conntracker.go b/core/conntracker.go deleted file mode 100644 index f3a98a5..0000000 --- a/core/conntracker.go +++ /dev/null @@ -1,258 +0,0 @@ -package core - -import ( - "context" - "net" - "s-ui/database/model" - "sync" - "time" - - "github.com/gofrs/uuid/v5" - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing/common/atomic" - "github.com/sagernet/sing/common/bufio" - "github.com/sagernet/sing/common/network" -) - -type Counter struct { - read *atomic.Int64 - write *atomic.Int64 -} - -type ConnectionInfo struct { - ID string - Conn net.Conn - PacketConn network.PacketConn - Inbound string - User string - CreatedAt time.Time - Type string // "tcp" or "udp" -} - -type ConnTracker struct { - access sync.Mutex - createdAt time.Time - inbounds map[string]Counter - outbounds map[string]Counter - users map[string]Counter - connections map[string]*ConnectionInfo -} - -func NewConnTracker() *ConnTracker { - return &ConnTracker{ - createdAt: time.Now(), - inbounds: make(map[string]Counter), - outbounds: make(map[string]Counter), - users: make(map[string]Counter), - connections: make(map[string]*ConnectionInfo), - } -} - -func (c *ConnTracker) getReadCounters(inbound string, outbound string, user string) ([]*atomic.Int64, []*atomic.Int64) { - var readCounter []*atomic.Int64 - var writeCounter []*atomic.Int64 - c.access.Lock() - if inbound != "" { - readCounter = append(readCounter, c.loadOrCreateCounter(&c.inbounds, inbound).read) - writeCounter = append(writeCounter, c.inbounds[inbound].write) - } - if outbound != "" { - readCounter = append(readCounter, c.loadOrCreateCounter(&c.outbounds, outbound).read) - writeCounter = append(writeCounter, c.outbounds[outbound].write) - } - if user != "" { - readCounter = append(readCounter, c.loadOrCreateCounter(&c.users, user).read) - writeCounter = append(writeCounter, c.users[user].write) - } - c.access.Unlock() - return readCounter, writeCounter -} - -func (c *ConnTracker) loadOrCreateCounter(obj *map[string]Counter, name string) Counter { - counter, loaded := (*obj)[name] - if loaded { - return counter - } - counter = Counter{read: &atomic.Int64{}, write: &atomic.Int64{}} - (*obj)[name] = counter - return counter -} - -func (c *ConnTracker) generateConnectionID() string { - return uuid.Must(uuid.NewV4()).String() -} - -func (c *ConnTracker) trackConnection(connID string, connInfo *ConnectionInfo) { - c.access.Lock() - defer c.access.Unlock() - c.connections[connID] = connInfo -} - -func (c *ConnTracker) untrackConnection(connID string) { - c.access.Lock() - defer c.access.Unlock() - delete(c.connections, connID) -} - -func (c *ConnTracker) createWrappedConn(conn net.Conn, connID string) net.Conn { - return &wrappedConn{ - Conn: conn, - tracker: c, - connID: connID, - } -} - -func (c *ConnTracker) createWrappedPacketConn(conn network.PacketConn, connID string) network.PacketConn { - return &wrappedPacketConn{ - PacketConn: conn, - tracker: c, - connID: connID, - } -} - -func (c *ConnTracker) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn { - readCounter, writeCounter := c.getReadCounters(metadata.Inbound, matchOutbound.Tag(), metadata.User) - - connID := c.generateConnectionID() - connInfo := &ConnectionInfo{ - ID: connID, - Conn: conn, - Inbound: metadata.Inbound, - User: metadata.User, - CreatedAt: time.Now(), - Type: "tcp", - } - - c.trackConnection(connID, connInfo) - - wrappedConn := c.createWrappedConn(conn, connID) - return bufio.NewInt64CounterConn(wrappedConn, readCounter, writeCounter) -} - -func (c *ConnTracker) RoutedPacketConnection(ctx context.Context, conn network.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) network.PacketConn { - readCounter, writeCounter := c.getReadCounters(metadata.Inbound, matchOutbound.Tag(), metadata.User) - - connID := c.generateConnectionID() - connInfo := &ConnectionInfo{ - ID: connID, - PacketConn: conn, - Inbound: metadata.Inbound, - User: metadata.User, - CreatedAt: time.Now(), - Type: "udp", - } - - c.trackConnection(connID, connInfo) - - wrappedConn := c.createWrappedPacketConn(conn, connID) - return bufio.NewInt64CounterPacketConn(wrappedConn, readCounter, writeCounter) -} - -func (c *ConnTracker) ForceCloseConn(inbound, user string) int { - c.access.Lock() - defer c.access.Unlock() - - closedCount := 0 - for connID, connInfo := range c.connections { - if connInfo.Inbound == inbound && connInfo.User == user { - if connInfo.Conn != nil { - connInfo.Conn.Close() - } - if connInfo.PacketConn != nil { - connInfo.PacketConn.Close() - } - delete(c.connections, connID) - closedCount++ - } - } - return closedCount -} - -func (c *ConnTracker) CloseConnByInbound(inbound string) int { - c.access.Lock() - defer c.access.Unlock() - - closedCount := 0 - for connID, connInfo := range c.connections { - if connInfo.Inbound == inbound { - if connInfo.Conn != nil { - connInfo.Conn.Close() - } - if connInfo.PacketConn != nil { - connInfo.PacketConn.Close() - } - delete(c.connections, connID) - closedCount++ - } - } - return closedCount -} - -func (c *ConnTracker) GetStats() *[]model.Stats { - c.access.Lock() - defer c.access.Unlock() - - dt := time.Now().Unix() - - s := []model.Stats{} - for inbound, counter := range c.inbounds { - down := counter.write.Swap(0) - up := counter.read.Swap(0) - if down > 0 || up > 0 { - s = append(s, model.Stats{ - DateTime: dt, - Resource: "inbound", - Tag: inbound, - Direction: false, - Traffic: down, - }, model.Stats{ - DateTime: dt, - Resource: "inbound", - Tag: inbound, - Direction: true, - Traffic: up, - }) - } - } - - for outbound, counter := range c.outbounds { - down := counter.write.Swap(0) - up := counter.read.Swap(0) - if down > 0 || up > 0 { - s = append(s, model.Stats{ - DateTime: dt, - Resource: "outbound", - Tag: outbound, - Direction: false, - Traffic: down, - }, model.Stats{ - DateTime: dt, - Resource: "outbound", - Tag: outbound, - Direction: true, - Traffic: up, - }) - } - } - - for user, counter := range c.users { - down := counter.write.Swap(0) - up := counter.read.Swap(0) - if down > 0 || up > 0 { - s = append(s, model.Stats{ - DateTime: dt, - Resource: "user", - Tag: user, - Direction: false, - Traffic: down, - }, model.Stats{ - DateTime: dt, - Resource: "user", - Tag: user, - Direction: true, - Traffic: up, - }) - } - } - return &s -} diff --git a/core/main.go b/core/main.go index 679ef1e..d373859 100644 --- a/core/main.go +++ b/core/main.go @@ -22,6 +22,7 @@ var ( service_manager adapter.ServiceManager endpoint_manager adapter.EndpointManager router adapter.Router + statsTracker *StatsTracker connTracker *ConnTracker factory log.Factory ) diff --git a/core/tracker_conn.go b/core/tracker_conn.go new file mode 100644 index 0000000..5b38a0f --- /dev/null +++ b/core/tracker_conn.go @@ -0,0 +1,136 @@ +package core + +import ( + "context" + "net" + "sync" + + "github.com/gofrs/uuid/v5" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/network" +) + +type ConnectionInfo struct { + ID string + Conn net.Conn + PacketConn network.PacketConn + Inbound string + Type string // "tcp" or "udp" +} + +type ConnTracker struct { + access sync.Mutex + connections map[string]*ConnectionInfo +} + +func NewConnTracker() *ConnTracker { + return &ConnTracker{ + connections: make(map[string]*ConnectionInfo), + } +} + +func (c *ConnTracker) generateConnectionID() string { + return uuid.Must(uuid.NewV4()).String() +} + +func (c *ConnTracker) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn { + connID := c.generateConnectionID() + connInfo := &ConnectionInfo{ + ID: connID, + Conn: conn, + Inbound: metadata.Inbound, + Type: "tcp", + } + + c.trackConnection(connID, connInfo) + + return c.createWrappedConn(conn, connID) +} + +func (c *ConnTracker) RoutedPacketConnection(ctx context.Context, conn network.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) network.PacketConn { + connID := c.generateConnectionID() + connInfo := &ConnectionInfo{ + ID: connID, + PacketConn: conn, + Inbound: metadata.Inbound, + Type: "udp", + } + + c.trackConnection(connID, connInfo) + + return c.createWrappedPacketConn(conn, connID) +} + +func (c *ConnTracker) CloseConnByInbound(inbound string) int { + c.access.Lock() + defer c.access.Unlock() + + closedCount := 0 + for connID, connInfo := range c.connections { + if connInfo.Inbound == inbound { + if connInfo.Conn != nil { + connInfo.Conn.Close() + } + if connInfo.PacketConn != nil { + connInfo.PacketConn.Close() + } + delete(c.connections, connID) + closedCount++ + } + } + return closedCount +} + +func (c *ConnTracker) trackConnection(connID string, connInfo *ConnectionInfo) { + c.access.Lock() + defer c.access.Unlock() + c.connections[connID] = connInfo +} + +func (c *ConnTracker) untrackConnection(connID string) { + c.access.Lock() + defer c.access.Unlock() + delete(c.connections, connID) +} + +func (c *ConnTracker) createWrappedConn(conn net.Conn, connID string) *wrappedConn { + return &wrappedConn{ + Conn: conn, + connID: connID, + } +} + +func (c *ConnTracker) createWrappedPacketConn(conn network.PacketConn, connID string) *wrappedPacketConn { + return &wrappedPacketConn{ + PacketConn: conn, + connID: connID, + } +} + +type wrappedConn struct { + net.Conn + connID string +} + +func (w *wrappedConn) Close() error { + connTracker.untrackConnection(w.connID) + return w.Conn.Close() +} + +func (w *wrappedConn) Upstream() any { + return w.Conn +} + +type wrappedPacketConn struct { + network.PacketConn + connID string +} + +func (w *wrappedPacketConn) Close() error { + connTracker.untrackConnection(w.connID) + return w.PacketConn.Close() +} + +func (w *wrappedPacketConn) Upstream() any { + return w.PacketConn +} diff --git a/core/tracker_stats.go b/core/tracker_stats.go new file mode 100644 index 0000000..fef3962 --- /dev/null +++ b/core/tracker_stats.go @@ -0,0 +1,144 @@ +package core + +import ( + "context" + "net" + "s-ui/database/model" + "sync" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/atomic" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/network" +) + +type Counter struct { + read *atomic.Int64 + write *atomic.Int64 +} + +type StatsTracker struct { + access sync.Mutex + inbounds map[string]Counter + outbounds map[string]Counter + users map[string]Counter +} + +func NewStatsTracker() *StatsTracker { + return &StatsTracker{ + inbounds: make(map[string]Counter), + outbounds: make(map[string]Counter), + users: make(map[string]Counter), + } +} + +func (c *StatsTracker) getReadCounters(inbound string, outbound string, user string) ([]*atomic.Int64, []*atomic.Int64) { + var readCounter []*atomic.Int64 + var writeCounter []*atomic.Int64 + c.access.Lock() + defer c.access.Unlock() + + if inbound != "" { + readCounter = append(readCounter, c.loadOrCreateCounter(&c.inbounds, inbound).read) + writeCounter = append(writeCounter, c.inbounds[inbound].write) + } + if outbound != "" { + readCounter = append(readCounter, c.loadOrCreateCounter(&c.outbounds, outbound).read) + writeCounter = append(writeCounter, c.outbounds[outbound].write) + } + if user != "" { + readCounter = append(readCounter, c.loadOrCreateCounter(&c.users, user).read) + writeCounter = append(writeCounter, c.users[user].write) + } + return readCounter, writeCounter +} + +func (c *StatsTracker) loadOrCreateCounter(obj *map[string]Counter, name string) Counter { + counter, loaded := (*obj)[name] + if loaded { + return counter + } + counter = Counter{read: &atomic.Int64{}, write: &atomic.Int64{}} + (*obj)[name] = counter + return counter +} + +func (c *StatsTracker) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn { + readCounter, writeCounter := c.getReadCounters(metadata.Inbound, matchOutbound.Tag(), metadata.User) + return bufio.NewInt64CounterConn(conn, readCounter, writeCounter) +} + +func (c *StatsTracker) RoutedPacketConnection(ctx context.Context, conn network.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) network.PacketConn { + readCounter, writeCounter := c.getReadCounters(metadata.Inbound, matchOutbound.Tag(), metadata.User) + return bufio.NewInt64CounterPacketConn(conn, readCounter, writeCounter) +} + +func (c *StatsTracker) GetStats() *[]model.Stats { + c.access.Lock() + defer c.access.Unlock() + + dt := time.Now().Unix() + + s := []model.Stats{} + for inbound, counter := range c.inbounds { + down := counter.write.Swap(0) + up := counter.read.Swap(0) + if down > 0 || up > 0 { + s = append(s, model.Stats{ + DateTime: dt, + Resource: "inbound", + Tag: inbound, + Direction: false, + Traffic: down, + }, model.Stats{ + DateTime: dt, + Resource: "inbound", + Tag: inbound, + Direction: true, + Traffic: up, + }) + } + } + + for outbound, counter := range c.outbounds { + down := counter.write.Swap(0) + up := counter.read.Swap(0) + if down > 0 || up > 0 { + s = append(s, model.Stats{ + DateTime: dt, + Resource: "outbound", + Tag: outbound, + Direction: false, + Traffic: down, + }, model.Stats{ + DateTime: dt, + Resource: "outbound", + Tag: outbound, + Direction: true, + Traffic: up, + }) + } + } + + for user, counter := range c.users { + down := counter.write.Swap(0) + up := counter.read.Swap(0) + if down > 0 || up > 0 { + s = append(s, model.Stats{ + DateTime: dt, + Resource: "user", + Tag: user, + Direction: false, + Traffic: down, + }, model.Stats{ + DateTime: dt, + Resource: "user", + Tag: user, + Direction: true, + Traffic: up, + }) + } + } + return &s +} diff --git a/core/wrapped_conn.go b/core/wrapped_conn.go deleted file mode 100644 index d10414c..0000000 --- a/core/wrapped_conn.go +++ /dev/null @@ -1,29 +0,0 @@ -package core - -import ( - "net" - - "github.com/sagernet/sing/common/network" -) - -type wrappedConn struct { - net.Conn - tracker *ConnTracker - connID string -} - -func (w *wrappedConn) Close() error { - w.tracker.untrackConnection(w.connID) - return w.Conn.Close() -} - -type wrappedPacketConn struct { - network.PacketConn - tracker *ConnTracker - connID string -} - -func (w *wrappedPacketConn) Close() error { - w.tracker.untrackConnection(w.connID) - return w.PacketConn.Close() -} diff --git a/service/stats.go b/service/stats.go index 876b1e1..693d21f 100644 --- a/service/stats.go +++ b/service/stats.go @@ -23,7 +23,7 @@ func (s *StatsService) SaveStats() error { if !corePtr.IsRunning() { return nil } - stats := corePtr.GetInstance().ConnTracker().GetStats() + stats := corePtr.GetInstance().StatsTracker().GetStats() // Reset onlines onlineResources.Inbound = nil