diff --git a/server.go b/server.go index cbc13bc6..12925eea 100644 --- a/server.go +++ b/server.go @@ -762,8 +762,21 @@ func (sp *serverPeer) OnGetCFilters(_ *peer.Peer, msg *wire.MsgGetCFilters) { return } - hashes, err := sp.server.chain.HeightToHashRange(int32(msg.StartHeight), - &msg.StopHash, wire.MaxGetCFiltersReqRange) + // We'll also ensure that the remote party is requesting a set of + // filters that we actually currently maintain. + switch msg.FilterType { + case wire.GCSFilterRegular: + break + + default: + peerLog.Debug("Filter request for unknown filter: %v", + msg.FilterType) + return + } + + hashes, err := sp.server.chain.HeightToHashRange( + int32(msg.StartHeight), &msg.StopHash, wire.MaxGetCFiltersReqRange, + ) if err != nil { peerLog.Debugf("Invalid getcfilters request: %v", err) return @@ -776,8 +789,9 @@ func (sp *serverPeer) OnGetCFilters(_ *peer.Peer, msg *wire.MsgGetCFilters) { hashPtrs[i] = &hashes[i] } - filters, err := sp.server.cfIndex.FiltersByBlockHashes(hashPtrs, - msg.FilterType) + filters, err := sp.server.cfIndex.FiltersByBlockHashes( + hashPtrs, msg.FilterType, + ) if err != nil { peerLog.Errorf("Error retrieving cfilters: %v", err) return @@ -785,10 +799,14 @@ func (sp *serverPeer) OnGetCFilters(_ *peer.Peer, msg *wire.MsgGetCFilters) { for i, filterBytes := range filters { if len(filterBytes) == 0 { - peerLog.Warnf("Could not obtain cfilter for %v", hashes[i]) + peerLog.Warnf("Could not obtain cfilter for %v", + hashes[i]) return } - filterMsg := wire.NewMsgCFilter(msg.FilterType, &hashes[i], filterBytes) + + filterMsg := wire.NewMsgCFilter( + msg.FilterType, &hashes[i], filterBytes, + ) sp.QueueMessage(filterMsg, nil) } } @@ -800,19 +818,32 @@ func (sp *serverPeer) OnGetCFHeaders(_ *peer.Peer, msg *wire.MsgGetCFHeaders) { return } + // We'll also ensure that the remote party is requesting a set of + // headers for filters that we actually currently maintain. + switch msg.FilterType { + case wire.GCSFilterRegular: + break + + default: + peerLog.Debug("Filter request for unknown headers for "+ + "filter: %v", msg.FilterType) + return + } + startHeight := int32(msg.StartHeight) maxResults := wire.MaxCFHeadersPerMsg - // If StartHeight is positive, fetch the predecessor block hash so we can - // populate the PrevFilterHeader field. + // If StartHeight is positive, fetch the predecessor block hash so we + // can populate the PrevFilterHeader field. if msg.StartHeight > 0 { startHeight-- maxResults++ } // Fetch the hashes from the block index. - hashList, err := sp.server.chain.HeightToHashRange(startHeight, - &msg.StopHash, maxResults) + hashList, err := sp.server.chain.HeightToHashRange( + startHeight, &msg.StopHash, maxResults, + ) if err != nil { peerLog.Debugf("Invalid getcfheaders request: %v", err) } @@ -833,8 +864,9 @@ func (sp *serverPeer) OnGetCFHeaders(_ *peer.Peer, msg *wire.MsgGetCFHeaders) { } // Fetch the raw filter hash bytes from the database for all blocks. - filterHashes, err := sp.server.cfIndex.FilterHashesByBlockHashes(hashPtrs, - msg.FilterType) + filterHashes, err := sp.server.cfIndex.FilterHashesByBlockHashes( + hashPtrs, msg.FilterType, + ) if err != nil { peerLog.Errorf("Error retrieving cfilter hashes: %v", err) return @@ -892,6 +924,7 @@ func (sp *serverPeer) OnGetCFHeaders(_ *peer.Peer, msg *wire.MsgGetCFHeaders) { headersMsg.FilterType = msg.FilterType headersMsg.StopHash = msg.StopHash + sp.QueueMessage(headersMsg, nil) } @@ -902,21 +935,38 @@ func (sp *serverPeer) OnGetCFCheckpt(_ *peer.Peer, msg *wire.MsgGetCFCheckpt) { return } - blockHashes, err := sp.server.chain.IntervalBlockHashes(&msg.StopHash, - wire.CFCheckptInterval) + // We'll also ensure that the remote party is requesting a set of + // checkpoints for filters that we actually currently maintain. + switch msg.FilterType { + case wire.GCSFilterRegular: + break + + default: + peerLog.Debug("Filter request for unknown checkpoints for "+ + "filter: %v", msg.FilterType) + return + } + + blockHashes, err := sp.server.chain.IntervalBlockHashes( + &msg.StopHash, wire.CFCheckptInterval, + ) if err != nil { peerLog.Debugf("Invalid getcfilters request: %v", err) return } - var updateCache bool - var checkptCache []cfHeaderKV + var ( + updateCache bool + checkptCache []cfHeaderKV + ) + // If the set of check points requested goes back further than what + // we've already generated in our cache, then we'll need to update it. if len(blockHashes) > len(checkptCache) { - // Update the cache if the checkpoint chain is longer than the cached - // one. This ensures that the cache is relatively stable and mostly - // overlaps with the best chain, since it follows the longest chain - // heuristic. + // Update the cache if the checkpoint chain is longer than the + // cached one. This ensures that the cache is relatively stable + // and mostly overlaps with the best chain, since it follows + // the longest chain heuristic. updateCache = true // Take write lock because we are going to update cache. @@ -925,9 +975,13 @@ func (sp *serverPeer) OnGetCFCheckpt(_ *peer.Peer, msg *wire.MsgGetCFCheckpt) { // Grow the checkptCache to be the length of blockHashes. additionalLength := len(blockHashes) - len(checkptCache) - checkptCache = append(sp.server.cfCheckptCaches[msg.FilterType], - make([]cfHeaderKV, additionalLength)...) + checkptCache = append( + sp.server.cfCheckptCaches[msg.FilterType], + make([]cfHeaderKV, additionalLength)..., + ) } else { + // Otherwise, we don't need to update the cache as we already + // have enough headers pre-generated. updateCache = false // Take reader lock because we are not going to update cache. @@ -946,8 +1000,9 @@ func (sp *serverPeer) OnGetCFCheckpt(_ *peer.Peer, msg *wire.MsgGetCFCheckpt) { } // Populate results with cached checkpoints. - checkptMsg := wire.NewMsgCFCheckpt(msg.FilterType, &msg.StopHash, - len(blockHashes)) + checkptMsg := wire.NewMsgCFCheckpt( + msg.FilterType, &msg.StopHash, len(blockHashes), + ) for i := 0; i < forkIdx; i++ { checkptMsg.AddCFHeader(&checkptCache[i].filterHeader) } @@ -958,8 +1013,9 @@ func (sp *serverPeer) OnGetCFCheckpt(_ *peer.Peer, msg *wire.MsgGetCFCheckpt) { blockHashPtrs = append(blockHashPtrs, &blockHashes[i]) } - filterHeaders, err := sp.server.cfIndex.FilterHeadersByBlockHashes(blockHashPtrs, - msg.FilterType) + filterHeaders, err := sp.server.cfIndex.FilterHeadersByBlockHashes( + blockHashPtrs, msg.FilterType, + ) if err != nil { peerLog.Errorf("Error retrieving cfilter headers: %v", err) return @@ -967,7 +1023,8 @@ func (sp *serverPeer) OnGetCFCheckpt(_ *peer.Peer, msg *wire.MsgGetCFCheckpt) { for i, filterHeaderBytes := range filterHeaders { if len(filterHeaderBytes) == 0 { - peerLog.Warnf("Could not obtain CF header for %v", blockHashPtrs[i]) + peerLog.Warnf("Could not obtain CF header for %v", + blockHashPtrs[i]) return }