diff --git a/btcutil/address.go b/btcutil/address.go index 95d3e6c3..eca7fa0c 100644 --- a/btcutil/address.go +++ b/btcutil/address.go @@ -52,6 +52,10 @@ var ( // than assuming or defaulting to one or the other, this error is // returned and the caller must decide how to decode the address. ErrAddressCollision = errors.New("address collision") + + // ErrIncorrect describes an error where an address could be decoded, but + // it was not for the expected network. + ErrIncorrectNet = errors.New("address is for incorrect net") ) // encodeAddress returns a human-readable payment address given a ripemd160 hash @@ -223,6 +227,26 @@ func DecodeAddress(addr string, defaultNet *chaincfg.Params) (Address, error) { } } +// DecodeAddressForNet decodes the string encoding of an address and returns +// the Address if addr is a valid encoding for a known address type and the +// given network. +// +// This method differs from DecodeAddress in that DecodeAddress tolerates +// differing networks, in case the network parameter isn't needed to decode +// the address. +func DecodeAddressForNet(addr string, net *chaincfg.Params) (Address, error) { + decoded, err := DecodeAddress(addr, net) + if err != nil { + return nil, err + } + + if !decoded.IsForNet(net) { + return nil, ErrIncorrectNet + } + + return decoded, nil +} + // decodeSegWitAddress parses a bech32 encoded segwit address string and // returns the witness version and witness program byte representation. func decodeSegWitAddress(address string) (byte, []byte, error) { diff --git a/btcutil/address_test.go b/btcutil/address_test.go index f5ae2ac0..89ae8982 100644 --- a/btcutil/address_test.go +++ b/btcutil/address_test.go @@ -879,6 +879,67 @@ func TestAddresses(t *testing.T) { continue } + // Ensure addresses are only valid for their given network + // The cases where test.addr and test.encoded are different are + // pay-to-pubkey, and these are valid for all nets. + if test.valid && test.addr == test.encoded { + _, err := btcutil.DecodeAddressForNet(test.addr, test.net) + if err != nil { + t.Errorf("%s: invalid for expected net: %s", test.name, err) + } + + nets := []chaincfg.Params{ + chaincfg.MainNetParams, + chaincfg.TestNet3Params, + chaincfg.RegressionNetParams, + chaincfg.SimNetParams, + chaincfg.SigNetParams, + customParams, + } + + // verify we can't decode for other nets + for _, net := range nets { + if net.Net == test.net.Net { + continue + } + + decoded, err := btcutil.DecodeAddressForNet(test.addr, &net) + if err != nil { + continue + } + // signet bech32 addresses have the same HRP prefix as + // testnet. Skip those. We verify this with an interface + // check instead of concrete types, as there's multiple + // concrete types that implement Segwit addresses. + type bech32 interface { + Hrp() string + } + if _, ok := decoded.(bech32); ok && + net.Net == chaincfg.SigNetParams.Net { + continue + } + + // testnet3, signet and regtest shares the same pubkey hash + // prefixes. Skip those. + if _, ok := decoded.(*btcutil.AddressPubKeyHash); ok && + (net.Net == chaincfg.SigNetParams.Net || + net.Net == chaincfg.RegressionNetParams.Net) { + continue + } + + // testnet3, signet and regtest shares the same script hash + // prefixes. Skip those. + if _, ok := decoded.(*btcutil.AddressScriptHash); ok && + (net.Net == chaincfg.SigNetParams.Net || + net.Net == chaincfg.RegressionNetParams.Net) { + continue + } + + t.Errorf("%s: was able to decode address %s for incorrect net %s: 0x%02x", + test.name, decoded, net.Name, int(net.Net)) + } + } + // Valid test, compare address created with f against expected result. addr, err := test.f() if err != nil {