lnwallet: add HTLC count validation

This commit is contained in:
Andrey Samokhvalov 2016-11-23 11:36:55 +03:00 committed by Olaoluwa Osuntokun
parent 391d5cd401
commit da3028e10c
3 changed files with 189 additions and 7 deletions

View File

@ -1107,6 +1107,11 @@ func (lc *LightningChannel) SignNextCommitment() ([]byte, uint32, error) {
lc.Lock() lc.Lock()
defer lc.Unlock() defer lc.Unlock()
err := lc.validateCommitmentSanity(lc.theirLogCounter, lc.ourLogCounter, false)
if err != nil {
return nil, 0, err
}
// Ensure that we have enough unused revocation hashes given to us by the // Ensure that we have enough unused revocation hashes given to us by the
// remote party. If the set is empty, then we're unable to create a new // remote party. If the set is empty, then we're unable to create a new
// state unless they first revoke a prior commitment transaction. // state unless they first revoke a prior commitment transaction.
@ -1165,7 +1170,47 @@ func (lc *LightningChannel) SignNextCommitment() ([]byte, uint32, error) {
return sig, lc.theirLogCounter, nil return sig, lc.theirLogCounter, nil
} }
// ReceiveNewCommitment processs a signature for a new commitment state sent by // validateCommitmentSanity is used to validate that on current state the commitment
// transaction is valid in terms of propagating it over Bitcoin network, and
// also that all outputs are meet Bitcoin spec requirements and they are
// spendable.
func (lc *LightningChannel) validateCommitmentSanity(theirLogCounter,
ourLogCounter uint32, prediction bool) error {
htlcCount := 0
if prediction {
htlcCount++
}
// Run through all the HTLC's that will be covered by this transaction
// in order to calculate theirs count.
htlcView := lc.fetchHTLCView(theirLogCounter, ourLogCounter)
for _, entry := range htlcView.ourUpdates {
if entry.EntryType == Add {
htlcCount++
} else {
htlcCount--
}
}
for _, entry := range htlcView.theirUpdates {
if entry.EntryType == Add {
htlcCount++
} else {
htlcCount--
}
}
if htlcCount > MaxHTLCNumber {
return ErrMaxHTLCNumber
}
return nil
}
// ReceiveNewCommitment process a signature for a new commitment state sent by
// the remote party. This method will should be called in response to the // the remote party. This method will should be called in response to the
// remote party initiating a new change, or when the remote party sends a // remote party initiating a new change, or when the remote party sends a
// signature fully accepting a new state we've initiated. If we are able to // signature fully accepting a new state we've initiated. If we are able to
@ -1179,6 +1224,11 @@ func (lc *LightningChannel) ReceiveNewCommitment(rawSig []byte,
lc.Lock() lc.Lock()
defer lc.Unlock() defer lc.Unlock()
err := lc.validateCommitmentSanity(lc.theirLogCounter, ourLogIndex, false)
if err != nil {
return err
}
theirCommitKey := lc.channelState.TheirCommitKey theirCommitKey := lc.channelState.TheirCommitKey
theirMultiSigKey := lc.channelState.TheirMultiSigKey theirMultiSigKey := lc.channelState.TheirMultiSigKey
@ -1517,10 +1567,15 @@ func (lc *LightningChannel) ExtendRevocationWindow() (*lnwire.CommitRevocation,
// should be called when preparing to send an outgoing HTLC. // should be called when preparing to send an outgoing HTLC.
// TODO(roasbeef): check for duplicates below? edge case during restart w/ HTLC // TODO(roasbeef): check for duplicates below? edge case during restart w/ HTLC
// persistence // persistence
func (lc *LightningChannel) AddHTLC(htlc *lnwire.HTLCAddRequest) uint32 { func (lc *LightningChannel) AddHTLC(htlc *lnwire.HTLCAddRequest) (uint32, error) {
lc.Lock() lc.Lock()
defer lc.Unlock() defer lc.Unlock()
err := lc.validateCommitmentSanity(lc.theirLogCounter, lc.ourLogCounter, true)
if err != nil {
return 0, err
}
pd := &PaymentDescriptor{ pd := &PaymentDescriptor{
EntryType: Add, EntryType: Add,
RHash: PaymentHash(htlc.RedemptionHashes[0]), RHash: PaymentHash(htlc.RedemptionHashes[0]),
@ -1532,16 +1587,21 @@ func (lc *LightningChannel) AddHTLC(htlc *lnwire.HTLCAddRequest) uint32 {
lc.ourLogIndex[pd.Index] = lc.ourUpdateLog.PushBack(pd) lc.ourLogIndex[pd.Index] = lc.ourUpdateLog.PushBack(pd)
lc.ourLogCounter++ lc.ourLogCounter++
return pd.Index return pd.Index, nil
} }
// ReceiveHTLC adds an HTLC to the state machine's remote update log. This // ReceiveHTLC adds an HTLC to the state machine's remote update log. This
// method should be called in response to receiving a new HTLC from the remote // method should be called in response to receiving a new HTLC from the remote
// party. // party.
func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.HTLCAddRequest) uint32 { func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.HTLCAddRequest) (uint32, error) {
lc.Lock() lc.Lock()
defer lc.Unlock() defer lc.Unlock()
err := lc.validateCommitmentSanity(lc.theirLogCounter, lc.ourLogCounter, true)
if err != nil {
return 0, err
}
pd := &PaymentDescriptor{ pd := &PaymentDescriptor{
EntryType: Add, EntryType: Add,
RHash: PaymentHash(htlc.RedemptionHashes[0]), RHash: PaymentHash(htlc.RedemptionHashes[0]),
@ -1553,7 +1613,7 @@ func (lc *LightningChannel) ReceiveHTLC(htlc *lnwire.HTLCAddRequest) uint32 {
lc.theirLogIndex[pd.Index] = lc.theirUpdateLog.PushBack(pd) lc.theirLogIndex[pd.Index] = lc.theirUpdateLog.PushBack(pd)
lc.theirLogCounter++ lc.theirLogCounter++
return pd.Index return pd.Index, nil
} }
// SettleHTLC attempts to settle an existing outstanding received HTLC. The // SettleHTLC attempts to settle an existing outstanding received HTLC. The

View File

@ -731,6 +731,108 @@ func TestCooperativeChannelClosure(t *testing.T) {
} }
} }
// TestCheckHTLCNumberConstraint checks that we can't add HTLC or receive
// HTLC if number of HTLCs exceed maximum available number, also this test
// checks that if for some reason max number of HTLCs was exceeded and not
// caught before, the creation of new commitment will not be possible because
// of validation error.
func TestCheckHTLCNumberConstraint(t *testing.T) {
createHTLC := func(i int) *lnwire.HTLCAddRequest {
preimage := bytes.Repeat([]byte{byte(i)}, 32)
paymentHash := fastsha256.Sum256(preimage)
return &lnwire.HTLCAddRequest{
RedemptionHashes: [][32]byte{paymentHash},
Amount: lnwire.CreditsAmount(1e7),
Expiry: uint32(5),
}
}
checkError := func(err error) error {
if err == nil {
return errors.New("Exceed max htlc count error was " +
"not received")
} else if err != ErrMaxHTLCNumber {
return errors.Errorf("Unexpected error occured: %v", err)
}
return nil
}
// Create a test channel which will be used for the duration of this
// unittest. The channel will be funded evenly with Alice having 5 BTC,
// and Bob having 5 BTC.
aliceChannel, bobChannel, cleanUp, err := createTestChannels(3)
if err != nil {
t.Fatalf("unable to create test channels: %v", err)
}
defer cleanUp()
// Add max available number of HTLCs.
for i := 0; i < MaxHTLCNumber; i++ {
htlc := createHTLC(i)
if _, err := aliceChannel.AddHTLC(htlc); err != nil {
t.Fatalf("alice unable to add htlc: %v", err)
}
if _, err := bobChannel.ReceiveHTLC(htlc); err != nil {
t.Fatalf("bob unable to receive htlc: %v", err)
}
}
// Next addition should cause HTLC max number validation error.
htlc := createHTLC(0)
if _, err := aliceChannel.AddHTLC(htlc); err != nil {
if err := checkError(err); err != nil {
t.Fatal(err)
}
} else {
t.Fatal("Error was not received")
}
if _, err := bobChannel.AddHTLC(htlc); err != nil {
if err := checkError(err); err != nil {
t.Fatal(err)
}
} else {
t.Fatal("Error was not received")
}
if _, err := aliceChannel.ReceiveHTLC(htlc); err != nil {
if err := checkError(err); err != nil {
t.Fatal(err)
}
} else {
t.Fatal("Error was not received")
}
if _, err := bobChannel.ReceiveHTLC(htlc); err != nil {
if err := checkError(err); err != nil {
t.Fatal(err)
}
} else {
t.Fatal("Error was not received")
}
// Manually add HTLC to check SignNextCommitment validation error.
pd := &PaymentDescriptor{Index: aliceChannel.theirLogCounter}
aliceChannel.theirLogIndex[pd.Index] = aliceChannel.theirUpdateLog.PushBack(pd)
aliceChannel.theirLogCounter++
_, _, err = aliceChannel.SignNextCommitment()
if err := checkError(err); err != nil {
t.Fatal(err)
}
// Manually add HTLC to check ReceiveNewCommitment validation error.
pd = &PaymentDescriptor{Index: bobChannel.theirLogCounter}
bobChannel.theirLogIndex[pd.Index] = bobChannel.theirUpdateLog.PushBack(pd)
bobChannel.theirLogCounter++
// And on this stage we should receive the weight error.
someSig := []byte("somesig")
err = bobChannel.ReceiveNewCommitment(someSig, aliceChannel.theirLogCounter)
if err := checkError(err); err != nil {
t.Fatal(err)
}
}
func TestStateUpdatePersistence(t *testing.T) { func TestStateUpdatePersistence(t *testing.T) {
// Create a test channel which will be used for the duration of this // Create a test channel which will be used for the duration of this
// unittest. The channel will be funded evenly with Alice having 5 BTC, // unittest. The channel will be funded evenly with Alice having 5 BTC,

24
peer.go
View File

@ -1177,7 +1177,23 @@ func (p *peer) handleDownStreamPkt(state *commitmentState, pkt *htlcPacket) {
// to our local log, then update the commitment // to our local log, then update the commitment
// chains. // chains.
htlc.ChannelPoint = state.chanPoint htlc.ChannelPoint = state.chanPoint
index := state.channel.AddHTLC(htlc) index, err := state.channel.AddHTLC(htlc)
if err != nil {
// TODO: possibly perform fallback/retry logic
// depending on type of error
// TODO: send a cancel message back to the htlcSwitch.
peerLog.Errorf("Adding HTLC rejected: %v", err)
pkt.err <- err
// Increase the available bandwidth of the link,
// previously it was decremented and because
// HTLC adding failed we should do the reverse
// operation.
htlcSwitch := p.server.htlcSwitch
htlcSwitch.UpdateLink(htlc.ChannelPoint, pkt.amt)
return
}
p.queueMsg(htlc, nil) p.queueMsg(htlc, nil)
state.pendingBatch = append(state.pendingBatch, &pendingPayment{ state.pendingBatch = append(state.pendingBatch, &pendingPayment{
@ -1258,7 +1274,11 @@ func (p *peer) handleUpstreamMsg(state *commitmentState, msg lnwire.Message) {
// We just received an add request from an upstream peer, so we // We just received an add request from an upstream peer, so we
// add it to our state machine, then add the HTLC to our // add it to our state machine, then add the HTLC to our
// "settle" list in the event that we know the pre-image // "settle" list in the event that we know the pre-image
index := state.channel.ReceiveHTLC(htlcPkt) index, err := state.channel.ReceiveHTLC(htlcPkt)
if err != nil {
peerLog.Errorf("Receiving HTLC rejected: %v", err)
return
}
switch sphinxPacket.Action { switch sphinxPacket.Action {
// We're the designated payment destination. Therefore we // We're the designated payment destination. Therefore we