common/utils: macros to help get copy/compare across different types right.

Things are often equivalent but different types:
1. u8 arrays in libwally.
2. sha256
3. Secrets derived via sha256
4. txids

Rather than open-coding a BUILD_ASSERT & memcpy, create a macro to do it.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
This commit is contained in:
Rusty Russell 2024-07-18 10:54:55 +09:30
parent bc5c528910
commit af90fdc0bb
12 changed files with 44 additions and 48 deletions

View file

@ -449,10 +449,9 @@ void psbt_elements_normalize_fees(struct wally_psbt *psbt)
}
void wally_psbt_input_get_txid(const struct wally_psbt_input *in,
struct bitcoin_txid *txid)
struct bitcoin_txid *txid)
{
BUILD_ASSERT(sizeof(struct bitcoin_txid) == sizeof(in->txhash));
memcpy(txid, in->txhash, sizeof(struct bitcoin_txid));
CROSS_TYPE_ASSIGNMENT(txid, &in->txhash);
}
bool psbt_has_input(const struct wally_psbt *psbt,
@ -886,25 +885,21 @@ struct amount_sat psbt_compute_fee(const struct wally_psbt *psbt)
}
bool wally_psbt_input_spends(const struct wally_psbt_input *input,
const struct bitcoin_outpoint *outpoint)
const struct bitcoin_outpoint *outpoint)
{
/* Useful, as tx_part can have some NULL inputs */
if (!input)
return false;
BUILD_ASSERT(sizeof(outpoint->txid) == sizeof(input->txhash));
/* Useful, as tx_part can have some NULL inputs */
if (!input)
return false;
if (input->index != outpoint->n)
return false;
if (memcmp(&outpoint->txid, input->txhash, sizeof(outpoint->txid)) != 0)
return false;
return true;
return CROSS_TYPE_EQ(&outpoint->txid, &input->txhash);
}
void wally_psbt_input_get_outpoint(const struct wally_psbt_input *in,
struct bitcoin_outpoint *outpoint)
struct bitcoin_outpoint *outpoint)
{
BUILD_ASSERT(sizeof(struct bitcoin_txid) == sizeof(in->txhash));
memcpy(&outpoint->txid, in->txhash, sizeof(struct bitcoin_txid));
outpoint->n = in->index;
CROSS_TYPE_ASSIGNMENT(&outpoint->txid, &in->txhash);
outpoint->n = in->index;
}
const u8 *wally_psbt_output_get_script(const tal_t *ctx,

View file

@ -403,8 +403,7 @@ void bitcoin_tx_input_set_outpoint(struct bitcoin_tx *tx, int innum,
assert(innum < tx->wtx->num_inputs);
in = &tx->wtx->inputs[innum];
BUILD_ASSERT(sizeof(struct bitcoin_txid) == sizeof(in->txhash));
memcpy(in->txhash, &outpoint->txid, sizeof(struct bitcoin_txid));
CROSS_TYPE_ASSIGNMENT(&in->txhash, &outpoint->txid);
in->index = outpoint->n;
}
@ -412,15 +411,13 @@ void bitcoin_tx_input_set_outpoint(struct bitcoin_tx *tx, int innum,
void wally_tx_input_get_txid(const struct wally_tx_input *in,
struct bitcoin_txid *txid)
{
BUILD_ASSERT(sizeof(struct bitcoin_txid) == sizeof(in->txhash));
memcpy(txid, in->txhash, sizeof(struct bitcoin_txid));
CROSS_TYPE_ASSIGNMENT(txid, &in->txhash);
}
void wally_tx_input_get_outpoint(const struct wally_tx_input *in,
struct bitcoin_outpoint *outpoint)
{
BUILD_ASSERT(sizeof(struct bitcoin_txid) == sizeof(in->txhash));
memcpy(&outpoint->txid, in->txhash, sizeof(struct bitcoin_txid));
wally_tx_input_get_txid(in, &outpoint->txid);
outpoint->n = in->index;
}
@ -824,8 +821,7 @@ bool wally_tx_input_spends(const struct wally_tx_input *input,
/* Useful, as tx_part can have some NULL inputs */
if (!input)
return false;
BUILD_ASSERT(sizeof(outpoint->txid) == sizeof(input->txhash));
if (memcmp(&outpoint->txid, input->txhash, sizeof(outpoint->txid)) != 0)
if (!CROSS_TYPE_EQ(&outpoint->txid, &input->txhash))
return false;
return input->index == outpoint->n;
}

View file

@ -9,8 +9,7 @@
void derive_channel_id(struct channel_id *channel_id,
const struct bitcoin_outpoint *outpoint)
{
BUILD_ASSERT(sizeof(*channel_id) == sizeof(outpoint->txid));
memcpy(channel_id, &outpoint->txid, sizeof(*channel_id));
CROSS_TYPE_ASSIGNMENT(channel_id, &outpoint->txid);
channel_id->id[sizeof(*channel_id)-2] ^= outpoint->n >> 8;
channel_id->id[sizeof(*channel_id)-1] ^= outpoint->n;
}
@ -43,8 +42,7 @@ void derive_channel_id_v2(struct channel_id *channel_id,
pubkey_to_der(der_keys + offset_1, basepoint_1);
pubkey_to_der(der_keys + offset_2, basepoint_2);
sha256(&sha, der_keys, sizeof(der_keys));
BUILD_ASSERT(sizeof(*channel_id) == sizeof(sha));
memcpy(channel_id, &sha, sizeof(*channel_id));
CROSS_TYPE_ASSIGNMENT(channel_id, &sha);
}
void derive_tmp_channel_id(struct channel_id *channel_id,
@ -61,8 +59,7 @@ void derive_tmp_channel_id(struct channel_id *channel_id,
memset(der_keys, 0, PUBKEY_CMPR_LEN);
pubkey_to_der(der_keys + PUBKEY_CMPR_LEN, opener_basepoint);
sha256(&sha, der_keys, sizeof(der_keys));
BUILD_ASSERT(sizeof(*channel_id) == sizeof(sha));
memcpy(channel_id, &sha, sizeof(*channel_id));
CROSS_TYPE_ASSIGNMENT(channel_id, &sha);
}
/* BOLT #2:

View file

@ -100,8 +100,7 @@ bool per_commit_secret(const struct sha256 *shaseed,
shachain_from_seed(shaseed, shachain_index(per_commit_index), &s);
BUILD_ASSERT(sizeof(s) == sizeof(*commit_secret));
memcpy(commit_secret, &s, sizeof(s));
CROSS_TYPE_ASSIGNMENT(commit_secret, &s);
return true;
}
@ -262,7 +261,6 @@ bool shachain_get_secret(const struct shachain *shachain,
if (!shachain_get_hash(shachain, shachain_index(commit_num), &sha))
return false;
BUILD_ASSERT(sizeof(*preimage) == sizeof(sha));
memcpy(preimage, &sha, sizeof(*preimage));
CROSS_TYPE_ASSIGNMENT(preimage, &sha);
return true;
}

View file

@ -3,6 +3,7 @@
#include <ccan/array_size/array_size.h>
#include <ccan/mem/mem.h>
#include <common/hmac.h>
#include <common/utils.h>
#include <wire/wire.h>
void hmac_start(crypto_auth_hmacsha256_state *state,
@ -40,8 +41,7 @@ void subkey_from_hmac(const char *prefix,
{
struct hmac h;
hmac(base->data, sizeof(base->data), prefix, strlen(prefix), &h);
BUILD_ASSERT(sizeof(h.bytes) == sizeof(key->data));
memcpy(key->data, h.bytes, sizeof(key->data));
CROSS_TYPE_ASSIGNMENT(&key->data, &h.bytes);
}
void towire_hmac(u8 **pptr, const struct hmac *hmac)

View file

@ -153,4 +153,20 @@ int tmpdir_mkstemp(const tal_t *ctx, const char *template TAKES, char **created)
*/
char *str_lowering(const void *ctx, const char *string TAKES);
/**
* Assign two different structs which are the same size.
* We use this for assigning secrets <-> sha256 for example.
*/
#define CROSS_TYPE_ASSIGNMENT(dst, src) \
memcpy((dst), (src), \
sizeof(*(dst)) + BUILD_ASSERT_OR_ZERO(sizeof(*(dst)) == sizeof(*(src))))
/**
* Compare two different structs which are the same size.
* We use this for comparing bitcoin_txid <-> sha256 for example.
*/
#define CROSS_TYPE_EQ(a, b) \
(memcmp((a), (b), \
sizeof(*(a)) + BUILD_ASSERT_OR_ZERO(sizeof(*(a)) == sizeof(*(b)))) == 0)
#endif /* LIGHTNING_COMMON_UTILS_H */

View file

@ -773,8 +773,7 @@ bool wireaddr_to_sockname(const struct wireaddr_internal *addr,
if (addr->itype != ADDR_INTERNAL_SOCKNAME)
return false;
sun->sun_family = AF_LOCAL;
BUILD_ASSERT(sizeof(sun->sun_path) == sizeof(addr->u.sockname));
memcpy(sun->sun_path, addr->u.sockname, sizeof(addr->u.sockname));
CROSS_TYPE_ASSIGNMENT(&sun->sun_path, &addr->u.sockname);
return true;
}

View file

@ -422,8 +422,7 @@ static struct handshake *new_handshake(const tal_t *ctx,
*
* 2. `ck = h`
*/
BUILD_ASSERT(sizeof(handshake->h) == sizeof(handshake->ck));
memcpy(&handshake->ck, &handshake->h, sizeof(handshake->ck));
CROSS_TYPE_ASSIGNMENT(&handshake->ck, &handshake->h);
SUPERVERBOSE("# ck=%s",
tal_hexstr(tmpctx, &handshake->ck, sizeof(handshake->ck)));

View file

@ -153,8 +153,7 @@ static void invoice_secret(const struct preimage *payment_preimage,
sha256(&secret, modified.r,
ARRAY_SIZE(modified.r) * sizeof(*modified.r));
BUILD_ASSERT(sizeof(secret.u.u8) == sizeof(payment_secret->data));
memcpy(payment_secret->data, secret.u.u8, sizeof(secret.u.u8));
CROSS_TYPE_ASSIGNMENT(&payment_secret->data, &secret.u.u8);
}
/* FIXME: The spec should require a *real* secret: a signature of the

View file

@ -315,9 +315,8 @@ static size_t pay_mpp_hash(const struct pay_sort_key *key)
static bool pay_mpp_eq(const struct pay_mpp *pm, const struct pay_sort_key *key)
{
return memcmp(pm->sortkey.payment_hash, key->payment_hash,
sizeof(struct sha256)) == 0 &&
pm->sortkey.groupid == key->groupid;
return sha256_eq(pm->sortkey.payment_hash, key->payment_hash)
&& pm->sortkey.groupid == key->groupid;
}
HTABLE_DEFINE_TYPE(struct pay_mpp, pay_mpp_key, pay_mpp_hash, pay_mpp_eq,

View file

@ -142,8 +142,7 @@ static u64 fromwire_tlv_uint(const u8 **cursor, size_t *max, size_t maxlen)
fromwire_fail(cursor, max);
return 0;
}
BUILD_ASSERT(sizeof(val) == sizeof(bytes));
memcpy(&val, bytes, sizeof(bytes));
CROSS_TYPE_ASSIGNMENT(&val, &bytes);
return be64_to_cpu(val);
}

View file

@ -67,8 +67,7 @@ static void towire_tlv_uint(u8 **pptr, u64 v)
be64 val;
val = cpu_to_be64(v);
BUILD_ASSERT(sizeof(val) == sizeof(bytes));
memcpy(bytes, &val, sizeof(bytes));
CROSS_TYPE_ASSIGNMENT(&bytes, &val);
for (num_zeroes = 0; num_zeroes < sizeof(bytes); num_zeroes++)
if (bytes[num_zeroes] != 0)