diff --git a/mempool/estimatefee.go b/mempool/estimatefee.go index 724f3c4e..234c5cf4 100644 --- a/mempool/estimatefee.go +++ b/mempool/estimatefee.go @@ -5,11 +5,15 @@ package mempool import ( + "bytes" + "encoding/binary" "errors" "fmt" + "io" "math" "math/rand" "sort" + "strings" "sync" "github.com/roasbeef/btcd/chaincfg/chainhash" @@ -20,8 +24,6 @@ import ( // TODO incorporate Alex Morcos' modifications to Gavin's initial model // https://lists.linuxfoundation.org/pipermail/bitcoin-dev/2014-October/006824.html -// TODO store and restore the FeeEstimator state in the database. - const ( // estimateFeeDepth is the maximum number of blocks before a transaction // is confirmed that we want to track. @@ -48,6 +50,12 @@ const ( btcPerSatoshi = 1E-8 ) +var ( + // EstimateFeeDatabaseKey is the key that we use to + // store the fee estimator in the database. + EstimateFeeDatabaseKey = []byte("estimatefee") +) + // SatoshiPerByte is number with units of satoshis per byte. type SatoshiPerByte float64 @@ -99,6 +107,29 @@ type observedTransaction struct { mined int32 } +func (o *observedTransaction) Serialize(w io.Writer) { + binary.Write(w, binary.BigEndian, o.hash) + binary.Write(w, binary.BigEndian, o.feeRate) + binary.Write(w, binary.BigEndian, o.observed) + binary.Write(w, binary.BigEndian, o.mined) +} + +func deserializeObservedTransaction(r io.Reader) (*observedTransaction, error) { + ot := observedTransaction{} + + // The first 32 bytes should be a hash. + binary.Read(r, binary.BigEndian, &ot.hash) + + // The next 8 are SatoshiPerByte + binary.Read(r, binary.BigEndian, &ot.feeRate) + + // And next there are two uint32's. + binary.Read(r, binary.BigEndian, &ot.observed) + binary.Read(r, binary.BigEndian, &ot.mined) + + return &ot, nil +} + // registeredBlock has the hash of a block and the list of transactions // it mined which had been previously observed by the FeeEstimator. It // is used if Rollback is called to reverse the effect of registering @@ -108,6 +139,15 @@ type registeredBlock struct { transactions []*observedTransaction } +func (rb *registeredBlock) serialize(w io.Writer, txs map[*observedTransaction]uint32) { + binary.Write(w, binary.BigEndian, rb.hash) + + binary.Write(w, binary.BigEndian, uint32(len(rb.transactions))) + for _, o := range rb.transactions { + binary.Write(w, binary.BigEndian, txs[o]) + } +} + // FeeEstimator manages the data necessary to create // fee estimations. It is safe for concurrent access. type FeeEstimator struct { @@ -533,3 +573,177 @@ func (ef *FeeEstimator) EstimateFee(numBlocks uint32) (BtcPerKilobyte, error) { return ef.cached[int(numBlocks)-1].ToBtcPerKb(), nil } + +// In case the format for the serialized version of the FeeEstimator changes, +// we use a version number. If the version number changes, it does not make +// sense to try to upgrade a previous version to a new version. Instead, just +// start fee estimation over. +const estimateFeeSaveVersion = 1 + +func deserializeRegisteredBlock(r io.Reader, txs map[uint32]*observedTransaction) (*registeredBlock, error) { + var lenTransactions uint32 + + rb := ®isteredBlock{} + binary.Read(r, binary.BigEndian, &rb.hash) + binary.Read(r, binary.BigEndian, &lenTransactions) + + rb.transactions = make([]*observedTransaction, lenTransactions) + + for i := uint32(0); i < lenTransactions; i++ { + var index uint32 + binary.Read(r, binary.BigEndian, &index) + rb.transactions[i] = txs[index] + } + + return rb, nil +} + +// FeeEstimatorState represents a saved FeeEstimator that can be +// restored with data from an earlier session of the program. +type FeeEstimatorState []byte + +// observedTxSet is a set of txs that can that is sorted +// by hash. It exists for serialization purposes so that +// a serialized state always comes out the same. +type observedTxSet []*observedTransaction + +func (q observedTxSet) Len() int { return len(q) } + +func (q observedTxSet) Less(i, j int) bool { + return strings.Compare(q[i].hash.String(), q[j].hash.String()) < 0 +} + +func (q observedTxSet) Swap(i, j int) { + q[i], q[j] = q[j], q[i] +} + +// Save records the current state of the FeeEstimator to a []byte that +// can be restored later. +func (ef *FeeEstimator) Save() FeeEstimatorState { + ef.mtx.Lock() + defer ef.mtx.Unlock() + + // TODO figure out what the capacity should be. + w := bytes.NewBuffer(make([]byte, 0)) + + binary.Write(w, binary.BigEndian, uint32(estimateFeeSaveVersion)) + + // Insert basic parameters. + binary.Write(w, binary.BigEndian, &ef.maxRollback) + binary.Write(w, binary.BigEndian, &ef.binSize) + binary.Write(w, binary.BigEndian, &ef.maxReplacements) + binary.Write(w, binary.BigEndian, &ef.minRegisteredBlocks) + binary.Write(w, binary.BigEndian, &ef.lastKnownHeight) + binary.Write(w, binary.BigEndian, &ef.numBlocksRegistered) + + // Put all the observed transactions in a sorted list. + var txCount uint32 + ots := make([]*observedTransaction, len(ef.observed)) + for hash := range ef.observed { + ots[txCount] = ef.observed[hash] + txCount++ + } + + sort.Sort(observedTxSet(ots)) + + txCount = 0 + observed := make(map[*observedTransaction]uint32) + binary.Write(w, binary.BigEndian, uint32(len(ef.observed))) + for _, ot := range ots { + ot.Serialize(w) + observed[ot] = txCount + txCount++ + } + + // Save all the right bins. + for _, list := range ef.bin { + + binary.Write(w, binary.BigEndian, uint32(len(list))) + + for _, o := range list { + binary.Write(w, binary.BigEndian, observed[o]) + } + } + + // Dropped transactions. + binary.Write(w, binary.BigEndian, uint32(len(ef.dropped))) + for _, registered := range ef.dropped { + registered.serialize(w, observed) + } + + // Commit the tx and return. + return FeeEstimatorState(w.Bytes()) +} + +// RestoreFeeEstimator takes a FeeEstimatorState that was previously +// returned by Save and restores it to a FeeEstimator +func RestoreFeeEstimator(data FeeEstimatorState) (*FeeEstimator, error) { + r := bytes.NewReader([]byte(data)) + + // Check version + var version uint32 + err := binary.Read(r, binary.BigEndian, &version) + if err != nil { + return nil, err + } + if version != estimateFeeSaveVersion { + return nil, fmt.Errorf("Incorrect version: expected %d found %d", estimateFeeSaveVersion, version) + } + + ef := &FeeEstimator{ + observed: make(map[chainhash.Hash]*observedTransaction), + } + + // Read basic parameters. + binary.Read(r, binary.BigEndian, &ef.maxRollback) + binary.Read(r, binary.BigEndian, &ef.binSize) + binary.Read(r, binary.BigEndian, &ef.maxReplacements) + binary.Read(r, binary.BigEndian, &ef.minRegisteredBlocks) + binary.Read(r, binary.BigEndian, &ef.lastKnownHeight) + binary.Read(r, binary.BigEndian, &ef.numBlocksRegistered) + + // Read transactions. + var numObserved uint32 + observed := make(map[uint32]*observedTransaction) + binary.Read(r, binary.BigEndian, &numObserved) + for i := uint32(0); i < numObserved; i++ { + ot, err := deserializeObservedTransaction(r) + if err != nil { + return nil, err + } + observed[i] = ot + ef.observed[ot.hash] = ot + } + + // Read bins. + for i := 0; i < estimateFeeDepth; i++ { + var numTransactions uint32 + binary.Read(r, binary.BigEndian, &numTransactions) + bin := make([]*observedTransaction, numTransactions) + for j := uint32(0); j < numTransactions; j++ { + var index uint32 + binary.Read(r, binary.BigEndian, &index) + + var exists bool + bin[j], exists = observed[index] + if !exists { + return nil, fmt.Errorf("Invalid transaction reference %d", index) + } + } + ef.bin[i] = bin + } + + // Read dropped transactions. + var numDropped uint32 + binary.Read(r, binary.BigEndian, &numDropped) + ef.dropped = make([]*registeredBlock, numDropped) + for i := uint32(0); i < numDropped; i++ { + var err error + ef.dropped[int(i)], err = deserializeRegisteredBlock(r, observed) + if err != nil { + return nil, err + } + } + + return ef, nil +} diff --git a/mempool/estimatefee_test.go b/mempool/estimatefee_test.go index df73e322..07130673 100644 --- a/mempool/estimatefee_test.go +++ b/mempool/estimatefee_test.go @@ -5,6 +5,7 @@ package mempool import ( + "bytes" "math/rand" "testing" @@ -364,3 +365,60 @@ func TestEstimateFeeRollback(t *testing.T) { estimateHistory = estimateHistory[0 : len(estimateHistory)-stepsBack] } } + +func (eft *estimateFeeTester) checkSaveAndRestore( + previousEstimates [estimateFeeDepth]BtcPerKilobyte) { + + // Get the save state. + save := eft.ef.Save() + + // Save and restore database. + var err error + eft.ef, err = RestoreFeeEstimator(save) + if err != nil { + eft.t.Fatalf("Could not restore database: %s", err) + } + + // Save again and check that it matches the previous one. + redo := eft.ef.Save() + if !bytes.Equal(save, redo) { + eft.t.Fatalf("Restored states do not match: %v %v", save, redo) + } + + // Check that the results match. + newEstimates := eft.estimates() + + for i, prev := range previousEstimates { + if prev != newEstimates[i] { + eft.t.Error("Mismatch in estimate ", i, " after restore; got ", newEstimates[i], " but expected ", prev) + } + } +} + +// TestSave tests saving and restoring to a []byte. +func TestDatabase(t *testing.T) { + + txPerRound := uint32(7) + txPerBlock := uint32(5) + binSize := uint32(6) + maxReplacements := uint32(4) + rounds := 8 + + eft := estimateFeeTester{ef: newTestFeeEstimator(binSize, maxReplacements, uint32(rounds)+1), t: t} + var txHistory [][]*TxDesc + estimateHistory := [][estimateFeeDepth]BtcPerKilobyte{eft.estimates()} + + for round := 0; round < rounds; round++ { + eft.checkSaveAndRestore(estimateHistory[len(estimateHistory)-1]) + + // Go forward one step. + txHistory, estimateHistory = + eft.round(txHistory, estimateHistory, txPerRound, txPerBlock) + } + + // Reverse the process and try again. + for round := 1; round <= rounds; round++ { + eft.rollback() + eft.checkSaveAndRestore(estimateHistory[len(estimateHistory)-round-1]) + } +} diff --git a/rpcserver.go b/rpcserver.go index d437775a..0dae52a7 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -873,7 +873,7 @@ func handleEstimateFee(s *rpcServer, cmd interface{}, closeChan <-chan struct{}) } // Convert to satoshis per kb. - return float64(feeRate.ToSatoshiPerKb()), nil + return float64(feeRate), nil } // handleGenerate handles generate commands. diff --git a/server.go b/server.go index fa9370ac..ae01b5e2 100644 --- a/server.go +++ b/server.go @@ -230,6 +230,10 @@ type server struct { txIndex *indexers.TxIndex addrIndex *indexers.AddrIndex cfIndex *indexers.CfIndex + + // The fee estimator keeps track of how long transactions are left in + // the mempool before they are mined into blocks. + feeEstimator *mempool.FeeEstimator } // serverPeer extends the peer to maintain state shared by the server and @@ -2107,6 +2111,14 @@ func (s *server) Stop() error { s.rpcServer.Stop() } + // Save fee estimator state in the database. + s.db.Update(func(tx database.Tx) error { + metadata := tx.Metadata() + metadata.Put(mempool.EstimateFeeDatabaseKey, s.feeEstimator.Save()) + + return nil + }) + // Signal the remaining goroutines to quit. close(s.quit) return nil @@ -2411,9 +2423,35 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param return nil, err } - feeEstimator := mempool.NewFeeEstimator( - mempool.DefaultEstimateFeeMaxRollback, - mempool.DefaultEstimateFeeMinRegisteredBlocks) + // Search for a FeeEstimator state in the database. If none can be found + // or if it cannot be loaded, create a new one. + db.Update(func(tx database.Tx) error { + metadata := tx.Metadata() + feeEstimationData := metadata.Get(mempool.EstimateFeeDatabaseKey) + if feeEstimationData != nil { + // delete it from the database so that we don't try to restore the + // same thing again somehow. + metadata.Delete(mempool.EstimateFeeDatabaseKey) + + // If there is an error, log it and make a new fee estimator. + var err error + s.feeEstimator, err = mempool.RestoreFeeEstimator(feeEstimationData) + + if err != nil { + peerLog.Errorf("Failed to restore fee estimator %v", err) + } + } + + return nil + }) + + // If no feeEstimator has been found, or if the one that has been found + // is behind somehow, create a new one and start over. + if s.feeEstimator == nil || s.feeEstimator.LastKnownHeight() != s.chain.BestSnapshot().Height { + s.feeEstimator = mempool.NewFeeEstimator( + mempool.DefaultEstimateFeeMaxRollback, + mempool.DefaultEstimateFeeMinRegisteredBlocks) + } txC := mempool.Config{ Policy: mempool.Policy{ @@ -2437,7 +2475,7 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param SigCache: s.sigCache, HashCache: s.hashCache, AddrIndex: s.addrIndex, - FeeEstimator: feeEstimator, + FeeEstimator: s.feeEstimator, } s.txMemPool = mempool.New(&txC) @@ -2586,7 +2624,7 @@ func newServer(listenAddrs []string, db database.DB, chainParams *chaincfg.Param TxIndex: s.txIndex, AddrIndex: s.addrIndex, CfIndex: s.cfIndex, - FeeEstimator: feeEstimator, + FeeEstimator: s.feeEstimator, }) if err != nil { return nil, err