channeldb: AddInvoice now returns the addIndex of the new invoice

This commit is contained in:
Olaoluwa Osuntokun 2018-06-29 18:05:51 -07:00
parent 2dcc2d63a6
commit 7aeed0b58f
No known key found for this signature in database
GPG Key ID: 964EA263DD637C21
2 changed files with 27 additions and 16 deletions

View File

@ -70,7 +70,7 @@ func TestInvoiceWorkflow(t *testing.T) {
// Add the invoice to the database, this should succeed as there aren't
// any existing invoices within the database with the same payment
// hash.
if err := db.AddInvoice(fakeInvoice); err != nil {
if _, err := db.AddInvoice(fakeInvoice); err != nil {
t.Fatalf("unable to find invoice: %v", err)
}
@ -126,7 +126,7 @@ func TestInvoiceWorkflow(t *testing.T) {
// Attempt to insert generated above again, this should fail as
// duplicates are rejected by the processing logic.
if err := db.AddInvoice(fakeInvoice); err != ErrDuplicateInvoice {
if _, err := db.AddInvoice(fakeInvoice); err != ErrDuplicateInvoice {
t.Fatalf("invoice insertion should fail due to duplication, "+
"instead %v", err)
}
@ -149,7 +149,7 @@ func TestInvoiceWorkflow(t *testing.T) {
t.Fatalf("unable to create invoice: %v", err)
}
if err := db.AddInvoice(invoice); err != nil {
if _, err := db.AddInvoice(invoice); err != nil {
t.Fatalf("unable to add invoice %v", err)
}
@ -198,7 +198,7 @@ func TestInvoiceAddTimeSeries(t *testing.T) {
t.Fatalf("unable to create invoice: %v", err)
}
if err := db.AddInvoice(invoice); err != nil {
if _, err := db.AddInvoice(invoice); err != nil {
t.Fatalf("unable to add invoice %v", err)
}

View File

@ -179,11 +179,12 @@ func validateInvoice(i *Invoice) error {
// has *any* payment hashes which already exists within the database, then the
// insertion will be aborted and rejected due to the strict policy banning any
// duplicate payment hashes.
func (d *DB) AddInvoice(newInvoice *Invoice) error {
func (d *DB) AddInvoice(newInvoice *Invoice) (uint64, error) {
if err := validateInvoice(newInvoice); err != nil {
return err
return 0, err
}
var invoiceAddIndex uint64
err := d.Update(func(tx *bolt.Tx) error {
invoices, err := tx.CreateBucketIfNotExists(invoiceBucket)
if err != nil {
@ -227,15 +228,21 @@ func (d *DB) AddInvoice(newInvoice *Invoice) error {
invoiceNum = byteOrder.Uint32(invoiceCounter)
}
return putInvoice(
newIndex, err := putInvoice(
invoices, invoiceIndex, addIndex, newInvoice, invoiceNum,
)
if err != nil {
return err
}
invoiceAddIndex = newIndex
return nil
})
if err != nil {
return err
return 0, err
}
return err
return invoiceAddIndex, err
}
// InvoicesAddedSince can be used by callers to seek into the event time series
@ -501,7 +508,7 @@ func (d *DB) InvoicesSettledSince(sinceSettleIndex uint64) ([]Invoice, error) {
}
func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket,
i *Invoice, invoiceNum uint32) error {
i *Invoice, invoiceNum uint32) (uint64, error) {
// Create the invoice key which is just the big-endian representation
// of the invoice number.
@ -514,7 +521,7 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket,
invoiceCounter := invoiceNum + 1
byteOrder.PutUint32(scratch[:], invoiceCounter)
if err := invoiceIndex.Put(numInvoicesKey, scratch[:]); err != nil {
return err
return 0, err
}
// Add the payment hash to the invoice index. This will let us quickly
@ -523,7 +530,7 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket,
paymentHash := sha256.Sum256(i.Terms.PaymentPreimage[:])
err := invoiceIndex.Put(paymentHash[:], invoiceKey[:])
if err != nil {
return err
return 0, err
}
// Next, we'll obtain the next add invoice index (sequence
@ -531,7 +538,7 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket,
// event stream.
nextAddSeqNo, err := addIndex.NextSequence()
if err != nil {
return err
return 0, err
}
// With the next sequence obtained, we'll updating the event series in
@ -540,7 +547,7 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket,
var seqNoBytes [8]byte
byteOrder.PutUint64(seqNoBytes[:], nextAddSeqNo)
if err := addIndex.Put(seqNoBytes[:], invoiceKey[:]); err != nil {
return err
return 0, err
}
i.AddIndex = nextAddSeqNo
@ -548,10 +555,14 @@ func putInvoice(invoices, invoiceIndex, addIndex *bolt.Bucket,
// Finally, serialize the invoice itself to be written to the disk.
var buf bytes.Buffer
if err := serializeInvoice(&buf, i); err != nil {
return nil
return 0, nil
}
return invoices.Put(invoiceKey[:], buf.Bytes())
if err := invoices.Put(invoiceKey[:], buf.Bytes()); err != nil {
return 0, err
}
return nextAddSeqNo, nil
}
func serializeInvoice(w io.Writer, i *Invoice) error {