mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-03 17:26:57 +01:00
wtclientrpc: prevent watchtower connections to local addresses
Adds validation to prevent watchtower client connections to local addresses by: - Implementing IsLocalAddress() to detect localhost/local network addresses - Adding check in AddTower RPC to reject local tower connections - Including comprehensive unit tests for address validation This helps prevent security issues from misconfigured watchtower setups that accidentally expose local addresses.
This commit is contained in:
parent
acbb33bb7b
commit
fa75071fee
3 changed files with 242 additions and 0 deletions
102
lncfg/address.go
102
lncfg/address.go
|
@ -391,3 +391,105 @@ func ClientAddressDialer(defaultPort string) func(context.Context,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsLocalAddress determines whether the given network address refers to a local
|
||||||
|
// network interface or loopback address. It performs comprehensive checks for
|
||||||
|
// various forms of local addressing including:
|
||||||
|
// - Unix domain sockets
|
||||||
|
// - Common localhost hostnames (localhost, ::1, 127.0.0.1)
|
||||||
|
// - Loopback IP addresses
|
||||||
|
// - IP addresses belonging to any local network interface
|
||||||
|
//
|
||||||
|
// The function handles both direct IP addresses and hostnames that require
|
||||||
|
// resolution. For hostnames, it attempts DNS resolution and checks the resulting
|
||||||
|
// IP. If the address cannot be parsed or resolved, it returns false.
|
||||||
|
//
|
||||||
|
// NOTE: This function performs network interface enumeration which may be
|
||||||
|
// relatively expensive in terms of system calls. Cache results if calling
|
||||||
|
// frequently.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - addr: The network address to check. Can be nil, in which case false is
|
||||||
|
// returned.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - true if the address is determined to be local
|
||||||
|
// - false if the address is non-local or if any error occurs during checking
|
||||||
|
func IsLocalAddress(addr net.Addr) bool {
|
||||||
|
// Handle nil input
|
||||||
|
if addr == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the string representation of the address
|
||||||
|
addrStr := addr.String()
|
||||||
|
|
||||||
|
// Check for Unix domain sockets which are always local
|
||||||
|
if addr.Network() == "unix" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the host from the address string
|
||||||
|
if hostPart, _, err := net.SplitHostPort(addrStr); err == nil {
|
||||||
|
// SplitHostPort worked, use the host part
|
||||||
|
addrStr = hostPart
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for common localhost names
|
||||||
|
if addrStr == "localhost" || addrStr == "::1" || addrStr == "127.0.0.1" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to resolve the IP address
|
||||||
|
ip := net.ParseIP(addrStr)
|
||||||
|
if ip == nil {
|
||||||
|
// If we can't parse it as an IP, try to resolve it
|
||||||
|
ips, err := net.LookupIP(addrStr)
|
||||||
|
if err != nil || len(ips) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
ip = ips[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's a loopback address
|
||||||
|
if ip.IsLoopback() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all interface addresses
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the IP matches any local interface
|
||||||
|
for _, iface := range interfaces {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ifaceAddr := range addrs {
|
||||||
|
// Get the network and IP from the interface address
|
||||||
|
var ipNet *net.IPNet
|
||||||
|
switch v := ifaceAddr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ipNet = v
|
||||||
|
case *net.IPAddr:
|
||||||
|
ipNet = &net.IPNet{
|
||||||
|
IP: v.IP,
|
||||||
|
Mask: v.IP.DefaultMask(),
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the IP is in the network range
|
||||||
|
if ipNet.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package lncfg
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -316,3 +317,139 @@ func TestIsPrivate(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mockNetAddr implements the net.Addr interface for testing.
|
||||||
|
type mockNetAddr struct {
|
||||||
|
network string
|
||||||
|
addr string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockNetAddr) Network() string { return m.network }
|
||||||
|
func (m mockNetAddr) String() string { return m.addr }
|
||||||
|
|
||||||
|
// TestIsLocalAddress tests the IsLocalAddress function with various
|
||||||
|
// types of addresses to ensure it correctly identifies local addresses.
|
||||||
|
func TestIsLocalAddress(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr net.Addr
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil address",
|
||||||
|
addr: nil,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unix socket",
|
||||||
|
addr: &mockNetAddr{
|
||||||
|
network: "unix",
|
||||||
|
addr: "/tmp/test.sock",
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "localhost",
|
||||||
|
addr: &mockNetAddr{
|
||||||
|
network: "tcp",
|
||||||
|
addr: "localhost:1234",
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv4 loopback",
|
||||||
|
addr: &mockNetAddr{
|
||||||
|
network: "tcp",
|
||||||
|
addr: "127.0.0.1:1234",
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv6 loopback",
|
||||||
|
addr: &mockNetAddr{
|
||||||
|
network: "tcp",
|
||||||
|
addr: "[::1]:1234",
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "public ipv4",
|
||||||
|
addr: &mockNetAddr{
|
||||||
|
network: "tcp",
|
||||||
|
addr: "8.8.8.8:1234",
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "public ipv6",
|
||||||
|
addr: &mockNetAddr{
|
||||||
|
network: "tcp",
|
||||||
|
addr: "[2001:4860:4860::8888]:1234",
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid address",
|
||||||
|
addr: &mockNetAddr{
|
||||||
|
network: "tcp",
|
||||||
|
addr: "not-a-real-address:1234",
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt // Capture range variable
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
got := IsLocalAddress(tt.addr)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("IsLocalAddress() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIsLocalAddressWithRealInterfaces tests the IsLocalAddress function
|
||||||
|
// against actual network interfaces on the system.
|
||||||
|
func TestIsLocalAddressWithRealInterfaces(t *testing.T) {
|
||||||
|
// Get all network interfaces
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to get network interfaces: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, iface := range interfaces {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Skipping interface %s due to error: %v", iface.Name, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
// Extract IP from addr
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a mock address with this IP
|
||||||
|
mockAddr := &mockNetAddr{
|
||||||
|
network: "tcp",
|
||||||
|
addr: ip.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test the address
|
||||||
|
t.Run(fmt.Sprintf("interface_%s_addr_%s", iface.Name, ip), func(t *testing.T) {
|
||||||
|
isLocal := IsLocalAddress(mockAddr)
|
||||||
|
t.Logf("Address %v on interface %s is local: %v", ip, iface.Name, isLocal)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -212,6 +212,9 @@ func (c *WatchtowerClient) AddTower(ctx context.Context,
|
||||||
err)
|
err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if lncfg.IsLocalAddress(addr) {
|
||||||
|
return nil, fmt.Errorf("cannot connect to local network address %v", addr)
|
||||||
|
}
|
||||||
towerAddr := &lnwire.NetAddress{
|
towerAddr := &lnwire.NetAddress{
|
||||||
IdentityKey: pubKey,
|
IdentityKey: pubKey,
|
||||||
Address: addr,
|
Address: addr,
|
||||||
|
|
Loading…
Add table
Reference in a new issue