close connection on restart inbound #684

Using new tracker
This commit is contained in:
Alireza Ahmadi
2025-07-30 12:20:25 +02:00
parent 58fd5f17cf
commit dd7e81c557
4 changed files with 157 additions and 12 deletions
+124 -11
View File
@@ -7,6 +7,7 @@ import (
"sync"
"time"
"github.com/gofrs/uuid/v5"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common/atomic"
"github.com/sagernet/sing/common/bufio"
@@ -18,20 +19,32 @@ type Counter struct {
write *atomic.Int64
}
type ConnectionInfo struct {
ID string
Conn net.Conn
PacketConn network.PacketConn
Inbound string
User string
CreatedAt time.Time
Type string // "tcp" or "udp"
}
type ConnTracker struct {
access sync.Mutex
createdAt time.Time
inbounds map[string]Counter
outbounds map[string]Counter
users map[string]Counter
access sync.Mutex
createdAt time.Time
inbounds map[string]Counter
outbounds map[string]Counter
users map[string]Counter
connections map[string]*ConnectionInfo
}
func NewConnTracker() *ConnTracker {
return &ConnTracker{
createdAt: time.Now(),
inbounds: make(map[string]Counter),
outbounds: make(map[string]Counter),
users: make(map[string]Counter),
createdAt: time.Now(),
inbounds: make(map[string]Counter),
outbounds: make(map[string]Counter),
users: make(map[string]Counter),
connections: make(map[string]*ConnectionInfo),
}
}
@@ -65,14 +78,114 @@ func (c *ConnTracker) loadOrCreateCounter(obj *map[string]Counter, name string)
return counter
}
func (c *ConnTracker) generateConnectionID() string {
return uuid.Must(uuid.NewV4()).String()
}
func (c *ConnTracker) trackConnection(connID string, connInfo *ConnectionInfo) {
c.access.Lock()
defer c.access.Unlock()
c.connections[connID] = connInfo
}
func (c *ConnTracker) untrackConnection(connID string) {
c.access.Lock()
defer c.access.Unlock()
delete(c.connections, connID)
}
func (c *ConnTracker) createWrappedConn(conn net.Conn, connID string) net.Conn {
return &wrappedConn{
Conn: conn,
tracker: c,
connID: connID,
}
}
func (c *ConnTracker) createWrappedPacketConn(conn network.PacketConn, connID string) network.PacketConn {
return &wrappedPacketConn{
PacketConn: conn,
tracker: c,
connID: connID,
}
}
func (c *ConnTracker) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn {
readCounter, writeCounter := c.getReadCounters(metadata.Inbound, matchOutbound.Tag(), metadata.User)
return bufio.NewInt64CounterConn(conn, readCounter, writeCounter)
connID := c.generateConnectionID()
connInfo := &ConnectionInfo{
ID: connID,
Conn: conn,
Inbound: metadata.Inbound,
User: metadata.User,
CreatedAt: time.Now(),
Type: "tcp",
}
c.trackConnection(connID, connInfo)
wrappedConn := c.createWrappedConn(conn, connID)
return bufio.NewInt64CounterConn(wrappedConn, readCounter, writeCounter)
}
func (c *ConnTracker) RoutedPacketConnection(ctx context.Context, conn network.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) network.PacketConn {
readCounter, writeCounter := c.getReadCounters(metadata.Inbound, matchOutbound.Tag(), metadata.User)
return bufio.NewInt64CounterPacketConn(conn, readCounter, writeCounter)
connID := c.generateConnectionID()
connInfo := &ConnectionInfo{
ID: connID,
PacketConn: conn,
Inbound: metadata.Inbound,
User: metadata.User,
CreatedAt: time.Now(),
Type: "udp",
}
c.trackConnection(connID, connInfo)
wrappedConn := c.createWrappedPacketConn(conn, connID)
return bufio.NewInt64CounterPacketConn(wrappedConn, readCounter, writeCounter)
}
func (c *ConnTracker) ForceCloseConn(inbound, user string) int {
c.access.Lock()
defer c.access.Unlock()
closedCount := 0
for connID, connInfo := range c.connections {
if connInfo.Inbound == inbound && connInfo.User == user {
if connInfo.Conn != nil {
connInfo.Conn.Close()
}
if connInfo.PacketConn != nil {
connInfo.PacketConn.Close()
}
delete(c.connections, connID)
closedCount++
}
}
return closedCount
}
func (c *ConnTracker) CloseConnByInbound(inbound string) int {
c.access.Lock()
defer c.access.Unlock()
closedCount := 0
for connID, connInfo := range c.connections {
if connInfo.Inbound == inbound {
if connInfo.Conn != nil {
connInfo.Conn.Close()
}
if connInfo.PacketConn != nil {
connInfo.PacketConn.Close()
}
delete(c.connections, connID)
closedCount++
}
}
return closedCount
}
func (c *ConnTracker) GetStats() *[]model.Stats {