// Package image implements a small content-addressed store for uploaded
// images and videos, served over HTTP. Each upload is saved in a store
// directory under an id derived from the SHA-256 of its content.
package image

import (
	"context"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/h2non/filetype"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
	"golang.org/x/sys/unix"

	"go.sour.is/paste/src/pkg/readutil"
	"go.sour.is/pkg/lg"
)

// image stores uploaded media as files inside store and serves them back.
type image struct {
	store   string // directory holding the stored files
	maxSize int64  // maximum accepted upload size in bytes
}

// DefaultMaxSize is the upload size limit (500 MiB) used when New is given
// a non-positive maxSize.
const DefaultMaxSize = 500 * 1024 * 1024

// New returns an image handler that stores files under store.
//
// A maxSize <= 0 selects DefaultMaxSize. The store directory is created if
// it does not exist; an error is returned if it is not a directory the
// process can read from and write into.
func New(store string, maxSize int64) (a *image, err error) {
	a = &image{
		store:   store,
		maxSize: DefaultMaxSize,
	}
	if maxSize > 0 {
		a.maxSize = maxSize
	}

	if !chkStore(a.store) {
		return nil, fmt.Errorf("image Store location [%s] does not exist or is not writable", a.store)
	}

	return a, nil
}

// RegisterHTTP mounts the handler on mux at /i and /i/ (with those prefixes
// stripped before ServeHTTP sees the path) and at /3/upload.
func (a *image) RegisterHTTP(mux *http.ServeMux) {
	mux.Handle("/i", http.StripPrefix("/i", a))
	mux.Handle("/i/", http.StripPrefix("/i/", a))
	mux.Handle("/3/upload", a)
}

// ServeHTTP serves stored files on GET and accepts uploads on POST.
//
// A POST to /3/upload expects a multipart form whose "image" file field is
// base64 encoded; any other POST takes the raw request body as the file.
// The JSON success response appears to mirror the Imgur v3 upload API
// shape (data.link, data.deletehash, success, status) — presumably for
// client compatibility. Other methods get 405 Method Not Allowed.
func (a *image) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx, span := lg.Span(r.Context())
	defer span.End()

	switch r.Method {
	case http.MethodGet:
		name := strings.TrimPrefix(r.URL.Path, "/")
		// get returns its errors before writing anything, so a status can
		// still be sent here.
		if err := a.get(ctx, w, name); err != nil {
			code := statusForErr(err)
			http.Error(w, http.StatusText(code), code)
		}
	case http.MethodPost:
		a.servePost(ctx, w, r)
	default:
		w.Header().Set("Allow", "GET, POST")
		http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
	}
}

// servePost handles an upload request: it stores the body via put and
// writes either an error status or the JSON success response.
func (a *image) servePost(ctx context.Context, w http.ResponseWriter, r *http.Request) {
	var (
		body   io.ReadCloser = r.Body
		length int
	)

	if r.URL.Path == "/3/upload" {
		file, _, err := r.FormFile("image")
		if err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		// Decode the base64 payload but keep the multipart file as the
		// Closer, so put's deferred Close releases it.
		body = struct {
			io.Reader
			io.Closer
		}{base64.NewDecoder(base64.StdEncoding, file), file}
	} else if h := r.Header.Get("Content-Length"); h != "" {
		// Only the raw-body path uses Content-Length as a size hint. For
		// multipart uploads the header covers the whole base64-encoded
		// form, not the decoded file.
		if n, err := strconv.Atoi(h); err == nil {
			length = n
		}
	}

	id, err := a.put(ctx, w, body, length)
	if err != nil {
		code := statusForErr(err)
		http.Error(w, http.StatusText(code), code)
		return
	}

	type data struct {
		Link       string `json:"link"`
		DeleteHash string `json:"deletehash"`
	}
	resp := struct {
		Data    data `json:"data"`
		Success bool `json:"success"`
		Status  int  `json:"status"`
	}{
		Data: data{
			Link: fmt.Sprintf("https://%s/i/%s", r.Host, id),
		},
		Success: true,
		Status:  http.StatusOK,
	}

	w.Header().Set("Content-Type", "application/json")
	// Once encoding starts the 200 status is committed, so a write failure
	// (e.g. the client went away) cannot be reported to the client.
	_ = json.NewEncoder(w).Encode(resp)
}

// statusForErr maps the package's sentinel errors to HTTP status codes.
// Unrecognised errors, including ErrReadingContent, map to 500.
func statusForErr(err error) int {
	switch {
	case errors.Is(err, ErrNotFound):
		return http.StatusNotFound
	case errors.Is(err, ErrGone):
		return http.StatusGone
	case errors.Is(err, ErrSizeTooLarge):
		return http.StatusRequestEntityTooLarge
	case errors.Is(err, ErrUnsupportedType):
		return http.StatusUnsupportedMediaType
	case errors.Is(err, ErrBadInput):
		return http.StatusBadRequest
	default:
		return http.StatusInternalServerError
	}
}

