// Copyright (c) 2018 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.

package psbt

// The Updater requires provision of a single PSBT and is able to add data to
// both input and output sections.  It can be called repeatedly to add more
// data.  It also allows addition of signatures via the addPartialSignature
// function, which is called from within this package by the Sign() method of
// Updater, located in signer.go.

import (
	"bytes"
	"crypto/sha256"

	"github.com/btcsuite/btcd/btcutil"
	"github.com/btcsuite/btcd/txscript"
	"github.com/btcsuite/btcd/wire"
)

// Updater encapsulates the role 'Updater' as specified in BIP174; it accepts
// Psbt structs and has methods to add fields to the inputs and outputs.
type Updater struct {
	// Upsbt is the packet that the Updater's methods read and modify in
	// place.
	Upsbt *Packet
}

// NewUpdater wraps the passed packet in an Updater. The packet must pass its
// own SanityCheck first; if it does not, that error is returned and no Updater
// is created.
func NewUpdater(p *Packet) (*Updater, error) {
	err := p.SanityCheck()
	if err != nil {
		return nil, err
	}

	updater := &Updater{Upsbt: p}
	return updater, nil
}

// AddInNonWitnessUtxo adds the utxo information for an input which is
// non-witness. This requires provision of a full transaction (which is the
// source of the corresponding prevOut), and the input index.
//
// ErrInvalidPrevOutNonWitnessTransaction is returned if inIndex does not refer
// to an existing input of the packet. ErrInvalidPsbtFormat is returned if the
// packet no longer passes SanityCheck once the transaction has been attached;
// the field has already been written at that point.
func (u *Updater) AddInNonWitnessUtxo(tx *wire.MsgTx, inIndex int) error {
	// Reject negative as well as too-large indexes so that the slice
	// access below can never panic.
	if inIndex < 0 || inIndex >= len(u.Upsbt.Inputs) {
		return ErrInvalidPrevOutNonWitnessTransaction
	}

	u.Upsbt.Inputs[inIndex].NonWitnessUtxo = tx

	if err := u.Upsbt.SanityCheck(); err != nil {
		return ErrInvalidPsbtFormat
	}

	return nil
}

// AddInWitnessUtxo adds the utxo information for an input which is witness.
// This requires provision of the transaction *output* being spent (the source
// of the corresponding prevOut) together with the input index; the full
// transaction is not needed because under BIP143 the output information is
// sufficient.
//
// ErrInvalidPsbtFormat is returned if inIndex does not refer to an existing
// input of the packet, or if the packet no longer passes SanityCheck once the
// output has been attached; in the latter case the field has already been
// written.
func (u *Updater) AddInWitnessUtxo(txout *wire.TxOut, inIndex int) error {
	// Reject negative as well as too-large indexes so that the slice
	// access below can never panic.
	if inIndex < 0 || inIndex >= len(u.Upsbt.Inputs) {
		return ErrInvalidPsbtFormat
	}

	u.Upsbt.Inputs[inIndex].WitnessUtxo = txout

	if err := u.Upsbt.SanityCheck(); err != nil {
		return ErrInvalidPsbtFormat
	}

	return nil
}

