lnd/channeldb/forwarding_policy.go

112 lines
3.2 KiB
Go
Raw Normal View History

package channeldb
import (
	"fmt"

	"github.com/lightningnetwork/lnd/channeldb/models"
	"github.com/lightningnetwork/lnd/kvdb"
	"github.com/lightningnetwork/lnd/lnwire"
)
// initialChannelForwardingPolicyBucket names the top-level database bucket
// that holds the forwarding policy of each permanent channel that is still
// in the process of being opened, keyed by channel ID.
var initialChannelForwardingPolicyBucket = []byte("initialChannelFwdingPolicy")
// SaveInitialForwardingPolicy persists the forwarding policy for the given
// permanent channel ID in initialChannelForwardingPolicyBucket, creating that
// top-level bucket if it does not exist yet.
//
// The policy is written as a fixed 36-byte record, every field encoded with
// byteOrder:
//
//	[0:8)   MinHTLCOut
//	[8:16)  MaxHTLC
//	[16:24) BaseFee
//	[24:32) FeeRate
//	[32:36) TimeLockDelta
func (c *ChannelStateDB) SaveInitialForwardingPolicy(chanID lnwire.ChannelID,
	forwardingPolicy *models.ForwardingPolicy) error {

	// Serialize the policy before opening the transaction so the
	// transaction body only has to perform the write.
	value := make([]byte, 36)
	wideFields := []uint64{
		uint64(forwardingPolicy.MinHTLCOut),
		uint64(forwardingPolicy.MaxHTLC),
		uint64(forwardingPolicy.BaseFee),
		uint64(forwardingPolicy.FeeRate),
	}
	for i, field := range wideFields {
		byteOrder.PutUint64(value[i*8:(i+1)*8], field)
	}
	byteOrder.PutUint32(value[32:], forwardingPolicy.TimeLockDelta)

	// The key is a standalone copy of the channel ID bytes.
	key := make([]byte, 32)
	copy(key, chanID[:])

	return kvdb.Update(c.backend, func(tx kvdb.RwTx) error {
		policies, err := tx.CreateTopLevelBucket(
			initialChannelForwardingPolicyBucket,
		)
		if err != nil {
			return err
		}

		return policies.Put(key, value)
	}, func() {})
}
// GetInitialForwardingPolicy fetches the forwarding policy that was stored for
// the provided channel ID in initialChannelForwardingPolicyBucket and decodes
// it from its fixed-size record (four 8-byte fields followed by a 4-byte
// TimeLockDelta, all encoded with byteOrder).
//
// It returns ErrChannelNotFound if the bucket does not exist or holds no entry
// for the channel ID. It returns a descriptive error, rather than panicking,
// if the stored record is too short to contain a full policy.
func (c *ChannelStateDB) GetInitialForwardingPolicy(
	chanID lnwire.ChannelID) (*models.ForwardingPolicy, error) {

	// serializedPolicyLen is the number of bytes a stored policy record
	// must contain for every field below to be decoded.
	const serializedPolicyLen = 36

	chanIDCopy := make([]byte, 32)
	copy(chanIDCopy, chanID[:])

	var forwardingPolicy *models.ForwardingPolicy
	err := kvdb.View(c.backend, func(tx kvdb.RTx) error {
		bucket := tx.ReadBucket(initialChannelForwardingPolicyBucket)
		if bucket == nil {
			// If the bucket does not exist, no forwarding policy
			// has ever been saved, so report the channel as not
			// found.
			return ErrChannelNotFound
		}

		stateBytes := bucket.Get(chanIDCopy)
		if stateBytes == nil {
			return ErrChannelNotFound
		}

		// A truncated or corrupt record would make the slicing
		// below panic, so reject it with an explicit error instead.
		if len(stateBytes) < serializedPolicyLen {
			return fmt.Errorf("invalid forwarding policy record "+
				"for channel %v: expected at least %d bytes, "+
				"got %d", chanID, serializedPolicyLen,
				len(stateBytes))
		}

		forwardingPolicy = &models.ForwardingPolicy{
			MinHTLCOut: lnwire.MilliSatoshi(
				byteOrder.Uint64(stateBytes[:8]),
			),
			MaxHTLC: lnwire.MilliSatoshi(
				byteOrder.Uint64(stateBytes[8:16]),
			),
			BaseFee: lnwire.MilliSatoshi(
				byteOrder.Uint64(stateBytes[16:24]),
			),
			FeeRate: lnwire.MilliSatoshi(
				byteOrder.Uint64(stateBytes[24:32]),
			),
			TimeLockDelta: byteOrder.Uint32(stateBytes[32:36]),
		}

		return nil
	}, func() {
		// Reset the result in case the view closure is retried.
		forwardingPolicy = nil
	})

	return forwardingPolicy, err
}
// DeleteInitialForwardingPolicy deletes the forwarding policy saved for the
// given channel ID from initialChannelForwardingPolicyBucket.
//
// ErrChannelNotFound is returned when that bucket has never been created.
// Otherwise the result is whatever the bucket's Delete call returns for the
// channel ID key.
func (c *ChannelStateDB) DeleteInitialForwardingPolicy(
	chanID lnwire.ChannelID) error {

	key := make([]byte, 32)
	copy(key, chanID[:])

	deletePolicy := func(tx kvdb.RwTx) error {
		policies := tx.ReadWriteBucket(
			initialChannelForwardingPolicyBucket,
		)
		if policies == nil {
			return ErrChannelNotFound
		}

		return policies.Delete(key)
	}

	return kvdb.Update(c.backend, deletePolicy, func() {})
}