From f8ac6b340569b8b9e2102a4525d6983054d983cd Mon Sep 17 00:00:00 2001 From: Sean Gilligan Date: Fri, 3 Mar 2023 21:11:09 -0800 Subject: [PATCH] SegwitAddress: store `witnessVersion` internally as a short Separate witnessVersion from witnessProgram in private members. Note that the witnessProgram is still stored in a 5-bit per byte encoding as it was in the combined byte[]. A future PR will change it to 8-bit per byte encoding. --- .../java/org/bitcoinj/base/SegwitAddress.java | 91 ++++++++++--------- 1 file changed, 46 insertions(+), 45 deletions(-) diff --git a/core/src/main/java/org/bitcoinj/base/SegwitAddress.java b/core/src/main/java/org/bitcoinj/base/SegwitAddress.java index ad65906fa..2eec2510e 100644 --- a/core/src/main/java/org/bitcoinj/base/SegwitAddress.java +++ b/core/src/main/java/org/bitcoinj/base/SegwitAddress.java @@ -122,23 +122,8 @@ public class SegwitAddress implements Address { } protected final Network network; - protected final byte[] bytes; - - /** - * Private constructor. Use {@link #fromBech32(Network, String)}, - * {@link #fromHash(Network, byte[])} or {@link ECKey#toAddress(ScriptType, Network)}. - * - * @param network - * network this address is valid for - * @param witnessVersion - * version number between 0 and 16 - * @param witnessProgram - * hash of pubkey, pubkey or script (depending on version) - */ - private SegwitAddress(Network network, int witnessVersion, byte[] witnessProgram) - throws AddressFormatException { - this(network, encode(witnessVersion, witnessProgram)); - } + protected final short witnessVersion; + protected final byte[] witnessProgram; // Currently this is in 5-bit Bech32 form private static Network normalizeNetwork(Network network) { // SegwitAddress does not distinguish between the SIGNET and TESTNET, normalize to TESTNET @@ -151,15 +136,12 @@ public class SegwitAddress implements Address { return network; } - /** - * Helper for the above constructor. - */ - private static byte[] encode(int witnessVersion, byte[] witnessProgram) throws AddressFormatException { - byte[] convertedProgram = convertBits(witnessProgram, 0, witnessProgram.length, 8, 5, true); - byte[] bytes = new byte[1 + convertedProgram.length]; - bytes[0] = (byte) (witnessVersion & 0xff); - System.arraycopy(convertedProgram, 0, bytes, 1, convertedProgram.length); - return bytes; + private static byte[] encode8to5(byte[] data) { + return convertBits(data, 0, data.length, 8, 5, true); + } + + private static byte[] decode5to8(byte[] data) { + return convertBits(data, 0, data.length, 5, 8, false); } /** @@ -168,20 +150,16 @@ public class SegwitAddress implements Address { * * @param network * network this address is valid for - * @param data - * in segwit address format, before bit re-arranging and bech32 encoding + * @param witnessVersion + * version number between 0 and 16 + * @param witnessProgram + * hash of pubkey, pubkey or script (depending on version) (8-bits per byte) * @throws AddressFormatException * if any of the sanity checks fail */ - private SegwitAddress(Network network, byte[] data) throws AddressFormatException { - this.network = normalizeNetwork(checkNotNull(network)); - this.bytes = checkNotNull(data); - if (data.length < 1) - throw new AddressFormatException.InvalidDataLength("Zero data found"); - final int witnessVersion = getWitnessVersion(); + private SegwitAddress(Network network, int witnessVersion, byte[] witnessProgram) throws AddressFormatException { if (witnessVersion < 0 || witnessVersion > 16) throw new AddressFormatException("Invalid script version: " + witnessVersion); - byte[] witnessProgram = getWitnessProgram(); if (witnessProgram.length < WITNESS_PROGRAM_MIN_LENGTH || witnessProgram.length > WITNESS_PROGRAM_MAX_LENGTH) throw new AddressFormatException.InvalidDataLength("Invalid length: " + witnessProgram.length); // Check script length for version 0: @@ -206,6 +184,9 @@ public class SegwitAddress implements Address { if (witnessVersion == 1 && witnessProgram.length != WITNESS_PROGRAM_LENGTH_TR) throw new AddressFormatException.InvalidDataLength( "Invalid length for address version 1: " + witnessProgram.length); + this.network = normalizeNetwork(checkNotNull(network)); + this.witnessVersion = (short) witnessVersion; + this.witnessProgram = encode8to5(checkNotNull(witnessProgram)); } /** @@ -214,7 +195,7 @@ public class SegwitAddress implements Address { * @return witness version, between 0 and 16 */ public int getWitnessVersion() { - return bytes[0] & 0xff; + return witnessVersion; } /** @@ -223,8 +204,8 @@ public class SegwitAddress implements Address { * @return witness program */ public byte[] getWitnessProgram() { - // skip version byte - return convertBits(bytes, 1, bytes.length - 1, 5, 8, false); + // no version byte + return decode5to8(witnessProgram); } @Override @@ -301,8 +282,12 @@ public class SegwitAddress implements Address { } private static SegwitAddress fromBechData(Network network, Bech32.Bech32Data bechData) { - final SegwitAddress address = new SegwitAddress(network, bechData.data); - final int witnessVersion = address.getWitnessVersion(); + if (bechData.data.length < 1) { + throw new AddressFormatException.InvalidDataLength("invalid address length (0)"); + } + final int witnessVersion = bechData.data[0]; + byte[] witnessProgram = decode5to8(trimVersion(bechData.data)); + final SegwitAddress address = new SegwitAddress(network, witnessVersion, witnessProgram); if ((witnessVersion == 0 && bechData.encoding != Bech32.Encoding.BECH32) || (witnessVersion != 0 && bechData.encoding != Bech32.Encoding.BECH32M)) throw new AddressFormatException.UnexpectedWitnessVersion("Unexpected witness version: " + witnessVersion); @@ -398,7 +383,7 @@ public class SegwitAddress implements Address { @Override public int hashCode() { - return Objects.hash(network, Arrays.hashCode(bytes)); + return Objects.hash(network, witnessVersion, Arrays.hashCode(witnessProgram)); } @Override @@ -406,7 +391,7 @@ public class SegwitAddress implements Address { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; SegwitAddress other = (SegwitAddress) o; - return this.network == other.network && Arrays.equals(this.bytes, other.bytes); + return this.network == other.network && witnessVersion == other.witnessVersion && Arrays.equals(this.witnessProgram, other.witnessProgram); } /** @@ -416,9 +401,24 @@ public class SegwitAddress implements Address { */ public String toBech32() { if (getWitnessVersion() == 0) - return Bech32.encode(Bech32.Encoding.BECH32, network.segwitAddressHrp(), bytes); + return Bech32.encode(Bech32.Encoding.BECH32, network.segwitAddressHrp(), appendVersion(witnessVersion, witnessProgram)); else - return Bech32.encode(Bech32.Encoding.BECH32M, network.segwitAddressHrp(), bytes); + return Bech32.encode(Bech32.Encoding.BECH32M, network.segwitAddressHrp(), appendVersion(witnessVersion, witnessProgram)); + } + + // Trim the version byte and return the witness program only + private static byte[] trimVersion(byte[] data) { + byte[] program = new byte[data.length - 1]; + System.arraycopy(data, 1, program, 0, program.length); + return program; + } + + // concatenate the witnessVersion and witnessProgram + private static byte[] appendVersion(short version, byte[] program) { + byte[] data = new byte[program.length + 1]; + data[0] = (byte) version; + System.arraycopy(program, 0, data, 1, program.length); + return data; } /** @@ -455,7 +455,8 @@ public class SegwitAddress implements Address { // Comparator for SegwitAddress, left argument must be SegwitAddress, right argument can be any Address private static final Comparator
SEGWIT_ADDRESS_COMPARATOR = Address.PARTIAL_ADDRESS_COMPARATOR - .thenComparing(a -> ((SegwitAddress) a).bytes, ByteUtils.arrayUnsignedComparator()); // Then compare Segwit bytes + .thenComparing(a -> ((SegwitAddress) a).witnessVersion) + .thenComparing(a -> ((SegwitAddress) a).witnessProgram, ByteUtils.arrayUnsignedComparator()); // Then compare Segwit bytes /** * {@inheritDoc}