lnd/lnrpc/routerrpc/forward_interceptor.go

202 lines
5.7 KiB
Go
Raw Normal View History

package routerrpc
import (
"errors"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// Sentinel errors exposed by the forward interceptor.
var (
	// ErrFwdNotExists signals that the caller attempted to resolve a
	// forward which no longer exists.
	ErrFwdNotExists = errors.New("forward does not exist")

	// ErrMissingPreimage signals that the caller requested a settle
	// without supplying a preimage.
	ErrMissingPreimage = errors.New("missing preimage")
)
// forwardInterceptor bundles the state needed to serve a single HTLC
// interceptor RPC streaming session: the client stream and the switch whose
// forwards are being intercepted.
type forwardInterceptor struct {
	// stream is the bidirectional RPC stream. Intercepted forwards are
	// sent on it and the client's resolutions are received from it.
	stream Router_HtlcInterceptorServer

	// htlcSwitch is the forwarder the interceptor callback is registered
	// with and to which client resolutions are submitted.
	htlcSwitch htlcswitch.InterceptableHtlcForwarder
}
// newForwardInterceptor returns a forwardInterceptor bound to the given switch
// and client stream. Nothing is registered with the switch until run is
// called.
func newForwardInterceptor(htlcSwitch htlcswitch.InterceptableHtlcForwarder,
	stream Router_HtlcInterceptorServer) *forwardInterceptor {

	interceptor := new(forwardInterceptor)
	interceptor.stream = stream
	interceptor.htlcSwitch = htlcSwitch

	return interceptor
}
// run sends the intercepted packets to the client and receives the
2022-01-13 17:29:43 +01:00
// corersponding responses. On one hand it registered itself as an interceptor
// that receives the switch packets and on the other hand launches a go routine
// to read from the client stream.
// To coordinate all this and make sure it is safe for concurrent access all
// packets are sent to the main where they are handled.
func (r *forwardInterceptor) run() error {
// Register our interceptor so we receive all forwarded packets.
r.htlcSwitch.SetInterceptor(r.onIntercept)
defer r.htlcSwitch.SetInterceptor(nil)
for {
resp, err := r.stream.Recv()
if err != nil {
return err
}
if err := r.resolveFromClient(resp); err != nil {
return err
}
}
}
// onIntercept is the callback that run registers with the switch. It converts
// the intercepted packet into a ForwardHtlcInterceptRequest and sends it to
// the client over the RPC stream, returning whatever error the send produces.
func (r *forwardInterceptor) onIntercept(
	htlc htlcswitch.InterceptedPacket) error {

	log.Tracef("Sending intercepted packet to client %v", htlc)

	// The incoming circuit key identifies this HTLC; the client echoes it
	// back in its response so the resolution can be matched to the
	// forward.
	incoming := &CircuitKey{
		ChanId: htlc.IncomingCircuit.ChanID.ToUint64(),
		HtlcId: htlc.IncomingCircuit.HtlcID,
	}

	return r.stream.Send(&ForwardHtlcInterceptRequest{
		IncomingCircuitKey:      incoming,
		OutgoingRequestedChanId: htlc.OutgoingChanID.ToUint64(),
		PaymentHash:             htlc.Hash[:],
		OutgoingAmountMsat:      uint64(htlc.OutgoingAmount),
		OutgoingExpiry:          htlc.OutgoingExpiry,
		IncomingAmountMsat:      uint64(htlc.IncomingAmount),
		IncomingExpiry:          htlc.IncomingExpiry,
		CustomRecords:           htlc.CustomRecords,
		OnionBlob:               htlc.OnionBlob[:],
		AutoFailHeight:          htlc.AutoFailHeight,
	})
}
// resolveFromClient validates a resolution received from the client,
// translates it into an htlcswitch.FwdResolution and submits it to the switch.
//
// An InvalidArgument status error is returned when the circuit key is missing,
// when a failure message and a failure code are both supplied, when the
// failure message has the wrong length, or when the failure code or action is
// not supported. ErrMissingPreimage is returned for a settle without a
// preimage. Otherwise the result of the switch's Resolve call is returned.
func (r *forwardInterceptor) resolveFromClient(
	in *ForwardHtlcInterceptResponse) error {

	rpcKey := in.IncomingCircuitKey
	if rpcKey == nil {
		return status.Errorf(codes.InvalidArgument,
			"CircuitKey missing from ForwardHtlcInterceptResponse")
	}

	log.Tracef("Resolving intercepted packet %v", in)

	// Every action targets the same circuit, so start from a resolution
	// carrying only the key and fill in the action-specific fields below.
	resolution := &htlcswitch.FwdResolution{
		Key: models.CircuitKey{
			ChanID: lnwire.NewShortChanIDFromInt(rpcKey.ChanId),
			HtlcID: rpcKey.HtlcId,
		},
	}

	switch in.Action {
	case ResolveHoldForwardAction_RESUME:
		resolution.Action = htlcswitch.FwdActionResume

	case ResolveHoldForwardAction_FAIL:
		resolution.Action = htlcswitch.FwdActionFail

		if in.FailureMessage != nil {
			// Fail with an encrypted reason. This cannot be
			// combined with an explicit failure code.
			if in.FailureCode != 0 {
				return status.Errorf(
					codes.InvalidArgument,
					"failure message and failure code "+
						"are mutually exclusive",
				)
			}

			// The size must equal the fixed failure message size
			// + hmac + two uint16 lengths. See BOLT #4.
			const overhead = 32 + 2 + 2
			wantLen := lnwire.FailureMessageLength + overhead
			if len(in.FailureMessage) != wantLen {
				return status.Errorf(
					codes.InvalidArgument,
					"failure message length invalid",
				)
			}

			resolution.FailureMessage = in.FailureMessage
		} else {
			// Fail with a plain failure code translated from its
			// RPC representation.
			switch in.FailureCode {
			case lnrpc.Failure_INVALID_ONION_HMAC:
				resolution.FailureCode =
					lnwire.CodeInvalidOnionHmac

			case lnrpc.Failure_INVALID_ONION_KEY:
				resolution.FailureCode =
					lnwire.CodeInvalidOnionKey

			case lnrpc.Failure_INVALID_ONION_VERSION:
				resolution.FailureCode =
					lnwire.CodeInvalidOnionVersion

			// An unset code defaults to TemporaryChannelFailure.
			case 0, lnrpc.Failure_TEMPORARY_CHANNEL_FAILURE:
				resolution.FailureCode =
					lnwire.CodeTemporaryChannelFailure

			default:
				return status.Errorf(
					codes.InvalidArgument,
					"unsupported failure code: %v",
					in.FailureCode,
				)
			}
		}

	case ResolveHoldForwardAction_SETTLE:
		if in.Preimage == nil {
			return ErrMissingPreimage
		}

		preimage, err := lntypes.MakePreimage(in.Preimage)
		if err != nil {
			return err
		}

		resolution.Action = htlcswitch.FwdActionSettle
		resolution.Preimage = preimage

	default:
		return status.Errorf(
			codes.InvalidArgument,
			"unrecognized resolve action %v", in.Action,
		)
	}

	return r.htlcSwitch.Resolve(resolution)
}