fix: chainmiddleware
This commit is contained in:
parent
0810ec73a0
commit
17569cfb2b
|
@ -108,4 +108,3 @@ func (p *PeerResults) ApplyEvent(lis ...event.Event) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2,10 +2,12 @@ package main
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/sour-is/ev/app/gql"
|
||||
"github.com/sour-is/ev/internal/lg"
|
||||
"github.com/sour-is/ev/pkg/gql/resolver"
|
||||
"github.com/sour-is/ev/pkg/mux"
|
||||
"github.com/sour-is/ev/pkg/service"
|
||||
"github.com/sour-is/ev/pkg/slice"
|
||||
)
|
||||
|
@ -20,10 +22,19 @@ var _ = apps.Register(90, func(ctx context.Context, svc *service.Harness) error
|
|||
span.RecordError(err)
|
||||
return err
|
||||
}
|
||||
gql.CheckOrigin = func(r *http.Request) bool {
|
||||
switch r.Header.Get("Origin") {
|
||||
case "https://ev.sour.is", "https://www.graphqlbin.com", "http://localhost:8080":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
svc.Add(gql)
|
||||
// svc.Add(mux.RegisterHTTP(func(mux *http.ServeMux) {
|
||||
// mux.Handle("/", http.RedirectHandler("/playground", http.StatusTemporaryRedirect))
|
||||
// }))
|
||||
svc.Add(mux.RegisterHTTP(func(mux *http.ServeMux) {
|
||||
mux.Handle("/", http.RedirectHandler("/playground", http.StatusTemporaryRedirect))
|
||||
}))
|
||||
|
||||
return nil
|
||||
})
|
||||
|
|
|
@ -30,19 +30,25 @@ type BaseResolver interface {
|
|||
}
|
||||
|
||||
type Resolver[T BaseResolver] struct {
|
||||
res T
|
||||
res T
|
||||
CheckOrigin func(r *http.Request) bool
|
||||
}
|
||||
type IsResolver interface {
|
||||
IsResolver()
|
||||
}
|
||||
|
||||
var defaultCheckOrign = func(r *http.Request) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func New[T BaseResolver](ctx context.Context, base T, resolvers ...IsResolver) (*Resolver[T], error) {
|
||||
_, span := lg.Span(ctx)
|
||||
defer span.End()
|
||||
|
||||
noop := reflect.ValueOf(base.BaseResolver())
|
||||
|
||||
v := reflect.ValueOf(base)
|
||||
v = reflect.Indirect(v)
|
||||
noop := reflect.ValueOf(base.BaseResolver())
|
||||
|
||||
outer:
|
||||
for _, idx := range reflect.VisibleFields(v.Type()) {
|
||||
|
@ -64,7 +70,7 @@ outer:
|
|||
field.Set(noop)
|
||||
}
|
||||
|
||||
return &Resolver[T]{base}, nil
|
||||
return &Resolver[T]{res: base, CheckOrigin: defaultCheckOrign}, nil
|
||||
}
|
||||
|
||||
func (r *Resolver[T]) Resolver() T {
|
||||
|
@ -73,15 +79,16 @@ func (r *Resolver[T]) Resolver() T {
|
|||
|
||||
// ChainMiddlewares will check all embeded resolvers for a GetMiddleware func and add to handler.
|
||||
func (r *Resolver[T]) ChainMiddlewares(h http.Handler) http.Handler {
|
||||
v := reflect.ValueOf(r) // Get reflected value of *Resolver
|
||||
v = reflect.Indirect(v) // Get the pointed value (returns a zero value on nil)
|
||||
n := v.NumField() // Get number of fields to iterate over.
|
||||
for i := 0; i < n; i++ {
|
||||
f := v.Field(i)
|
||||
if !f.CanInterface() { // Skip non-interface types.
|
||||
v := reflect.ValueOf(r.Resolver()) // Get reflected value of *Resolver
|
||||
v = reflect.Indirect(v) // Get the pointed value (returns a zero value on nil)
|
||||
for _, idx := range reflect.VisibleFields(v.Type()) {
|
||||
field := v.FieldByIndex(idx.Index)
|
||||
// log.Print("middleware ", field.Type().Name())
|
||||
|
||||
if !field.CanInterface() { // Skip non-interface types.
|
||||
continue
|
||||
}
|
||||
if iface, ok := f.Interface().(interface {
|
||||
if iface, ok := field.Interface().(interface {
|
||||
GetMiddleware() func(http.Handler) http.Handler
|
||||
}); ok {
|
||||
h = iface.GetMiddleware()(h) // Append only items that fulfill the interface.
|
||||
|
@ -92,11 +99,11 @@ func (r *Resolver[T]) ChainMiddlewares(h http.Handler) http.Handler {
|
|||
}
|
||||
|
||||
func (r *Resolver[T]) RegisterHTTP(mux *http.ServeMux) {
|
||||
gql := NewServer(r.res.ExecutableSchema())
|
||||
gql := NewServer(r.Resolver().ExecutableSchema(), r.CheckOrigin)
|
||||
gql.SetRecoverFunc(NoopRecover)
|
||||
gql.Use(otelgqlgen.Middleware())
|
||||
mux.Handle("/graphiql", graphiql.Handler("GraphiQL playground", "/gql"))
|
||||
mux.Handle("/gql", lg.Htrace(r.ChainMiddlewares(gql), "gql"))
|
||||
mux.Handle("/graphiql", graphiql.Handler("GraphiQL playground", "/gql"))
|
||||
mux.Handle("/playground", playground.Handler("GraphQL playground", "/gql"))
|
||||
}
|
||||
|
||||
|
@ -111,19 +118,12 @@ func NoopRecover(ctx context.Context, err interface{}) error {
|
|||
return gqlerror.Errorf("internal system error")
|
||||
}
|
||||
|
||||
func NewServer(es graphql.ExecutableSchema) *handler.Server {
|
||||
func NewServer(es graphql.ExecutableSchema, checkOrigin func(*http.Request) bool) *handler.Server {
|
||||
srv := handler.New(es)
|
||||
|
||||
srv.AddTransport(transport.Websocket{
|
||||
Upgrader: websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
switch r.Header.Get("Origin") {
|
||||
case "https://ev.sour.is", "https://www.graphqlbin.com", "http://localhost:8080":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
},
|
||||
CheckOrigin: checkOrigin,
|
||||
},
|
||||
KeepAlivePingInterval: 10 * time.Second,
|
||||
})
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package mux
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
|
@ -13,16 +12,16 @@ type mux struct {
|
|||
|
||||
func (mux *mux) Add(fns ...interface{ RegisterHTTP(*http.ServeMux) }) {
|
||||
for _, fn := range fns {
|
||||
log.Printf("HTTP: %T", fn)
|
||||
// log.Printf("HTTP: %T", fn)
|
||||
fn.RegisterHTTP(mux.ServeMux)
|
||||
|
||||
if fn, ok := fn.(interface{ RegisterAPIv1(*http.ServeMux) }); ok {
|
||||
log.Printf("APIv1: %T", fn)
|
||||
// log.Printf("APIv1: %T", fn)
|
||||
fn.RegisterAPIv1(mux.api)
|
||||
}
|
||||
|
||||
if fn, ok := fn.(interface{ RegisterWellKnown(*http.ServeMux) }); ok {
|
||||
log.Printf("WellKnown: %T", fn)
|
||||
// log.Printf("WellKnown: %T", fn)
|
||||
fn.RegisterWellKnown(mux.wellknown)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user