From 17569cfb2b9d6e68c6a5a5dbe3558ef83bd03f5f Mon Sep 17 00:00:00 2001 From: Jon Lundy Date: Mon, 9 Jan 2023 13:09:58 -0700 Subject: [PATCH] fix: chainmiddleware --- app/peerfinder/ev-peer.go | 1 - app/webfinger/webfinger.go | 2 +- cmd/ev/svc.gql.go | 17 ++++++++++++--- pkg/gql/resolver/resolver.go | 42 ++++++++++++++++++------------------ pkg/mux/httpmux.go | 7 +++--- 5 files changed, 39 insertions(+), 30 deletions(-) diff --git a/app/peerfinder/ev-peer.go b/app/peerfinder/ev-peer.go index c2bd8aa..ac64d1e 100644 --- a/app/peerfinder/ev-peer.go +++ b/app/peerfinder/ev-peer.go @@ -108,4 +108,3 @@ func (p *PeerResults) ApplyEvent(lis ...event.Event) { } } } - diff --git a/app/webfinger/webfinger.go b/app/webfinger/webfinger.go index 011bc99..a6a1ab2 100644 --- a/app/webfinger/webfinger.go +++ b/app/webfinger/webfinger.go @@ -1 +1 @@ -package webfinger \ No newline at end of file +package webfinger diff --git a/cmd/ev/svc.gql.go b/cmd/ev/svc.gql.go index ad3f025..88d4b9e 100644 --- a/cmd/ev/svc.gql.go +++ b/cmd/ev/svc.gql.go @@ -2,10 +2,12 @@ package main import ( "context" + "net/http" "github.com/sour-is/ev/app/gql" "github.com/sour-is/ev/internal/lg" "github.com/sour-is/ev/pkg/gql/resolver" + "github.com/sour-is/ev/pkg/mux" "github.com/sour-is/ev/pkg/service" "github.com/sour-is/ev/pkg/slice" ) @@ -20,10 +22,19 @@ var _ = apps.Register(90, func(ctx context.Context, svc *service.Harness) error span.RecordError(err) return err } + gql.CheckOrigin = func(r *http.Request) bool { + switch r.Header.Get("Origin") { + case "https://ev.sour.is", "https://www.graphqlbin.com", "http://localhost:8080": + return true + default: + return false + } + } + svc.Add(gql) - // svc.Add(mux.RegisterHTTP(func(mux *http.ServeMux) { - // mux.Handle("/", http.RedirectHandler("/playground", http.StatusTemporaryRedirect)) - // })) + svc.Add(mux.RegisterHTTP(func(mux *http.ServeMux) { + mux.Handle("/", http.RedirectHandler("/playground", http.StatusTemporaryRedirect)) + })) return nil }) diff --git a/pkg/gql/resolver/resolver.go b/pkg/gql/resolver/resolver.go index 68b6e1b..224ad44 100644 --- a/pkg/gql/resolver/resolver.go +++ b/pkg/gql/resolver/resolver.go @@ -30,19 +30,25 @@ type BaseResolver interface { } type Resolver[T BaseResolver] struct { - res T + res T + CheckOrigin func(r *http.Request) bool } type IsResolver interface { IsResolver() } +var defaultCheckOrign = func(r *http.Request) bool { + return true +} + func New[T BaseResolver](ctx context.Context, base T, resolvers ...IsResolver) (*Resolver[T], error) { _, span := lg.Span(ctx) defer span.End() + noop := reflect.ValueOf(base.BaseResolver()) + v := reflect.ValueOf(base) v = reflect.Indirect(v) - noop := reflect.ValueOf(base.BaseResolver()) outer: for _, idx := range reflect.VisibleFields(v.Type()) { @@ -64,7 +70,7 @@ outer: field.Set(noop) } - return &Resolver[T]{base}, nil + return &Resolver[T]{res: base, CheckOrigin: defaultCheckOrign}, nil } func (r *Resolver[T]) Resolver() T { @@ -73,15 +79,16 @@ func (r *Resolver[T]) Resolver() T { // ChainMiddlewares will check all embeded resolvers for a GetMiddleware func and add to handler. func (r *Resolver[T]) ChainMiddlewares(h http.Handler) http.Handler { - v := reflect.ValueOf(r) // Get reflected value of *Resolver - v = reflect.Indirect(v) // Get the pointed value (returns a zero value on nil) - n := v.NumField() // Get number of fields to iterate over. - for i := 0; i < n; i++ { - f := v.Field(i) - if !f.CanInterface() { // Skip non-interface types. + v := reflect.ValueOf(r.Resolver()) // Get reflected value of *Resolver + v = reflect.Indirect(v) // Get the pointed value (returns a zero value on nil) + for _, idx := range reflect.VisibleFields(v.Type()) { + field := v.FieldByIndex(idx.Index) + // log.Print("middleware ", field.Type().Name()) + + if !field.CanInterface() { // Skip non-interface types. continue } - if iface, ok := f.Interface().(interface { + if iface, ok := field.Interface().(interface { GetMiddleware() func(http.Handler) http.Handler }); ok { h = iface.GetMiddleware()(h) // Append only items that fulfill the interface. @@ -92,11 +99,11 @@ func (r *Resolver[T]) ChainMiddlewares(h http.Handler) http.Handler { } func (r *Resolver[T]) RegisterHTTP(mux *http.ServeMux) { - gql := NewServer(r.res.ExecutableSchema()) + gql := NewServer(r.Resolver().ExecutableSchema(), r.CheckOrigin) gql.SetRecoverFunc(NoopRecover) gql.Use(otelgqlgen.Middleware()) - mux.Handle("/graphiql", graphiql.Handler("GraphiQL playground", "/gql")) mux.Handle("/gql", lg.Htrace(r.ChainMiddlewares(gql), "gql")) + mux.Handle("/graphiql", graphiql.Handler("GraphiQL playground", "/gql")) mux.Handle("/playground", playground.Handler("GraphQL playground", "/gql")) } @@ -111,19 +118,12 @@ func NoopRecover(ctx context.Context, err interface{}) error { return gqlerror.Errorf("internal system error") } -func NewServer(es graphql.ExecutableSchema) *handler.Server { +func NewServer(es graphql.ExecutableSchema, checkOrigin func(*http.Request) bool) *handler.Server { srv := handler.New(es) srv.AddTransport(transport.Websocket{ Upgrader: websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { - switch r.Header.Get("Origin") { - case "https://ev.sour.is", "https://www.graphqlbin.com", "http://localhost:8080": - return true - default: - return false - } - }, + CheckOrigin: checkOrigin, }, KeepAlivePingInterval: 10 * time.Second, }) diff --git a/pkg/mux/httpmux.go b/pkg/mux/httpmux.go index 4c28934..add065f 100644 --- a/pkg/mux/httpmux.go +++ b/pkg/mux/httpmux.go @@ -1,7 +1,6 @@ package mux import ( - "log" "net/http" ) @@ -13,16 +12,16 @@ type mux struct { func (mux *mux) Add(fns ...interface{ RegisterHTTP(*http.ServeMux) }) { for _, fn := range fns { - log.Printf("HTTP: %T", fn) + // log.Printf("HTTP: %T", fn) fn.RegisterHTTP(mux.ServeMux) if fn, ok := fn.(interface{ RegisterAPIv1(*http.ServeMux) }); ok { - log.Printf("APIv1: %T", fn) + // log.Printf("APIv1: %T", fn) fn.RegisterAPIv1(mux.api) } if fn, ok := fn.(interface{ RegisterWellKnown(*http.ServeMux) }); ok { - log.Printf("WellKnown: %T", fn) + // log.Printf("WellKnown: %T", fn) fn.RegisterWellKnown(mux.wellknown) } }