Skip to content

Commit 0308cf7

Browse files
authored
Merge pull request #577 from AkihiroSuda/fix-576
port/builtin: support source IP propagation for UDP via IP_TRANSPARENT
2 parents 67c0a39 + 0475df5 commit 0308cf7

File tree

6 files changed

+138
-10
lines changed

6 files changed

+138
-10
lines changed

pkg/port/builtin/builtin_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,5 @@ func TestBuiltIn(t *testing.T) {
3030
}
3131
testsuite.Run(t, pf)
3232
testsuite.RunTCPTransparent(t, pf)
33+
testsuite.RunUDPTransparent(t, pf)
3334
}

pkg/port/builtin/child/child.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ func (d *childDriver) handleConnectRequest(c *net.UnixConn, req *msg.Request) er
147147

148148
var targetConn net.Conn
149149
var err error
150-
if d.sourceIPTransparent && req.SourceIP != "" && req.SourcePort != 0 && dialProto == "tcp" && !net.ParseIP(req.SourceIP).IsLoopback() {
150+
if d.sourceIPTransparent && req.SourceIP != "" && req.SourcePort != 0 && (dialProto == "tcp" || dialProto == "udp") && !net.ParseIP(req.SourceIP).IsLoopback() {
151151
d.routingSetup.Do(func() { d.routingReady = d.setupTransparentRouting() })
152152
if !d.routingReady {
153153
d.routingWarn.Do(func() {
@@ -251,9 +251,16 @@ func (d *childDriver) setupTransparentRouting() bool {
251251
// transparentDial dials targetAddr using IP_TRANSPARENT, binding to the given
252252
// source IP and port so the backend service sees the real client address.
253253
func transparentDial(dialProto, targetAddr, sourceIP string, sourcePort int) (net.Conn, error) {
254+
var localAddr net.Addr
255+
switch dialProto {
256+
case "tcp":
257+
localAddr = &net.TCPAddr{IP: net.ParseIP(sourceIP), Port: sourcePort}
258+
case "udp":
259+
localAddr = &net.UDPAddr{IP: net.ParseIP(sourceIP), Port: sourcePort}
260+
}
254261
dialer := net.Dialer{
255262
Timeout: time.Second,
256-
LocalAddr: &net.TCPAddr{IP: net.ParseIP(sourceIP), Port: sourcePort},
263+
LocalAddr: localAddr,
257264
Control: func(network, address string, c syscall.RawConn) error {
258265
var sockErr error
259266
if err := c.Control(func(fd uintptr) {

pkg/port/builtin/msg/msg.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,17 @@ func ConnectToChild(c *net.UnixConn, spec port.Spec, sourceAddr net.Addr) (int,
8282
ParentIP: spec.ParentIP,
8383
HostGatewayIP: hostGatewayIP(),
8484
}
85-
if tcpAddr, ok := sourceAddr.(*net.TCPAddr); ok && tcpAddr != nil {
86-
req.SourceIP = tcpAddr.IP.String()
87-
req.SourcePort = tcpAddr.Port
85+
switch a := sourceAddr.(type) {
86+
case *net.TCPAddr:
87+
if a != nil {
88+
req.SourceIP = a.IP.String()
89+
req.SourcePort = a.Port
90+
}
91+
case *net.UDPAddr:
92+
if a != nil {
93+
req.SourceIP = a.IP.String()
94+
req.SourcePort = a.Port
95+
}
8896
}
8997
if _, err := lowlevelmsgutil.MarshalToWriter(c, &req); err != nil {
9098
return 0, err

pkg/port/builtin/parent/udp/udp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ func Run(socketPath string, spec port.Spec, stopCh <-chan struct{}, stoppedCh ch
2424
udpp := &udpproxy.UDPProxy{
2525
LogWriter: logWriter,
2626
Listener: c,
27-
BackendDial: func() (*net.UDPConn, error) {
27+
BackendDial: func(from *net.UDPAddr) (*net.UDPConn, error) {
2828
// get fd from the child as an SCM_RIGHTS cmsg
29-
fd, err := msg.ConnectToChildWithRetry(socketPath, spec, 10, nil)
29+
fd, err := msg.ConnectToChildWithRetry(socketPath, spec, 10, from)
3030
if err != nil {
3131
return nil, err
3232
}

pkg/port/builtin/parent/udp/udpproxy/udp_proxy.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ type connTrackMap map[connTrackKey]*net.UDPConn
4949
type UDPProxy struct {
5050
LogWriter io.Writer
5151
Listener *net.UDPConn
52-
BackendDial func() (*net.UDPConn, error)
52+
BackendDial func(from *net.UDPAddr) (*net.UDPConn, error)
5353
connTrackTable connTrackMap
5454
connTrackLock sync.Mutex
5555
}
@@ -108,7 +108,7 @@ func (proxy *UDPProxy) Run() {
108108
proxy.connTrackLock.Lock()
109109
proxyConn, hit := proxy.connTrackTable[*fromKey]
110110
if !hit {
111-
proxyConn, err = proxy.BackendDial()
111+
proxyConn, err = proxy.BackendDial(from)
112112
if err != nil {
113113
fmt.Fprintf(proxy.LogWriter, "Can't proxy a datagram to udp: %v\n", err)
114114
proxy.connTrackLock.Unlock()

pkg/port/testsuite/testsuite.go

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ func Main(m *testing.M, cf func() port.ChildDriver) {
3636
case "echoserver":
3737
runEchoServer()
3838
os.Exit(0)
39+
case "udpechoserver":
40+
runUDPEchoServer()
41+
os.Exit(0)
3942
default:
4043
panic(fmt.Errorf("unknown mode: %q", mode))
4144
}
@@ -504,14 +507,32 @@ func transparentTCPDialAndSend(t *testing.T, parentAddr string) string {
504507
return clientAddr
505508
}
506509

510+
func transparentUDPDialAndSend(t *testing.T, parentAddr string) string {
511+
conn, err := net.Dial("udp", parentAddr)
512+
if err != nil {
513+
t.Fatal(err)
514+
}
515+
clientAddr := conn.LocalAddr().String()
516+
if _, err := conn.Write([]byte("hello")); err != nil {
517+
t.Fatal(err)
518+
}
519+
conn.Close()
520+
return clientAddr
521+
}
522+
507523
func testTransparentWithPID(t *testing.T, proto string, d port.ParentDriver, childPID int) {
508524
ensureDeps(t, "nsenter")
509525
const childPort = 80
510526

511527
var dialAndSend transparentDialAndSend
528+
var echoMode string
512529
switch proto {
513530
case "tcp":
514531
dialAndSend = transparentTCPDialAndSend
532+
echoMode = "echoserver"
533+
case "udp":
534+
dialAndSend = transparentUDPDialAndSend
535+
echoMode = "udpechoserver"
515536
default:
516537
t.Fatalf("unsupported proto for transparent test: %s", proto)
517538
}
@@ -560,7 +581,7 @@ func testTransparentWithPID(t *testing.T, proto string, d port.ParentDriver, chi
560581
"-t", strconv.Itoa(childPID),
561582
exe)
562583
echoCmd.Env = append([]string{
563-
reexecKeyMode + "=echoserver",
584+
reexecKeyMode + "=" + echoMode,
564585
reexecKeyEchoPort + "=" + strconv.Itoa(childPort),
565586
}, os.Environ()...)
566587
echoCmd.Stdout = stdoutW
@@ -611,6 +632,11 @@ func testTransparentWithPID(t *testing.T, proto string, d port.ParentDriver, chi
611632
}
612633
t.Logf("opened port: %+v", portStatus)
613634

635+
if proto == "udp" {
636+
// UDP dial does not return an error even if the proxy is not ready yet
637+
time.Sleep(500 * time.Millisecond)
638+
}
639+
614640
// Dial and send data
615641
parentAddr := net.JoinHostPort(parentIP, strconv.Itoa(parentPort))
616642
clientAddr := dialAndSend(t, parentAddr)
@@ -650,3 +676,89 @@ func testTransparentWithPID(t *testing.T, proto string, d port.ParentDriver, chi
650676
t.Fatal(err)
651677
}
652678
}
679+
680+
// runUDPEchoServer is a re-exec mode that runs a minimal UDP server.
681+
// It listens on 127.0.0.1:<port>, signals readiness by closing fd 3,
682+
// receives one datagram, writes the remote address to stdout, and echoes the data back.
683+
func runUDPEchoServer() {
684+
portStr := os.Getenv(reexecKeyEchoPort)
685+
if portStr == "" {
686+
panic("udpechoserver: missing " + reexecKeyEchoPort)
687+
}
688+
addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:"+portStr)
689+
if err != nil {
690+
panic(fmt.Errorf("udpechoserver: resolve: %w", err))
691+
}
692+
conn, err := net.ListenUDP("udp", addr)
693+
if err != nil {
694+
panic(fmt.Errorf("udpechoserver: listen: %w", err))
695+
}
696+
defer conn.Close()
697+
// Signal readiness by closing fd 3
698+
readyW := os.NewFile(3, "ready")
699+
readyW.Close()
700+
701+
buf := make([]byte, 65507)
702+
n, from, err := conn.ReadFromUDP(buf)
703+
if err != nil {
704+
panic(fmt.Errorf("udpechoserver: read: %w", err))
705+
}
706+
fmt.Fprintln(os.Stdout, from.String())
707+
conn.WriteToUDP(buf[:n], from)
708+
}
709+
710+
func RunUDPTransparent(t *testing.T, pf func() port.ParentDriver) {
711+
t.Run("TestUDPTransparent", func(t *testing.T) { TestUDPTransparent(t, pf()) })
712+
}
713+
714+
func TestUDPTransparent(t *testing.T, d port.ParentDriver) {
715+
ensureDeps(t, "nsenter")
716+
t.Logf("creating USER+NET namespace")
717+
opaque := d.OpaqueForChild()
718+
opaqueJSON, err := json.Marshal(opaque)
719+
if err != nil {
720+
t.Fatal(err)
721+
}
722+
pr, pw, err := os.Pipe()
723+
if err != nil {
724+
t.Fatal(err)
725+
}
726+
cmd := exec.Command("/proc/self/exe")
727+
cmd.Stdout = os.Stderr
728+
cmd.Stderr = os.Stderr
729+
cmd.Env = append([]string{
730+
reexecKeyMode + "=child",
731+
reexecKeyOpaque + "=" + string(opaqueJSON),
732+
reexecKeyQuitFD + "=3"}, os.Environ()...)
733+
cmd.SysProcAttr = &syscall.SysProcAttr{
734+
Pdeathsig: syscall.SIGKILL,
735+
Cloneflags: syscall.CLONE_NEWUSER | syscall.CLONE_NEWNET,
736+
UidMappings: []syscall.SysProcIDMap{
737+
{
738+
ContainerID: 0,
739+
HostID: os.Geteuid(),
740+
Size: 1,
741+
},
742+
},
743+
GidMappings: []syscall.SysProcIDMap{
744+
{
745+
ContainerID: 0,
746+
HostID: os.Getegid(),
747+
Size: 1,
748+
},
749+
},
750+
}
751+
cmd.ExtraFiles = []*os.File{pr}
752+
if err := cmd.Start(); err != nil {
753+
t.Fatal(err)
754+
}
755+
defer func() {
756+
pw.Close()
757+
cmd.Wait()
758+
}()
759+
childPID := cmd.Process.Pid
760+
if out, err := nsenterExec(childPID, "ip", "link", "set", "lo", "up"); err != nil {
761+
t.Fatalf("%v, out=%s", err, string(out))
762+
}
763+
testTransparentWithPID(t, "udp", d, childPID)
764+
}

0 commit comments

Comments
 (0)