lnwire: add wire type for stfu

This commit is contained in:
Keagan McClelland 2023-12-06 16:44:59 -08:00
parent c262b1b5a5
commit 0176fca826
No known key found for this signature in database
GPG Key ID: FA7E65C951F12439
4 changed files with 107 additions and 0 deletions

View File

@ -191,6 +191,17 @@ func FuzzWarning(f *testing.F) {
}) })
} }
func FuzzStfu(f *testing.F) {
f.Fuzz(func(t *testing.T, data []byte) {
// Prefix with MsgStfu.
data = prefixWithMsgType(data, MsgStfu)
// Pass the message into our general fuzz harness for wire
// messages.
harness(t, data)
})
}
func FuzzFundingCreated(f *testing.F) { func FuzzFundingCreated(f *testing.F) {
f.Fuzz(func(t *testing.T, data []byte) { f.Fuzz(func(t *testing.T, data []byte) {
// Prefix with MsgFundingCreated. // Prefix with MsgFundingCreated.

View File

@ -438,6 +438,22 @@ func TestLightningWireProtocol(t *testing.T) {
// are too complex for the testing/quick package to automatically // are too complex for the testing/quick package to automatically
// generate. // generate.
customTypeGen := map[MessageType]func([]reflect.Value, *rand.Rand){ customTypeGen := map[MessageType]func([]reflect.Value, *rand.Rand){
MsgStfu: func(v []reflect.Value, r *rand.Rand) {
req := Stfu{}
if _, err := r.Read(req.ChanID[:]); err != nil {
t.Fatalf("unable to generate ChanID: %v", err)
}
// 1/2 chance of being initiator
req.Initiator = r.Intn(2) == 1
// 1/2 chance additional TLV data.
if r.Intn(2) == 0 {
req.ExtraData = []byte{0xfd, 0x00, 0xff, 0x00}
}
v[0] = reflect.ValueOf(req)
},
MsgInit: func(v []reflect.Value, r *rand.Rand) { MsgInit: func(v []reflect.Value, r *rand.Rand) {
req := NewInitMessage( req := NewInitMessage(
randRawFeatureVector(r), randRawFeatureVector(r),
@ -1384,6 +1400,12 @@ func TestLightningWireProtocol(t *testing.T) {
msgType MessageType msgType MessageType
scenario interface{} scenario interface{}
}{ }{
{
msgType: MsgStfu,
scenario: func(m Stfu) bool {
return mainScenario(&m)
},
},
{ {
msgType: MsgInit, msgType: MsgInit,
scenario: func(m Init) bool { scenario: func(m Init) bool {

View File

@ -23,6 +23,7 @@ type MessageType uint16
// Lightning protocol. // Lightning protocol.
const ( const (
MsgWarning MessageType = 1 MsgWarning MessageType = 1
MsgStfu = 2
MsgInit = 16 MsgInit = 16
MsgError = 17 MsgError = 17
MsgPing = 18 MsgPing = 18
@ -84,6 +85,8 @@ func (t MessageType) String() string {
switch t { switch t {
case MsgWarning: case MsgWarning:
return "Warning" return "Warning"
case MsgStfu:
return "Stfu"
case MsgInit: case MsgInit:
return "Init" return "Init"
case MsgOpenChannel: case MsgOpenChannel:
@ -211,6 +214,8 @@ func makeEmptyMessage(msgType MessageType) (Message, error) {
switch msgType { switch msgType {
case MsgWarning: case MsgWarning:
msg = &Warning{} msg = &Warning{}
case MsgStfu:
msg = &Stfu{}
case MsgInit: case MsgInit:
msg = &Init{} msg = &Init{}
case MsgOpenChannel: case MsgOpenChannel:

69
lnwire/stfu.go Normal file
View File

@ -0,0 +1,69 @@
package lnwire
import (
"bytes"
"io"
)
// Stfu is a message that is sent to lock the channel state prior to some other
// interactive protocol where channel updates need to be paused.
type Stfu struct {
// ChanID identifies which channel needs to be frozen.
ChanID ChannelID
// Initiator is a byte that identifies whether we are the initiator of
// this process.
Initiator bool
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
ExtraData ExtraOpaqueData
}
// A compile time check to ensure Stfu implements the lnwire.Message interface.
var _ Message = (*Stfu)(nil)
// Encode serializes the target Stfu into the passed io.Writer.
// Serialization will observe the rules defined by the passed protocol version.
//
// This is a part of the lnwire.Message interface.
func (s *Stfu) Encode(w *bytes.Buffer, _ uint32) error {
if err := WriteChannelID(w, s.ChanID); err != nil {
return err
}
if err := WriteBool(w, s.Initiator); err != nil {
return err
}
return WriteBytes(w, s.ExtraData)
}
// Decode deserializes the serialized Stfu stored in the passed io.Reader
// into the target Stfu using the deserialization rules defined by the
// passed protocol version.
//
// This is a part of the lnwire.Message interface.
func (s *Stfu) Decode(r io.Reader, _ uint32) error {
if err := ReadElements(
r, &s.ChanID, &s.Initiator, &s.ExtraData,
); err != nil {
return err
}
// This is required to pass the fuzz test round trip equality check.
if len(s.ExtraData) == 0 {
s.ExtraData = nil
}
return nil
}
// MsgType returns the MessageType code which uniquely identifies this message
// as a Stfu on the wire.
//
// This is part of the lnwire.Message interface.
func (s *Stfu) MsgType() MessageType {
return MsgStfu
}