// addPartialSignature allows the Updater role to insert fields of type partial
// signature into a Psbt, consisting of both the pubkey (as keydata) and the
// ECDSA signature (as value).  Note that the Signer role is encapsulated in
// this function; signatures are only allowed to be added that follow the
// sanity-check on signing rules explained in the BIP under `Signer`; if the
// rules are not satisfied, an ErrInvalidSignatureForInput is returned.
//
// Other errors returned:
//   - ErrInvalidPsbtFormat if inIndex does not refer to an existing input,
//     if the (sig, pubkey) pair fails checkValid, or if the input carries
//     neither a witness nor a non-witness UTXO.
//   - ErrDuplicateKey if a partial signature for pubkey is already present.
//   - ErrInvalidPrevOutNonWitnessTransaction if the unsigned transaction has
//     no input at inIndex, or if the non-witness UTXO has no output at the
//     index named by that input's previous outpoint.
//   - Any error produced by the script builder or by the final SanityCheck.
//
// NOTE: This function does *not* validate the ECDSA signature itself.
func (u *Updater) addPartialSignature(inIndex int, sig []byte,
	pubkey []byte) error {

	// Reject indexes that do not address an existing input rather than
	// panicking on the slice access further down.
	if inIndex < 0 || inIndex >= len(u.Upsbt.Inputs) {
		return ErrInvalidPsbtFormat
	}

	partialSig := PartialSig{
		PubKey: pubkey, Signature: sig,
	}

	// First validate the passed (sig, pub).
	if !partialSig.checkValid() {
		return ErrInvalidPsbtFormat
	}

	pInput := u.Upsbt.Inputs[inIndex]

	// First check; don't add duplicates.
	for _, x := range pInput.PartialSigs {
		if bytes.Equal(x.PubKey, partialSig.PubKey) {
			return ErrDuplicateKey
		}
	}

	// Attaching signature without utxo field is not allowed.
	if pInput.WitnessUtxo == nil && pInput.NonWitnessUtxo == nil {
		return ErrInvalidPsbtFormat
	}

	// Next, we perform a series of additional sanity checks.
	if pInput.NonWitnessUtxo != nil {
		if len(u.Upsbt.UnsignedTx.TxIn) < inIndex+1 {
			return ErrInvalidPrevOutNonWitnessTransaction
		}

		prevOutPoint := u.Upsbt.UnsignedTx.TxIn[inIndex].PreviousOutPoint
		if pInput.NonWitnessUtxo.TxHash() != prevOutPoint.Hash {
			return ErrInvalidSignatureForInput
		}

		// To validate that the redeem script matches, we must pull out
		// the scriptPubKey of the corresponding output and compare
		// that with the P2SH scriptPubKey that is generated by
		// redeemScript.
		if pInput.RedeemScript != nil {
			// The outpoint index comes from the packet and is not
			// guaranteed to exist in the supplied transaction, so
			// make sure it is in range before indexing. The
			// comparison is done in uint64 so that neither operand
			// can wrap.
			prevOuts := pInput.NonWitnessUtxo.TxOut
			outIndex := prevOutPoint.Index
			if uint64(outIndex) >= uint64(len(prevOuts)) {
				return ErrInvalidPrevOutNonWitnessTransaction
			}

			scriptPubKey := prevOuts[outIndex].PkScript
			scriptHash := btcutil.Hash160(pInput.RedeemScript)

			scriptHashScript, err := txscript.NewScriptBuilder().
				AddOp(txscript.OP_HASH160).
				AddData(scriptHash).
				AddOp(txscript.OP_EQUAL).
				Script()
			if err != nil {
				return err
			}

			if !bytes.Equal(scriptHashScript, scriptPubKey) {
				return ErrInvalidSignatureForInput
			}
		}

	}

	// It could be that we set both the non-witness and witness UTXO fields
	// in case it's from a wallet that patched the CVE-2020-14199
	// vulnerability. We detect whether the input being spent is actually a
	// witness input and then copy it over to the witness UTXO field in the
	// signer. Run the witness checks as well, even if we might already have
	// checked the script hash. But that should be a negligible performance
	// penalty.
	if pInput.WitnessUtxo != nil {
		scriptPubKey := pInput.WitnessUtxo.PkScript

		// script is the witness program to check: the redeem script
		// for nested (P2SH-wrapped) inputs, else the scriptPubKey.
		var script []byte
		if pInput.RedeemScript != nil {
			scriptHash := btcutil.Hash160(pInput.RedeemScript)
			scriptHashScript, err := txscript.NewScriptBuilder().
				AddOp(txscript.OP_HASH160).
				AddData(scriptHash).
				AddOp(txscript.OP_EQUAL).
				Script()
			if err != nil {
				return err
			}

			if !bytes.Equal(scriptHashScript, scriptPubKey) {
				return ErrInvalidSignatureForInput
			}

			script = pInput.RedeemScript
		} else {
			script = scriptPubKey
		}

		// If a witnessScript field is present, this is a P2WSH,
		// whether nested or not (that is handled by the assignment to
		// `script` above); in that case, sanity check that `script` is
		// the p2wsh of witnessScript. Contrariwise, if no
		// witnessScript field is present, this will be signed as
		// p2wkh.
		if pInput.WitnessScript != nil {
			witnessScriptHash := sha256.Sum256(pInput.WitnessScript)
			witnessScriptHashScript, err := txscript.NewScriptBuilder().
				AddOp(txscript.OP_0).
				AddData(witnessScriptHash[:]).
				Script()
			if err != nil {
				return err
			}

			if !bytes.Equal(script, witnessScriptHashScript) {
				return ErrInvalidSignatureForInput
			}
		} else {
			// Otherwise, this is a p2wkh input.
			pubkeyHash := btcutil.Hash160(pubkey)
			pubkeyHashScript, err := txscript.NewScriptBuilder().
				AddOp(txscript.OP_0).
				AddData(pubkeyHash).
				Script()
			if err != nil {
				return err
			}

			// Validate that we're able to properly reconstruct the
			// witness program.
			if !bytes.Equal(pubkeyHashScript, script) {
				return ErrInvalidSignatureForInput
			}
		}
	}

	// All checks passed; record the signature on the packet itself
	// (pInput above is only a local copy of the input).
	u.Upsbt.Inputs[inIndex].PartialSigs = append(
		u.Upsbt.Inputs[inIndex].PartialSigs, &partialSig,
	)

	if err := u.Upsbt.SanityCheck(); err != nil {
		return err
	}

	// Addition of a non-duplicate-key partial signature cannot violate
	// sanity-check rules.
	return nil
}

