Skip to content

Commit d219bda

Browse files
committed
Import udp_conn_pool
1 parent ca24e51 commit d219bda

2 files changed

Lines changed: 412 additions & 0 deletions

File tree

dnscrypt-proxy/udp_conn_pool.go

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
package main
2+
3+
import (
4+
"net"
5+
"sync"
6+
"sync/atomic"
7+
"time"
8+
9+
"github.com/jedisct1/dlog"
10+
)
11+
12+
const (
13+
UDPPoolMaxConnsPerAddr = 4
14+
UDPPoolMaxIdleTime = 30 * time.Second
15+
UDPPoolCleanupInterval = 10 * time.Second
16+
UDPPoolShards = 64
17+
)
18+
19+
type pooledConn struct {
20+
conn *net.UDPConn
21+
lastUsed time.Time
22+
}
23+
24+
type poolShard struct {
25+
sync.Mutex
26+
conns map[string][]*pooledConn
27+
}
28+
29+
type UDPConnPool struct {
30+
shards [UDPPoolShards]poolShard
31+
closed int32 // atomic
32+
stopOnce sync.Once
33+
stopCh chan struct{}
34+
}
35+
36+
func NewUDPConnPool() *UDPConnPool {
37+
pool := &UDPConnPool{
38+
stopCh: make(chan struct{}),
39+
}
40+
for i := range pool.shards {
41+
pool.shards[i].conns = make(map[string][]*pooledConn)
42+
}
43+
go pool.cleanupLoop()
44+
return pool
45+
}
46+
47+
func (p *UDPConnPool) getShard(addr string) *poolShard {
48+
h := uint32(0)
49+
for i := 0; i < len(addr); i++ {
50+
h = h*31 + uint32(addr[i])
51+
}
52+
return &p.shards[h%UDPPoolShards]
53+
}
54+
55+
func (p *UDPConnPool) cleanupLoop() {
56+
ticker := time.NewTicker(UDPPoolCleanupInterval)
57+
defer ticker.Stop()
58+
59+
for {
60+
select {
61+
case <-ticker.C:
62+
p.cleanupStale()
63+
case <-p.stopCh:
64+
return
65+
}
66+
}
67+
}
68+
69+
func (p *UDPConnPool) cleanupStale() {
70+
now := time.Now()
71+
for i := range p.shards {
72+
shard := &p.shards[i]
73+
shard.Lock()
74+
for addr, conns := range shard.conns {
75+
var active []*pooledConn
76+
for _, pc := range conns {
77+
if now.Sub(pc.lastUsed) > UDPPoolMaxIdleTime {
78+
pc.conn.Close()
79+
dlog.Debugf("UDP pool: closed stale connection to %s", addr)
80+
} else {
81+
active = append(active, pc)
82+
}
83+
}
84+
if len(active) == 0 {
85+
delete(shard.conns, addr)
86+
} else {
87+
shard.conns[addr] = active
88+
}
89+
}
90+
shard.Unlock()
91+
}
92+
}
93+
94+
func (p *UDPConnPool) Get(addr *net.UDPAddr) (*net.UDPConn, error) {
95+
addrStr := addr.String()
96+
shard := p.getShard(addrStr)
97+
98+
shard.Lock()
99+
conns := shard.conns[addrStr]
100+
if len(conns) > 0 {
101+
pc := conns[len(conns)-1]
102+
shard.conns[addrStr] = conns[:len(conns)-1]
103+
shard.Unlock()
104+
pc.conn.SetReadDeadline(time.Time{})
105+
pc.conn.SetWriteDeadline(time.Time{})
106+
return pc.conn, nil
107+
}
108+
shard.Unlock()
109+
110+
return net.DialUDP("udp", nil, addr)
111+
}
112+
113+
func (p *UDPConnPool) Put(addr *net.UDPAddr, conn *net.UDPConn) {
114+
if conn == nil {
115+
return
116+
}
117+
if atomic.LoadInt32(&p.closed) != 0 {
118+
conn.Close()
119+
return
120+
}
121+
122+
addrStr := addr.String()
123+
shard := p.getShard(addrStr)
124+
125+
shard.Lock()
126+
conns := shard.conns[addrStr]
127+
if len(conns) >= UDPPoolMaxConnsPerAddr {
128+
shard.Unlock()
129+
conn.Close()
130+
return
131+
}
132+
shard.conns[addrStr] = append(conns, &pooledConn{
133+
conn: conn,
134+
lastUsed: time.Now(),
135+
})
136+
shard.Unlock()
137+
}
138+
139+
func (p *UDPConnPool) Discard(conn *net.UDPConn) {
140+
if conn != nil {
141+
conn.Close()
142+
}
143+
}
144+
145+
func (p *UDPConnPool) Close() {
146+
p.stopOnce.Do(func() {
147+
close(p.stopCh)
148+
})
149+
atomic.StoreInt32(&p.closed, 1)
150+
151+
for i := range p.shards {
152+
shard := &p.shards[i]
153+
shard.Lock()
154+
for addr, conns := range shard.conns {
155+
for _, pc := range conns {
156+
pc.conn.Close()
157+
}
158+
delete(shard.conns, addr)
159+
}
160+
shard.Unlock()
161+
}
162+
dlog.Debug("UDP connection pool closed")
163+
}
164+
165+
func (p *UDPConnPool) Stats() (totalConns int, addrCount int) {
166+
for i := range p.shards {
167+
shard := &p.shards[i]
168+
shard.Lock()
169+
addrCount += len(shard.conns)
170+
for _, conns := range shard.conns {
171+
totalConns += len(conns)
172+
}
173+
shard.Unlock()
174+
}
175+
return
176+
}

0 commit comments

Comments
 (0)