Merge pull request #3742 from joostjager/expose-custom-tlv

invoices: expose custom tlv records from the payload
This commit is contained in:
Joost Jager 2019-12-10 07:53:59 +01:00 committed by GitHub
commit 699bb193e4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 815 additions and 614 deletions

View file

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
@ -209,13 +210,15 @@ func TestInvoiceCancelSingleHtlc(t *testing.T) {
// Accept an htlc on this invoice. // Accept an htlc on this invoice.
key := CircuitKey{ChanID: lnwire.NewShortChanIDFromInt(1), HtlcID: 4} key := CircuitKey{ChanID: lnwire.NewShortChanIDFromInt(1), HtlcID: 4}
htlc := HtlcAcceptDesc{
Amt: 500,
CustomRecords: make(hop.CustomRecordSet),
}
invoice, err := db.UpdateInvoice(paymentHash, invoice, err := db.UpdateInvoice(paymentHash,
func(invoice *Invoice) (*InvoiceUpdateDesc, error) { func(invoice *Invoice) (*InvoiceUpdateDesc, error) {
return &InvoiceUpdateDesc{ return &InvoiceUpdateDesc{
AddHtlcs: map[CircuitKey]*HtlcAcceptDesc{ AddHtlcs: map[CircuitKey]*HtlcAcceptDesc{
key: { key: &htlc,
Amt: 500,
},
}, },
}, nil }, nil
}) })
@ -432,10 +435,11 @@ func TestDuplicateSettleInvoice(t *testing.T) {
invoice.SettleDate = dbInvoice.SettleDate invoice.SettleDate = dbInvoice.SettleDate
invoice.Htlcs = map[CircuitKey]*InvoiceHTLC{ invoice.Htlcs = map[CircuitKey]*InvoiceHTLC{
{}: { {}: {
Amt: amt, Amt: amt,
AcceptTime: time.Unix(1, 0), AcceptTime: time.Unix(1, 0),
ResolveTime: time.Unix(1, 0), ResolveTime: time.Unix(1, 0),
State: HtlcStateSettled, State: HtlcStateSettled,
CustomRecords: make(hop.CustomRecordSet),
}, },
} }
@ -747,6 +751,8 @@ func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback {
return nil, ErrInvoiceAlreadySettled return nil, ErrInvoiceAlreadySettled
} }
noRecords := make(hop.CustomRecordSet)
update := &InvoiceUpdateDesc{ update := &InvoiceUpdateDesc{
State: &InvoiceStateUpdateDesc{ State: &InvoiceStateUpdateDesc{
Preimage: invoice.Terms.PaymentPreimage, Preimage: invoice.Terms.PaymentPreimage,
@ -754,7 +760,8 @@ func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback {
}, },
AddHtlcs: map[CircuitKey]*HtlcAcceptDesc{ AddHtlcs: map[CircuitKey]*HtlcAcceptDesc{
{}: { {}: {
Amt: amt, Amt: amt,
CustomRecords: noRecords,
}, },
}, },
} }
@ -762,3 +769,64 @@ func getUpdateInvoice(amt lnwire.MilliSatoshi) InvoiceUpdateCallback {
return update, nil return update, nil
} }
} }
// TestCustomRecords tests that custom records are properly recorded in the
// invoice database.
func TestCustomRecords(t *testing.T) {
t.Parallel()
db, cleanUp, err := makeTestDB()
defer cleanUp()
if err != nil {
t.Fatalf("unable to make test db: %v", err)
}
testInvoice := &Invoice{
Htlcs: map[CircuitKey]*InvoiceHTLC{},
}
testInvoice.Terms.Value = lnwire.NewMSatFromSatoshis(10000)
testInvoice.Terms.Features = emptyFeatures
var paymentHash lntypes.Hash
if _, err := db.AddInvoice(testInvoice, paymentHash); err != nil {
t.Fatalf("unable to find invoice: %v", err)
}
// Accept an htlc with custom records on this invoice.
key := CircuitKey{ChanID: lnwire.NewShortChanIDFromInt(1), HtlcID: 4}
records := hop.CustomRecordSet{
100000: []byte{},
100001: []byte{1, 2},
}
_, err = db.UpdateInvoice(paymentHash,
func(invoice *Invoice) (*InvoiceUpdateDesc, error) {
return &InvoiceUpdateDesc{
AddHtlcs: map[CircuitKey]*HtlcAcceptDesc{
key: {
Amt: 500,
CustomRecords: records,
},
},
}, nil
},
)
if err != nil {
t.Fatalf("unable to add invoice htlc: %v", err)
}
// Retrieve the invoice from that database and verify that the custom
// records are present.
dbInvoice, err := db.LookupInvoice(paymentHash)
if err != nil {
t.Fatalf("unable to lookup invoice: %v", err)
}
if len(dbInvoice.Htlcs) != 1 {
t.Fatalf("expected the htlc to be added")
}
if !reflect.DeepEqual(records, dbInvoice.Htlcs[key].CustomRecords) {
t.Fatalf("invalid custom records")
}
}

