mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-01-18 13:27:56 +01:00
lnwire: add wire type for stfu
This commit is contained in:
parent
c262b1b5a5
commit
0176fca826
@ -191,6 +191,17 @@ func FuzzWarning(f *testing.F) {
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzStfu(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
// Prefix with MsgStfu.
|
||||
data = prefixWithMsgType(data, MsgStfu)
|
||||
|
||||
// Pass the message into our general fuzz harness for wire
|
||||
// messages.
|
||||
harness(t, data)
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzFundingCreated(f *testing.F) {
|
||||
f.Fuzz(func(t *testing.T, data []byte) {
|
||||
// Prefix with MsgFundingCreated.
|
||||
|
@ -438,6 +438,22 @@ func TestLightningWireProtocol(t *testing.T) {
|
||||
// are too complex for the testing/quick package to automatically
|
||||
// generate.
|
||||
customTypeGen := map[MessageType]func([]reflect.Value, *rand.Rand){
|
||||
MsgStfu: func(v []reflect.Value, r *rand.Rand) {
|
||||
req := Stfu{}
|
||||
if _, err := r.Read(req.ChanID[:]); err != nil {
|
||||
t.Fatalf("unable to generate ChanID: %v", err)
|
||||
}
|
||||
|
||||
// 1/2 chance of being initiator
|
||||
req.Initiator = r.Intn(2) == 1
|
||||
|
||||
// 1/2 chance additional TLV data.
|
||||
if r.Intn(2) == 0 {
|
||||
req.ExtraData = []byte{0xfd, 0x00, 0xff, 0x00}
|
||||
}
|
||||
|
||||
v[0] = reflect.ValueOf(req)
|
||||
},
|
||||
MsgInit: func(v []reflect.Value, r *rand.Rand) {
|
||||
req := NewInitMessage(
|
||||
randRawFeatureVector(r),
|
||||
@ -1384,6 +1400,12 @@ func TestLightningWireProtocol(t *testing.T) {
|
||||
msgType MessageType
|
||||
scenario interface{}
|
||||
}{
|
||||
{
|
||||
msgType: MsgStfu,
|
||||
scenario: func(m Stfu) bool {
|
||||
return mainScenario(&m)
|
||||
},
|
||||
},
|
||||
{
|
||||
msgType: MsgInit,
|
||||
scenario: func(m Init) bool {
|
||||
|
@ -23,6 +23,7 @@ type MessageType uint16
|
||||
// Lightning protocol.
|
||||
const (
|
||||
MsgWarning MessageType = 1
|
||||
MsgStfu = 2
|
||||
MsgInit = 16
|
||||
MsgError = 17
|
||||
MsgPing = 18
|
||||
@ -84,6 +85,8 @@ func (t MessageType) String() string {
|
||||
switch t {
|
||||
case MsgWarning:
|
||||
return "Warning"
|
||||
case MsgStfu:
|
||||
return "Stfu"
|
||||
case MsgInit:
|
||||
return "Init"
|
||||
case MsgOpenChannel:
|
||||
@ -211,6 +214,8 @@ func makeEmptyMessage(msgType MessageType) (Message, error) {
|
||||
switch msgType {
|
||||
case MsgWarning:
|
||||
msg = &Warning{}
|
||||
case MsgStfu:
|
||||
msg = &Stfu{}
|
||||
case MsgInit:
|
||||
msg = &Init{}
|
||||
case MsgOpenChannel:
|
||||
|
69
lnwire/stfu.go
Normal file
69
lnwire/stfu.go
Normal file
@ -0,0 +1,69 @@
|
||||
package lnwire
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Stfu is a message that is sent to lock the channel state prior to some other
|
||||
// interactive protocol where channel updates need to be paused.
|
||||
type Stfu struct {
|
||||
// ChanID identifies which channel needs to be frozen.
|
||||
ChanID ChannelID
|
||||
|
||||
// Initiator is a byte that identifies whether we are the initiator of
|
||||
// this process.
|
||||
Initiator bool
|
||||
|
||||
// ExtraData is the set of data that was appended to this message to
|
||||
// fill out the full maximum transport message size. These fields can
|
||||
// be used to specify optional data such as custom TLV fields.
|
||||
ExtraData ExtraOpaqueData
|
||||
}
|
||||
|
||||
// A compile time check to ensure Stfu implements the lnwire.Message interface.
|
||||
var _ Message = (*Stfu)(nil)
|
||||
|
||||
// Encode serializes the target Stfu into the passed io.Writer.
|
||||
// Serialization will observe the rules defined by the passed protocol version.
|
||||
//
|
||||
// This is a part of the lnwire.Message interface.
|
||||
func (s *Stfu) Encode(w *bytes.Buffer, _ uint32) error {
|
||||
if err := WriteChannelID(w, s.ChanID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := WriteBool(w, s.Initiator); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return WriteBytes(w, s.ExtraData)
|
||||
}
|
||||
|
||||
// Decode deserializes the serialized Stfu stored in the passed io.Reader
|
||||
// into the target Stfu using the deserialization rules defined by the
|
||||
// passed protocol version.
|
||||
//
|
||||
// This is a part of the lnwire.Message interface.
|
||||
func (s *Stfu) Decode(r io.Reader, _ uint32) error {
|
||||
if err := ReadElements(
|
||||
r, &s.ChanID, &s.Initiator, &s.ExtraData,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// This is required to pass the fuzz test round trip equality check.
|
||||
if len(s.ExtraData) == 0 {
|
||||
s.ExtraData = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MsgType returns the MessageType code which uniquely identifies this message
|
||||
// as a Stfu on the wire.
|
||||
//
|
||||
// This is part of the lnwire.Message interface.
|
||||
func (s *Stfu) MsgType() MessageType {
|
||||
return MsgStfu
|
||||
}
|
Loading…
Reference in New Issue
Block a user