// AddInSighashType adds the sighash type information for an input.  The
// sighash type is passed as a txscript.SigHashType, along with the index for
// the input.
//
// ErrInvalidPsbtFormat is returned if inIndex does not refer to an existing
// input of the packet. If the packet no longer passes SanityCheck once the
// value has been set, the SanityCheck error is returned; the field has already
// been written at that point.
func (u *Updater) AddInSighashType(sighashType txscript.SigHashType,
	inIndex int) error {

	// Return an error for an out-of-range index instead of panicking on
	// the slice access below.
	if inIndex < 0 || inIndex >= len(u.Upsbt.Inputs) {
		return ErrInvalidPsbtFormat
	}

	u.Upsbt.Inputs[inIndex].SighashType = sighashType

	if err := u.Upsbt.SanityCheck(); err != nil {
		return err
	}

	return nil
}

// AddInRedeemScript adds the redeem script information for an input.  The
// redeem script is passed serialized, as a byte slice, along with the index of
// the input. Any redeem script previously set on that input is replaced.
//
// ErrInvalidPsbtFormat is returned if inIndex does not refer to an existing
// input of the packet, or if the packet no longer passes SanityCheck once the
// script has been set; in the latter case the field has already been written.
func (u *Updater) AddInRedeemScript(redeemScript []byte,
	inIndex int) error {

	// Return an error for an out-of-range index instead of panicking on
	// the slice access below.
	if inIndex < 0 || inIndex >= len(u.Upsbt.Inputs) {
		return ErrInvalidPsbtFormat
	}

	u.Upsbt.Inputs[inIndex].RedeemScript = redeemScript

	if err := u.Upsbt.SanityCheck(); err != nil {
		return ErrInvalidPsbtFormat
	}

	return nil
}

// AddInWitnessScript adds the witness script information for an input.  The
// witness script is passed serialized, as a byte slice, along with the index
// of the input. Any witness script previously set on that input is replaced.
//
// ErrInvalidPsbtFormat is returned if inIndex does not refer to an existing
// input of the packet. If the packet no longer passes SanityCheck once the
// script has been set, the SanityCheck error is returned; the field has
// already been written at that point.
func (u *Updater) AddInWitnessScript(witnessScript []byte,
	inIndex int) error {

	// Return an error for an out-of-range index instead of panicking on
	// the slice access below.
	if inIndex < 0 || inIndex >= len(u.Upsbt.Inputs) {
		return ErrInvalidPsbtFormat
	}

	u.Upsbt.Inputs[inIndex].WitnessScript = witnessScript

	if err := u.Upsbt.SanityCheck(); err != nil {
		return err
	}

	return nil
}

