btcd/rpcclient/chain_test.go
2024-03-25 09:44:25 -04:00

297 lines
7.4 KiB
Go

package rpcclient
import (
"errors"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{}
// TestUnmarshalGetBlockChainInfoResult ensures that the SoftForks and
// UnifiedSoftForks fields of GetBlockChainInfoResult are properly unmarshaled
// when using the expected backend version.
func TestUnmarshalGetBlockChainInfoResultSoftForks(t *testing.T) {
t.Parallel()
tests := []struct {
name string
version BackendVersion
res []byte
compatible bool
}{
{
name: "bitcoind < 0.19.0 with separate softforks",
version: BitcoindPre19,
res: []byte(`{"softforks": [{"version": 2}]}`),
compatible: true,
},
{
name: "bitcoind >= 0.19.0 with separate softforks",
version: BitcoindPre22,
res: []byte(`{"softforks": [{"version": 2}]}`),
compatible: false,
},
{
name: "bitcoind < 0.19.0 with unified softforks",
version: BitcoindPre19,
res: []byte(`{"softforks": {"segwit": {"type": "bip9"}}}`),
compatible: false,
},
{
name: "bitcoind >= 0.19.0 with unified softforks",
version: BitcoindPre22,
res: []byte(`{"softforks": {"segwit": {"type": "bip9"}}}`),
compatible: true,
},
}
for _, test := range tests {
success := t.Run(test.name, func(t *testing.T) {
// We'll start by unmarshalling the JSON into a struct.
// The SoftForks and UnifiedSoftForks field should not
// be set yet, as they are unmarshaled within a
// different function.
info, err := unmarshalPartialGetBlockChainInfoResult(test.res)
if err != nil {
t.Fatal(err)
}
if info.SoftForks != nil {
t.Fatal("expected SoftForks to be empty")
}
if info.UnifiedSoftForks != nil {
t.Fatal("expected UnifiedSoftForks to be empty")
}
// Proceed to unmarshal the softforks of the response
// with the expected version. If the version is
// incompatible with the response, then this should
// fail.
err = unmarshalGetBlockChainInfoResultSoftForks(
info, test.version, test.res,
)
if test.compatible && err != nil {
t.Fatalf("unable to unmarshal softforks: %v", err)
}
if !test.compatible && err == nil {
t.Fatal("expected to not unmarshal softforks")
}
if !test.compatible {
return
}
// If the version is compatible with the response, we
// should expect to see the proper softforks field set.
if test.version == BitcoindPre22 &&
info.SoftForks != nil {
t.Fatal("expected SoftForks to be empty")
}
if test.version == BitcoindPre19 &&
info.UnifiedSoftForks != nil {
t.Fatal("expected UnifiedSoftForks to be empty")
}
})
if !success {
return
}
}
}
func TestFutureGetBlockCountResultReceiveErrors(t *testing.T) {
responseChan := FutureGetBlockCountResult(make(chan *Response))
response := Response{
result: []byte{},
err: errors.New("blah blah something bad happened"),
}
go func() {
responseChan <- &response
}()
_, err := responseChan.Receive()
if err == nil || err.Error() != "blah blah something bad happened" {
t.Fatalf("unexpected error: %s", err.Error())
}
}
func TestFutureGetBlockCountResultReceiveMarshalsResponseCorrectly(t *testing.T) {
responseChan := FutureGetBlockCountResult(make(chan *Response))
response := Response{
result: []byte{0x36, 0x36},
err: nil,
}
go func() {
responseChan <- &response
}()
res, err := responseChan.Receive()
if err != nil {
t.Fatalf("unexpected error: %s", err.Error())
}
if res != 66 {
t.Fatalf("unexpected response: %d (0x%X)", res, res)
}
}
func TestClientConnectedToWSServerRunner(t *testing.T) {
type TestTableItem struct {
Name string
TestCase func(t *testing.T)
}
testTable := []TestTableItem{
TestTableItem{
Name: "TestGetChainTxStatsAsyncSuccessTx",
TestCase: func(t *testing.T) {
client, serverReceivedChannel, cleanup := makeClient(t)
defer cleanup()
client.GetChainTxStatsAsync()
message := <-serverReceivedChannel
if message != "{\"jsonrpc\":\"1.0\",\"method\":\"getchaintxstats\",\"params\":[],\"id\":1}" {
t.Fatalf("received unexpected message: %s", message)
}
},
},
TestTableItem{
Name: "TestGetChainTxStatsAsyncShutdownError",
TestCase: func(t *testing.T) {
client, _, cleanup := makeClient(t)
defer cleanup()
// a bit of a hack here: since there are multiple places where we read
// from the shutdown channel, and it is not buffered, ensure that a shutdown
// message is sent every time it is read from, this will ensure that
// when client.GetChainTxStatsAsync() gets called, it hits the non-blocking
// read from the shutdown channel
go func() {
type shutdownMessage struct{}
for {
client.shutdown <- shutdownMessage{}
}
}()
var response *Response = nil
for response == nil {
respChan := client.GetChainTxStatsAsync()
select {
case response = <-respChan:
default:
}
}
if response.err == nil || response.err.Error() != "the client has been shutdown" {
t.Fatalf("unexpected error: %s", response.err.Error())
}
},
},
TestTableItem{
Name: "TestGetBestBlockHashAsync",
TestCase: func(t *testing.T) {
client, serverReceivedChannel, cleanup := makeClient(t)
defer cleanup()
ch := client.GetBestBlockHashAsync()
message := <-serverReceivedChannel
if message != "{\"jsonrpc\":\"1.0\",\"method\":\"getbestblockhash\",\"params\":[],\"id\":1}" {
t.Fatalf("received unexpected message: %s", message)
}
expectedResponse := Response{}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
for {
client.requestLock.Lock()
if client.requestList.Len() > 0 {
r := client.requestList.Back()
r.Value.(*jsonRequest).responseChan <- &expectedResponse
client.requestLock.Unlock()
return
}
client.requestLock.Unlock()
}
}()
response := <-ch
if &expectedResponse != response {
t.Fatalf("received unexpected response")
}
// ensure the goroutine created in this test exists,
// the test is ran with a timeout
wg.Wait()
},
},
}
// since these tests rely on concurrency, ensure there is a reasonable timeout
// that they should run within
for _, testCase := range testTable {
done := make(chan bool)
go func() {
t.Run(testCase.Name, testCase.TestCase)
done <- true
}()
select {
case <-done:
case <-time.After(5 * time.Second):
t.Fatalf("timeout exceeded for: %s", testCase.Name)
}
}
}
func makeClient(t *testing.T) (*Client, chan string, func()) {
serverReceivedChannel := make(chan string)
s := httptest.NewServer(http.HandlerFunc(makeUpgradeOnConnect(serverReceivedChannel)))
url := strings.TrimPrefix(s.URL, "http://")
config := ConnConfig{
DisableTLS: true,
User: "username",
Pass: "password",
Host: url,
}
client, err := New(&config, nil)
if err != nil {
t.Fatalf("error when creating new client %s", err.Error())
}
return client, serverReceivedChannel, func() {
s.Close()
}
}
func makeUpgradeOnConnect(ch chan string) func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer c.Close()
for {
_, message, err := c.ReadMessage()
if err != nil {
break
}
go func() {
ch <- string(message)
}()
}
}
}