diff --git a/channeldb/payment_control.go b/channeldb/payment_control.go index 1cbc0b81f..6bd094771 100644 --- a/channeldb/payment_control.go +++ b/channeldb/payment_control.go @@ -8,6 +8,7 @@ import ( "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/routing/route" ) var ( @@ -33,6 +34,10 @@ var ( // ErrUnknownPaymentStatus is returned when we do not recognize the // existing state of a payment. ErrUnknownPaymentStatus = errors.New("unknown payment status") + + // errNoAttemptInfo is returned when no attempt info is stored yet. + errNoAttemptInfo = errors.New("unable to find attempt info for " + + "inflight payment") ) // PaymentControl implements persistence for payments and payment attempts. @@ -187,9 +192,12 @@ func (p *PaymentControl) RegisterAttempt(paymentHash lntypes.Hash, // duplicate payments to the same payment hash. The provided preimage is // atomically saved to the DB for record keeping. func (p *PaymentControl) Success(paymentHash lntypes.Hash, - preimage lntypes.Preimage) error { + preimage lntypes.Preimage) (*route.Route, error) { - var updateErr error + var ( + updateErr error + route *route.Route + ) err := p.db.Batch(func(tx *bbolt.Tx) error { // Reset the update error, to avoid carrying over an error // from a previous execution of the batched db transaction. @@ -211,14 +219,26 @@ func (p *PaymentControl) Success(paymentHash lntypes.Hash, // Record the successful payment info atomically to the // payments record. - return bucket.Put(paymentSettleInfoKey, preimage[:]) + err = bucket.Put(paymentSettleInfoKey, preimage[:]) + if err != nil { + return err + } + + // Retrieve attempt info for the notification. + attempt, err := fetchPaymentAttempt(bucket) + if err != nil { + return err + } + + route = &attempt.Route + + return nil }) if err != nil { - return err + return nil, err } - return updateErr - + return route, updateErr } // Fail transitions a payment into the Failed state, and records the reason the @@ -259,6 +279,28 @@ func (p *PaymentControl) Fail(paymentHash lntypes.Hash, return updateErr } +// FetchPayment returns information about a payment from the database. +func (p *PaymentControl) FetchPayment(paymentHash lntypes.Hash) ( + *Payment, error) { + + var payment *Payment + err := p.db.View(func(tx *bbolt.Tx) error { + bucket, err := fetchPaymentBucket(tx, paymentHash) + if err != nil { + return err + } + + payment, err = fetchPayment(bucket) + + return err + }) + if err != nil { + return nil, err + } + + return payment, nil +} + // createPaymentBucket creates or fetches the sub-bucket assigned to this // payment hash. func createPaymentBucket(tx *bbolt.Tx, paymentHash lntypes.Hash) ( @@ -357,6 +399,17 @@ func ensureInFlight(bucket *bbolt.Bucket) error { } } +// fetchPaymentAttempt fetches the payment attempt from the bucket. +func fetchPaymentAttempt(bucket *bbolt.Bucket) (*PaymentAttemptInfo, error) { + attemptData := bucket.Get(paymentAttemptInfoKey) + if attemptData == nil { + return nil, errNoAttemptInfo + } + + r := bytes.NewReader(attemptData) + return deserializePaymentAttemptInfo(r) +} + // InFlightPayment is a wrapper around a payment that has status InFlight. type InFlightPayment struct { // Info is the PaymentCreationInfo of the in-flight payment. @@ -408,15 +461,11 @@ func (p *PaymentControl) FetchInFlightPayments() ([]*InFlightPayment, error) { return err } - // Now get the attempt info, which may or may not be - // available. - attempt := bucket.Get(paymentAttemptInfoKey) - if attempt != nil { - r = bytes.NewReader(attempt) - inFlight.Attempt, err = deserializePaymentAttemptInfo(r) - if err != nil { - return err - } + // Now get the attempt info. It could be that there is + // no attempt info yet. + inFlight.Attempt, err = fetchPaymentAttempt(bucket) + if err != nil && err != errNoAttemptInfo { + return err } inFlights = append(inFlights, inFlight) diff --git a/channeldb/payment_control_test.go b/channeldb/payment_control_test.go index 370300d73..ea69645f5 100644 --- a/channeldb/payment_control_test.go +++ b/channeldb/payment_control_test.go @@ -14,6 +14,7 @@ import ( "github.com/coreos/bbolt" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/routing/route" ) func initDB() (*DB, error) { @@ -131,9 +132,14 @@ func TestPaymentControlSwitchFail(t *testing.T) { ) // Verifies that status was changed to StatusSucceeded. - if err := pControl.Success(info.PaymentHash, preimg); err != nil { + var route *route.Route + route, err = pControl.Success(info.PaymentHash, preimg) + if err != nil { t.Fatalf("error shouldn't have been received, got: %v", err) } + if !reflect.DeepEqual(*route, attempt.Route) { + t.Fatalf("unexpected route returned") + } assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) assertPaymentInfo(t, db, info.PaymentHash, info, attempt, preimg, nil) @@ -204,7 +210,7 @@ func TestPaymentControlSwitchDoubleSend(t *testing.T) { } // After settling, the error should be ErrAlreadyPaid. - if err := pControl.Success(info.PaymentHash, preimg); err != nil { + if _, err := pControl.Success(info.PaymentHash, preimg); err != nil { t.Fatalf("error shouldn't have been received, got: %v", err) } assertPaymentStatus(t, db, info.PaymentHash, StatusSucceeded) @@ -234,7 +240,7 @@ func TestPaymentControlSuccessesWithoutInFlight(t *testing.T) { } // Attempt to complete the payment should fail. - err = pControl.Success(info.PaymentHash, preimg) + _, err = pControl.Success(info.PaymentHash, preimg) if err != ErrPaymentNotInitiated { t.Fatalf("expected ErrPaymentNotInitiated, got %v", err) } @@ -337,7 +343,7 @@ func TestPaymentControlDeleteNonInFligt(t *testing.T) { ) } else if p.success { // Verifies that status was changed to StatusSucceeded. - err := pControl.Success(info.PaymentHash, preimg) + _, err := pControl.Success(info.PaymentHash, preimg) if err != nil { t.Fatalf("error shouldn't have been received, got: %v", err) } diff --git a/routing/control_tower.go b/routing/control_tower.go index c7a2216ae..d8a1e0971 100644 --- a/routing/control_tower.go +++ b/routing/control_tower.go @@ -1,9 +1,12 @@ package routing import ( - "github.com/lightningnetwork/lnd/channeldb" + "errors" + "sync" + "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/routing/route" ) // ControlTower tracks all outgoing payments made, whose primary purpose is to @@ -35,18 +38,47 @@ type ControlTower interface { // FetchInFlightPayments returns all payments with status InFlight. FetchInFlightPayments() ([]*channeldb.InFlightPayment, error) + + // SubscribePayment subscribes to updates for the payment with the given + // hash. It returns a boolean indicating whether the payment is still in + // flight and a channel that provides the final outcome of the payment. + SubscribePayment(paymentHash lntypes.Hash) (bool, chan PaymentResult, + error) +} + +// PaymentResult is the struct describing the events received by payment +// subscribers. +type PaymentResult struct { + // Success indicates whether the payment was successful. + Success bool + + // Route is the (last) route attempted to send the HTLC. It is only set + // for successful payments. + Route *route.Route + + // PaymentPreimage is the preimage of a successful payment. This serves + // as a proof of payment. It is only set for successful payments. + Preimage lntypes.Preimage + + // Failure is a failure reason code indicating the reason the payment + // failed. It is only set for failed payments. + FailureReason channeldb.FailureReason } // controlTower is persistent implementation of ControlTower to restrict // double payment sending. type controlTower struct { db *channeldb.PaymentControl + + subscribers map[lntypes.Hash][]chan PaymentResult + subscribersMtx sync.Mutex } // NewControlTower creates a new instance of the controlTower. func NewControlTower(db *channeldb.PaymentControl) ControlTower { return &controlTower{ - db: db, + db: db, + subscribers: make(map[lntypes.Hash][]chan PaymentResult), } } @@ -75,7 +107,21 @@ func (p *controlTower) RegisterAttempt(paymentHash lntypes.Hash, func (p *controlTower) Success(paymentHash lntypes.Hash, preimage lntypes.Preimage) error { - return p.db.Success(paymentHash, preimage) + route, err := p.db.Success(paymentHash, preimage) + if err != nil { + return err + } + + // Notify subscribers of success event. + p.notifyFinalEvent( + paymentHash, PaymentResult{ + Success: true, + Preimage: preimage, + Route: route, + }, + ) + + return nil } // Fail transitions a payment into the Failed state, and records the reason the @@ -85,10 +131,108 @@ func (p *controlTower) Success(paymentHash lntypes.Hash, func (p *controlTower) Fail(paymentHash lntypes.Hash, reason channeldb.FailureReason) error { - return p.db.Fail(paymentHash, reason) + err := p.db.Fail(paymentHash, reason) + if err != nil { + return err + } + + // Notify subscribers of fail event. + p.notifyFinalEvent( + paymentHash, PaymentResult{ + Success: false, + FailureReason: reason, + }, + ) + + return nil } // FetchInFlightPayments returns all payments with status InFlight. func (p *controlTower) FetchInFlightPayments() ([]*channeldb.InFlightPayment, error) { return p.db.FetchInFlightPayments() } + +// SubscribePayment subscribes to updates for the payment with the given hash. +// It returns a boolean indicating whether the payment is still in flight and a +// channel that provides the final outcome of the payment. +func (p *controlTower) SubscribePayment(paymentHash lntypes.Hash) ( + bool, chan PaymentResult, error) { + + // Create a channel with buffer size 1. For every payment there will be + // exactly one event sent. + c := make(chan PaymentResult, 1) + + // Take lock before querying the db to prevent this scenario: + // FetchPayment returns us an in-flight state -> payment succeeds, but + // there is no subscriber to notify yet -> we add ourselves as a + // subscriber -> ... we will never receive a notification. + p.subscribersMtx.Lock() + defer p.subscribersMtx.Unlock() + + payment, err := p.db.FetchPayment(paymentHash) + if err != nil { + return false, nil, err + } + + var event PaymentResult + + switch payment.Status { + + // Payment is currently in flight. Register this subscriber and + // return without writing a result to the channel yet. + case channeldb.StatusInFlight: + p.subscribers[paymentHash] = append( + p.subscribers[paymentHash], c, + ) + + return true, c, nil + + // Payment already succeeded. It is not necessary to register as + // a subscriber, because we can send the result on the channel + // immediately. + case channeldb.StatusSucceeded: + event.Success = true + event.Preimage = *payment.PaymentPreimage + event.Route = &payment.Attempt.Route + + // Payment already failed. It is not necessary to register as a + // subscriber, because we can send the result on the channel + // immediately. + case channeldb.StatusFailed: + event.Success = false + event.FailureReason = *payment.Failure + + default: + return false, nil, errors.New("unknown payment status") + } + + // Write immediate result to the channel. + c <- event + close(c) + + return false, c, nil +} + +// notifyFinalEvent sends a final payment event to all subscribers of this +// payment. The channel will be closed after this. +func (p *controlTower) notifyFinalEvent(paymentHash lntypes.Hash, + event PaymentResult) { + + // Get all subscribers for this hash. As there is only a single outcome, + // the subscriber list can be cleared. + p.subscribersMtx.Lock() + list, ok := p.subscribers[paymentHash] + if !ok { + p.subscribersMtx.Unlock() + return + } + delete(p.subscribers, paymentHash) + p.subscribersMtx.Unlock() + + // Notify all subscribers of the event. The subscriber channel is + // buffered, so it cannot block here. + for _, subscriber := range list { + subscriber <- event + close(subscriber) + } +} diff --git a/routing/control_tower_test.go b/routing/control_tower_test.go new file mode 100644 index 000000000..0a765fdbc --- /dev/null +++ b/routing/control_tower_test.go @@ -0,0 +1,283 @@ +package routing + +import ( + "crypto/rand" + "crypto/sha256" + "fmt" + "io" + "io/ioutil" + "reflect" + "testing" + "time" + + "github.com/btcsuite/btcd/btcec" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/routing/route" + + "github.com/lightningnetwork/lnd/lntypes" +) + +var ( + priv, _ = btcec.NewPrivateKey(btcec.S256()) + pub = priv.PubKey() + + testHop = &route.Hop{ + PubKeyBytes: route.NewVertex(pub), + ChannelID: 12345, + OutgoingTimeLock: 111, + AmtToForward: 555, + } + + testRoute = route.Route{ + TotalTimeLock: 123, + TotalAmount: 1234567, + SourcePubKey: route.NewVertex(pub), + Hops: []*route.Hop{ + testHop, + testHop, + }, + } + + testTimeout = 5 * time.Second +) + +// TestControlTowerSubscribeUnknown tests that subscribing to an unknown +// payment fails. +func TestControlTowerSubscribeUnknown(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewControlTower(channeldb.NewPaymentControl(db)) + + // Subscription should fail when the payment is not known. + _, _, err = pControl.SubscribePayment(lntypes.Hash{1}) + if err != channeldb.ErrPaymentNotInitiated { + t.Fatal("expected subscribe to fail for unknown payment") + } +} + +// TestControlTowerSubscribeSuccess tests that payment updates for a +// successful payment are properly sent to subscribers. +func TestControlTowerSubscribeSuccess(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewControlTower(channeldb.NewPaymentControl(db)) + + // Initiate a payment. + info, attempt, preimg, err := genInfo() + if err != nil { + t.Fatal(err) + } + + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatal(err) + } + + // Subscription should succeed and immediately report the InFlight + // status. + inFlight, subscriber1, err := pControl.SubscribePayment(info.PaymentHash) + if err != nil { + t.Fatalf("expected subscribe to succeed, but got: %v", err) + } + if !inFlight { + t.Fatalf("unexpected payment to be in flight") + } + + // Register an attempt. + err = pControl.RegisterAttempt(info.PaymentHash, attempt) + if err != nil { + t.Fatal(err) + } + + // Register a second subscriber after the first attempt has started. + inFlight, subscriber2, err := pControl.SubscribePayment(info.PaymentHash) + if err != nil { + t.Fatalf("expected subscribe to succeed, but got: %v", err) + } + if !inFlight { + t.Fatalf("unexpected payment to be in flight") + } + + // Mark the payment as successful. + if err := pControl.Success(info.PaymentHash, preimg); err != nil { + t.Fatal(err) + } + + // Register a third subscriber after the payment succeeded. + inFlight, subscriber3, err := pControl.SubscribePayment(info.PaymentHash) + if err != nil { + t.Fatalf("expected subscribe to succeed, but got: %v", err) + } + if inFlight { + t.Fatalf("expected payment to be finished") + } + + // We expect all subscribers to now report the final outcome followed by + // no other events. + subscribers := []chan PaymentResult{ + subscriber1, subscriber2, subscriber3, + } + + for _, s := range subscribers { + var result PaymentResult + select { + case result = <-s: + case <-time.After(testTimeout): + t.Fatal("timeout waiting for payment result") + } + + if !result.Success { + t.Fatal("unexpected payment state") + } + if result.Preimage != preimg { + t.Fatal("unexpected preimage") + } + if !reflect.DeepEqual(result.Route, &attempt.Route) { + t.Fatal("unexpected route") + } + + // After the final event, we expect the channel to be closed. + select { + case _, ok := <-s: + if ok { + t.Fatal("expected channel to be closed") + } + case <-time.After(testTimeout): + t.Fatal("timeout waiting for result channel close") + } + } +} + +// TestPaymentControlSubscribeFail tests that payment updates for a +// failed payment are properly sent to subscribers. +func TestPaymentControlSubscribeFail(t *testing.T) { + t.Parallel() + + db, err := initDB() + if err != nil { + t.Fatalf("unable to init db: %v", err) + } + + pControl := NewControlTower(channeldb.NewPaymentControl(db)) + + // Initiate a payment. + info, _, _, err := genInfo() + if err != nil { + t.Fatal(err) + } + + err = pControl.InitPayment(info.PaymentHash, info) + if err != nil { + t.Fatal(err) + } + + // Subscription should succeed. + _, subscriber1, err := pControl.SubscribePayment(info.PaymentHash) + if err != nil { + t.Fatalf("expected subscribe to succeed, but got: %v", err) + } + + // Mark the payment as failed. + if err := pControl.Fail(info.PaymentHash, channeldb.FailureReasonTimeout); err != nil { + t.Fatal(err) + } + + // Register a second subscriber after the payment failed. + inFlight, subscriber2, err := pControl.SubscribePayment(info.PaymentHash) + if err != nil { + t.Fatalf("expected subscribe to succeed, but got: %v", err) + } + if inFlight { + t.Fatalf("expected payment to be finished") + } + + // We expect all subscribers to now report the final outcome followed by + // no other events. + subscribers := []chan PaymentResult{ + subscriber1, subscriber2, + } + + for _, s := range subscribers { + var result PaymentResult + select { + case result = <-s: + case <-time.After(testTimeout): + t.Fatal("timeout waiting for payment result") + } + + if result.Success { + t.Fatal("unexpected payment state") + } + if result.Route != nil { + t.Fatal("expected no route") + } + if result.FailureReason != channeldb.FailureReasonTimeout { + t.Fatal("unexpected failure reason") + } + + // After the final event, we expect the channel to be closed. + select { + case _, ok := <-s: + if ok { + t.Fatal("expected channel to be closed") + } + case <-time.After(testTimeout): + t.Fatal("timeout waiting for result channel close") + } + } +} + +func initDB() (*channeldb.DB, error) { + tempPath, err := ioutil.TempDir("", "routingdb") + if err != nil { + return nil, err + } + + db, err := channeldb.Open(tempPath) + if err != nil { + return nil, err + } + + return db, err +} + +func genInfo() (*channeldb.PaymentCreationInfo, *channeldb.PaymentAttemptInfo, + lntypes.Preimage, error) { + + preimage, err := genPreimage() + if err != nil { + return nil, nil, preimage, fmt.Errorf("unable to "+ + "generate preimage: %v", err) + } + + rhash := sha256.Sum256(preimage[:]) + return &channeldb.PaymentCreationInfo{ + PaymentHash: rhash, + Value: 1, + CreationDate: time.Unix(time.Now().Unix(), 0), + PaymentRequest: []byte("hola"), + }, + &channeldb.PaymentAttemptInfo{ + PaymentID: 1, + SessionKey: priv, + Route: testRoute, + }, preimage, nil +} + +func genPreimage() ([32]byte, error) { + var preimage [32]byte + if _, err := io.ReadFull(rand.Reader, preimage[:]); err != nil { + return preimage, err + } + return preimage, nil +} diff --git a/routing/mock_test.go b/routing/mock_test.go index e61b6677f..600c55874 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -4,6 +4,7 @@ import ( "fmt" "sync" + "github.com/go-errors/errors" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/lntypes" @@ -284,3 +285,9 @@ func (m *mockControlTower) FetchInFlightPayments() ( return fl, nil } + +func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) ( + bool, chan PaymentResult, error) { + + return false, nil, errors.New("not implemented") +}