package lnwire

import (
	"bytes"
	"fmt"
	"testing"

	"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
	"github.com/lightningnetwork/lnd/tlv"
	"github.com/stretchr/testify/require"
)

// testCaseShutdown describes a single Shutdown message encode/decode
// round-trip scenario.
type testCaseShutdown struct {
	// Msg is the message to be encoded and then decoded again.
	Msg Shutdown

	// ExpectEncodeError indicates whether encoding Msg is expected to
	// fail. When set, the test only asserts that encoding returns an error
	// and skips the decode step.
	ExpectEncodeError bool
}

// generateShutdownTestCases generates a set of Shutdown message test cases.
//
// Every returned case shares the same randomly generated channel ID, delivery
// address, custom records and extra opaque data. The second case additionally
// carries a MuSig2 shutdown nonce. Any failure while preparing the fixtures
// aborts the calling test via t.
func generateShutdownTestCases(t *testing.T) []testCaseShutdown {
	// Mark this function as a helper so that setup failures are attributed
	// to the calling test.
	t.Helper()

	// Firstly, we'll set basic values for the message fields.
	//
	// Generate random channel ID.
	chanIDBytes, err := generateRandomBytes(32)
	require.NoError(t, err)

	var chanID ChannelID
	copy(chanID[:], chanIDBytes)

	// Generate a random delivery address.
	deliveryAddr, err := generateRandomBytes(16)
	require.NoError(t, err)

	// Define custom records.
	recordKey1 := uint64(MinCustomRecordsTlvType + 1)
	recordValue1, err := generateRandomBytes(10)
	require.NoError(t, err)

	recordKey2 := uint64(MinCustomRecordsTlvType + 2)
	recordValue2, err := generateRandomBytes(10)
	require.NoError(t, err)

	customRecords := CustomRecords{
		recordKey1: recordValue1,
		recordKey2: recordValue2,
	}

	// Derive a MuSig2 nonce pair from a fixed public key. Only the public
	// nonce is used, as the shutdown nonce of the second test case.
	dummyPubKey, err := pubkeyFromHex(
		"0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d4" +
			"8236c39",
	)
	require.NoError(t, err)

	muSig2Nonce, err := musig2.GenNonces(musig2.WithPublicKey(dummyPubKey))
	require.NoError(t, err)

	// Construct an instance of extra data that contains records with TLV
	// types below the minimum custom records threshold and that lack
	// corresponding fields in the message struct. Content should persist in
	// the extra data field after encoding and decoding.
	var (
		recordBytes45 = []byte("recordBytes45")
		tlvRecord45   = tlv.NewPrimitiveRecord[tlv.TlvType45](
			recordBytes45,
		)

		recordBytes55 = []byte("recordBytes55")
		tlvRecord55   = tlv.NewPrimitiveRecord[tlv.TlvType55](
			recordBytes55,
		)
	)

	var extraData ExtraOpaqueData
	err = extraData.PackRecords(
		[]tlv.RecordProducer{&tlvRecord45, &tlvRecord55}...,
	)
	require.NoError(t, err)

	return []testCaseShutdown{
		// Shutdown message without a shutdown nonce.
		{
			Msg: Shutdown{
				ChannelID:     chanID,
				CustomRecords: customRecords,
				ExtraData:     extraData,
				Address:       deliveryAddr,
			},
		},
		// Shutdown message that also carries a MuSig2 shutdown nonce.
		{
			Msg: Shutdown{
				ChannelID:     chanID,
				CustomRecords: customRecords,
				ExtraData:     extraData,
				Address:       deliveryAddr,
				ShutdownNonce: SomeShutdownNonce(
					muSig2Nonce.PubNonce,
				),
			},
		},
	}
}

// TestShutdownEncodeDecode verifies that each generated Shutdown message
// survives an encode/decode round trip unchanged, or that encoding fails for
// cases flagged as expecting an encode error.
func TestShutdownEncodeDecode(t *testing.T) {
	t.Parallel()

	// Run every generated case as its own indexed sub-test.
	for idx, testCase := range generateShutdownTestCases(t) {
		name := fmt.Sprintf("testcase-%d", idx)

		t.Run(name, func(t *testing.T) {
			// Serialize the message under test.
			var encoded bytes.Buffer
			encodeErr := testCase.Msg.Encode(&encoded, 0)

			// Cases flagged as failing stop once the encode error
			// has been confirmed.
			if testCase.ExpectEncodeError {
				require.Error(t, encodeErr)
				return
			}

			require.NoError(t, encodeErr)

			// Parse the serialized bytes back into a fresh message.
			var decoded Shutdown
			reader := bytes.NewReader(encoded.Bytes())
			decodeErr := decoded.Decode(reader, 0)
			require.NoError(t, decodeErr)

			// The decoded message must match the original exactly.
			require.Equal(t, testCase.Msg, decoded)
		})
	}
}