mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-03 17:26:57 +01:00
wtclientrpc: prevent watchtower connections to local addresses
Adds validation to prevent watchtower client connections to local addresses by: - Implementing IsLocalAddress() to detect localhost/local network addresses - Adding check in AddTower RPC to reject local tower connections - Including comprehensive unit tests for address validation This helps prevent security issues from misconfigured watchtower setups that accidentally expose local addresses.
This commit is contained in:
parent
acbb33bb7b
commit
fa75071fee
3 changed files with 242 additions and 0 deletions
102
lncfg/address.go
102
lncfg/address.go
|
@ -391,3 +391,105 @@ func ClientAddressDialer(defaultPort string) func(context.Context,
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
// IsLocalAddress determines whether the given network address refers to a local
|
||||
// network interface or loopback address. It performs comprehensive checks for
|
||||
// various forms of local addressing including:
|
||||
// - Unix domain sockets
|
||||
// - Common localhost hostnames (localhost, ::1, 127.0.0.1)
|
||||
// - Loopback IP addresses
|
||||
// - IP addresses belonging to any local network interface
|
||||
//
|
||||
// The function handles both direct IP addresses and hostnames that require
|
||||
// resolution. For hostnames, it attempts DNS resolution and checks the resulting
|
||||
// IP. If the address cannot be parsed or resolved, it returns false.
|
||||
//
|
||||
// NOTE: This function performs network interface enumeration which may be
|
||||
// relatively expensive in terms of system calls. Cache results if calling
|
||||
// frequently.
|
||||
//
|
||||
// Parameters:
|
||||
// - addr: The network address to check. Can be nil, in which case false is
|
||||
// returned.
|
||||
//
|
||||
// Returns:
|
||||
// - true if the address is determined to be local
|
||||
// - false if the address is non-local or if any error occurs during checking
|
||||
func IsLocalAddress(addr net.Addr) bool {
|
||||
// Handle nil input
|
||||
if addr == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Get the string representation of the address
|
||||
addrStr := addr.String()
|
||||
|
||||
// Check for Unix domain sockets which are always local
|
||||
if addr.Network() == "unix" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Parse the host from the address string
|
||||
if hostPart, _, err := net.SplitHostPort(addrStr); err == nil {
|
||||
// SplitHostPort worked, use the host part
|
||||
addrStr = hostPart
|
||||
}
|
||||
|
||||
// Check for common localhost names
|
||||
if addrStr == "localhost" || addrStr == "::1" || addrStr == "127.0.0.1" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Try to resolve the IP address
|
||||
ip := net.ParseIP(addrStr)
|
||||
if ip == nil {
|
||||
// If we can't parse it as an IP, try to resolve it
|
||||
ips, err := net.LookupIP(addrStr)
|
||||
if err != nil || len(ips) == 0 {
|
||||
return false
|
||||
}
|
||||
ip = ips[0]
|
||||
}
|
||||
|
||||
// Check if it's a loopback address
|
||||
if ip.IsLoopback() {
|
||||
return true
|
||||
}
|
||||
|
||||
// Get all interface addresses
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if the IP matches any local interface
|
||||
for _, iface := range interfaces {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ifaceAddr := range addrs {
|
||||
// Get the network and IP from the interface address
|
||||
var ipNet *net.IPNet
|
||||
switch v := ifaceAddr.(type) {
|
||||
case *net.IPNet:
|
||||
ipNet = v
|
||||
case *net.IPAddr:
|
||||
ipNet = &net.IPNet{
|
||||
IP: v.IP,
|
||||
Mask: v.IP.DefaultMask(),
|
||||
}
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if the IP is in the network range
|
||||
if ipNet.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package lncfg
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
|
@ -316,3 +317,139 @@ func TestIsPrivate(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockNetAddr implements the net.Addr interface for testing.
|
||||
type mockNetAddr struct {
|
||||
network string
|
||||
addr string
|
||||
}
|
||||
|
||||
func (m mockNetAddr) Network() string { return m.network }
|
||||
func (m mockNetAddr) String() string { return m.addr }
|
||||
|
||||
// TestIsLocalAddress tests the IsLocalAddress function with various
|
||||
// types of addresses to ensure it correctly identifies local addresses.
|
||||
func TestIsLocalAddress(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr net.Addr
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil address",
|
||||
addr: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "unix socket",
|
||||
addr: &mockNetAddr{
|
||||
network: "unix",
|
||||
addr: "/tmp/test.sock",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "localhost",
|
||||
addr: &mockNetAddr{
|
||||
network: "tcp",
|
||||
addr: "localhost:1234",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ipv4 loopback",
|
||||
addr: &mockNetAddr{
|
||||
network: "tcp",
|
||||
addr: "127.0.0.1:1234",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ipv6 loopback",
|
||||
addr: &mockNetAddr{
|
||||
network: "tcp",
|
||||
addr: "[::1]:1234",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "public ipv4",
|
||||
addr: &mockNetAddr{
|
||||
network: "tcp",
|
||||
addr: "8.8.8.8:1234",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "public ipv6",
|
||||
addr: &mockNetAddr{
|
||||
network: "tcp",
|
||||
addr: "[2001:4860:4860::8888]:1234",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid address",
|
||||
addr: &mockNetAddr{
|
||||
network: "tcp",
|
||||
addr: "not-a-real-address:1234",
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt // Capture range variable
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := IsLocalAddress(tt.addr)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsLocalAddress() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsLocalAddressWithRealInterfaces tests the IsLocalAddress function
|
||||
// against actual network interfaces on the system.
|
||||
func TestIsLocalAddressWithRealInterfaces(t *testing.T) {
|
||||
// Get all network interfaces
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get network interfaces: %v", err)
|
||||
}
|
||||
|
||||
for _, iface := range interfaces {
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
t.Logf("Skipping interface %s due to error: %v", iface.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
// Extract IP from addr
|
||||
var ip net.IP
|
||||
switch v := addr.(type) {
|
||||
case *net.IPNet:
|
||||
ip = v.IP
|
||||
case *net.IPAddr:
|
||||
ip = v.IP
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
// Create a mock address with this IP
|
||||
mockAddr := &mockNetAddr{
|
||||
network: "tcp",
|
||||
addr: ip.String(),
|
||||
}
|
||||
|
||||
// Test the address
|
||||
t.Run(fmt.Sprintf("interface_%s_addr_%s", iface.Name, ip), func(t *testing.T) {
|
||||
isLocal := IsLocalAddress(mockAddr)
|
||||
t.Logf("Address %v on interface %s is local: %v", ip, iface.Name, isLocal)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -212,6 +212,9 @@ func (c *WatchtowerClient) AddTower(ctx context.Context,
|
|||
err)
|
||||
}
|
||||
|
||||
if lncfg.IsLocalAddress(addr) {
|
||||
return nil, fmt.Errorf("cannot connect to local network address %v", addr)
|
||||
}
|
||||
towerAddr := &lnwire.NetAddress{
|
||||
IdentityKey: pubKey,
|
||||
Address: addr,
|
||||
|
|
Loading…
Add table
Reference in a new issue