// get writes the stored file for name to w. Any file extension on name is
// ignored. The Content-Type comes from readutil.ReadMIME, and
// X-Content-Type-Options is set to nosniff.
//
// It returns ErrNotFound for missing or invalid ids and ErrGone for files
// that are empty. Errors are returned before anything is written to w.
func (a *image) get(ctx context.Context, w http.ResponseWriter, name string) error {
	_, span := lg.Span(ctx)
	defer span.End()

	id := strings.TrimSuffix(name, filepath.Ext(name))
	// The id comes straight from the URL. Refuse anything that could resolve
	// to a path outside the store directory.
	if id == "" || id == "." || id == ".." || strings.ContainsAny(id, `/\`) {
		return fmt.Errorf("%w: %q", ErrNotFound, name)
	}
	fname := filepath.Join(a.store, id)

	if !chkFile(fname) {
		return fmt.Errorf("%w: %s", ErrNotFound, fname)
	}
	if chkGone(fname) {
		return fmt.Errorf("%w: %s", ErrGone, fname)
	}

	f, err := os.Open(fname)
	if errors.Is(err, os.ErrNotExist) {
		// The file was removed between the checks above and the open.
		return fmt.Errorf("%w: %s", ErrNotFound, fname)
	}
	if err != nil {
		return fmt.Errorf("opening %s: %w", fname, err)
	}
	defer f.Close()

	pr := readutil.NewPreviewReader(f)
	mime, err := readutil.ReadMIME(pr, name)
	if err != nil {
		return err
	}

	w.Header().Set("Content-Type", mime)
	w.Header().Set("X-Content-Type-Options", "nosniff")
	// Headers are committed once the copy starts; a copy error can only
	// mean a broken connection, which cannot be reported.
	_, _ = io.Copy(w, pr.Drain())

	return nil
}

// put streams r into the store and returns the id it was saved under.
// r is always closed.
//
// length is an optional size hint (0 = unknown) used to reject oversized
// uploads before reading. The a.maxSize limit is also enforced while
// copying. The leading bytes must be recognised as an image or video.
// The id is the raw URL base64 encoding of the last 20 bytes of the
// content's SHA-256, so identical uploads map to the same file. Errors:
// ErrSizeTooLarge, ErrUnsupportedType, ErrBadInput (copy failures), or a
// wrapped filesystem error.
func (a *image) put(ctx context.Context, w http.ResponseWriter, r io.ReadCloser, length int) (string, error) {
	_, span := lg.Span(ctx)
	defer span.End()
	defer r.Close()

	if length > 0 && int64(length) > a.maxSize {
		return "", ErrSizeTooLarge
	}

	// Read at most one byte past the limit, so an oversized body is
	// detected instead of being silently truncated and stored.
	// NOTE(review): assumes readutil.PreviewReader replays every byte read
	// before Drain (across multiple Read calls) — confirm.
	pr := readutil.NewPreviewReader(io.LimitReader(r, a.maxSize+1))
	if !isImageOrVideo(pr) {
		return "", ErrUnsupportedType
	}

	tmp, err := os.CreateTemp(a.store, "image-")
	if err != nil {
		return "", fmt.Errorf("creating temp file: %w", err)
	}
	stored := false
	defer func() {
		// On any failure, release the descriptor and drop the partial file.
		if !stored {
			_ = tmp.Close()
			_ = os.Remove(tmp.Name())
		}
	}()

	s256 := sha256.New()
	n, err := io.Copy(io.MultiWriter(s256, tmp), pr.Drain())
	if err != nil {
		return "", fmt.Errorf("%w: %w", ErrBadInput, err)
	}
	if n > a.maxSize {
		return "", ErrSizeTooLarge
	}
	if err := tmp.Close(); err != nil {
		return "", fmt.Errorf("closing temp file: %w", err)
	}

	id := base64.RawURLEncoding.EncodeToString(s256.Sum(nil)[12:])
	fname := filepath.Join(a.store, id)

	span.AddEvent("image: moving file", trace.WithAttributes(
		attribute.String("src", tmp.Name()),
		attribute.String("dst", fname),
	))
	if err := os.Rename(tmp.Name(), fname); err != nil {
		return "", fmt.Errorf("storing %s: %w", fname, err)
	}
	stored = true

	return id, nil
}

// isImageOrVideo reports whether the first bytes of in (up to 320) match an
// image or video signature known to filetype. It consumes those bytes
// from in.
func isImageOrVideo(in io.Reader) bool {
	buf := make([]byte, 320)
	// A single Read may legally return fewer bytes than requested, so keep
	// reading. ErrUnexpectedEOF only means the input is shorter than buf.
	n, err := io.ReadFull(in, buf)
	if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) {
		return false
	}
	buf = buf[:n]

	return filetype.IsImage(buf) || filetype.IsVideo(buf)
}

// chkStore reports whether path is a directory the process can list, read
// and create files in. It creates the directory (mode 0744) if missing.
func chkStore(path string) bool {
	info, err := os.Stat(path)
	if errors.Is(err, os.ErrNotExist) {
		if err = os.MkdirAll(path, 0744); err != nil {
			return false
		}
		info, err = os.Stat(path)
	}
	if err != nil || !info.IsDir() {
		return false
	}

	// access(2) modes are bit flags and must be OR'd together; W_OK&R_OK is
	// 0 (F_OK), which only tests existence. Creating files in a directory
	// also needs search (X) permission.
	return unix.Access(path, unix.R_OK|unix.W_OK|unix.X_OK) == nil
}

// chkFile reports whether path is an existing non-directory that the
// process can read.
func chkFile(path string) bool {
	info, err := os.Stat(path)
	if err != nil || info.IsDir() {
		return false
	}

	return unix.Access(path, unix.R_OK) == nil
}

// chkGone reports whether path can no longer be served: either it cannot be
// stat'ed or it is an empty file.
func chkGone(path string) bool {
	info, err := os.Stat(path)
	return err != nil || info.Size() == 0
}

// Sentinel errors returned by this package. ServeHTTP translates them into
// HTTP status codes via statusForErr.
var (
	ErrNotFound        = errors.New("not found")
	ErrGone            = errors.New("gone")
	ErrReadingContent  = errors.New("reading content")
	ErrSizeTooLarge    = errors.New("size too large")
	ErrBadInput        = errors.New("bad input")
	ErrUnsupportedType = errors.New("unsupported type")
)