fix: chainmiddleware

This commit is contained in:
Jon Lundy
2023-01-09 13:09:58 -07:00
parent 0810ec73a0
commit 17569cfb2b
5 changed files with 39 additions and 30 deletions

View File

@@ -30,19 +30,25 @@ type BaseResolver interface {
}
type Resolver[T BaseResolver] struct {
res T
res T
CheckOrigin func(r *http.Request) bool
}
type IsResolver interface {
IsResolver()
}
var defaultCheckOrign = func(r *http.Request) bool {
return true
}
func New[T BaseResolver](ctx context.Context, base T, resolvers ...IsResolver) (*Resolver[T], error) {
_, span := lg.Span(ctx)
defer span.End()
noop := reflect.ValueOf(base.BaseResolver())
v := reflect.ValueOf(base)
v = reflect.Indirect(v)
noop := reflect.ValueOf(base.BaseResolver())
outer:
for _, idx := range reflect.VisibleFields(v.Type()) {
@@ -64,7 +70,7 @@ outer:
field.Set(noop)
}
return &Resolver[T]{base}, nil
return &Resolver[T]{res: base, CheckOrigin: defaultCheckOrign}, nil
}
func (r *Resolver[T]) Resolver() T {
@@ -73,15 +79,16 @@ func (r *Resolver[T]) Resolver() T {
// ChainMiddlewares will check all embeded resolvers for a GetMiddleware func and add to handler.
func (r *Resolver[T]) ChainMiddlewares(h http.Handler) http.Handler {
v := reflect.ValueOf(r) // Get reflected value of *Resolver
v = reflect.Indirect(v) // Get the pointed value (returns a zero value on nil)
n := v.NumField() // Get number of fields to iterate over.
for i := 0; i < n; i++ {
f := v.Field(i)
if !f.CanInterface() { // Skip non-interface types.
v := reflect.ValueOf(r.Resolver()) // Get reflected value of *Resolver
v = reflect.Indirect(v) // Get the pointed value (returns a zero value on nil)
for _, idx := range reflect.VisibleFields(v.Type()) {
field := v.FieldByIndex(idx.Index)
// log.Print("middleware ", field.Type().Name())
if !field.CanInterface() { // Skip non-interface types.
continue
}
if iface, ok := f.Interface().(interface {
if iface, ok := field.Interface().(interface {
GetMiddleware() func(http.Handler) http.Handler
}); ok {
h = iface.GetMiddleware()(h) // Append only items that fulfill the interface.
@@ -92,11 +99,11 @@ func (r *Resolver[T]) ChainMiddlewares(h http.Handler) http.Handler {
}
func (r *Resolver[T]) RegisterHTTP(mux *http.ServeMux) {
gql := NewServer(r.res.ExecutableSchema())
gql := NewServer(r.Resolver().ExecutableSchema(), r.CheckOrigin)
gql.SetRecoverFunc(NoopRecover)
gql.Use(otelgqlgen.Middleware())
mux.Handle("/graphiql", graphiql.Handler("GraphiQL playground", "/gql"))
mux.Handle("/gql", lg.Htrace(r.ChainMiddlewares(gql), "gql"))
mux.Handle("/graphiql", graphiql.Handler("GraphiQL playground", "/gql"))
mux.Handle("/playground", playground.Handler("GraphQL playground", "/gql"))
}
@@ -111,19 +118,12 @@ func NoopRecover(ctx context.Context, err interface{}) error {
return gqlerror.Errorf("internal system error")
}
func NewServer(es graphql.ExecutableSchema) *handler.Server {
func NewServer(es graphql.ExecutableSchema, checkOrigin func(*http.Request) bool) *handler.Server {
srv := handler.New(es)
srv.AddTransport(transport.Websocket{
Upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
switch r.Header.Get("Origin") {
case "https://ev.sour.is", "https://www.graphqlbin.com", "http://localhost:8080":
return true
default:
return false
}
},
CheckOrigin: checkOrigin,
},
KeepAlivePingInterval: 10 * time.Second,
})