spv header sync works OK

This commit is contained in:
Tadge Dryja 2016-01-14 23:08:37 -08:00 committed by Olaoluwa Osuntokun
parent 6cebc7c952
commit e6970e132e
4 changed files with 86 additions and 54 deletions

View File

@ -23,7 +23,7 @@ const (
)
var (
params = &chaincfg.TestNet3Params
params = &chaincfg.TestNetLParams
)
type SPVCon struct {
@ -147,7 +147,7 @@ func (s *SPVCon) SendFilter(f *bloom.Filter) {
return
}
func (s *SPVCon) GrabHeaders() error {
func (s *SPVCon) AskForHeaders() error {
var hdr wire.BlockHeader
ghdr := wire.NewMsgGetHeaders()
ghdr.ProtocolVersion = s.localVersion
@ -188,66 +188,86 @@ func (s *SPVCon) GrabHeaders() error {
s.outMsgQueue <- ghdr
return nil
// =============================================================
// ask for headers. probably will get 2000.
log.Printf("getheader version %d \n", ghdr.ProtocolVersion)
n, m, _, err := wire.ReadMessageN(s.con, VERSION, NETVERSION)
if err != nil {
return err
}
log.Printf("4got %d byte response\n command: %s\n", n, m.Command())
hdrresponse, ok := m.(*wire.MsgHeaders)
if !ok {
log.Printf("got non-header message.")
return nil
// this can acutally happen and we should deal with / ignore it
// also pings, they don't like it when you don't respond to pings.
// invs and the rest we can ignore for now until filters are up.
}
}
func (s *SPVCon) IngestHeaders(m *wire.MsgHeaders) (bool, error) {
var err error
_, err = s.headerFile.Seek(-80, os.SEEK_END)
if err != nil {
return err
return false, err
}
var last wire.BlockHeader
err = last.Deserialize(s.headerFile)
if err != nil {
return err
return false, err
}
prevHash := last.BlockSha()
gotNum := int64(len(hdrresponse.Headers))
gotNum := int64(len(m.Headers))
if gotNum > 0 {
fmt.Printf("got %d headers. Range:\n%s - %s\n",
gotNum, hdrresponse.Headers[0].BlockSha().String(),
hdrresponse.Headers[len(hdrresponse.Headers)-1].BlockSha().String())
gotNum, m.Headers[0].BlockSha().String(),
m.Headers[len(m.Headers)-1].BlockSha().String())
} else {
log.Printf("got 0 headers, we're probably synced up")
return false, nil
}
_, err = s.headerFile.Seek(0, os.SEEK_END)
endPos, err := s.headerFile.Seek(0, os.SEEK_END)
if err != nil {
return err
return false, err
}
for i, resphdr := range hdrresponse.Headers {
// check first header returned to make sure it fits on the end
// of our header file
if i == 0 && !resphdr.PrevBlock.IsEqual(&prevHash) {
return fmt.Errorf("header doesn't fit. points to %s, expect %s",
resphdr.PrevBlock.String(), prevHash.String())
if !m.Headers[0].PrevBlock.IsEqual(&prevHash) {
// delete 100 headers if this happens! Dumb reorg.
log.Printf("possible reorg; header msg doesn't fit. points to %s, expect %s",
m.Headers[0].PrevBlock.String(), prevHash.String())
if endPos < 8080 {
// jeez I give up, back to genesis
s.headerFile.Truncate(80)
} else {
err = s.headerFile.Truncate(endPos - 8000)
if err != nil {
return false, fmt.Errorf("couldn't truncate header file")
}
}
return false, fmt.Errorf("Truncated header file to try again")
}
tip := endPos / 80
tip-- // move back header length so it can read last header
for _, resphdr := range m.Headers {
// write to end of file
err = resphdr.Serialize(s.headerFile)
if err != nil {
return err
}
return false, err
}
endPos, _ := s.headerFile.Seek(0, os.SEEK_END)
tip := endPos / 80
go CheckRange(s.headerFile, tip-gotNum, tip-1, params)
return nil
// advance chain tip
tip++
// check last header
worked := CheckHeader(s.headerFile, tip, params)
if !worked {
if endPos < 8080 {
// jeez I give up, back to genesis
s.headerFile.Truncate(80)
} else {
err = s.headerFile.Truncate(endPos - 8000)
if err != nil {
return false, fmt.Errorf("couldn't truncate header file")
}
}
// probably should disconnect from spv node at this point,
// since they're giving us invalid headers.
return false, fmt.Errorf(
"Header %d - %s doesn't fit, dropping 100 headers.",
resphdr.BlockSha().String(), tip)
}
}
log.Printf("Headers to height %d OK.", tip)
return true, nil
}
func sendMBReq(cn net.Conn, blkhash wire.ShaHash) error {

View File

@ -110,7 +110,7 @@ func CheckHeader(r io.ReadSeeker, height int64, p *chaincfg.Params) bool {
log.Printf(err.Error())
return false
}
log.Printf("start epoch at height %d ", height-(height%epochLength))
// log.Printf("start epoch at height %d ", height-(height%epochLength))
// seek to n-1 header
_, err = r.Seek(80*(height-1), os.SEEK_SET)

View File

@ -7,27 +7,27 @@ import (
"github.com/btcsuite/btcd/wire"
)
func (e3c *SPVCon) incomingMessageHandler() {
func (s *SPVCon) incomingMessageHandler() {
for {
n, xm, _, err := wire.ReadMessageN(e3c.con, e3c.localVersion, e3c.netType)
n, xm, _, err := wire.ReadMessageN(s.con, s.localVersion, s.netType)
if err != nil {
log.Printf("ReadMessageN error. Disconnecting: %s\n", err.Error())
return
}
e3c.RBytes += uint64(n)
s.RBytes += uint64(n)
// log.Printf("Got %d byte %s message\n", n, xm.Command())
switch m := xm.(type) {
case *wire.MsgVersion:
log.Printf("Got version message. Agent %s, version %d, at height %d\n",
m.UserAgent, m.ProtocolVersion, m.LastBlock)
e3c.remoteVersion = uint32(m.ProtocolVersion) // weird cast! bug?
s.remoteVersion = uint32(m.ProtocolVersion) // weird cast! bug?
case *wire.MsgVerAck:
log.Printf("Got verack. Whatever.\n")
case *wire.MsgAddr:
log.Printf("got %d addresses.\n", len(m.AddrList))
case *wire.MsgPing:
log.Printf("Got a ping message. We should pong back or they will kick us off.")
e3c.PongBack(m.Nonce)
s.PongBack(m.Nonce)
case *wire.MsgPong:
log.Printf("Got a pong response. OK.\n")
case *wire.MsgMerkleBlock:
@ -43,6 +43,16 @@ func (e3c *SPVCon) incomingMessageHandler() {
fmt.Printf(" = got %d txs from block %s\n",
len(txids), m.Header.BlockSha().String())
// nextReq <- true
case *wire.MsgHeaders:
moar, err := s.IngestHeaders(m)
if err != nil {
log.Printf("Header error: %s\n", err.Error())
return
}
if moar {
s.AskForHeaders()
}
case *wire.MsgTx:
log.Printf("Got tx %s\n", m.TxSha().String())
@ -55,14 +65,14 @@ func (e3c *SPVCon) incomingMessageHandler() {
// this one seems kindof pointless? could get ridf of it and let
// functions call WriteMessageN themselves...
func (e3c *SPVCon) outgoingMessageHandler() {
func (s *SPVCon) outgoingMessageHandler() {
for {
msg := <-e3c.outMsgQueue
n, err := wire.WriteMessageN(e3c.con, msg, e3c.localVersion, e3c.netType)
msg := <-s.outMsgQueue
n, err := wire.WriteMessageN(s.con, msg, s.localVersion, s.netType)
if err != nil {
log.Printf("Write message error: %s", err.Error())
}
e3c.WBytes += uint64(n)
s.WBytes += uint64(n)
}
return
}

View File

@ -29,6 +29,7 @@ type MyAdr struct { // an address I have the private key for
KeyIdx uint32 // index for private key needed to sign / spend
}
// add addresses into the TxStore
func (t *TxStore) AddAdr(a btcutil.Address, kidx uint32) {
var ma MyAdr
ma.Address = a
@ -37,6 +38,7 @@ func (t *TxStore) AddAdr(a btcutil.Address, kidx uint32) {
return
}
// ... or I'm gonna fade away
func (t *TxStore) GimmeFilter() (*bloom.Filter, error) {
if len(t.Adrs) == 0 {
return nil, fmt.Errorf("no addresses to filter for")