simplify conn tracker #1056
This commit is contained in:
+2
-6
@@ -326,13 +326,9 @@ func NewBox(options Options) (*Box, error) {
|
|||||||
return nil, common.NewError("initialize platform interface", err)
|
return nil, common.NewError("initialize platform interface", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if statsTracker == nil {
|
statsTracker := NewStatsTracker()
|
||||||
statsTracker = NewStatsTracker()
|
connTracker := NewConnTracker()
|
||||||
}
|
|
||||||
router.AppendTracker(statsTracker)
|
router.AppendTracker(statsTracker)
|
||||||
if connTracker == nil {
|
|
||||||
connTracker = NewConnTracker()
|
|
||||||
}
|
|
||||||
router.AppendTracker(connTracker)
|
router.AppendTracker(connTracker)
|
||||||
|
|
||||||
if needCacheFile {
|
if needCacheFile {
|
||||||
|
|||||||
@@ -22,8 +22,6 @@ var (
|
|||||||
service_manager adapter.ServiceManager
|
service_manager adapter.ServiceManager
|
||||||
endpoint_manager adapter.EndpointManager
|
endpoint_manager adapter.EndpointManager
|
||||||
router adapter.Router
|
router adapter.Router
|
||||||
statsTracker *StatsTracker
|
|
||||||
connTracker *ConnTracker
|
|
||||||
factory log.Factory
|
factory log.Factory
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
+71
-2
@@ -2,11 +2,15 @@ package core
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/gofrs/uuid/v5"
|
"github.com/gofrs/uuid/v5"
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing/common/buf"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
"github.com/sagernet/sing/common/network"
|
"github.com/sagernet/sing/common/network"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -107,9 +111,25 @@ func (c *ConnTracker) untrackConnection(connID string) {
|
|||||||
delete(c.connections, connID)
|
delete(c.connections, connID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shouldUntrackIOErr reports whether err indicates the connection is done (peer closed, reset, etc.).
|
||||||
|
func shouldUntrackIOErr(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
var ne net.Error
|
||||||
|
if errors.As(err, &ne) {
|
||||||
|
return !ne.Temporary()
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ConnTracker) createWrappedConn(conn net.Conn, connID string) *wrappedConn {
|
func (c *ConnTracker) createWrappedConn(conn net.Conn, connID string) *wrappedConn {
|
||||||
return &wrappedConn{
|
return &wrappedConn{
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
|
tracker: c,
|
||||||
connID: connID,
|
connID: connID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -117,17 +137,42 @@ func (c *ConnTracker) createWrappedConn(conn net.Conn, connID string) *wrappedCo
|
|||||||
func (c *ConnTracker) createWrappedPacketConn(conn network.PacketConn, connID string) *wrappedPacketConn {
|
func (c *ConnTracker) createWrappedPacketConn(conn network.PacketConn, connID string) *wrappedPacketConn {
|
||||||
return &wrappedPacketConn{
|
return &wrappedPacketConn{
|
||||||
PacketConn: conn,
|
PacketConn: conn,
|
||||||
|
tracker: c,
|
||||||
connID: connID,
|
connID: connID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type wrappedConn struct {
|
type wrappedConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
|
tracker *ConnTracker
|
||||||
connID string
|
connID string
|
||||||
|
untrackOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wrappedConn) doUntrack() {
|
||||||
|
w.untrackOnce.Do(func() {
|
||||||
|
w.tracker.untrackConnection(w.connID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wrappedConn) Read(b []byte) (int, error) {
|
||||||
|
n, err := w.Conn.Read(b)
|
||||||
|
if shouldUntrackIOErr(err) {
|
||||||
|
w.doUntrack()
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wrappedConn) Write(b []byte) (int, error) {
|
||||||
|
n, err := w.Conn.Write(b)
|
||||||
|
if err != nil && shouldUntrackIOErr(err) {
|
||||||
|
w.doUntrack()
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *wrappedConn) Close() error {
|
func (w *wrappedConn) Close() error {
|
||||||
connTracker.untrackConnection(w.connID)
|
w.doUntrack()
|
||||||
return w.Conn.Close()
|
return w.Conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -137,11 +182,35 @@ func (w *wrappedConn) Upstream() any {
|
|||||||
|
|
||||||
type wrappedPacketConn struct {
|
type wrappedPacketConn struct {
|
||||||
network.PacketConn
|
network.PacketConn
|
||||||
|
tracker *ConnTracker
|
||||||
connID string
|
connID string
|
||||||
|
untrackOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wrappedPacketConn) doUntrack() {
|
||||||
|
w.untrackOnce.Do(func() {
|
||||||
|
w.tracker.untrackConnection(w.connID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wrappedPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) {
|
||||||
|
dest, err := w.PacketConn.ReadPacket(buffer)
|
||||||
|
if shouldUntrackIOErr(err) {
|
||||||
|
w.doUntrack()
|
||||||
|
}
|
||||||
|
return dest, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *wrappedPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
|
||||||
|
err := w.PacketConn.WritePacket(buffer, destination)
|
||||||
|
if err != nil && shouldUntrackIOErr(err) {
|
||||||
|
w.doUntrack()
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *wrappedPacketConn) Close() error {
|
func (w *wrappedPacketConn) Close() error {
|
||||||
connTracker.untrackConnection(w.connID)
|
w.doUntrack()
|
||||||
return w.PacketConn.Close()
|
return w.PacketConn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user