From f76203727d27b537e4f2ae8d28f3b49beab556b3 Mon Sep 17 00:00:00 2001 From: Rusty Russell Date: Fri, 27 Jul 2018 06:55:37 +0930 Subject: [PATCH] htlc_wire: be stricter in marshaling/unmarshaling struct failed_htlc. There are three cases: 1. failcode is 0, scid is NULL, failreason is the onion to fwd. 2. failcode is non-zero, but UPDATE bit not set. scid is NULL, failreason NULL. 3. failcode has UPDATE bit set. scid is non-NULL, failreason is NULL. Assert these on marshaling, and only send the parts we need so unmarshal is always canonical. Signed-off-by: Rusty Russell --- common/htlc_wire.c | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/common/htlc_wire.c b/common/htlc_wire.c index bedd1e1eb..937854256 100644 --- a/common/htlc_wire.c +++ b/common/htlc_wire.c @@ -30,12 +30,18 @@ void towire_failed_htlc(u8 **pptr, const struct failed_htlc *failed) assert(!failed->failcode || !tal_len(failed->failreason)); towire_u64(pptr, failed->id); towire_u16(pptr, failed->failcode); - if (failed->failcode & UPDATE) + if (failed->failcode & UPDATE) { + assert(!failed->failreason); towire_short_channel_id(pptr, failed->scid); - else + } else { assert(!failed->scid); - towire_u16(pptr, tal_count(failed->failreason)); - towire_u8_array(pptr, failed->failreason, tal_count(failed->failreason)); + if (!failed->failcode) { + assert(failed->failreason); + towire_u16(pptr, tal_count(failed->failreason)); + towire_u8_array(pptr, failed->failreason, + tal_count(failed->failreason)); + } + } } void towire_htlc_state(u8 **pptr, const enum htlc_state hstate) @@ -87,22 +93,26 @@ void fromwire_fulfilled_htlc(const u8 **cursor, size_t *max, struct failed_htlc *fromwire_failed_htlc(const tal_t *ctx, const u8 **cursor, size_t *max) { - u16 failreason_len; struct failed_htlc *failed = tal(ctx, struct failed_htlc); failed->id = fromwire_u64(cursor, max); failed->failcode = fromwire_u16(cursor, max); - if (failed->failcode & UPDATE) { - failed->scid = tal(failed, struct short_channel_id); - fromwire_short_channel_id(cursor, max, failed->scid); - } else + if (failed->failcode == 0) { + u16 failreason_len; failed->scid = NULL; - failreason_len = fromwire_u16(cursor, max); - if (failreason_len) + failreason_len = fromwire_u16(cursor, max); failed->failreason = tal_arr(failed, u8, failreason_len); - else + fromwire_u8_array(cursor, max, failed->failreason, + failreason_len); + } else { failed->failreason = NULL; - fromwire_u8_array(cursor, max, failed->failreason, failreason_len); + if (failed->failcode & UPDATE) { + failed->scid = tal(failed, struct short_channel_id); + fromwire_short_channel_id(cursor, max, failed->scid); + } else { + failed->scid = NULL; + } + } return failed; }