diff --git a/wire/bench_test.go b/wire/bench_test.go index ffae564c..f6b29c2c 100644 --- a/wire/bench_test.go +++ b/wire/bench_test.go @@ -187,7 +187,7 @@ func BenchmarkReadVarInt9(b *testing.B) { func BenchmarkReadVarStr4(b *testing.B) { buf := []byte{0x04, 't', 'e', 's', 't'} for i := 0; i < b.N; i++ { - readVarString(bytes.NewReader(buf), 0) + ReadVarString(bytes.NewReader(buf), 0) } } @@ -196,7 +196,7 @@ func BenchmarkReadVarStr4(b *testing.B) { func BenchmarkReadVarStr10(b *testing.B) { buf := []byte{0x0a, 't', 'e', 's', 't', '0', '1', '2', '3', '4', '5'} for i := 0; i < b.N; i++ { - readVarString(bytes.NewReader(buf), 0) + ReadVarString(bytes.NewReader(buf), 0) } } @@ -204,7 +204,7 @@ func BenchmarkReadVarStr10(b *testing.B) { // four byte variable length string. func BenchmarkWriteVarStr4(b *testing.B) { for i := 0; i < b.N; i++ { - writeVarString(ioutil.Discard, 0, "test") + WriteVarString(ioutil.Discard, 0, "test") } } @@ -212,7 +212,7 @@ func BenchmarkWriteVarStr4(b *testing.B) { // ten byte variable length string. func BenchmarkWriteVarStr10(b *testing.B) { for i := 0; i < b.N; i++ { - writeVarString(ioutil.Discard, 0, "test012345") + WriteVarString(ioutil.Discard, 0, "test012345") } } diff --git a/wire/common.go b/wire/common.go index 673daa89..f70659c5 100644 --- a/wire/common.go +++ b/wire/common.go @@ -440,14 +440,13 @@ func VarIntSerializeSize(val uint64) int { return 9 } -// readVarString reads a variable length string from r and returns it as a Go -// string. A varString is encoded as a varInt containing the length of the -// string, and the bytes that represent the string itself. An error is returned -// if the length is greater than the maximum block payload size, since it would -// not be possible to put a varString of that size into a block anyways and it -// also helps protect against memory exhaustion attacks and forced panics -// through malformed messages. -func readVarString(r io.Reader, pver uint32) (string, error) { +// ReadVarString reads a variable length string from r and returns it as a Go +// string. A variable length string is encoded as a variable length integer +// containing the length of the string followed by the bytes that represent the +// string itself. An error is returned if the length is greater than the +// maximum block payload size since it helps protect against memory exhaustion +// attacks and forced panics through malformed messages. +func ReadVarString(r io.Reader, pver uint32) (string, error) { count, err := readVarInt(r, pver) if err != nil { return "", err @@ -459,7 +458,7 @@ func readVarString(r io.Reader, pver uint32) (string, error) { if count > MaxMessagePayload { str := fmt.Sprintf("variable length string is too long "+ "[count %d, max %d]", count, MaxMessagePayload) - return "", messageError("readVarString", str) + return "", messageError("ReadVarString", str) } buf := make([]byte, count) @@ -470,9 +469,10 @@ func readVarString(r io.Reader, pver uint32) (string, error) { return string(buf), nil } -// writeVarString serializes str to w as a varInt containing the length of the -// string followed by the bytes that represent the string itself. -func writeVarString(w io.Writer, pver uint32, str string) error { +// WriteVarString serializes str to w as a variable length integer containing +// the length of the string followed by the bytes that represent the string +// itself. +func WriteVarString(w io.Writer, pver uint32, str string) error { err := writeVarInt(w, pver, uint64(len(str))) if err != nil { return err diff --git a/wire/common_test.go b/wire/common_test.go index 09df1c0c..9dc4ebb1 100644 --- a/wire/common_test.go +++ b/wire/common_test.go @@ -471,26 +471,26 @@ func TestVarStringWire(t *testing.T) { for i, test := range tests { // Encode to wire format. var buf bytes.Buffer - err := wire.TstWriteVarString(&buf, test.pver, test.in) + err := wire.WriteVarString(&buf, test.pver, test.in) if err != nil { - t.Errorf("writeVarString #%d error %v", i, err) + t.Errorf("WriteVarString #%d error %v", i, err) continue } if !bytes.Equal(buf.Bytes(), test.buf) { - t.Errorf("writeVarString #%d\n got: %s want: %s", i, + t.Errorf("WriteVarString #%d\n got: %s want: %s", i, spew.Sdump(buf.Bytes()), spew.Sdump(test.buf)) continue } // Decode from wire format. rbuf := bytes.NewReader(test.buf) - val, err := wire.TstReadVarString(rbuf, test.pver) + val, err := wire.ReadVarString(rbuf, test.pver) if err != nil { - t.Errorf("readVarString #%d error %v", i, err) + t.Errorf("ReadVarString #%d error %v", i, err) continue } if val != test.out { - t.Errorf("readVarString #%d\n got: %s want: %s", i, + t.Errorf("ReadVarString #%d\n got: %s want: %s", i, val, test.out) continue } @@ -526,18 +526,18 @@ func TestVarStringWireErrors(t *testing.T) { for i, test := range tests { // Encode to wire format. w := newFixedWriter(test.max) - err := wire.TstWriteVarString(w, test.pver, test.in) + err := wire.WriteVarString(w, test.pver, test.in) if err != test.writeErr { - t.Errorf("writeVarString #%d wrong error got: %v, want: %v", + t.Errorf("WriteVarString #%d wrong error got: %v, want: %v", i, err, test.writeErr) continue } // Decode from wire format. r := newFixedReader(test.max, test.buf) - _, err = wire.TstReadVarString(r, test.pver) + _, err = wire.ReadVarString(r, test.pver) if err != test.readErr { - t.Errorf("readVarString #%d wrong error got: %v, want: %v", + t.Errorf("ReadVarString #%d wrong error got: %v, want: %v", i, err, test.readErr) continue } @@ -566,9 +566,9 @@ func TestVarStringOverflowErrors(t *testing.T) { for i, test := range tests { // Decode from wire format. rbuf := bytes.NewReader(test.buf) - _, err := wire.TstReadVarString(rbuf, test.pver) + _, err := wire.ReadVarString(rbuf, test.pver) if reflect.TypeOf(err) != reflect.TypeOf(test.err) { - t.Errorf("readVarString #%d wrong error got: %v, "+ + t.Errorf("ReadVarString #%d wrong error got: %v, "+ "want: %v", i, err, reflect.TypeOf(test.err)) continue } diff --git a/wire/internal_test.go b/wire/internal_test.go index e3fec695..07539e2e 100644 --- a/wire/internal_test.go +++ b/wire/internal_test.go @@ -63,18 +63,6 @@ func TstWriteVarInt(w io.Writer, pver uint32, val uint64) error { return writeVarInt(w, pver, val) } -// TstReadVarString makes the internal readVarString function available to the -// test package. -func TstReadVarString(r io.Reader, pver uint32) (string, error) { - return readVarString(r, pver) -} - -// TstWriteVarString makes the internal writeVarString function available to the -// test package. -func TstWriteVarString(w io.Writer, pver uint32, str string) error { - return writeVarString(w, pver, str) -} - // TstReadVarBytes makes the internal readVarBytes function available to the // test package. func TstReadVarBytes(r io.Reader, pver uint32, maxAllowed uint32, fieldName string) ([]byte, error) { diff --git a/wire/msgalert.go b/wire/msgalert.go index 2f9d2ed8..f07ad64c 100644 --- a/wire/msgalert.go +++ b/wire/msgalert.go @@ -188,7 +188,7 @@ func (alert *Alert) Serialize(w io.Writer, pver uint32) error { return err } for i := 0; i < int(count); i++ { - err = writeVarString(w, pver, alert.SetSubVer[i]) + err = WriteVarString(w, pver, alert.SetSubVer[i]) if err != nil { return err } @@ -198,15 +198,15 @@ func (alert *Alert) Serialize(w io.Writer, pver uint32) error { if err != nil { return err } - err = writeVarString(w, pver, alert.Comment) + err = WriteVarString(w, pver, alert.Comment) if err != nil { return err } - err = writeVarString(w, pver, alert.StatusBar) + err = WriteVarString(w, pver, alert.StatusBar) if err != nil { return err } - err = writeVarString(w, pver, alert.Reserved) + err = WriteVarString(w, pver, alert.Reserved) if err != nil { return err } @@ -260,7 +260,7 @@ func (alert *Alert) Deserialize(r io.Reader, pver uint32) error { } alert.SetSubVer = make([]string, count) for i := 0; i < int(count); i++ { - alert.SetSubVer[i], err = readVarString(r, pver) + alert.SetSubVer[i], err = ReadVarString(r, pver) if err != nil { return err } @@ -270,15 +270,15 @@ func (alert *Alert) Deserialize(r io.Reader, pver uint32) error { if err != nil { return err } - alert.Comment, err = readVarString(r, pver) + alert.Comment, err = ReadVarString(r, pver) if err != nil { return err } - alert.StatusBar, err = readVarString(r, pver) + alert.StatusBar, err = ReadVarString(r, pver) if err != nil { return err } - alert.Reserved, err = readVarString(r, pver) + alert.Reserved, err = ReadVarString(r, pver) if err != nil { return err } diff --git a/wire/msgreject.go b/wire/msgreject.go index f1ad73a8..bf8659a4 100644 --- a/wire/msgreject.go +++ b/wire/msgreject.go @@ -79,7 +79,7 @@ func (msg *MsgReject) BtcDecode(r io.Reader, pver uint32) error { } // Command that was rejected. - cmd, err := readVarString(r, pver) + cmd, err := ReadVarString(r, pver) if err != nil { return err } @@ -93,7 +93,7 @@ func (msg *MsgReject) BtcDecode(r io.Reader, pver uint32) error { // Human readable string with specific details (over and above the // reject code above) about why the command was rejected. - reason, err := readVarString(r, pver) + reason, err := ReadVarString(r, pver) if err != nil { return err } @@ -121,7 +121,7 @@ func (msg *MsgReject) BtcEncode(w io.Writer, pver uint32) error { } // Command that was rejected. - err := writeVarString(w, pver, msg.Cmd) + err := WriteVarString(w, pver, msg.Cmd) if err != nil { return err } @@ -134,7 +134,7 @@ func (msg *MsgReject) BtcEncode(w io.Writer, pver uint32) error { // Human readable string with specific details (over and above the // reject code above) about why the command was rejected. - err = writeVarString(w, pver, msg.Reason) + err = WriteVarString(w, pver, msg.Reason) if err != nil { return err } diff --git a/wire/msgversion.go b/wire/msgversion.go index 9905b410..6c76e239 100644 --- a/wire/msgversion.go +++ b/wire/msgversion.go @@ -115,7 +115,7 @@ func (msg *MsgVersion) BtcDecode(r io.Reader, pver uint32) error { } } if buf.Len() > 0 { - userAgent, err := readVarString(buf, pver) + userAgent, err := ReadVarString(buf, pver) if err != nil { return err } @@ -181,7 +181,7 @@ func (msg *MsgVersion) BtcEncode(w io.Writer, pver uint32) error { return err } - err = writeVarString(w, pver, msg.UserAgent) + err = WriteVarString(w, pver, msg.UserAgent) if err != nil { return err }