From 11505a5c05f3489053ec00c01717b755255f3a2c Mon Sep 17 00:00:00 2001 From: Alireza Ahmadi Date: Sun, 22 Mar 2026 18:39:23 +0100 Subject: [PATCH] simplify conn tracker #1056 --- core/box.go | 8 ++--- core/main.go | 2 -- core/tracker_conn.go | 81 ++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 77 insertions(+), 14 deletions(-) diff --git a/core/box.go b/core/box.go index aa91b7f..f9797f8 100644 --- a/core/box.go +++ b/core/box.go @@ -326,13 +326,9 @@ func NewBox(options Options) (*Box, error) { return nil, common.NewError("initialize platform interface", err) } } - if statsTracker == nil { - statsTracker = NewStatsTracker() - } + statsTracker := NewStatsTracker() + connTracker := NewConnTracker() router.AppendTracker(statsTracker) - if connTracker == nil { - connTracker = NewConnTracker() - } router.AppendTracker(connTracker) if needCacheFile { diff --git a/core/main.go b/core/main.go index e43eb94..61aad08 100644 --- a/core/main.go +++ b/core/main.go @@ -22,8 +22,6 @@ var ( service_manager adapter.ServiceManager endpoint_manager adapter.EndpointManager router adapter.Router - statsTracker *StatsTracker - connTracker *ConnTracker factory log.Factory ) diff --git a/core/tracker_conn.go b/core/tracker_conn.go index 2bd3384..626715f 100644 --- a/core/tracker_conn.go +++ b/core/tracker_conn.go @@ -2,11 +2,15 @@ package core import ( "context" + "errors" + "io" "net" "sync" "github.com/gofrs/uuid/v5" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/network" ) @@ -107,27 +111,68 @@ func (c *ConnTracker) untrackConnection(connID string) { delete(c.connections, connID) } +// shouldUntrackIOErr reports whether err indicates the connection is done (peer closed, reset, etc.). +func shouldUntrackIOErr(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) { + return true + } + var ne net.Error + if errors.As(err, &ne) { + return !ne.Temporary() + } + return true +} + func (c *ConnTracker) createWrappedConn(conn net.Conn, connID string) *wrappedConn { return &wrappedConn{ - Conn: conn, - connID: connID, + Conn: conn, + tracker: c, + connID: connID, } } func (c *ConnTracker) createWrappedPacketConn(conn network.PacketConn, connID string) *wrappedPacketConn { return &wrappedPacketConn{ PacketConn: conn, + tracker: c, connID: connID, } } type wrappedConn struct { net.Conn - connID string + tracker *ConnTracker + connID string + untrackOnce sync.Once +} + +func (w *wrappedConn) doUntrack() { + w.untrackOnce.Do(func() { + w.tracker.untrackConnection(w.connID) + }) +} + +func (w *wrappedConn) Read(b []byte) (int, error) { + n, err := w.Conn.Read(b) + if shouldUntrackIOErr(err) { + w.doUntrack() + } + return n, err +} + +func (w *wrappedConn) Write(b []byte) (int, error) { + n, err := w.Conn.Write(b) + if err != nil && shouldUntrackIOErr(err) { + w.doUntrack() + } + return n, err } func (w *wrappedConn) Close() error { - connTracker.untrackConnection(w.connID) + w.doUntrack() return w.Conn.Close() } @@ -137,11 +182,35 @@ func (w *wrappedConn) Upstream() any { type wrappedPacketConn struct { network.PacketConn - connID string + tracker *ConnTracker + connID string + untrackOnce sync.Once +} + +func (w *wrappedPacketConn) doUntrack() { + w.untrackOnce.Do(func() { + w.tracker.untrackConnection(w.connID) + }) +} + +func (w *wrappedPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + dest, err := w.PacketConn.ReadPacket(buffer) + if shouldUntrackIOErr(err) { + w.doUntrack() + } + return dest, err +} + +func (w *wrappedPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + err := w.PacketConn.WritePacket(buffer, destination) + if err != nil && shouldUntrackIOErr(err) { + w.doUntrack() + } + return err } func (w *wrappedPacketConn) Close() error { - connTracker.untrackConnection(w.connID) + w.doUntrack() return w.PacketConn.Close() }