diff --git a/htlcswitch/failure.go b/htlcswitch/failure.go index 4352d267e..1f8e8f6ee 100644 --- a/htlcswitch/failure.go +++ b/htlcswitch/failure.go @@ -9,6 +9,22 @@ import ( "github.com/lightningnetwork/lnd/lnwire" ) +// ClearTextError is an interface which is implemented by errors that occur +// when we know the underlying wire failure message. These errors are the +// opposite to opaque errors which are onion-encrypted blobs only understandable +// to the initiating node. ClearTextErrors are used when we fail a htlc at our +// node, or one of our initiated payments failed and we can decrypt the onion +// encrypted error fully. +type ClearTextError interface { + error + + // WireMessage extracts a valid wire failure message from an internal + // error which may contain additional metadata (which should not be + // exposed to the network). This value may be nil in the case where + // an unknown wire error is returned by one of our peers. + WireMessage() lnwire.FailureMessage +} + // ForwardingError wraps an lnwire.FailureMessage in a struct that also // includes the source of the error. type ForwardingError struct { @@ -22,7 +38,20 @@ type ForwardingError struct { // order to provide context specific error details. ExtraMsg string - lnwire.FailureMessage + // msg is the wire message associated with the error. This value may + // be nil in the case where we fail to decode failure message sent by + // a peer. + msg lnwire.FailureMessage +} + +// WireMessage extracts a valid wire failure message from an internal +// error which may contain additional metadata (which should not be +// exposed to the network). This value may be nil in the case where +// an unknown wire error is returned by one of our peers. +// +// Note this is part of the ClearTextError interface. +func (f *ForwardingError) WireMessage() lnwire.FailureMessage { + return f.msg } // Error implements the built-in error interface. We use this method to allow @@ -30,13 +59,11 @@ type ForwardingError struct { // returned. func (f *ForwardingError) Error() string { if f.ExtraMsg == "" { - return fmt.Sprintf( - "%v@%v", f.FailureMessage, f.FailureSourceIdx, - ) + return fmt.Sprintf("%v@%v", f.msg, f.FailureSourceIdx) } return fmt.Sprintf( - "%v@%v: %v", f.FailureMessage, f.FailureSourceIdx, f.ExtraMsg, + "%v@%v: %v", f.msg, f.FailureSourceIdx, f.ExtraMsg, ) } @@ -47,7 +74,7 @@ func NewForwardingError(failure lnwire.FailureMessage, index int, return &ForwardingError{ FailureSourceIdx: index, - FailureMessage: failure, + msg: failure, ExtraMsg: extraMsg, } } diff --git a/htlcswitch/link_test.go b/htlcswitch/link_test.go index 9efe89dca..0bee52715 100644 --- a/htlcswitch/link_test.go +++ b/htlcswitch/link_test.go @@ -579,7 +579,7 @@ func TestExitNodeTimelockPayloadMismatch(t *testing.T) { t.Fatalf("expected a ForwardingError, instead got: %T", err) } - switch ferr.FailureMessage.(type) { + switch ferr.WireMessage().(type) { case *lnwire.FailFinalIncorrectCltvExpiry: default: t.Fatalf("incorrect error, expected incorrect cltv expiry, "+ @@ -679,7 +679,7 @@ func TestLinkForwardTimelockPolicyMismatch(t *testing.T) { t.Fatalf("expected a ForwardingError, instead got: %T", err) } - switch ferr.FailureMessage.(type) { + switch ferr.WireMessage().(type) { case *lnwire.FailIncorrectCltvExpiry: default: t.Fatalf("incorrect error, expected incorrect cltv expiry, "+ @@ -737,7 +737,7 @@ func TestLinkForwardFeePolicyMismatch(t *testing.T) { t.Fatalf("expected a ForwardingError, instead got: %T", err) } - switch ferr.FailureMessage.(type) { + switch ferr.WireMessage().(type) { case *lnwire.FailFeeInsufficient: default: t.Fatalf("incorrect error, expected fee insufficient, "+ @@ -795,7 +795,7 @@ func TestLinkForwardMinHTLCPolicyMismatch(t *testing.T) { t.Fatalf("expected a ForwardingError, instead got: %T", err) } - switch ferr.FailureMessage.(type) { + switch ferr.WireMessage().(type) { case *lnwire.FailAmountBelowMinimum: default: t.Fatalf("incorrect error, expected amount below minimum, "+ @@ -862,7 +862,7 @@ func TestLinkForwardMaxHTLCPolicyMismatch(t *testing.T) { t.Fatalf("expected a ForwardingError, instead got: %T", err) } - switch ferr.FailureMessage.(type) { + switch ferr.WireMessage().(type) { case *lnwire.FailTemporaryChannelFailure: default: t.Fatalf("incorrect error, expected temporary channel failure, "+ @@ -968,7 +968,7 @@ func TestUpdateForwardingPolicy(t *testing.T) { if !ok { t.Fatalf("expected a ForwardingError, instead got (%T): %v", err, err) } - switch ferr.FailureMessage.(type) { + switch ferr.WireMessage().(type) { case *lnwire.FailFeeInsufficient: default: t.Fatalf("expected FailFeeInsufficient instead got: %v", err) @@ -1008,7 +1008,7 @@ func TestUpdateForwardingPolicy(t *testing.T) { t.Fatalf("expected a ForwardingError, instead got (%T): %v", err, err) } - switch ferr.FailureMessage.(type) { + switch ferr.WireMessage().(type) { case *lnwire.FailTemporaryChannelFailure: default: t.Fatalf("expected TemporaryChannelFailure, instead got: %v", @@ -1253,9 +1253,9 @@ func TestChannelLinkMultiHopUnknownNextHop(t *testing.T) { if !ok { t.Fatalf("expected ForwardingError") } - if _, ok = fErr.FailureMessage.(*lnwire.FailUnknownNextPeer); !ok { + if _, ok = fErr.WireMessage().(*lnwire.FailUnknownNextPeer); !ok { t.Fatalf("wrong error has been received: %T", - fErr.FailureMessage) + fErr.WireMessage()) } // Wait for Alice to receive the revocation. @@ -1369,7 +1369,7 @@ func TestChannelLinkMultiHopDecodeError(t *testing.T) { t.Fatalf("expected a ForwardingError, instead got: %T", err) } - switch ferr.FailureMessage.(type) { + switch ferr.WireMessage().(type) { case *lnwire.FailInvalidOnionVersion: default: t.Fatalf("wrong error have been received: %v", err) @@ -1462,7 +1462,7 @@ func TestChannelLinkExpiryTooSoonExitNode(t *testing.T) { err, err) } - switch ferr.FailureMessage.(type) { + switch ferr.WireMessage().(type) { case *lnwire.FailIncorrectDetails: default: t.Fatalf("expected incorrect_or_unknown_payment_details, "+ @@ -1524,7 +1524,7 @@ func TestChannelLinkExpiryTooSoonMidNode(t *testing.T) { t.Fatalf("expected a ForwardingError, instead got: %T: %v", err, err) } - switch ferr.FailureMessage.(type) { + switch ferr.WireMessage().(type) { case *lnwire.FailExpiryTooSoon: default: t.Fatalf("incorrect error, expected final time lock too "+ @@ -5636,7 +5636,7 @@ func TestChannelLinkCanceledInvoice(t *testing.T) { if !ok { t.Fatalf("expected ForwardingError, but got %v", err) } - _, ok = fErr.FailureMessage.(*lnwire.FailIncorrectDetails) + _, ok = fErr.WireMessage().(*lnwire.FailIncorrectDetails) if !ok { t.Fatalf("expected unknown payment hash, but got %v", err) } @@ -6221,8 +6221,8 @@ func assertFailureCode(t *testing.T, err error, code lnwire.FailCode) { t.Fatalf("expected ForwardingError but got %T", err) } - if fErr.FailureMessage.Code() != code { + if fErr.WireMessage().Code() != code { t.Fatalf("expected %v but got %v", - code, fErr.FailureMessage.Code()) + code, fErr.WireMessage().Code()) } } diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 28493c831..7f78bfe4e 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -2179,10 +2179,10 @@ func TestUpdateFailMalformedHTLCErrorConversion(t *testing.T) { } fwdingErr := err.(*ForwardingError) - failureMsg := fwdingErr.FailureMessage + failureMsg := fwdingErr.WireMessage() if _, ok := failureMsg.(*lnwire.FailInvalidOnionKey); !ok { t.Fatalf("expected onion failure instead got: %v", - fwdingErr.FailureMessage) + fwdingErr.WireMessage()) } } @@ -2448,7 +2448,7 @@ func TestInvalidFailure(t *testing.T) { if fErr.FailureSourceIdx != 2 { t.Fatal("unexpected error source index") } - if fErr.FailureMessage != nil { + if fErr.WireMessage() != nil { t.Fatal("expected empty failure message") } diff --git a/lnrpc/routerrpc/router_server.go b/lnrpc/routerrpc/router_server.go index ba23a0799..f761883c4 100644 --- a/lnrpc/routerrpc/router_server.go +++ b/lnrpc/routerrpc/router_server.go @@ -331,7 +331,7 @@ func marshallError(sendError error) (*Failure, error) { return nil, sendError } - switch onionErr := fErr.FailureMessage.(type) { + switch onionErr := fErr.WireMessage().(type) { case *lnwire.FailIncorrectDetails: response.Code = Failure_INCORRECT_OR_UNKNOWN_PAYMENT_DETAILS diff --git a/routing/router.go b/routing/router.go index d09383e9f..0b2ab2846 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1917,7 +1917,7 @@ func (r *ChannelRouter) processSendError(paymentID uint64, rt *route.Route, return &internalErrorReason } - failureMessage := fErr.FailureMessage + failureMessage := fErr.WireMessage() failureSourceIdx := fErr.FailureSourceIdx // Apply channel update if the error contains one. For unknown diff --git a/routing/router_test.go b/routing/router_test.go index f62473a3b..53eaee8fd 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -3320,7 +3320,7 @@ func TestSendToRouteStructuredError(t *testing.T) { t.Fatalf("expected forwarding error") } - if _, ok := fErr.FailureMessage.(*lnwire.FailFeeInsufficient); !ok { + if _, ok := fErr.WireMessage().(*lnwire.FailFeeInsufficient); !ok { t.Fatalf("expected fee insufficient error") }