// AddInBip32Derivation takes a master key fingerprint as defined in BIP32, a
// BIP32 path as a slice of uint32 values, and a serialized pubkey as a byte
// slice, along with the integer index of the input, and inserts this data into
// that input.
//
// ErrInvalidPsbtFormat is returned if inIndex does not refer to an existing
// input of the packet or if the derivation fails checkValid. ErrDuplicateKey
// is returned if the input already holds a derivation for the same pubkey. If
// the packet no longer passes SanityCheck once the derivation has been
// appended, the SanityCheck error is returned.
//
// NOTE: This can be called multiple times for the same input.
func (u *Updater) AddInBip32Derivation(masterKeyFingerprint uint32,
	bip32Path []uint32, pubKeyData []byte, inIndex int) error {

	// Return an error for an out-of-range index instead of panicking on
	// the slice accesses below.
	if inIndex < 0 || inIndex >= len(u.Upsbt.Inputs) {
		return ErrInvalidPsbtFormat
	}

	bip32Derivation := Bip32Derivation{
		PubKey:               pubKeyData,
		MasterKeyFingerprint: masterKeyFingerprint,
		Bip32Path:            bip32Path,
	}

	if !bip32Derivation.checkValid() {
		return ErrInvalidPsbtFormat
	}

	// Don't allow duplicate keys
	for _, x := range u.Upsbt.Inputs[inIndex].Bip32Derivation {
		if bytes.Equal(x.PubKey, bip32Derivation.PubKey) {
			return ErrDuplicateKey
		}
	}

	u.Upsbt.Inputs[inIndex].Bip32Derivation = append(
		u.Upsbt.Inputs[inIndex].Bip32Derivation, &bip32Derivation,
	)

	if err := u.Upsbt.SanityCheck(); err != nil {
		return err
	}

	return nil
}

// AddOutBip32Derivation takes a master key fingerprint as defined in BIP32, a
// BIP32 path as a slice of uint32 values, and a serialized pubkey as a byte
// slice, along with the integer index of the output, and inserts this data
// into that output.
//
// ErrInvalidPsbtFormat is returned if outIndex does not refer to an existing
// output of the packet or if the derivation fails checkValid. ErrDuplicateKey
// is returned if the output already holds a derivation for the same pubkey.
// If the packet no longer passes SanityCheck once the derivation has been
// appended, the SanityCheck error is returned.
//
// NOTE: This can be called multiple times for the same output.
func (u *Updater) AddOutBip32Derivation(masterKeyFingerprint uint32,
	bip32Path []uint32, pubKeyData []byte, outIndex int) error {

	// Return an error for an out-of-range index instead of panicking on
	// the slice accesses below.
	if outIndex < 0 || outIndex >= len(u.Upsbt.Outputs) {
		return ErrInvalidPsbtFormat
	}

	bip32Derivation := Bip32Derivation{
		PubKey:               pubKeyData,
		MasterKeyFingerprint: masterKeyFingerprint,
		Bip32Path:            bip32Path,
	}

	if !bip32Derivation.checkValid() {
		return ErrInvalidPsbtFormat
	}

	// Don't allow duplicate keys
	for _, x := range u.Upsbt.Outputs[outIndex].Bip32Derivation {
		if bytes.Equal(x.PubKey, bip32Derivation.PubKey) {
			return ErrDuplicateKey
		}
	}

	u.Upsbt.Outputs[outIndex].Bip32Derivation = append(
		u.Upsbt.Outputs[outIndex].Bip32Derivation, &bip32Derivation,
	)

	if err := u.Upsbt.SanityCheck(); err != nil {
		return err
	}

	return nil
}

// AddOutRedeemScript takes a redeem script as a byte slice and sets it on the
// output at index outIndex, replacing any redeem script already stored there.
//
// ErrInvalidPsbtFormat is returned if outIndex does not refer to an existing
// output of the packet, or if the packet no longer passes SanityCheck once the
// script has been set; in the latter case the field has already been written.
func (u *Updater) AddOutRedeemScript(redeemScript []byte,
	outIndex int) error {

	// Return an error for an out-of-range index instead of panicking on
	// the slice access below.
	if outIndex < 0 || outIndex >= len(u.Upsbt.Outputs) {
		return ErrInvalidPsbtFormat
	}

	u.Upsbt.Outputs[outIndex].RedeemScript = redeemScript

	if err := u.Upsbt.SanityCheck(); err != nil {
		return ErrInvalidPsbtFormat
	}

	return nil
}

// AddOutWitnessScript takes a witness script as a byte slice and sets it on
// the output at index outIndex, replacing any witness script already stored
// there.
//
// ErrInvalidPsbtFormat is returned if outIndex does not refer to an existing
// output of the packet. If the packet no longer passes SanityCheck once the
// script has been set, the SanityCheck error is returned; the field has
// already been written at that point.
func (u *Updater) AddOutWitnessScript(witnessScript []byte,
	outIndex int) error {

	// Return an error for an out-of-range index instead of panicking on
	// the slice access below.
	if outIndex < 0 || outIndex >= len(u.Upsbt.Outputs) {
		return ErrInvalidPsbtFormat
	}

	u.Upsbt.Outputs[outIndex].WitnessScript = witnessScript

	if err := u.Upsbt.SanityCheck(); err != nil {
		return err
	}

	return nil
}