View file

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tlv"
@ -308,6 +309,10 @@ type InvoiceHTLC struct {
// canceled htlc isn't just removed from the invoice htlcs map, because // canceled htlc isn't just removed from the invoice htlcs map, because
// we need AcceptHeight to properly cancel the htlc back. // we need AcceptHeight to properly cancel the htlc back.
State HtlcState State HtlcState
// CustomRecords contains the custom key/value pairs that accompanied
// the htlc.
CustomRecords hop.CustomRecordSet
} }
// HtlcAcceptDesc describes the details of a newly accepted htlc. // HtlcAcceptDesc describes the details of a newly accepted htlc.
@ -320,6 +325,10 @@ type HtlcAcceptDesc struct {
// Expiry is the expiry height of this htlc. // Expiry is the expiry height of this htlc.
Expiry uint32 Expiry uint32
// CustomRecords contains the custom key/value pairs that accompanied
// the htlc.
CustomRecords hop.CustomRecordSet
} }
// InvoiceUpdateDesc describes the changes that should be applied to the // InvoiceUpdateDesc describes the changes that should be applied to the
@ -1013,7 +1022,8 @@ func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error {
resolveTime := uint64(htlc.ResolveTime.UnixNano()) resolveTime := uint64(htlc.ResolveTime.UnixNano())
state := uint8(htlc.State) state := uint8(htlc.State)
tlvStream, err := tlv.NewStream( var records []tlv.Record
records = append(records,
tlv.MakePrimitiveRecord(chanIDType, &chanID), tlv.MakePrimitiveRecord(chanIDType, &chanID),
tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID), tlv.MakePrimitiveRecord(htlcIDType, &key.HtlcID),
tlv.MakePrimitiveRecord(amtType, &amt), tlv.MakePrimitiveRecord(amtType, &amt),
@ -1025,6 +1035,16 @@ func serializeHtlcs(w io.Writer, htlcs map[CircuitKey]*InvoiceHTLC) error {
tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry), tlv.MakePrimitiveRecord(expiryHeightType, &htlc.Expiry),
tlv.MakePrimitiveRecord(htlcStateType, &state), tlv.MakePrimitiveRecord(htlcStateType, &state),
) )
// Convert the custom records to tlv.Record types that are ready
// for serialization.
customRecords := tlv.MapToRecords(htlc.CustomRecords)
// Append the custom records. Their ids are in the experimental
// range and sorted, so there is no need to sort again.
records = append(records, customRecords...)
tlvStream, err := tlv.NewStream(records...)
if err != nil { if err != nil {
return err return err
} }
@ -1191,7 +1211,8 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
return nil, err return nil, err
} }
if err := tlvStream.Decode(htlcReader); err != nil { parsedTypes, err := tlvStream.DecodeWithParsedTypes(htlcReader)
if err != nil {
return nil, err return nil, err
} }
@ -1201,6 +1222,10 @@ func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {
htlc.State = HtlcState(state) htlc.State = HtlcState(state)
htlc.Amt = lnwire.MilliSatoshi(amt) htlc.Amt = lnwire.MilliSatoshi(amt)
// Reconstruct the custom records fields from the parsed types
// map return from the tlv parser.
htlc.CustomRecords = hop.NewCustomRecords(parsedTypes)
htlcs[key] = &htlc htlcs[key] = &htlc
} }
@ -1290,12 +1315,20 @@ func (d *DB) updateInvoice(hash lntypes.Hash, invoices, settleIndex *bbolt.Bucke
if _, exists := invoice.Htlcs[key]; exists { if _, exists := invoice.Htlcs[key]; exists {
return nil, fmt.Errorf("duplicate add of htlc %v", key) return nil, fmt.Errorf("duplicate add of htlc %v", key)
} }
// Force caller to supply htlc without custom records in a
// consistent way.
if htlcUpdate.CustomRecords == nil {
return nil, errors.New("nil custom records map")
}
htlc := &InvoiceHTLC{ htlc := &InvoiceHTLC{
Amt: htlcUpdate.Amt, Amt: htlcUpdate.Amt,
Expiry: htlcUpdate.Expiry, Expiry: htlcUpdate.Expiry,
AcceptHeight: uint32(htlcUpdate.AcceptHeight), AcceptHeight: uint32(htlcUpdate.AcceptHeight),
AcceptTime: now, AcceptTime: now,
State: HtlcStateAccepted, State: HtlcStateAccepted,
CustomRecords: htlcUpdate.CustomRecords,
} }
invoice.Htlcs[key] = htlc invoice.Htlcs[key] = htlc

View file

@ -79,6 +79,9 @@ func (e ErrInvalidPayload) Error() string {
hopType, e.Violation, e.Type) hopType, e.Violation, e.Type)
} }
// CustomRecordSet stores a set of custom key/value pairs.
type CustomRecordSet map[uint64][]byte
// Payload encapsulates all information delivered to a hop in an onion payload. // Payload encapsulates all information delivered to a hop in an onion payload.
// A Hop can represent either a TLV or legacy payload. The primary forwarding // A Hop can represent either a TLV or legacy payload. The primary forwarding
// instruction can be accessed via ForwardingInfo, and additional records can be // instruction can be accessed via ForwardingInfo, and additional records can be
@ -91,6 +94,10 @@ type Payload struct {
// MPP holds the info provided in an option_mpp record when parsed from // MPP holds the info provided in an option_mpp record when parsed from
// a TLV onion payload. // a TLV onion payload.
MPP *record.MPP MPP *record.MPP
// customRecords are user-defined records in the custom type range that
// were included in the payload.
customRecords CustomRecordSet
} }
// NewLegacyPayload builds a Payload from the amount, cltv, and next hop // NewLegacyPayload builds a Payload from the amount, cltv, and next hop
@ -105,6 +112,7 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload {
AmountToForward: lnwire.MilliSatoshi(f.ForwardAmount), AmountToForward: lnwire.MilliSatoshi(f.ForwardAmount),
OutgoingCTLV: f.OutgoingCltv, OutgoingCTLV: f.OutgoingCltv,
}, },
customRecords: make(CustomRecordSet),
} }
} }
@ -157,6 +165,9 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
mpp = nil mpp = nil
} }
// Filter out the custom records.
customRecords := NewCustomRecords(parsedTypes)
return &Payload{ return &Payload{
FwdInfo: ForwardingInfo{ FwdInfo: ForwardingInfo{
Network: BitcoinNetwork, Network: BitcoinNetwork,
@ -164,7 +175,8 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) {
AmountToForward: lnwire.MilliSatoshi(amt), AmountToForward: lnwire.MilliSatoshi(amt),
OutgoingCTLV: cltv, OutgoingCTLV: cltv,
}, },
MPP: mpp, MPP: mpp,
customRecords: customRecords,
}, nil }, nil
} }
@ -174,11 +186,24 @@ func (h *Payload) ForwardingInfo() ForwardingInfo {
return h.FwdInfo return h.FwdInfo
} }
// NewCustomRecords filters the types parsed from the tlv stream for custom
// records.
func NewCustomRecords(parsedTypes tlv.TypeMap) CustomRecordSet {
customRecords := make(CustomRecordSet)
for t, parseResult := range parsedTypes {
if parseResult == nil || t < CustomTypeStart {
continue
}
customRecords[uint64(t)] = parseResult
}
return customRecords
}
// ValidateParsedPayloadTypes checks the types parsed from a hop payload to // ValidateParsedPayloadTypes checks the types parsed from a hop payload to
// ensure that the proper fields are either included or omitted. The finalHop // ensure that the proper fields are either included or omitted. The finalHop
// boolean should be true if the payload was parsed for an exit hop. The // boolean should be true if the payload was parsed for an exit hop. The
// requirements for this method are described in BOLT 04. // requirements for this method are described in BOLT 04.
func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet, func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap,
nextHop lnwire.ShortChannelID) error { nextHop lnwire.ShortChannelID) error {
isFinalHop := nextHop == Exit isFinalHop := nextHop == Exit
@ -234,22 +259,28 @@ func (h *Payload) MultiPath() *record.MPP {
return h.MPP return h.MPP
} }
// CustomRecords returns the custom tlv type records that were parsed from the
// payload.
func (h *Payload) CustomRecords() CustomRecordSet {
return h.customRecords
}
// getMinRequiredViolation checks for unrecognized required (even) fields in the // getMinRequiredViolation checks for unrecognized required (even) fields in the
// standard range and returns the lowest required type. Always returning the // standard range and returns the lowest required type. Always returning the
// lowest required type allows a failure message to be deterministic. // lowest required type allows a failure message to be deterministic.
func getMinRequiredViolation(set tlv.TypeSet) *tlv.Type { func getMinRequiredViolation(set tlv.TypeMap) *tlv.Type {
var ( var (
requiredViolation bool requiredViolation bool
minRequiredViolationType tlv.Type minRequiredViolationType tlv.Type
) )
for t, known := range set { for t, parseResult := range set {
// If a type is even but not known to us, we cannot process the // If a type is even but not known to us, we cannot process the
// payload. We are required to understand a field that we don't // payload. We are required to understand a field that we don't
// support. // support.
// //
// We always accept custom fields, because a higher level // We always accept custom fields, because a higher level
// application may understand them. // application may understand them.
if known || t%2 != 0 || t >= CustomTypeStart { if parseResult == nil || t%2 != 0 || t >= CustomTypeStart {
continue continue
} }

View file

@ -11,10 +11,11 @@ import (
) )
type decodePayloadTest struct { type decodePayloadTest struct {
name string name string
payload []byte payload []byte
expErr error expErr error
shouldHaveMPP bool expCustomRecords map[uint64][]byte
shouldHaveMPP bool
} }
var decodePayloadTests = []decodePayloadTest{ var decodePayloadTests = []decodePayloadTest{
@ -133,7 +134,10 @@ var decodePayloadTests = []decodePayloadTest{
{ {
name: "required type in custom range", name: "required type in custom range",
payload: []byte{0x02, 0x00, 0x04, 0x00, payload: []byte{0x02, 0x00, 0x04, 0x00,
0xfe, 0x00, 0x01, 0x00, 0x00, 0x00, 0xfe, 0x00, 0x01, 0x00, 0x00, 0x02, 0x10, 0x11,
},
expCustomRecords: map[uint64][]byte{
65536: {0x10, 0x11},
}, },
}, },
{ {
@ -237,4 +241,14 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) {
} else if p.MPP != nil { } else if p.MPP != nil {
t.Fatalf("unexpected MPP payload") t.Fatalf("unexpected MPP payload")
} }
// Convert expected nil map to empty map, because we always expect an
// initiated map from the payload.
expCustomRecords := make(hop.CustomRecordSet)
if test.expCustomRecords != nil {
expCustomRecords = test.expCustomRecords
}
if !reflect.DeepEqual(expCustomRecords, p.CustomRecords()) {
t.Fatalf("invalid custom records")
}
} }

View file

@ -1,6 +1,9 @@
package invoices package invoices
import "github.com/lightningnetwork/lnd/record" import (
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/record"
)
// Payload abstracts access to any additional fields provided in the final hop's // Payload abstracts access to any additional fields provided in the final hop's
// TLV onion payload. // TLV onion payload.
@ -8,4 +11,8 @@ type Payload interface {
// MultiPath returns the record corresponding the option_mpp parsed from // MultiPath returns the record corresponding the option_mpp parsed from
// the onion payload. // the onion payload.
MultiPath() *record.MPP MultiPath() *record.MPP
// CustomRecords returns the custom tlv type records that were parsed
// from the payload.
CustomRecords() hop.CustomRecordSet
} }

View file

@ -443,6 +443,7 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash,
expiry: expiry, expiry: expiry,
currentHeight: currentHeight, currentHeight: currentHeight,
finalCltvRejectDelta: i.finalCltvRejectDelta, finalCltvRejectDelta: i.finalCltvRejectDelta,
customRecords: payload.CustomRecords(),
} }
// We'll attempt to settle an invoice matching this rHash on disk (if // We'll attempt to settle an invoice matching this rHash on disk (if

View file

@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
@ -669,3 +670,7 @@ type mockPayload struct {
func (p *mockPayload) MultiPath() *record.MPP { func (p *mockPayload) MultiPath() *record.MPP {
return p.mpp return p.mpp
} }
func (p *mockPayload) CustomRecords() hop.CustomRecordSet {
return make(hop.CustomRecordSet)
}

View file

@ -3,6 +3,8 @@ package invoices
import ( import (
"errors" "errors"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
@ -74,6 +76,7 @@ type invoiceUpdateCtx struct {
expiry uint32 expiry uint32
currentHeight int32 currentHeight int32
finalCltvRejectDelta int32 finalCltvRejectDelta int32
customRecords hop.CustomRecordSet
} }
// updateInvoice is a callback for DB.UpdateInvoice that contains the invoice // updateInvoice is a callback for DB.UpdateInvoice that contains the invoice
@ -125,9 +128,10 @@ func updateInvoice(ctx *invoiceUpdateCtx, inv *channeldb.Invoice) (
// Record HTLC in the invoice database. // Record HTLC in the invoice database.
newHtlcs := map[channeldb.CircuitKey]*channeldb.HtlcAcceptDesc{ newHtlcs := map[channeldb.CircuitKey]*channeldb.HtlcAcceptDesc{
ctx.circuitKey: { ctx.circuitKey: {
Amt: ctx.amtPaid, Amt: ctx.amtPaid,
Expiry: ctx.expiry, Expiry: ctx.expiry,
AcceptHeight: ctx.currentHeight, AcceptHeight: ctx.currentHeight,
CustomRecords: ctx.customRecords,
}, },
} }

View file

@ -75,13 +75,14 @@ func CreateRPCInvoice(invoice *channeldb.Invoice,
} }
rpcHtlc := lnrpc.InvoiceHTLC{ rpcHtlc := lnrpc.InvoiceHTLC{
ChanId: key.ChanID.ToUint64(), ChanId: key.ChanID.ToUint64(),
HtlcIndex: key.HtlcID, HtlcIndex: key.HtlcID,
AcceptHeight: int32(htlc.AcceptHeight), AcceptHeight: int32(htlc.AcceptHeight),
AcceptTime: htlc.AcceptTime.Unix(), AcceptTime: htlc.AcceptTime.Unix(),
ExpiryHeight: int32(htlc.Expiry), ExpiryHeight: int32(htlc.Expiry),
AmtMsat: uint64(htlc.Amt), AmtMsat: uint64(htlc.Amt),
State: state, State: state,
CustomRecords: htlc.CustomRecords,
} }
// Only report resolved times if htlc is resolved. // Only report resolved times if htlc is resolved.

File diff suppressed because it is too large Load diff

View file

@ -2396,6 +2396,9 @@ message InvoiceHTLC {
/// Current state the htlc is in. /// Current state the htlc is in.
InvoiceHTLCState state = 8 [json_name = "state"]; InvoiceHTLCState state = 8 [json_name = "state"];
/// Custom tlv records.
map<uint64, bytes> custom_records = 9 [json_name = "custom_records"];
} }
message AddInvoiceResponse { message AddInvoiceResponse {

View file

@ -2819,6 +2819,14 @@
"state": { "state": {
"$ref": "#/definitions/lnrpcInvoiceHTLCState", "$ref": "#/definitions/lnrpcInvoiceHTLCState",
"description": "/ Current state the htlc is in." "description": "/ Current state the htlc is in."
},
"custom_records": {
"type": "object",
"additionalProperties": {
"type": "string",
"format": "byte"
},
"description": "/ Custom tlv records."
} }
}, },
"title": "/ Details of an HTLC that paid to an invoice" "title": "/ Details of an HTLC that paid to an invoice"

View file

@ -12,9 +12,10 @@ import (
// Type is an 64-bit identifier for a TLV Record. // Type is an 64-bit identifier for a TLV Record.
type Type uint64 type Type uint64
// TypeSet is an unordered set of Types. The map item boolean values indicate // TypeMap is a map of parsed Types. The map values are byte slices. If the byte
// whether the type that we parsed was known. // slice is nil, the type was successfully parsed. Otherwise the value is byte
type TypeSet map[Type]bool // slice containing the encoded data.
type TypeMap map[Type][]byte
// Encoder is a signature for methods that can encode TLV values. An error // Encoder is a signature for methods that can encode TLV values. An error
// should be returned if the Encoder cannot support the underlying type of val. // should be returned if the Encoder cannot support the underlying type of val.

View file

@ -1,6 +1,7 @@
package tlv package tlv
import ( import (
"bytes"
"errors" "errors"
"io" "io"
"io/ioutil" "io/ioutil"
@ -139,16 +140,16 @@ func (s *Stream) Decode(r io.Reader) error {
} }
// DecodeWithParsedTypes is identical to Decode, but if successful, returns a // DecodeWithParsedTypes is identical to Decode, but if successful, returns a
// TypeSet containing the types of all records that were decoded or ignored from // TypeMap containing the types of all records that were decoded or ignored from
// the stream. // the stream.
func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeSet, error) { func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeMap, error) {
return s.decode(r, make(TypeSet)) return s.decode(r, make(TypeMap))
} }
// decode is a helper function that performs the basis of stream decoding. If // decode is a helper function that performs the basis of stream decoding. If
// the caller needs the set of parsed types, it must provide an initialized // the caller needs the set of parsed types, it must provide an initialized
// parsedTypes, otherwise the returned TypeSet will be nil. // parsedTypes, otherwise the returned TypeMap will be nil.
func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) { func (s *Stream) decode(r io.Reader, parsedTypes TypeMap) (TypeMap, error) {
var ( var (
typ Type typ Type
min Type min Type
@ -230,10 +231,25 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
return nil, err return nil, err
} }
// Record the successfully decoded type if the caller
// provided an initialized TypeMap.
if parsedTypes != nil {
parsedTypes[typ] = nil
}
// Otherwise, the record type is unknown and is odd, discard the // Otherwise, the record type is unknown and is odd, discard the
// number of bytes specified by length. // number of bytes specified by length.
default: default:
_, err := io.CopyN(ioutil.Discard, r, int64(length)) // If the caller provided an initialized TypeMap, record
// the encoded bytes.
var b *bytes.Buffer
writer := ioutil.Discard
if parsedTypes != nil {
b = bytes.NewBuffer(make([]byte, 0, length))
writer = b
}
_, err := io.CopyN(writer, r, int64(length))
switch { switch {
// We'll convert any EOFs to ErrUnexpectedEOF, since this // We'll convert any EOFs to ErrUnexpectedEOF, since this
@ -245,12 +261,10 @@ func (s *Stream) decode(r io.Reader, parsedTypes TypeSet) (TypeSet, error) {
case err != nil: case err != nil:
return nil, err return nil, err
} }
}
// Record the successfully decoded or ignored type if the if parsedTypes != nil {
// caller provided an initialized TypeSet. parsedTypes[typ] = b.Bytes()
if parsedTypes != nil { }
parsedTypes[typ] = ok
} }
// Update our record index so that we can begin our next search // Update our record index so that we can begin our next search

View file

@ -12,7 +12,7 @@ type parsedTypeTest struct {
name string name string
encode []tlv.Type encode []tlv.Type
decode []tlv.Type decode []tlv.Type
expParsedTypes tlv.TypeSet expParsedTypes tlv.TypeMap
} }
// TestParsedTypes asserts that a Stream will properly return the set of types // TestParsedTypes asserts that a Stream will properly return the set of types
@ -29,17 +29,17 @@ func TestParsedTypes(t *testing.T) {
name: "known and unknown", name: "known and unknown",
encode: []tlv.Type{knownType, unknownType}, encode: []tlv.Type{knownType, unknownType},
decode: []tlv.Type{knownType}, decode: []tlv.Type{knownType},
expParsedTypes: tlv.TypeSet{ expParsedTypes: tlv.TypeMap{
unknownType: false, unknownType: []byte{0, 0, 0, 0, 0, 0, 0, 0},
knownType: true, knownType: nil,
}, },
}, },
{ {
name: "known and missing known", name: "known and missing known",
encode: []tlv.Type{knownType}, encode: []tlv.Type{knownType},
decode: []tlv.Type{knownType, secondKnownType}, decode: []tlv.Type{knownType, secondKnownType},
expParsedTypes: tlv.TypeSet{ expParsedTypes: tlv.TypeMap{
knownType: true, knownType: nil,
}, },
}, },
} }