wtclientrpc: prevent watchtower connections to local addresses

Adds validation to prevent watchtower client connections to local addresses by:
- Implementing IsLocalAddress() to detect localhost/local network addresses
- Adding check in AddTower RPC to reject local tower connections
- Including comprehensive unit tests for address validation

This helps prevent security issues from misconfigured watchtower setups that
accidentally expose local addresses.
This commit is contained in:
Animesh Bilthare 2024-10-28 20:10:23 +05:30
parent acbb33bb7b
commit fa75071fee
No known key found for this signature in database
GPG key ID: 09BB99849ADF6212
3 changed files with 242 additions and 0 deletions

View file

@ -391,3 +391,105 @@ func ClientAddressDialer(defaultPort string) func(context.Context,
) )
} }
} }
// IsLocalAddress determines whether the given network address refers to a local
// network interface or loopback address. It performs comprehensive checks for
// various forms of local addressing including:
// - Unix domain sockets
// - Common localhost hostnames (localhost, ::1, 127.0.0.1)
// - Loopback IP addresses
// - IP addresses belonging to any local network interface
//
// The function handles both direct IP addresses and hostnames that require
// resolution. For hostnames, it attempts DNS resolution and checks the resulting
// IP. If the address cannot be parsed or resolved, it returns false.
//
// NOTE: This function performs network interface enumeration which may be
// relatively expensive in terms of system calls. Cache results if calling
// frequently.
//
// Parameters:
// - addr: The network address to check. Can be nil, in which case false is
// returned.
//
// Returns:
// - true if the address is determined to be local
// - false if the address is non-local or if any error occurs during checking
func IsLocalAddress(addr net.Addr) bool {
// Handle nil input
if addr == nil {
return false
}
// Get the string representation of the address
addrStr := addr.String()
// Check for Unix domain sockets which are always local
if addr.Network() == "unix" {
return true
}
// Parse the host from the address string
if hostPart, _, err := net.SplitHostPort(addrStr); err == nil {
// SplitHostPort worked, use the host part
addrStr = hostPart
}
// Check for common localhost names
if addrStr == "localhost" || addrStr == "::1" || addrStr == "127.0.0.1" {
return true
}
// Try to resolve the IP address
ip := net.ParseIP(addrStr)
if ip == nil {
// If we can't parse it as an IP, try to resolve it
ips, err := net.LookupIP(addrStr)
if err != nil || len(ips) == 0 {
return false
}
ip = ips[0]
}
// Check if it's a loopback address
if ip.IsLoopback() {
return true
}
// Get all interface addresses
interfaces, err := net.Interfaces()
if err != nil {
return false
}
// Check if the IP matches any local interface
for _, iface := range interfaces {
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, ifaceAddr := range addrs {
// Get the network and IP from the interface address
var ipNet *net.IPNet
switch v := ifaceAddr.(type) {
case *net.IPNet:
ipNet = v
case *net.IPAddr:
ipNet = &net.IPNet{
IP: v.IP,
Mask: v.IP.DefaultMask(),
}
default:
continue
}
// Check if the IP is in the network range
if ipNet.Contains(ip) {
return true
}
}
}
return false
}

View file

@ -3,6 +3,7 @@ package lncfg
import ( import (
"bytes" "bytes"
"encoding/hex" "encoding/hex"
"fmt"
"net" "net"
"testing" "testing"
@ -316,3 +317,139 @@ func TestIsPrivate(t *testing.T) {
}) })
} }
} }
// mockNetAddr implements the net.Addr interface for testing.
type mockNetAddr struct {
network string
addr string
}
func (m mockNetAddr) Network() string { return m.network }
func (m mockNetAddr) String() string { return m.addr }
// TestIsLocalAddress tests the IsLocalAddress function with various
// types of addresses to ensure it correctly identifies local addresses.
func TestIsLocalAddress(t *testing.T) {
tests := []struct {
name string
addr net.Addr
want bool
}{
{
name: "nil address",
addr: nil,
want: false,
},
{
name: "unix socket",
addr: &mockNetAddr{
network: "unix",
addr: "/tmp/test.sock",
},
want: true,
},
{
name: "localhost",
addr: &mockNetAddr{
network: "tcp",
addr: "localhost:1234",
},
want: true,
},
{
name: "ipv4 loopback",
addr: &mockNetAddr{
network: "tcp",
addr: "127.0.0.1:1234",
},
want: true,
},
{
name: "ipv6 loopback",
addr: &mockNetAddr{
network: "tcp",
addr: "[::1]:1234",
},
want: true,
},
{
name: "public ipv4",
addr: &mockNetAddr{
network: "tcp",
addr: "8.8.8.8:1234",
},
want: false,
},
{
name: "public ipv6",
addr: &mockNetAddr{
network: "tcp",
addr: "[2001:4860:4860::8888]:1234",
},
want: false,
},
{
name: "invalid address",
addr: &mockNetAddr{
network: "tcp",
addr: "not-a-real-address:1234",
},
want: false,
},
}
for _, tt := range tests {
tt := tt // Capture range variable
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := IsLocalAddress(tt.addr)
if got != tt.want {
t.Errorf("IsLocalAddress() got = %v, want %v", got, tt.want)
}
})
}
}
// TestIsLocalAddressWithRealInterfaces tests the IsLocalAddress function
// against actual network interfaces on the system.
func TestIsLocalAddressWithRealInterfaces(t *testing.T) {
// Get all network interfaces
interfaces, err := net.Interfaces()
if err != nil {
t.Fatalf("Failed to get network interfaces: %v", err)
}
for _, iface := range interfaces {
addrs, err := iface.Addrs()
if err != nil {
t.Logf("Skipping interface %s due to error: %v", iface.Name, err)
continue
}
for _, addr := range addrs {
// Extract IP from addr
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
default:
continue
}
// Create a mock address with this IP
mockAddr := &mockNetAddr{
network: "tcp",
addr: ip.String(),
}
// Test the address
t.Run(fmt.Sprintf("interface_%s_addr_%s", iface.Name, ip), func(t *testing.T) {
isLocal := IsLocalAddress(mockAddr)
t.Logf("Address %v on interface %s is local: %v", ip, iface.Name, isLocal)
})
}
}
}

View file

@ -212,6 +212,9 @@ func (c *WatchtowerClient) AddTower(ctx context.Context,
err) err)
} }
if lncfg.IsLocalAddress(addr) {
return nil, fmt.Errorf("cannot connect to local network address %v", addr)
}
towerAddr := &lnwire.NetAddress{ towerAddr := &lnwire.NetAddress{
IdentityKey: pubKey, IdentityKey: pubKey,
Address: addr, Address: addr,