From c7d8791123df2f48c02373fe7918cd06ecc45ea7 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Fri, 7 Feb 2025 18:10:45 -0800 Subject: [PATCH] lnwallet/chancloser: enforce pubkey binding for msg mapper --- lnwallet/chancloser/rbf_coop_msg_mapper.go | 27 ++++++++++++++-------- lnwallet/chancloser/rbf_coop_test.go | 2 +- protofsm/msg_mapper.go | 4 ++-- protofsm/state_machine_test.go | 2 +- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/lnwallet/chancloser/rbf_coop_msg_mapper.go b/lnwallet/chancloser/rbf_coop_msg_mapper.go index 96c855686..a66cf78cc 100644 --- a/lnwallet/chancloser/rbf_coop_msg_mapper.go +++ b/lnwallet/chancloser/rbf_coop_msg_mapper.go @@ -1,8 +1,10 @@ package chancloser import ( + "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/msgmux" ) // RbfMsgMapper is a struct that implements the MsgMapper interface for the @@ -16,16 +18,21 @@ type RbfMsgMapper struct { // chanID is the channel ID of the channel being closed. chanID lnwire.ChannelID + + // peerPub is the public key of the peer that the channel is being + // closed. + peerPub btcec.PublicKey } // NewRbfMsgMapper creates a new RbfMsgMapper instance given the current block // height when the co-op close request was initiated. func NewRbfMsgMapper(blockHeight uint32, - chanID lnwire.ChannelID) *RbfMsgMapper { + chanID lnwire.ChannelID, peerPub btcec.PublicKey) *RbfMsgMapper { return &RbfMsgMapper{ blockHeight: blockHeight, chanID: chanID, + peerPub: peerPub, } } @@ -34,18 +41,20 @@ func someEvent[T ProtocolEvent](m T) fn.Option[ProtocolEvent] { return fn.Some(ProtocolEvent(m)) } -// isExpectedChanID returns true if the channel ID of the message matches the +// isForUs returns true if the channel ID + pubkey of the message matches the // bound instance. -func (r *RbfMsgMapper) isExpectedChanID(chanID lnwire.ChannelID) bool { - return r.chanID == chanID +func (r *RbfMsgMapper) isForUs(chanID lnwire.ChannelID, + fromPub btcec.PublicKey) bool { + + return r.chanID == chanID && r.peerPub.IsEqual(&fromPub) } // MapMsg maps a wire message into a FSM event. If the message is not mappable, // then an error is returned. -func (r *RbfMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[ProtocolEvent] { - switch msg := wireMsg.(type) { +func (r *RbfMsgMapper) MapMsg(wireMsg msgmux.PeerMsg) fn.Option[ProtocolEvent] { + switch msg := wireMsg.Message.(type) { case *lnwire.Shutdown: - if !r.isExpectedChanID(msg.ChannelID) { + if !r.isForUs(msg.ChannelID, wireMsg.PeerPub) { return fn.None[ProtocolEvent]() } @@ -55,7 +64,7 @@ func (r *RbfMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[ProtocolEvent] { }) case *lnwire.ClosingComplete: - if !r.isExpectedChanID(msg.ChannelID) { + if !r.isForUs(msg.ChannelID, wireMsg.PeerPub) { return fn.None[ProtocolEvent]() } @@ -64,7 +73,7 @@ func (r *RbfMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[ProtocolEvent] { }) case *lnwire.ClosingSig: - if !r.isExpectedChanID(msg.ChannelID) { + if !r.isForUs(msg.ChannelID, wireMsg.PeerPub) { return fn.None[ProtocolEvent]() } diff --git a/lnwallet/chancloser/rbf_coop_test.go b/lnwallet/chancloser/rbf_coop_test.go index 17d783eae..d0bd76472 100644 --- a/lnwallet/chancloser/rbf_coop_test.go +++ b/lnwallet/chancloser/rbf_coop_test.go @@ -687,7 +687,7 @@ func newRbfCloserTestHarness(t *testing.T, peerPub := randPubKey(t) - msgMapper := NewRbfMsgMapper(uint32(startingHeight), chanID) + msgMapper := NewRbfMsgMapper(uint32(startingHeight), chanID, *peerPub) initialState := cfg.initialState.UnwrapOr(&ChannelActive{}) diff --git a/protofsm/msg_mapper.go b/protofsm/msg_mapper.go index 5e24255fa..a00d86379 100644 --- a/protofsm/msg_mapper.go +++ b/protofsm/msg_mapper.go @@ -2,7 +2,7 @@ package protofsm import ( "github.com/lightningnetwork/lnd/fn/v2" - "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/msgmux" ) // MsgMapper is used to map incoming wire messages into a FSM event. This is @@ -11,5 +11,5 @@ import ( type MsgMapper[Event any] interface { // MapMsg maps a wire message into a FSM event. If the message is not // mappable, then an None is returned. - MapMsg(msg lnwire.Message) fn.Option[Event] + MapMsg(msg msgmux.PeerMsg) fn.Option[Event] } diff --git a/protofsm/state_machine_test.go b/protofsm/state_machine_test.go index ed18743b5..ca05d336c 100644 --- a/protofsm/state_machine_test.go +++ b/protofsm/state_machine_test.go @@ -406,7 +406,7 @@ type dummyMsgMapper struct { mock.Mock } -func (d *dummyMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[dummyEvents] { +func (d *dummyMsgMapper) MapMsg(wireMsg msgmux.PeerMsg) fn.Option[dummyEvents] { args := d.Called(wireMsg) //nolint:forcetypeassert