diff --git a/common/onion.c b/common/onion.c index 37afa9ab6..584f67af1 100644 --- a/common/onion.c +++ b/common/onion.c @@ -214,17 +214,22 @@ struct onion_payload *onion_decode(const tal_t *ctx, const u8 *cursor = rs->raw_payload; size_t max = tal_bytelen(cursor), len; struct tlv_tlv_payload *tlv; + size_t badfield; if (!pull_payload_length(&cursor, &max, true, &len)) - return tal_free(p); + goto general_fail; tlv = tlv_tlv_payload_new(p); - if (!fromwire_tlv_payload(&cursor, &max, tlv)) - goto fail; + if (!fromwire_tlv_payload(&cursor, &max, tlv)) { + /* FIXME: Fill in correct thing here! */ + goto general_fail; + } - if (!tlv_fields_valid(tlv->fields, accepted_extra_tlvs, failtlvpos)) { - *failtlvtype = tlv->fields[*failtlvpos].numtype; - goto fail; + /* FIXME: This API makes it really hard to get the actual + * offset of field. */ + if (!tlv_fields_valid(tlv->fields, accepted_extra_tlvs, &badfield)) { + *failtlvtype = tlv->fields[badfield].numtype; + goto field_bad; } /* BOLT #4: @@ -233,8 +238,14 @@ struct onion_payload *onion_decode(const tal_t *ctx, * - MUST return an error if `amt_to_forward` or * `outgoing_cltv_value` are not present. */ - if (!tlv->amt_to_forward || !tlv->outgoing_cltv_value) - goto fail; + if (!tlv->amt_to_forward) { + *failtlvtype = TLV_TLV_PAYLOAD_AMT_TO_FORWARD; + goto field_bad; + } + if (!tlv->outgoing_cltv_value) { + *failtlvtype = TLV_TLV_PAYLOAD_OUTGOING_CLTV_VALUE; + goto field_bad; + } p->amt_to_forward = amount_msat(*tlv->amt_to_forward); p->outgoing_cltv = *tlv->outgoing_cltv_value; @@ -247,8 +258,10 @@ struct onion_payload *onion_decode(const tal_t *ctx, * - MUST include `short_channel_id` */ if (rs->nextcase == ONION_FORWARD) { - if (!tlv->short_channel_id) - goto fail; + if (!tlv->short_channel_id) { + *failtlvtype = TLV_TLV_PAYLOAD_SHORT_CHANNEL_ID; + goto field_bad; + } p->forward_channel = tal_dup(p, struct short_channel_id, tlv->short_channel_id); p->total_msat = NULL; @@ -283,18 +296,30 @@ struct onion_payload *onion_decode(const tal_t *ctx, if (rs->nextcase == ONION_FORWARD) { struct tlv_tlv_payload *ntlv; - if (!tlv->encrypted_recipient_data) - goto fail; + if (!tlv->encrypted_recipient_data) { + *failtlvtype = TLV_TLV_PAYLOAD_ENCRYPTED_RECIPIENT_DATA; + goto field_bad; + } ntlv = decrypt_tlv(tmpctx, &p->blinding_ss, tlv->encrypted_recipient_data); - if (!ntlv) - goto fail; + if (!ntlv) { + *failtlvtype = TLV_TLV_PAYLOAD_ENCRYPTED_RECIPIENT_DATA; + goto field_bad; + } /* Must override short_channel_id */ - if (!ntlv->short_channel_id) + if (!ntlv->short_channel_id) { + *failtlvtype = TLV_TLV_PAYLOAD_ENCRYPTED_RECIPIENT_DATA; + /* Place error at *end* of enctlv, + * indicating missing field. */ + *failtlvpos = tlv_field_offset(rs->raw_payload, + tal_bytelen(rs->raw_payload), + *failtlvtype) + + tal_bytelen(tlv->encrypted_recipient_data); goto fail; + } *p->forward_channel = *ntlv->short_channel_id; @@ -313,6 +338,15 @@ struct onion_payload *onion_decode(const tal_t *ctx, p->tlv = tal_steal(p, tlv); return p; +field_bad: + *failtlvpos = tlv_field_offset(rs->raw_payload, tal_bytelen(rs->raw_payload), + *failtlvtype); + goto fail; + +general_fail: + *failtlvtype = 0; + *failtlvpos = tal_bytelen(rs->raw_payload); + goto fail; fail: tal_free(tlv); tal_free(p); diff --git a/common/test/run-sphinx.c b/common/test/run-sphinx.c index 7cb201a8e..36248e44a 100644 --- a/common/test/run-sphinx.c +++ b/common/test/run-sphinx.c @@ -62,6 +62,9 @@ bool fromwire_tlv(const u8 **cursor UNNEEDED, size_t *max UNNEEDED, /* Generated stub for pubkey_from_node_id */ bool pubkey_from_node_id(struct pubkey *key UNNEEDED, const struct node_id *id UNNEEDED) { fprintf(stderr, "pubkey_from_node_id called!\n"); abort(); } +/* Generated stub for tlv_field_offset */ +size_t tlv_field_offset(const u8 *tlvstream UNNEEDED, size_t tlvlen UNNEEDED, u64 fieldtype UNNEEDED) +{ fprintf(stderr, "tlv_field_offset called!\n"); abort(); } /* Generated stub for tlv_fields_valid */ bool tlv_fields_valid(const struct tlv_field *fields UNNEEDED, u64 *allow_extra UNNEEDED, size_t *err_index UNNEEDED) diff --git a/wire/tlvstream.c b/wire/tlvstream.c index 9bf10a774..aadd04460 100644 --- a/wire/tlvstream.c +++ b/wire/tlvstream.c @@ -82,6 +82,34 @@ void tlvstream_set_tu32(struct tlv_field **stream, u64 type, u32 value) tlvstream_set_raw(stream, type, take(ser), tal_bytelen(ser)); } +/* Get the offset of this field: returns size of msg if not found (or + * tlv malformed) */ +size_t tlv_field_offset(const u8 *tlvstream, size_t tlvlen, u64 fieldtype) +{ + size_t max = tlvlen; + while (max > 0) { + u64 type, length; + size_t field_off = tlvlen - max; + + type = fromwire_bigsize(&tlvstream, &max); + length = fromwire_bigsize(&tlvstream, &max); + + if (!tlvstream) + break; + + /* Found it! */ + if (type == fieldtype) + return field_off; + + if (length > max) + break; + + max -= length; + tlvstream += length; + } + return tlvlen; +} + bool fromwire_tlv(const u8 **cursor, size_t *max, const struct tlv_record_type *types, size_t num_types, void *record, struct tlv_field **fields) diff --git a/wire/tlvstream.h b/wire/tlvstream.h index e289513e6..61addfea8 100644 --- a/wire/tlvstream.h +++ b/wire/tlvstream.h @@ -48,6 +48,10 @@ void towire_tlv(u8 **pptr, bool tlv_fields_valid(const struct tlv_field *fields, u64 *allow_extra, size_t *err_index); +/* Get the offset of this field: returns size of msg if not found (or + * tlv malformed) */ +size_t tlv_field_offset(const u8 *tlvstream, size_t tlvlen, u64 fieldtype); + /* Generic primitive setters for tlvstreams. */ void tlvstream_set_raw(struct tlv_field **stream, u64 type, void *value TAKES, size_t valuelen); void tlvstream_set_short_channel_id(struct tlv_field **stream, u64 type,