diff --git a/lnwire/fuzz_test.go b/lnwire/fuzz_test.go index 542a2f0c0..9a759604a 100644 --- a/lnwire/fuzz_test.go +++ b/lnwire/fuzz_test.go @@ -191,6 +191,17 @@ func FuzzWarning(f *testing.F) { }) } +func FuzzStfu(f *testing.F) { + f.Fuzz(func(t *testing.T, data []byte) { + // Prefix with MsgStfu. + data = prefixWithMsgType(data, MsgStfu) + + // Pass the message into our general fuzz harness for wire + // messages. + harness(t, data) + }) +} + func FuzzFundingCreated(f *testing.F) { f.Fuzz(func(t *testing.T, data []byte) { // Prefix with MsgFundingCreated. diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index e4c5c6baf..122a99660 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -438,6 +438,22 @@ func TestLightningWireProtocol(t *testing.T) { // are too complex for the testing/quick package to automatically // generate. customTypeGen := map[MessageType]func([]reflect.Value, *rand.Rand){ + MsgStfu: func(v []reflect.Value, r *rand.Rand) { + req := Stfu{} + if _, err := r.Read(req.ChanID[:]); err != nil { + t.Fatalf("unable to generate ChanID: %v", err) + } + + // 1/2 chance of being initiator + req.Initiator = r.Intn(2) == 1 + + // 1/2 chance additional TLV data. + if r.Intn(2) == 0 { + req.ExtraData = []byte{0xfd, 0x00, 0xff, 0x00} + } + + v[0] = reflect.ValueOf(req) + }, MsgInit: func(v []reflect.Value, r *rand.Rand) { req := NewInitMessage( randRawFeatureVector(r), @@ -1384,6 +1400,12 @@ func TestLightningWireProtocol(t *testing.T) { msgType MessageType scenario interface{} }{ + { + msgType: MsgStfu, + scenario: func(m Stfu) bool { + return mainScenario(&m) + }, + }, { msgType: MsgInit, scenario: func(m Init) bool { diff --git a/lnwire/message.go b/lnwire/message.go index bcee9f86d..2bf64a313 100644 --- a/lnwire/message.go +++ b/lnwire/message.go @@ -23,6 +23,7 @@ type MessageType uint16 // Lightning protocol. const ( MsgWarning MessageType = 1 + MsgStfu = 2 MsgInit = 16 MsgError = 17 MsgPing = 18 @@ -84,6 +85,8 @@ func (t MessageType) String() string { switch t { case MsgWarning: return "Warning" + case MsgStfu: + return "Stfu" case MsgInit: return "Init" case MsgOpenChannel: @@ -211,6 +214,8 @@ func makeEmptyMessage(msgType MessageType) (Message, error) { switch msgType { case MsgWarning: msg = &Warning{} + case MsgStfu: + msg = &Stfu{} case MsgInit: msg = &Init{} case MsgOpenChannel: diff --git a/lnwire/stfu.go b/lnwire/stfu.go new file mode 100644 index 000000000..0ba1730f4 --- /dev/null +++ b/lnwire/stfu.go @@ -0,0 +1,69 @@ +package lnwire + +import ( + "bytes" + "io" +) + +// Stfu is a message that is sent to lock the channel state prior to some other +// interactive protocol where channel updates need to be paused. +type Stfu struct { + // ChanID identifies which channel needs to be frozen. + ChanID ChannelID + + // Initiator is a byte that identifies whether we are the initiator of + // this process. + Initiator bool + + // ExtraData is the set of data that was appended to this message to + // fill out the full maximum transport message size. These fields can + // be used to specify optional data such as custom TLV fields. + ExtraData ExtraOpaqueData +} + +// A compile time check to ensure Stfu implements the lnwire.Message interface. +var _ Message = (*Stfu)(nil) + +// Encode serializes the target Stfu into the passed io.Writer. +// Serialization will observe the rules defined by the passed protocol version. +// +// This is a part of the lnwire.Message interface. +func (s *Stfu) Encode(w *bytes.Buffer, _ uint32) error { + if err := WriteChannelID(w, s.ChanID); err != nil { + return err + } + + if err := WriteBool(w, s.Initiator); err != nil { + return err + } + + return WriteBytes(w, s.ExtraData) +} + +// Decode deserializes the serialized Stfu stored in the passed io.Reader +// into the target Stfu using the deserialization rules defined by the +// passed protocol version. +// +// This is a part of the lnwire.Message interface. +func (s *Stfu) Decode(r io.Reader, _ uint32) error { + if err := ReadElements( + r, &s.ChanID, &s.Initiator, &s.ExtraData, + ); err != nil { + return err + } + + // This is required to pass the fuzz test round trip equality check. + if len(s.ExtraData) == 0 { + s.ExtraData = nil + } + + return nil +} + +// MsgType returns the MessageType code which uniquely identifies this message +// as a Stfu on the wire. +// +// This is part of the lnwire.Message interface. +func (s *Stfu) MsgType() MessageType { + return MsgStfu +}