From 3f944dd3371ddeebecffe07732f27c831afe07f2 Mon Sep 17 00:00:00 2001 From: Graham Krizek Date: Wed, 19 Aug 2020 23:06:28 -0500 Subject: [PATCH] lnd: Add CORS support to the WalletUnlocker proxy This commit adds the same CORS functionality that's currently in the main gRPC proxy to the WalletUnlocker proxy. This ensures the CORS configuration is carried through all API endpoints --- lnd.go | 2 +- rpcserver.go | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/lnd.go b/lnd.go index 21a6923d1..1722061b5 100644 --- a/lnd.go +++ b/lnd.go @@ -1011,7 +1011,7 @@ func waitForWalletPassword(cfg *Config, restEndpoints []net.Addr, return nil, err } - srv := &http.Server{Handler: mux} + srv := &http.Server{Handler: allowCORS(mux, cfg.RestCORS)} for _, restEndpoint := range restEndpoints { lis, err := lncfg.TLSListenOnAddress(restEndpoint, tlsConf) diff --git a/rpcserver.go b/rpcserver.go index 088f4a869..9348088da 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -810,12 +810,6 @@ func (r *rpcServer) Start() error { // Wrap the default grpc-gateway handler with the WebSocket handler. restHandler := lnrpc.NewWebSocketProxy(restMux, rpcsLog) - // Set the CORS headers if configured. This wraps the HTTP handler with - // another handler. - if len(r.cfg.RestCORS) > 0 { - restHandler = allowCORS(restHandler, r.cfg.RestCORS) - } - // With our custom REST proxy mux created, register our main RPC and // give all subservers a chance to register as well. err := lnrpc.RegisterLightningHandlerFromEndpoint( @@ -871,7 +865,8 @@ func (r *rpcServer) Start() error { // through the following chain: // req ---> CORS handler --> WS proxy ---> // REST proxy --> gRPC endpoint - err := http.Serve(lis, restHandler) + corsHandler := allowCORS(restHandler, r.cfg.RestCORS) + err := http.Serve(lis, corsHandler) if err != nil && !lnrpc.IsClosedConnError(err) { rpcsLog.Error(err) } @@ -944,6 +939,12 @@ func allowCORS(handler http.Handler, origins []string) http.Handler { allowMethods := "Access-Control-Allow-Methods" allowOrigin := "Access-Control-Allow-Origin" + // If the user didn't supply any origins that means CORS is disabled + // and we should return the original handler. + if len(origins) == 0 { + return handler + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin")