From dd7e81c557013fb2c8f28bfbdf36dedc0ec0476f Mon Sep 17 00:00:00 2001 From: Alireza Ahmadi Date: Wed, 30 Jul 2025 12:20:25 +0200 Subject: [PATCH] close connection on restart inbound #684 Using new tracker --- core/conntracker.go | 135 +++++++++++++++++++++++++++++++++++++++---- core/wrapped_conn.go | 29 ++++++++++ go.mod | 2 +- service/inbounds.go | 3 + 4 files changed, 157 insertions(+), 12 deletions(-) create mode 100644 core/wrapped_conn.go diff --git a/core/conntracker.go b/core/conntracker.go index 5f99adc..f3a98a5 100644 --- a/core/conntracker.go +++ b/core/conntracker.go @@ -7,6 +7,7 @@ import ( "sync" "time" + "github.com/gofrs/uuid/v5" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/bufio" @@ -18,20 +19,32 @@ type Counter struct { write *atomic.Int64 } +type ConnectionInfo struct { + ID string + Conn net.Conn + PacketConn network.PacketConn + Inbound string + User string + CreatedAt time.Time + Type string // "tcp" or "udp" +} + type ConnTracker struct { - access sync.Mutex - createdAt time.Time - inbounds map[string]Counter - outbounds map[string]Counter - users map[string]Counter + access sync.Mutex + createdAt time.Time + inbounds map[string]Counter + outbounds map[string]Counter + users map[string]Counter + connections map[string]*ConnectionInfo } func NewConnTracker() *ConnTracker { return &ConnTracker{ - createdAt: time.Now(), - inbounds: make(map[string]Counter), - outbounds: make(map[string]Counter), - users: make(map[string]Counter), + createdAt: time.Now(), + inbounds: make(map[string]Counter), + outbounds: make(map[string]Counter), + users: make(map[string]Counter), + connections: make(map[string]*ConnectionInfo), } } @@ -65,14 +78,114 @@ func (c *ConnTracker) loadOrCreateCounter(obj *map[string]Counter, name string) return counter } +func (c *ConnTracker) generateConnectionID() string { + return uuid.Must(uuid.NewV4()).String() +} + +func (c *ConnTracker) trackConnection(connID string, connInfo *ConnectionInfo) { + c.access.Lock() + defer c.access.Unlock() + c.connections[connID] = connInfo +} + +func (c *ConnTracker) untrackConnection(connID string) { + c.access.Lock() + defer c.access.Unlock() + delete(c.connections, connID) +} + +func (c *ConnTracker) createWrappedConn(conn net.Conn, connID string) net.Conn { + return &wrappedConn{ + Conn: conn, + tracker: c, + connID: connID, + } +} + +func (c *ConnTracker) createWrappedPacketConn(conn network.PacketConn, connID string) network.PacketConn { + return &wrappedPacketConn{ + PacketConn: conn, + tracker: c, + connID: connID, + } +} + func (c *ConnTracker) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) net.Conn { readCounter, writeCounter := c.getReadCounters(metadata.Inbound, matchOutbound.Tag(), metadata.User) - return bufio.NewInt64CounterConn(conn, readCounter, writeCounter) + + connID := c.generateConnectionID() + connInfo := &ConnectionInfo{ + ID: connID, + Conn: conn, + Inbound: metadata.Inbound, + User: metadata.User, + CreatedAt: time.Now(), + Type: "tcp", + } + + c.trackConnection(connID, connInfo) + + wrappedConn := c.createWrappedConn(conn, connID) + return bufio.NewInt64CounterConn(wrappedConn, readCounter, writeCounter) } func (c *ConnTracker) RoutedPacketConnection(ctx context.Context, conn network.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule, matchOutbound adapter.Outbound) network.PacketConn { readCounter, writeCounter := c.getReadCounters(metadata.Inbound, matchOutbound.Tag(), metadata.User) - return bufio.NewInt64CounterPacketConn(conn, readCounter, writeCounter) + + connID := c.generateConnectionID() + connInfo := &ConnectionInfo{ + ID: connID, + PacketConn: conn, + Inbound: metadata.Inbound, + User: metadata.User, + CreatedAt: time.Now(), + Type: "udp", + } + + c.trackConnection(connID, connInfo) + + wrappedConn := c.createWrappedPacketConn(conn, connID) + return bufio.NewInt64CounterPacketConn(wrappedConn, readCounter, writeCounter) +} + +func (c *ConnTracker) ForceCloseConn(inbound, user string) int { + c.access.Lock() + defer c.access.Unlock() + + closedCount := 0 + for connID, connInfo := range c.connections { + if connInfo.Inbound == inbound && connInfo.User == user { + if connInfo.Conn != nil { + connInfo.Conn.Close() + } + if connInfo.PacketConn != nil { + connInfo.PacketConn.Close() + } + delete(c.connections, connID) + closedCount++ + } + } + return closedCount +} + +func (c *ConnTracker) CloseConnByInbound(inbound string) int { + c.access.Lock() + defer c.access.Unlock() + + closedCount := 0 + for connID, connInfo := range c.connections { + if connInfo.Inbound == inbound { + if connInfo.Conn != nil { + connInfo.Conn.Close() + } + if connInfo.PacketConn != nil { + connInfo.PacketConn.Close() + } + delete(c.connections, connID) + closedCount++ + } + } + return closedCount } func (c *ConnTracker) GetStats() *[]model.Stats { diff --git a/core/wrapped_conn.go b/core/wrapped_conn.go new file mode 100644 index 0000000..d10414c --- /dev/null +++ b/core/wrapped_conn.go @@ -0,0 +1,29 @@ +package core + +import ( + "net" + + "github.com/sagernet/sing/common/network" +) + +type wrappedConn struct { + net.Conn + tracker *ConnTracker + connID string +} + +func (w *wrappedConn) Close() error { + w.tracker.untrackConnection(w.connID) + return w.Conn.Close() +} + +type wrappedPacketConn struct { + network.PacketConn + tracker *ConnTracker + connID string +} + +func (w *wrappedPacketConn) Close() error { + w.tracker.untrackConnection(w.connID) + return w.PacketConn.Close() +} diff --git a/go.mod b/go.mod index 871a4a3..e9ec8a5 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/gin-contrib/gzip v1.2.3 github.com/gin-contrib/sessions v1.0.4 github.com/gin-gonic/gin v1.10.1 + github.com/google/uuid v1.6.0 github.com/op/go-logging v0.0.0-20160315200505-970db520ece7 github.com/robfig/cron/v3 v3.0.1 github.com/sagernet/sing v0.7.0-beta.1.0.20250722151551-64142925accb @@ -58,7 +59,6 @@ require ( github.com/google/btree v1.1.3 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/gorilla/context v1.1.2 // indirect github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30 // indirect github.com/gorilla/securecookie v1.1.2 // indirect diff --git a/service/inbounds.go b/service/inbounds.go index a0963c6..3eb1e3e 100644 --- a/service/inbounds.go +++ b/service/inbounds.go @@ -341,6 +341,9 @@ func (s *InboundService) RestartInbounds(tx *gorm.DB, ids []uint) error { if err != nil && err != os.ErrInvalid { return err } + // Close all existing connections + corePtr.GetInstance().ConnTracker().CloseConnByInbound(inbound.Tag) + inboundConfig, err := inbound.MarshalJSON() if err != nil { return err