diff --git a/activitypub.go b/activitypub.go index d34e70c..0e69075 100644 --- a/activitypub.go +++ b/activitypub.go @@ -1,825 +1,825 @@ /* * Copyright © 2018-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package writefreely import ( "bytes" "crypto/sha256" "database/sql" "encoding/base64" "encoding/json" "fmt" "io/ioutil" "net/http" "net/http/httputil" "net/url" "strconv" "time" "github.com/gorilla/mux" "github.com/writeas/activity/streams" "github.com/writeas/httpsig" "github.com/writeas/impart" "github.com/writeas/nerds/store" "github.com/writeas/web-core/activitypub" "github.com/writeas/web-core/activitystreams" "github.com/writeas/web-core/log" ) const ( // TODO: delete. don't use this! apCustomHandleDefault = "blog" apCacheTime = time.Minute ) type RemoteUser struct { ID int64 ActorID string Inbox string SharedInbox string Handle string } func (ru *RemoteUser) AsPerson() *activitystreams.Person { return &activitystreams.Person{ BaseObject: activitystreams.BaseObject{ Type: "Person", Context: []interface{}{ activitystreams.Namespace, }, ID: ru.ActorID, }, Inbox: ru.Inbox, Endpoints: activitystreams.Endpoints{ SharedInbox: ru.SharedInbox, }, } } func activityPubClient() *http.Client { return &http.Client{ Timeout: 15 * time.Second, } } func handleFetchCollectionActivities(app *App, w http.ResponseWriter, r *http.Request) error { w.Header().Set("Server", serverSoftware) vars := mux.Vars(r) alias := vars["alias"] // TODO: enforce visibility // Get base Collection data var c *Collection var err error if app.cfg.App.SingleUser { c, err = app.db.GetCollectionByID(1) } else { c, err = app.db.GetCollection(alias) } if err != nil { return err } silenced, err := app.db.IsUserSilenced(c.OwnerID) if err != nil { log.Error("fetch collection activities: %v", err) return ErrInternalGeneral } if silenced { return ErrCollectionNotFound } c.hostName = app.cfg.App.Host p := c.PersonObject() setCacheControl(w, apCacheTime) return impart.RenderActivityJSON(w, p, http.StatusOK) } func handleFetchCollectionOutbox(app *App, w http.ResponseWriter, r *http.Request) error { w.Header().Set("Server", serverSoftware) vars := mux.Vars(r) alias := vars["alias"] // TODO: enforce visibility // Get base Collection data var c *Collection var err error if app.cfg.App.SingleUser { c, err = app.db.GetCollectionByID(1) } else { c, err = app.db.GetCollection(alias) } if err != nil { return err } silenced, err := app.db.IsUserSilenced(c.OwnerID) if err != nil { log.Error("fetch collection outbox: %v", err) return ErrInternalGeneral } if silenced { return ErrCollectionNotFound } c.hostName = app.cfg.App.Host if app.cfg.App.SingleUser { if alias != c.Alias { return ErrCollectionNotFound } } res := &CollectionObj{Collection: *c} app.db.GetPostsCount(res, false) accountRoot := c.FederatedAccount() page := r.FormValue("page") p, err := strconv.Atoi(page) if err != nil || p < 1 { // Return outbox oc := activitystreams.NewOrderedCollection(accountRoot, "outbox", res.TotalPosts) return impart.RenderActivityJSON(w, oc, http.StatusOK) } // Return outbox page ocp := activitystreams.NewOrderedCollectionPage(accountRoot, "outbox", res.TotalPosts, p) ocp.OrderedItems = []interface{}{} posts, err := app.db.GetPosts(app.cfg, c, p, false, true, false) for _, pp := range *posts { pp.Collection = res o := pp.ActivityObject(app) a := activitystreams.NewCreateActivity(o) a.Context = nil ocp.OrderedItems = append(ocp.OrderedItems, *a) } setCacheControl(w, apCacheTime) return impart.RenderActivityJSON(w, ocp, http.StatusOK) } func handleFetchCollectionFollowers(app *App, w http.ResponseWriter, r *http.Request) error { w.Header().Set("Server", serverSoftware) vars := mux.Vars(r) alias := vars["alias"] // TODO: enforce visibility // Get base Collection data var c *Collection var err error if app.cfg.App.SingleUser { c, err = app.db.GetCollectionByID(1) } else { c, err = app.db.GetCollection(alias) } if err != nil { return err } silenced, err := app.db.IsUserSilenced(c.OwnerID) if err != nil { log.Error("fetch collection followers: %v", err) return ErrInternalGeneral } if silenced { return ErrCollectionNotFound } c.hostName = app.cfg.App.Host accountRoot := c.FederatedAccount() folls, err := app.db.GetAPFollowers(c) if err != nil { return err } page := r.FormValue("page") p, err := strconv.Atoi(page) if err != nil || p < 1 { // Return outbox oc := activitystreams.NewOrderedCollection(accountRoot, "followers", len(*folls)) return impart.RenderActivityJSON(w, oc, http.StatusOK) } // Return outbox page ocp := activitystreams.NewOrderedCollectionPage(accountRoot, "followers", len(*folls), p) ocp.OrderedItems = []interface{}{} /* for _, f := range *folls { ocp.OrderedItems = append(ocp.OrderedItems, f.ActorID) } */ setCacheControl(w, apCacheTime) return impart.RenderActivityJSON(w, ocp, http.StatusOK) } func handleFetchCollectionFollowing(app *App, w http.ResponseWriter, r *http.Request) error { w.Header().Set("Server", serverSoftware) vars := mux.Vars(r) alias := vars["alias"] // TODO: enforce visibility // Get base Collection data var c *Collection var err error if app.cfg.App.SingleUser { c, err = app.db.GetCollectionByID(1) } else { c, err = app.db.GetCollection(alias) } if err != nil { return err } silenced, err := app.db.IsUserSilenced(c.OwnerID) if err != nil { log.Error("fetch collection following: %v", err) return ErrInternalGeneral } if silenced { return ErrCollectionNotFound } c.hostName = app.cfg.App.Host accountRoot := c.FederatedAccount() page := r.FormValue("page") p, err := strconv.Atoi(page) if err != nil || p < 1 { // Return outbox oc := activitystreams.NewOrderedCollection(accountRoot, "following", 0) return impart.RenderActivityJSON(w, oc, http.StatusOK) } // Return outbox page ocp := activitystreams.NewOrderedCollectionPage(accountRoot, "following", 0, p) ocp.OrderedItems = []interface{}{} setCacheControl(w, apCacheTime) return impart.RenderActivityJSON(w, ocp, http.StatusOK) } func handleFetchCollectionInbox(app *App, w http.ResponseWriter, r *http.Request) error { w.Header().Set("Server", serverSoftware) vars := mux.Vars(r) alias := vars["alias"] var c *Collection var err error if app.cfg.App.SingleUser { c, err = app.db.GetCollectionByID(1) } else { c, err = app.db.GetCollection(alias) } if err != nil { // TODO: return Reject? return err } silenced, err := app.db.IsUserSilenced(c.OwnerID) if err != nil { log.Error("fetch collection inbox: %v", err) return ErrInternalGeneral } if silenced { return ErrCollectionNotFound } c.hostName = app.cfg.App.Host if debugging { dump, err := httputil.DumpRequest(r, true) if err != nil { log.Error("Can't dump: %v", err) } else { log.Info("Rec'd! %q", dump) } } var m map[string]interface{} if err := json.NewDecoder(r.Body).Decode(&m); err != nil { return err } a := streams.NewAccept() p := c.PersonObject() var to *url.URL var isFollow, isUnfollow bool fullActor := &activitystreams.Person{} var remoteUser *RemoteUser res := &streams.Resolver{ FollowCallback: func(f *streams.Follow) error { isFollow = true // 1) Use the Follow concrete type here // 2) Errors are propagated to res.Deserialize call below m["@context"] = []string{activitystreams.Namespace} b, _ := json.Marshal(m) if debugging { log.Info("Follow: %s", b) } _, followID := f.GetId() if followID == nil { log.Error("Didn't resolve follow ID") } else { aID := c.FederatedAccount() + "#accept-" + store.GenerateFriendlyRandomString(20) acceptID, err := url.Parse(aID) if err != nil { log.Error("Couldn't parse generated Accept URL '%s': %v", aID, err) } a.SetId(acceptID) } a.AppendObject(f.Raw()) _, to = f.GetActor(0) obj := f.Raw().GetObjectIRI(0) a.AppendActor(obj) // First get actor information if to == nil { return fmt.Errorf("No valid `to` string") } fullActor, remoteUser, err = getActor(app, to.String()) if err != nil { return err } return impart.RenderActivityJSON(w, m, http.StatusOK) }, UndoCallback: func(u *streams.Undo) error { isUnfollow = true m["@context"] = []string{activitystreams.Namespace} b, _ := json.Marshal(m) if debugging { log.Info("Undo: %s", b) } a.AppendObject(u.Raw()) _, to = u.GetActor(0) // TODO: get actor from object.object, not object obj := u.Raw().GetObjectIRI(0) a.AppendActor(obj) if to != nil { // Populate fullActor from DB? remoteUser, err = getRemoteUser(app, to.String()) if err != nil { if iErr, ok := err.(*impart.HTTPError); ok { if iErr.Status == http.StatusNotFound { log.Error("No remoteuser info for Undo event!") } } return err } else { fullActor = remoteUser.AsPerson() } } else { log.Error("No to on Undo!") } return impart.RenderActivityJSON(w, m, http.StatusOK) }, } if err := res.Deserialize(m); err != nil { // 3) Any errors from #2 can be handled, or the payload is an unknown type. log.Error("Unable to resolve Follow: %v", err) if debugging { log.Error("Map: %s", m) } return err } go func() { if to == nil { if debugging { log.Error("No `to` value!") } return } time.Sleep(2 * time.Second) am, err := a.Serialize() if err != nil { log.Error("Unable to serialize Accept: %v", err) return } am["@context"] = []string{activitystreams.Namespace} err = makeActivityPost(app.cfg.App.Host, p, fullActor.Inbox, am) if err != nil { log.Error("Unable to make activity POST: %v", err) return } if isFollow { t, err := app.db.Begin() if err != nil { log.Error("Unable to start transaction: %v", err) return } var followerID int64 if remoteUser != nil { followerID = remoteUser.ID } else { // Add follower locally, since it wasn't found before res, err := t.Exec("INSERT INTO remoteusers (actor_id, inbox, shared_inbox) VALUES (?, ?, ?)", fullActor.ID, fullActor.Inbox, fullActor.Endpoints.SharedInbox) if err != nil { // if duplicate key, res will be nil and panic on // res.LastInsertId below t.Rollback() log.Error("Couldn't add new remoteuser in DB: %v\n", err) return } followerID, err = res.LastInsertId() if err != nil { t.Rollback() log.Error("no lastinsertid for followers, rolling back: %v", err) return } // Add in key _, err = t.Exec("INSERT INTO remoteuserkeys (id, remote_user_id, public_key) VALUES (?, ?, ?)", fullActor.PublicKey.ID, followerID, fullActor.PublicKey.PublicKeyPEM) if err != nil { if !app.db.isDuplicateKeyErr(err) { t.Rollback() log.Error("Couldn't add follower keys in DB: %v\n", err) return } } } // Add follow _, err = t.Exec("INSERT INTO remotefollows (collection_id, remote_user_id, created) VALUES (?, ?, "+app.db.now()+")", c.ID, followerID) if err != nil { if !app.db.isDuplicateKeyErr(err) { t.Rollback() log.Error("Couldn't add follower in DB: %v\n", err) return } } err = t.Commit() if err != nil { t.Rollback() log.Error("Rolling back after Commit(): %v\n", err) return } } else if isUnfollow { // Remove follower locally _, err = app.db.Exec("DELETE FROM remotefollows WHERE collection_id = ? AND remote_user_id = (SELECT id FROM remoteusers WHERE actor_id = ?)", c.ID, to.String()) if err != nil { log.Error("Couldn't remove follower from DB: %v\n", err) } } }() return nil } func makeActivityPost(hostName string, p *activitystreams.Person, url string, m interface{}) error { log.Info("POST %s", url) b, err := json.Marshal(m) if err != nil { return err } r, _ := http.NewRequest("POST", url, bytes.NewBuffer(b)) r.Header.Add("Content-Type", "application/activity+json") - r.Header.Set("User-Agent", "Go ("+serverSoftware+"/"+softwareVer+"; +"+hostName+")") + r.Header.Set("User-Agent", ServerUserAgent(hostName)) h := sha256.New() h.Write(b) r.Header.Add("Digest", "SHA-256="+base64.StdEncoding.EncodeToString(h.Sum(nil))) // Sign using the 'Signature' header privKey, err := activitypub.DecodePrivateKey(p.GetPrivKey()) if err != nil { return err } signer := httpsig.NewSigner(p.PublicKey.ID, privKey, httpsig.RSASHA256, []string{"(request-target)", "date", "host", "digest"}) err = signer.SignSigHeader(r) if err != nil { log.Error("Can't sign: %v", err) } if debugging { dump, err := httputil.DumpRequestOut(r, true) if err != nil { log.Error("Can't dump: %v", err) } else { log.Info("%s", dump) } } resp, err := activityPubClient().Do(r) if err != nil { return err } if resp != nil && resp.Body != nil { defer resp.Body.Close() } body, err := ioutil.ReadAll(resp.Body) if err != nil { return err } if debugging { log.Info("Status : %s", resp.Status) log.Info("Response: %s", body) } return nil } func resolveIRI(hostName, url string) ([]byte, error) { log.Info("GET %s", url) r, _ := http.NewRequest("GET", url, nil) r.Header.Add("Accept", "application/activity+json") - r.Header.Set("User-Agent", "Go ("+serverSoftware+"/"+softwareVer+"; +"+hostName+")") + r.Header.Set("User-Agent", ServerUserAgent(hostName)) if debugging { dump, err := httputil.DumpRequestOut(r, true) if err != nil { log.Error("Can't dump: %v", err) } else { log.Info("%s", dump) } } resp, err := activityPubClient().Do(r) if err != nil { return nil, err } if resp != nil && resp.Body != nil { defer resp.Body.Close() } body, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, err } if debugging { log.Info("Status : %s", resp.Status) log.Info("Response: %s", body) } return body, nil } func deleteFederatedPost(app *App, p *PublicPost, collID int64) error { if debugging { log.Info("Deleting federated post!") } p.Collection.hostName = app.cfg.App.Host actor := p.Collection.PersonObject(collID) na := p.ActivityObject(app) // Add followers p.Collection.ID = collID followers, err := app.db.GetAPFollowers(&p.Collection.Collection) if err != nil { log.Error("Couldn't delete post (get followers)! %v", err) return err } inboxes := map[string][]string{} for _, f := range *followers { inbox := f.SharedInbox if inbox == "" { inbox = f.Inbox } if _, ok := inboxes[inbox]; ok { inboxes[inbox] = append(inboxes[inbox], f.ActorID) } else { inboxes[inbox] = []string{f.ActorID} } } for si, instFolls := range inboxes { na.CC = []string{} for _, f := range instFolls { na.CC = append(na.CC, f) } da := activitystreams.NewDeleteActivity(na) // Make the ID unique to ensure it works in Pleroma // See: https://git.pleroma.social/pleroma/pleroma/issues/1481 da.ID += "#Delete" err = makeActivityPost(app.cfg.App.Host, actor, si, da) if err != nil { log.Error("Couldn't delete post! %v", err) } } return nil } func federatePost(app *App, p *PublicPost, collID int64, isUpdate bool) error { if debugging { if isUpdate { log.Info("Federating updated post!") } else { log.Info("Federating new post!") } } actor := p.Collection.PersonObject(collID) na := p.ActivityObject(app) // Add followers p.Collection.ID = collID followers, err := app.db.GetAPFollowers(&p.Collection.Collection) if err != nil { log.Error("Couldn't post! %v", err) return err } log.Info("Followers for %d: %+v", collID, followers) inboxes := map[string][]string{} for _, f := range *followers { inbox := f.SharedInbox if inbox == "" { inbox = f.Inbox } if _, ok := inboxes[inbox]; ok { // check if we're already sending to this shared inbox inboxes[inbox] = append(inboxes[inbox], f.ActorID) } else { // add the new shared inbox to the list inboxes[inbox] = []string{f.ActorID} } } var activity *activitystreams.Activity // for each one of the shared inboxes for si, instFolls := range inboxes { // add all followers from that instance // to the CC field na.CC = []string{} for _, f := range instFolls { na.CC = append(na.CC, f) } // create a new "Create" activity // with our article as object if isUpdate { activity = activitystreams.NewUpdateActivity(na) } else { activity = activitystreams.NewCreateActivity(na) activity.To = na.To activity.CC = na.CC } // and post it to that sharedInbox err = makeActivityPost(app.cfg.App.Host, actor, si, activity) if err != nil { log.Error("Couldn't post! %v", err) } } // re-create the object so that the CC list gets reset and has // the mentioned users. This might seem wasteful but the code is // cleaner than adding the mentioned users to CC here instead of // in p.ActivityObject() na = p.ActivityObject(app) for _, tag := range na.Tag { if tag.Type == "Mention" { activity = activitystreams.NewCreateActivity(na) activity.To = na.To activity.CC = na.CC // This here might be redundant in some cases as we might have already // sent this to the sharedInbox of this instance above, but we need too // much logic to catch this at the expense of the odd extra request. // I don't believe we'd ever have too many mentions in a single post that this // could become a burden. remoteUser, err := getRemoteUser(app, tag.HRef) if err != nil { log.Error("Unable to find remote user %s. Skipping: %v", tag.HRef, err) continue } err = makeActivityPost(app.cfg.App.Host, actor, remoteUser.Inbox, activity) if err != nil { log.Error("Couldn't post! %v", err) } } } return nil } func getRemoteUser(app *App, actorID string) (*RemoteUser, error) { u := RemoteUser{ActorID: actorID} var handle sql.NullString err := app.db.QueryRow("SELECT id, inbox, shared_inbox, handle FROM remoteusers WHERE actor_id = ?", actorID).Scan(&u.ID, &u.Inbox, &u.SharedInbox, &handle) switch { case err == sql.ErrNoRows: return nil, impart.HTTPError{http.StatusNotFound, "No remote user with that ID."} case err != nil: log.Error("Couldn't get remote user %s: %v", actorID, err) return nil, err } u.Handle = handle.String return &u, nil } // getRemoteUserFromHandle retrieves the profile page of a remote user // from the @user@server.tld handle func getRemoteUserFromHandle(app *App, handle string) (*RemoteUser, error) { u := RemoteUser{Handle: handle} err := app.db.QueryRow("SELECT id, actor_id, inbox, shared_inbox FROM remoteusers WHERE handle = ?", handle).Scan(&u.ID, &u.ActorID, &u.Inbox, &u.SharedInbox) switch { case err == sql.ErrNoRows: return nil, ErrRemoteUserNotFound case err != nil: log.Error("Couldn't get remote user %s: %v", handle, err) return nil, err } return &u, nil } func getActor(app *App, actorIRI string) (*activitystreams.Person, *RemoteUser, error) { log.Info("Fetching actor %s locally", actorIRI) actor := &activitystreams.Person{} remoteUser, err := getRemoteUser(app, actorIRI) if err != nil { if iErr, ok := err.(impart.HTTPError); ok { if iErr.Status == http.StatusNotFound { // Fetch remote actor log.Info("Not found; fetching actor %s remotely", actorIRI) actorResp, err := resolveIRI(app.cfg.App.Host, actorIRI) if err != nil { log.Error("Unable to get actor! %v", err) return nil, nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't fetch actor."} } if err := unmarshalActor(actorResp, actor); err != nil { log.Error("Unable to unmarshal actor! %v", err) return nil, nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't parse actor."} } } else { return nil, nil, err } } else { return nil, nil, err } } else { actor = remoteUser.AsPerson() } return actor, remoteUser, nil } // unmarshal actor normalizes the actor response to conform to // the type Person from github.com/writeas/web-core/activitysteams // // some implementations return different context field types // this converts any non-slice contexts into a slice func unmarshalActor(actorResp []byte, actor *activitystreams.Person) error { // FIXME: Hubzilla has an object for the Actor's url: cannot unmarshal object into Go struct field Person.url of type string // flexActor overrides the Context field to allow // all valid representations during unmarshal flexActor := struct { activitystreams.Person Context json.RawMessage `json:"@context,omitempty"` }{} if err := json.Unmarshal(actorResp, &flexActor); err != nil { return err } actor.Endpoints = flexActor.Endpoints actor.Followers = flexActor.Followers actor.Following = flexActor.Following actor.ID = flexActor.ID actor.Icon = flexActor.Icon actor.Inbox = flexActor.Inbox actor.Name = flexActor.Name actor.Outbox = flexActor.Outbox actor.PreferredUsername = flexActor.PreferredUsername actor.PublicKey = flexActor.PublicKey actor.Summary = flexActor.Summary actor.Type = flexActor.Type actor.URL = flexActor.URL func(val interface{}) { switch val.(type) { case []interface{}: // already a slice, do nothing actor.Context = val.([]interface{}) default: actor.Context = []interface{}{val} } }(flexActor.Context) return nil } func setCacheControl(w http.ResponseWriter, ttl time.Duration) { w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%.0f", ttl.Seconds())) } diff --git a/app.go b/app.go index af0a56f..06e677b 100644 --- a/app.go +++ b/app.go @@ -1,905 +1,915 @@ /* * Copyright © 2018-2019 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package writefreely import ( "crypto/tls" "database/sql" "fmt" "html/template" "io/ioutil" "net/http" "net/url" "os" "os/signal" "path/filepath" "regexp" "strings" "syscall" "time" "github.com/gorilla/mux" "github.com/gorilla/schema" "github.com/gorilla/sessions" "github.com/manifoldco/promptui" stripmd "github.com/writeas/go-strip-markdown" "github.com/writeas/impart" "github.com/writeas/web-core/auth" "github.com/writeas/web-core/converter" "github.com/writeas/web-core/log" "github.com/writeas/writefreely/author" "github.com/writeas/writefreely/config" "github.com/writeas/writefreely/key" "github.com/writeas/writefreely/migrations" "github.com/writeas/writefreely/page" "golang.org/x/crypto/acme/autocert" ) const ( staticDir = "static" assumedTitleLen = 80 postsPerPage = 10 serverSoftware = "WriteFreely" softwareURL = "https://writefreely.org" ) var ( debugging bool // Software version can be set from git env using -ldflags softwareVer = "0.12.0" // DEPRECATED VARS isSingleUser bool ) // App holds data and configuration for an individual WriteFreely instance. type App struct { router *mux.Router shttp *http.ServeMux db *datastore cfg *config.Config cfgFile string keys *key.Keychain sessionStore sessions.Store formDecoder *schema.Decoder updates *updatesCache timeline *localTimeline } // DB returns the App's datastore func (app *App) DB() *datastore { return app.db } // Router returns the App's router func (app *App) Router() *mux.Router { return app.router } // Config returns the App's current configuration. func (app *App) Config() *config.Config { return app.cfg } // SetConfig updates the App's Config to the given value. func (app *App) SetConfig(cfg *config.Config) { app.cfg = cfg } // SetKeys updates the App's Keychain to the given value. func (app *App) SetKeys(k *key.Keychain) { app.keys = k } func (app *App) SessionStore() sessions.Store { return app.sessionStore } func (app *App) SetSessionStore(s sessions.Store) { app.sessionStore = s } // Apper is the interface for getting data into and out of a WriteFreely // instance (or "App"). // // App returns the App for the current instance. // // LoadConfig reads an app configuration into the App, returning any error // encountered. // // SaveConfig persists the current App configuration. // // LoadKeys reads the App's encryption keys and loads them into its // key.Keychain. type Apper interface { App() *App LoadConfig() error SaveConfig(*config.Config) error LoadKeys() error ReqLog(r *http.Request, status int, timeSince time.Duration) string } // App returns the App func (app *App) App() *App { return app } // LoadConfig loads and parses a config file. func (app *App) LoadConfig() error { log.Info("Loading %s configuration...", app.cfgFile) cfg, err := config.Load(app.cfgFile) if err != nil { log.Error("Unable to load configuration: %v", err) os.Exit(1) return err } app.cfg = cfg return nil } // SaveConfig saves the given Config to disk -- namely, to the App's cfgFile. func (app *App) SaveConfig(c *config.Config) error { return config.Save(c, app.cfgFile) } // LoadKeys reads all needed keys from disk into the App. In order to use the // configured `Server.KeysParentDir`, you must call initKeyPaths(App) before // this. func (app *App) LoadKeys() error { var err error app.keys = &key.Keychain{} if debugging { log.Info(" %s", emailKeyPath) } app.keys.EmailKey, err = ioutil.ReadFile(emailKeyPath) if err != nil { return err } if debugging { log.Info(" %s", cookieAuthKeyPath) } app.keys.CookieAuthKey, err = ioutil.ReadFile(cookieAuthKeyPath) if err != nil { return err } if debugging { log.Info(" %s", cookieKeyPath) } app.keys.CookieKey, err = ioutil.ReadFile(cookieKeyPath) if err != nil { return err } return nil } func (app *App) ReqLog(r *http.Request, status int, timeSince time.Duration) string { return fmt.Sprintf("\"%s %s\" %d %s \"%s\"", r.Method, r.RequestURI, status, timeSince, r.UserAgent()) } // handleViewHome shows page at root path. It checks the configuration and // authentication state to show the correct page. func handleViewHome(app *App, w http.ResponseWriter, r *http.Request) error { if app.cfg.App.SingleUser { // Render blog index return handleViewCollection(app, w, r) } // Multi-user instance forceLanding := r.FormValue("landing") == "1" if !forceLanding { // Show correct page based on user auth status and configured landing path u := getUserSession(app, r) if app.cfg.App.Chorus { // This instance is focused on reading, so show Reader on home route if not // private or a private-instance user is logged in. if !app.cfg.App.Private || u != nil { return viewLocalTimeline(app, w, r) } } if u != nil { // User is logged in, so show the Pad return handleViewPad(app, w, r) } if app.cfg.App.Private { return viewLogin(app, w, r) } if land := app.cfg.App.LandingPath(); land != "/" { return impart.HTTPError{http.StatusFound, land} } } return handleViewLanding(app, w, r) } func handleViewLanding(app *App, w http.ResponseWriter, r *http.Request) error { forceLanding := r.FormValue("landing") == "1" p := struct { page.StaticPage Flashes []template.HTML Banner template.HTML Content template.HTML ForcedLanding bool OauthSlack bool OauthWriteAs bool OauthGitlab bool OauthGeneric bool OauthGenericDisplayName string GitlabDisplayName string }{ StaticPage: pageForReq(app, r), ForcedLanding: forceLanding, OauthSlack: app.Config().SlackOauth.ClientID != "", OauthWriteAs: app.Config().WriteAsOauth.ClientID != "", OauthGitlab: app.Config().GitlabOauth.ClientID != "", OauthGeneric: app.Config().GenericOauth.ClientID != "", OauthGenericDisplayName: config.OrDefaultString(app.Config().GenericOauth.DisplayName, genericOauthDisplayName), GitlabDisplayName: config.OrDefaultString(app.Config().GitlabOauth.DisplayName, gitlabDisplayName), } banner, err := getLandingBanner(app) if err != nil { log.Error("unable to get landing banner: %v", err) return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get banner: %v", err)} } p.Banner = template.HTML(applyMarkdown([]byte(banner.Content), "", app.cfg)) content, err := getLandingBody(app) if err != nil { log.Error("unable to get landing content: %v", err) return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get content: %v", err)} } p.Content = template.HTML(applyMarkdown([]byte(content.Content), "", app.cfg)) // Get error messages session, err := app.sessionStore.Get(r, cookieName) if err != nil { // Ignore this log.Error("Unable to get session in handleViewHome; ignoring: %v", err) } flashes, _ := getSessionFlashes(app, w, r, session) for _, flash := range flashes { p.Flashes = append(p.Flashes, template.HTML(flash)) } // Show landing page return renderPage(w, "landing.tmpl", p) } func handleTemplatedPage(app *App, w http.ResponseWriter, r *http.Request, t *template.Template) error { p := struct { page.StaticPage ContentTitle string Content template.HTML PlainContent string Updated string AboutStats *InstanceStats }{ StaticPage: pageForReq(app, r), } if r.URL.Path == "/about" || r.URL.Path == "/privacy" { var c *instanceContent var err error if r.URL.Path == "/about" { c, err = getAboutPage(app) // Fetch stats p.AboutStats = &InstanceStats{} p.AboutStats.NumPosts, _ = app.db.GetTotalPosts() p.AboutStats.NumBlogs, _ = app.db.GetTotalCollections() } else { c, err = getPrivacyPage(app) } if err != nil { return err } p.ContentTitle = c.Title.String p.Content = template.HTML(applyMarkdown([]byte(c.Content), "", app.cfg)) p.PlainContent = shortPostDescription(stripmd.Strip(c.Content)) if !c.Updated.IsZero() { p.Updated = c.Updated.Format("January 2, 2006") } } // Serve templated page err := t.ExecuteTemplate(w, "base", p) if err != nil { log.Error("Unable to render page: %v", err) } return nil } func pageForReq(app *App, r *http.Request) page.StaticPage { p := page.StaticPage{ AppCfg: app.cfg.App, Path: r.URL.Path, Version: "v" + softwareVer, } // Add user information, if given var u *User accessToken := r.FormValue("t") if accessToken != "" { userID := app.db.GetUserID(accessToken) if userID != -1 { var err error u, err = app.db.GetUserByID(userID) if err == nil { p.Username = u.Username } } } else { u = getUserSession(app, r) if u != nil { p.Username = u.Username p.IsAdmin = u != nil && u.IsAdmin() p.CanInvite = canUserInvite(app.cfg, p.IsAdmin) } } p.CanViewReader = !app.cfg.App.Private || u != nil return p } var fileRegex = regexp.MustCompile("/([^/]*\\.[^/]*)$") // Initialize loads the app configuration and initializes templates, keys, // session, route handlers, and the database connection. func Initialize(apper Apper, debug bool) (*App, error) { debugging = debug apper.LoadConfig() // Load templates err := InitTemplates(apper.App().Config()) if err != nil { return nil, fmt.Errorf("load templates: %s", err) } // Load keys and set up session initKeyPaths(apper.App()) // TODO: find a better way to do this, since it's unneeded in all Apper implementations err = InitKeys(apper) if err != nil { return nil, fmt.Errorf("init keys: %s", err) } apper.App().InitUpdates() apper.App().InitSession() apper.App().InitDecoder() err = ConnectToDatabase(apper.App()) if err != nil { return nil, fmt.Errorf("connect to DB: %s", err) } // Handle local timeline, if enabled if apper.App().cfg.App.LocalTimeline { log.Info("Initializing local timeline...") initLocalTimeline(apper.App()) } return apper.App(), nil } func Serve(app *App, r *mux.Router) { log.Info("Going to serve...") isSingleUser = app.cfg.App.SingleUser app.cfg.Server.Dev = debugging // Handle shutdown c := make(chan os.Signal, 2) signal.Notify(c, os.Interrupt, syscall.SIGTERM) go func() { <-c log.Info("Shutting down...") shutdown(app) log.Info("Done.") os.Exit(0) }() // Start gopher server if app.cfg.Server.GopherPort > 0 && !app.cfg.App.Private { go initGopher(app) } // Start web application server var bindAddress = app.cfg.Server.Bind if bindAddress == "" { bindAddress = "localhost" } var err error if app.cfg.IsSecureStandalone() { if app.cfg.Server.Autocert { m := &autocert.Manager{ Prompt: autocert.AcceptTOS, Cache: autocert.DirCache(app.cfg.Server.TLSCertPath), } host, err := url.Parse(app.cfg.App.Host) if err != nil { log.Error("[WARNING] Unable to parse configured host! %s", err) log.Error(`[WARNING] ALL hosts are allowed, which can open you to an attack where clients connect to a server by IP address and pretend to be asking for an incorrect host name, and cause you to reach the CA's rate limit for certificate requests. We recommend supplying a valid host name.`) log.Info("Using autocert on ANY host") } else { log.Info("Using autocert on host %s", host.Host) m.HostPolicy = autocert.HostWhitelist(host.Host) } s := &http.Server{ Addr: ":https", Handler: r, TLSConfig: &tls.Config{ GetCertificate: m.GetCertificate, }, } s.SetKeepAlivesEnabled(false) go func() { log.Info("Serving redirects on http://%s:80", bindAddress) err = http.ListenAndServe(":80", m.HTTPHandler(nil)) log.Error("Unable to start redirect server: %v", err) }() log.Info("Serving on https://%s:443", bindAddress) log.Info("---") err = s.ListenAndServeTLS("", "") } else { go func() { log.Info("Serving redirects on http://%s:80", bindAddress) err = http.ListenAndServe(fmt.Sprintf("%s:80", bindAddress), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, app.cfg.App.Host, http.StatusMovedPermanently) })) log.Error("Unable to start redirect server: %v", err) }() log.Info("Serving on https://%s:443", bindAddress) log.Info("Using manual certificates") log.Info("---") err = http.ListenAndServeTLS(fmt.Sprintf("%s:443", bindAddress), app.cfg.Server.TLSCertPath, app.cfg.Server.TLSKeyPath, r) } } else { log.Info("Serving on http://%s:%d\n", bindAddress, app.cfg.Server.Port) log.Info("---") err = http.ListenAndServe(fmt.Sprintf("%s:%d", bindAddress, app.cfg.Server.Port), r) } if err != nil { log.Error("Unable to start: %v", err) os.Exit(1) } } func (app *App) InitDecoder() { // TODO: do this at the package level, instead of the App level // Initialize modules app.formDecoder = schema.NewDecoder() app.formDecoder.RegisterConverter(converter.NullJSONString{}, converter.ConvertJSONNullString) app.formDecoder.RegisterConverter(converter.NullJSONBool{}, converter.ConvertJSONNullBool) app.formDecoder.RegisterConverter(sql.NullString{}, converter.ConvertSQLNullString) app.formDecoder.RegisterConverter(sql.NullBool{}, converter.ConvertSQLNullBool) app.formDecoder.RegisterConverter(sql.NullInt64{}, converter.ConvertSQLNullInt64) app.formDecoder.RegisterConverter(sql.NullFloat64{}, converter.ConvertSQLNullFloat64) } // ConnectToDatabase validates and connects to the configured database, then // tests the connection. func ConnectToDatabase(app *App) error { // Check database configuration if app.cfg.Database.Type == driverMySQL && (app.cfg.Database.User == "" || app.cfg.Database.Password == "") { return fmt.Errorf("Database user or password not set.") } if app.cfg.Database.Host == "" { app.cfg.Database.Host = "localhost" } if app.cfg.Database.Database == "" { app.cfg.Database.Database = "writefreely" } // TODO: check err connectToDatabase(app) // Test database connection err := app.db.Ping() if err != nil { return fmt.Errorf("Database ping failed: %s", err) } return nil } // FormatVersion constructs the version string for the application func FormatVersion() string { return serverSoftware + " " + softwareVer } // OutputVersion prints out the version of the application. func OutputVersion() { fmt.Println(FormatVersion()) } // NewApp creates a new app instance. func NewApp(cfgFile string) *App { return &App{ cfgFile: cfgFile, } } // CreateConfig creates a default configuration and saves it to the app's cfgFile. func CreateConfig(app *App) error { log.Info("Creating configuration...") c := config.New() log.Info("Saving configuration %s...", app.cfgFile) err := config.Save(c, app.cfgFile) if err != nil { return fmt.Errorf("Unable to save configuration: %v", err) } return nil } // DoConfig runs the interactive configuration process. func DoConfig(app *App, configSections string) { if configSections == "" { configSections = "server db app" } // let's check there aren't any garbage in the list configSectionsArray := strings.Split(configSections, " ") for _, element := range configSectionsArray { if element != "server" && element != "db" && element != "app" { log.Error("Invalid argument to --sections. Valid arguments are only \"server\", \"db\" and \"app\"") os.Exit(1) } } d, err := config.Configure(app.cfgFile, configSections) if err != nil { log.Error("Unable to configure: %v", err) os.Exit(1) } app.cfg = d.Config connectToDatabase(app) defer shutdown(app) if !app.db.DatabaseInitialized() { err = adminInitDatabase(app) if err != nil { log.Error(err.Error()) os.Exit(1) } } else { log.Info("Database already initialized.") } if d.User != nil { u := &User{ Username: d.User.Username, HashedPass: d.User.HashedPass, Created: time.Now().Truncate(time.Second).UTC(), } // Create blog log.Info("Creating user %s...\n", u.Username) err = app.db.CreateUser(app.cfg, u, app.cfg.App.SiteName) if err != nil { log.Error("Unable to create user: %s", err) os.Exit(1) } log.Info("Done!") } os.Exit(0) } // GenerateKeyFiles creates app encryption keys and saves them into the configured KeysParentDir. func GenerateKeyFiles(app *App) error { // Read keys path from config app.LoadConfig() // Create keys dir if it doesn't exist yet fullKeysDir := filepath.Join(app.cfg.Server.KeysParentDir, keysDir) if _, err := os.Stat(fullKeysDir); os.IsNotExist(err) { err = os.Mkdir(fullKeysDir, 0700) if err != nil { return err } } // Generate keys initKeyPaths(app) // TODO: use something like https://github.com/hashicorp/go-multierror to return errors var keyErrs error err := generateKey(emailKeyPath) if err != nil { keyErrs = err } err = generateKey(cookieAuthKeyPath) if err != nil { keyErrs = err } err = generateKey(cookieKeyPath) if err != nil { keyErrs = err } return keyErrs } // CreateSchema creates all database tables needed for the application. func CreateSchema(apper Apper) error { apper.LoadConfig() connectToDatabase(apper.App()) defer shutdown(apper.App()) err := adminInitDatabase(apper.App()) if err != nil { return err } return nil } // Migrate runs all necessary database migrations. func Migrate(apper Apper) error { apper.LoadConfig() connectToDatabase(apper.App()) defer shutdown(apper.App()) err := migrations.Migrate(migrations.NewDatastore(apper.App().db.DB, apper.App().db.driverName)) if err != nil { return fmt.Errorf("migrate: %s", err) } return nil } // ResetPassword runs the interactive password reset process. func ResetPassword(apper Apper, username string) error { // Connect to the database apper.LoadConfig() connectToDatabase(apper.App()) defer shutdown(apper.App()) // Fetch user u, err := apper.App().db.GetUserForAuth(username) if err != nil { log.Error("Get user: %s", err) os.Exit(1) } // Prompt for new password prompt := promptui.Prompt{ Templates: &promptui.PromptTemplates{ Success: "{{ . | bold | faint }}: ", }, Label: "New password", Mask: '*', } newPass, err := prompt.Run() if err != nil { log.Error("%s", err) os.Exit(1) } // Do the update log.Info("Updating...") err = adminResetPassword(apper.App(), u, newPass) if err != nil { log.Error("%s", err) os.Exit(1) } log.Info("Success.") return nil } // DoDeleteAccount runs the confirmation and account delete process. func DoDeleteAccount(apper Apper, username string) error { // Connect to the database apper.LoadConfig() connectToDatabase(apper.App()) defer shutdown(apper.App()) // check user exists u, err := apper.App().db.GetUserForAuth(username) if err != nil { log.Error("%s", err) os.Exit(1) } userID := u.ID // do not delete the admin account // TODO: check for other admins and skip? if u.IsAdmin() { log.Error("Can not delete admin account") os.Exit(1) } // confirm deletion, w/ w/out posts prompt := promptui.Prompt{ Templates: &promptui.PromptTemplates{ Success: "{{ . | bold | faint }}: ", }, Label: fmt.Sprintf("Really delete user : %s", username), IsConfirm: true, } _, err = prompt.Run() if err != nil { log.Info("Aborted...") os.Exit(0) } log.Info("Deleting...") err = apper.App().db.DeleteAccount(userID) if err != nil { log.Error("%s", err) os.Exit(1) } log.Info("Success.") return nil } func connectToDatabase(app *App) { log.Info("Connecting to %s database...", app.cfg.Database.Type) var db *sql.DB var err error if app.cfg.Database.Type == driverMySQL { db, err = sql.Open(app.cfg.Database.Type, fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=%s&tls=%t", app.cfg.Database.User, app.cfg.Database.Password, app.cfg.Database.Host, app.cfg.Database.Port, app.cfg.Database.Database, url.QueryEscape(time.Local.String()), app.cfg.Database.TLS)) db.SetMaxOpenConns(50) } else if app.cfg.Database.Type == driverSQLite { if !SQLiteEnabled { log.Error("Invalid database type '%s'. Binary wasn't compiled with SQLite3 support.", app.cfg.Database.Type) os.Exit(1) } if app.cfg.Database.FileName == "" { log.Error("SQLite database filename value in config.ini is empty.") os.Exit(1) } db, err = sql.Open("sqlite3_with_regex", app.cfg.Database.FileName+"?parseTime=true&cached=shared") db.SetMaxOpenConns(1) } else { log.Error("Invalid database type '%s'. Only 'mysql' and 'sqlite3' are supported right now.", app.cfg.Database.Type) os.Exit(1) } if err != nil { log.Error("%s", err) os.Exit(1) } app.db = &datastore{db, app.cfg.Database.Type} } func shutdown(app *App) { log.Info("Closing database connection...") app.db.Close() } // CreateUser creates a new admin or normal user from the given credentials. func CreateUser(apper Apper, username, password string, isAdmin bool) error { // Create an admin user with --create-admin apper.LoadConfig() connectToDatabase(apper.App()) defer shutdown(apper.App()) // Ensure an admin / first user doesn't already exist firstUser, _ := apper.App().db.GetUserByID(1) if isAdmin { // Abort if trying to create admin user, but one already exists if firstUser != nil { return fmt.Errorf("Admin user already exists (%s). Create a regular user with: writefreely --create-user", firstUser.Username) } } else { // Abort if trying to create regular user, but no admin exists yet if firstUser == nil { return fmt.Errorf("No admin user exists yet. Create an admin first with: writefreely --create-admin") } } // Create the user // Normalize and validate username desiredUsername := username username = getSlug(username, "") usernameDesc := username if username != desiredUsername { usernameDesc += " (originally: " + desiredUsername + ")" } if !author.IsValidUsername(apper.App().cfg, username) { return fmt.Errorf("Username %s is invalid, reserved, or shorter than configured minimum length (%d characters).", usernameDesc, apper.App().cfg.App.MinUsernameLen) } // Hash the password hashedPass, err := auth.HashPass([]byte(password)) if err != nil { return fmt.Errorf("Unable to hash password: %v", err) } u := &User{ Username: username, HashedPass: hashedPass, Created: time.Now().Truncate(time.Second).UTC(), } userType := "user" if isAdmin { userType = "admin" } log.Info("Creating %s %s...", userType, usernameDesc) err = apper.App().db.CreateUser(apper.App().Config(), u, desiredUsername) if err != nil { return fmt.Errorf("Unable to create user: %s", err) } log.Info("Done!") return nil } func adminInitDatabase(app *App) error { schemaFileName := "schema.sql" if app.cfg.Database.Type == driverSQLite { schemaFileName = "sqlite.sql" } schema, err := Asset(schemaFileName) if err != nil { return fmt.Errorf("Unable to load schema file: %v", err) } tblReg := regexp.MustCompile("CREATE TABLE (IF NOT EXISTS )?`([a-z_]+)`") queries := strings.Split(string(schema), ";\n") for _, q := range queries { if strings.TrimSpace(q) == "" { continue } parts := tblReg.FindStringSubmatch(q) if len(parts) >= 3 { log.Info("Creating table %s...", parts[2]) } else { log.Info("Creating table ??? (Weird query) No match in: %v", parts) } _, err = app.db.Exec(q) if err != nil { log.Error("%s", err) } else { log.Info("Created.") } } // Set up migrations table log.Info("Initializing appmigrations table...") err = migrations.SetInitialMigrations(migrations.NewDatastore(app.db.DB, app.db.driverName)) if err != nil { return fmt.Errorf("Unable to set initial migrations: %v", err) } log.Info("Running migrations...") err = migrations.Migrate(migrations.NewDatastore(app.db.DB, app.db.driverName)) if err != nil { return fmt.Errorf("migrate: %s", err) } log.Info("Done.") return nil } + +// ServerUserAgent returns a User-Agent string to use in external requests. The +// hostName parameter may be left empty. +func ServerUserAgent(hostName string) string { + hostUAStr := "" + if hostName != "" { + hostUAStr = "; +" + hostName + } + return "Go (" + serverSoftware + "/" + softwareVer + hostUAStr + ")" +} diff --git a/oauth.go b/oauth.go index fe9fe74..dbcf3bf 100644 --- a/oauth.go +++ b/oauth.go @@ -1,448 +1,448 @@ /* * Copyright © 2019-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package writefreely import ( "context" "encoding/json" "fmt" "io" "io/ioutil" "net/http" "net/url" "strings" "time" "github.com/gorilla/mux" "github.com/gorilla/sessions" "github.com/writeas/impart" "github.com/writeas/web-core/log" "github.com/writeas/writefreely/config" ) // OAuthButtons holds display information for different OAuth providers we support. type OAuthButtons struct { SlackEnabled bool WriteAsEnabled bool GitLabEnabled bool GitLabDisplayName string } // NewOAuthButtons creates a new OAuthButtons struct based on our app configuration. func NewOAuthButtons(cfg *config.Config) *OAuthButtons { return &OAuthButtons{ SlackEnabled: cfg.SlackOauth.ClientID != "", WriteAsEnabled: cfg.WriteAsOauth.ClientID != "", GitLabEnabled: cfg.GitlabOauth.ClientID != "", GitLabDisplayName: config.OrDefaultString(cfg.GitlabOauth.DisplayName, gitlabDisplayName), } } // TokenResponse contains data returned when a token is created either // through a code exchange or using a refresh token. type TokenResponse struct { AccessToken string `json:"access_token"` ExpiresIn int `json:"expires_in"` RefreshToken string `json:"refresh_token"` TokenType string `json:"token_type"` Error string `json:"error"` } // InspectResponse contains data returned when an access token is inspected. type InspectResponse struct { ClientID string `json:"client_id"` UserID string `json:"user_id"` ExpiresAt time.Time `json:"expires_at"` Username string `json:"username"` DisplayName string `json:"-"` Email string `json:"email"` Error string `json:"error"` } // tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token // endpoint. One megabyte is plenty. const tokenRequestMaxLen = 1000000 // infoRequestMaxLen is the most bytes that we'll read from the // /oauth/inspect endpoint. const infoRequestMaxLen = 1000000 // OAuthDatastoreProvider provides a minimal interface of data store, config, // and session store for use with the oauth handlers. type OAuthDatastoreProvider interface { DB() OAuthDatastore Config() *config.Config SessionStore() sessions.Store } // OAuthDatastore provides a minimal interface of data store methods used in // oauth functionality. type OAuthDatastore interface { GetIDForRemoteUser(context.Context, string, string, string) (int64, error) RecordRemoteUserID(context.Context, int64, string, string, string, string) error ValidateOAuthState(context.Context, string) (string, string, int64, string, error) GenerateOAuthState(context.Context, string, string, int64, string) (string, error) CreateUser(*config.Config, *User, string) error GetUserByID(int64) (*User, error) } type HttpClient interface { Do(req *http.Request) (*http.Response, error) } type oauthClient interface { GetProvider() string GetClientID() string GetCallbackLocation() string buildLoginURL(state string) (string, error) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) } type callbackProxyClient struct { server string callbackLocation string httpClient HttpClient } type oauthHandler struct { Config *config.Config DB OAuthDatastore Store sessions.Store EmailKey []byte oauthClient oauthClient callbackProxy *callbackProxyClient } func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { ctx := r.Context() var attachUser int64 if attach := r.URL.Query().Get("attach"); attach == "t" { user, _ := getUserAndSession(app, r) if user == nil { return impart.HTTPError{http.StatusInternalServerError, "cannot attach auth to user: user not found in session"} } attachUser = user.ID } state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser, r.FormValue("invite_code")) if err != nil { log.Error("viewOauthInit error: %s", err) return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} } if h.callbackProxy != nil { if err := h.callbackProxy.register(ctx, state); err != nil { log.Error("viewOauthInit error: %s", err) return impart.HTTPError{http.StatusInternalServerError, "could not register state server"} } } location, err := h.oauthClient.buildLoginURL(state) if err != nil { log.Error("viewOauthInit error: %s", err) return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} } return impart.HTTPError{http.StatusTemporaryRedirect, location} } func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) { if app.Config().SlackOauth.ClientID != "" { callbackLocation := app.Config().App.Host + "/oauth/callback/slack" var stateRegisterClient *callbackProxyClient = nil if app.Config().SlackOauth.CallbackProxyAPI != "" { stateRegisterClient = &callbackProxyClient{ server: app.Config().SlackOauth.CallbackProxyAPI, callbackLocation: app.Config().App.Host + "/oauth/callback/slack", httpClient: config.DefaultHTTPClient(), } callbackLocation = app.Config().SlackOauth.CallbackProxy } oauthClient := slackOauthClient{ ClientID: app.Config().SlackOauth.ClientID, ClientSecret: app.Config().SlackOauth.ClientSecret, TeamID: app.Config().SlackOauth.TeamID, HttpClient: config.DefaultHTTPClient(), CallbackLocation: callbackLocation, } configureOauthRoutes(parentHandler, r, app, oauthClient, stateRegisterClient) } } func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) { if app.Config().WriteAsOauth.ClientID != "" { callbackLocation := app.Config().App.Host + "/oauth/callback/write.as" var callbackProxy *callbackProxyClient = nil if app.Config().WriteAsOauth.CallbackProxy != "" { callbackProxy = &callbackProxyClient{ server: app.Config().WriteAsOauth.CallbackProxyAPI, callbackLocation: app.Config().App.Host + "/oauth/callback/write.as", httpClient: config.DefaultHTTPClient(), } callbackLocation = app.Config().WriteAsOauth.CallbackProxy } oauthClient := writeAsOauthClient{ ClientID: app.Config().WriteAsOauth.ClientID, ClientSecret: app.Config().WriteAsOauth.ClientSecret, ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation), InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation), AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation), HttpClient: config.DefaultHTTPClient(), CallbackLocation: callbackLocation, } configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) } } func configureGitlabOauth(parentHandler *Handler, r *mux.Router, app *App) { if app.Config().GitlabOauth.ClientID != "" { callbackLocation := app.Config().App.Host + "/oauth/callback/gitlab" var callbackProxy *callbackProxyClient = nil if app.Config().GitlabOauth.CallbackProxy != "" { callbackProxy = &callbackProxyClient{ server: app.Config().GitlabOauth.CallbackProxyAPI, callbackLocation: app.Config().App.Host + "/oauth/callback/gitlab", httpClient: config.DefaultHTTPClient(), } callbackLocation = app.Config().GitlabOauth.CallbackProxy } address := config.OrDefaultString(app.Config().GitlabOauth.Host, gitlabHost) oauthClient := gitlabOauthClient{ ClientID: app.Config().GitlabOauth.ClientID, ClientSecret: app.Config().GitlabOauth.ClientSecret, ExchangeLocation: address + "/oauth/token", InspectLocation: address + "/api/v4/user", AuthLocation: address + "/oauth/authorize", HttpClient: config.DefaultHTTPClient(), CallbackLocation: callbackLocation, } configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) } } func configureGenericOauth(parentHandler *Handler, r *mux.Router, app *App) { if app.Config().GenericOauth.ClientID != "" { callbackLocation := app.Config().App.Host + "/oauth/callback/generic" var callbackProxy *callbackProxyClient = nil if app.Config().GenericOauth.CallbackProxy != "" { callbackProxy = &callbackProxyClient{ server: app.Config().GenericOauth.CallbackProxyAPI, callbackLocation: app.Config().App.Host + "/oauth/callback/generic", httpClient: config.DefaultHTTPClient(), } callbackLocation = app.Config().GenericOauth.CallbackProxy } oauthClient := genericOauthClient{ ClientID: app.Config().GenericOauth.ClientID, ClientSecret: app.Config().GenericOauth.ClientSecret, ExchangeLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.TokenEndpoint, InspectLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.InspectEndpoint, AuthLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.AuthEndpoint, HttpClient: config.DefaultHTTPClient(), CallbackLocation: callbackLocation, } configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) } } func configureGiteaOauth(parentHandler *Handler, r *mux.Router, app *App) { if app.Config().GiteaOauth.ClientID != "" { callbackLocation := app.Config().App.Host + "/oauth/callback/gitea" var callbackProxy *callbackProxyClient = nil if app.Config().GiteaOauth.CallbackProxy != "" { callbackProxy = &callbackProxyClient{ server: app.Config().GiteaOauth.CallbackProxyAPI, callbackLocation: app.Config().App.Host + "/oauth/callback/gitea", httpClient: config.DefaultHTTPClient(), } callbackLocation = app.Config().GiteaOauth.CallbackProxy } oauthClient := giteaOauthClient{ ClientID: app.Config().GiteaOauth.ClientID, ClientSecret: app.Config().GiteaOauth.ClientSecret, ExchangeLocation: app.Config().GiteaOauth.Host + "/login/oauth/access_token", InspectLocation: app.Config().GiteaOauth.Host + "/api/v1/user", AuthLocation: app.Config().GiteaOauth.Host + "/login/oauth/authorize", HttpClient: config.DefaultHTTPClient(), CallbackLocation: callbackLocation, } configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) } } func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) { handler := &oauthHandler{ Config: app.Config(), DB: app.DB(), Store: app.SessionStore(), oauthClient: oauthClient, EmailKey: app.keys.EmailKey, callbackProxy: callbackProxy, } r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET") r.HandleFunc("/oauth/callback/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET") r.HandleFunc("/oauth/signup", parentHandler.OAuth(handler.viewOauthSignup)).Methods("POST") } func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error { ctx := r.Context() code := r.FormValue("code") state := r.FormValue("state") provider, clientID, attachUserID, inviteCode, err := h.DB.ValidateOAuthState(ctx, state) if err != nil { log.Error("Unable to ValidateOAuthState: %s", err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code) if err != nil { log.Error("Unable to exchangeOauthCode: %s", err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } // Now that we have the access token, let's use it real quick to make sure // it really really works. tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) if err != nil { log.Error("Unable to inspectOauthAccessToken: %s", err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID) if err != nil { log.Error("Unable to GetIDForRemoteUser: %s", err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } if localUserID != -1 && attachUserID > 0 { if err = addSessionFlash(app, w, r, "This Slack account is already attached to another user.", nil); err != nil { return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()} } return impart.HTTPError{http.StatusFound, "/me/settings"} } if localUserID != -1 { // Existing user, so log in now user, err := h.DB.GetUserByID(localUserID) if err != nil { log.Error("Unable to GetUserByID %d: %s", localUserID, err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } if err = loginOrFail(h.Store, w, r, user); err != nil { log.Error("Unable to loginOrFail %d: %s", localUserID, err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } return nil } if attachUserID > 0 { log.Info("attaching to user %d", attachUserID) err = h.DB.RecordRemoteUserID(r.Context(), attachUserID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken) if err != nil { return impart.HTTPError{http.StatusInternalServerError, err.Error()} } return impart.HTTPError{http.StatusFound, "/me/settings"} } // New user registration below. // First, verify that user is allowed to register if inviteCode != "" { // Verify invite code is valid i, err := app.db.GetUserInvite(inviteCode) if err != nil { return impart.HTTPError{http.StatusInternalServerError, err.Error()} } if !i.Active(app.db) { return impart.HTTPError{http.StatusNotFound, "Invite link has expired."} } } else if !app.cfg.App.OpenRegistration { addSessionFlash(app, w, r, ErrUserNotFound.Error(), nil) return impart.HTTPError{http.StatusFound, "/login"} } displayName := tokenInfo.DisplayName if len(displayName) == 0 { displayName = tokenInfo.Username } tp := &oauthSignupPageParams{ AccessToken: tokenResponse.AccessToken, TokenUsername: tokenInfo.Username, TokenAlias: tokenInfo.DisplayName, TokenEmail: tokenInfo.Email, TokenRemoteUser: tokenInfo.UserID, Provider: provider, ClientID: clientID, InviteCode: inviteCode, } tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed) return h.showOauthSignupPage(app, w, r, tp, nil) } func (r *callbackProxyClient) register(ctx context.Context, state string) error { form := url.Values{} form.Add("state", state) form.Add("location", r.callbackLocation) req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode())) if err != nil { return err } - req.Header.Set("User-Agent", "writefreely") + req.Header.Set("User-Agent", ServerUserAgent("")) req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := r.httpClient.Do(req) if err != nil { return err } if resp.StatusCode != http.StatusCreated { return fmt.Errorf("unable register state location: %d", resp.StatusCode) } return nil } func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error { lr := io.LimitReader(body, int64(n+1)) data, err := ioutil.ReadAll(lr) if err != nil { return err } if len(data) == n+1 { return fmt.Errorf("content larger than max read allowance: %d", n) } return json.Unmarshal(data, thing) } func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error { // An error may be returned, but a valid session should always be returned. session, _ := store.Get(r, cookieName) session.Values[cookieUserVal] = user.Cookie() if err := session.Save(r, w); err != nil { fmt.Println("error saving session", err) return err } http.Redirect(w, r, "/", http.StatusTemporaryRedirect) return nil } diff --git a/oauth_generic.go b/oauth_generic.go index 42c84b0..ce65bca 100644 --- a/oauth_generic.go +++ b/oauth_generic.go @@ -1,114 +1,114 @@ package writefreely import ( "context" "errors" "net/http" "net/url" "strings" ) type genericOauthClient struct { ClientID string ClientSecret string AuthLocation string ExchangeLocation string InspectLocation string CallbackLocation string HttpClient HttpClient } var _ oauthClient = genericOauthClient{} const ( genericOauthDisplayName = "OAuth" ) func (c genericOauthClient) GetProvider() string { return "generic" } func (c genericOauthClient) GetClientID() string { return c.ClientID } func (c genericOauthClient) GetCallbackLocation() string { return c.CallbackLocation } func (c genericOauthClient) buildLoginURL(state string) (string, error) { u, err := url.Parse(c.AuthLocation) if err != nil { return "", err } q := u.Query() q.Set("client_id", c.ClientID) q.Set("redirect_uri", c.CallbackLocation) q.Set("response_type", "code") q.Set("state", state) q.Set("scope", "read_user") u.RawQuery = q.Encode() return u.String(), nil } func (c genericOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { form := url.Values{} form.Add("grant_type", "authorization_code") form.Add("redirect_uri", c.CallbackLocation) form.Add("scope", "read_user") form.Add("code", code) req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode())) if err != nil { return nil, err } req.WithContext(ctx) - req.Header.Set("User-Agent", "writefreely") + req.Header.Set("User-Agent", ServerUserAgent("")) req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.SetBasicAuth(c.ClientID, c.ClientSecret) resp, err := c.HttpClient.Do(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { return nil, errors.New("unable to exchange code for access token") } var tokenResponse TokenResponse if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { return nil, err } if tokenResponse.Error != "" { return nil, errors.New(tokenResponse.Error) } return &tokenResponse, nil } func (c genericOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { req, err := http.NewRequest("GET", c.InspectLocation, nil) if err != nil { return nil, err } req.WithContext(ctx) - req.Header.Set("User-Agent", "writefreely") + req.Header.Set("User-Agent", ServerUserAgent("")) req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+accessToken) resp, err := c.HttpClient.Do(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { return nil, errors.New("unable to inspect access token") } var inspectResponse InspectResponse if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { return nil, err } if inspectResponse.Error != "" { return nil, errors.New(inspectResponse.Error) } return &inspectResponse, nil } diff --git a/oauth_gitea.go b/oauth_gitea.go index e6e1000..a9b7741 100644 --- a/oauth_gitea.go +++ b/oauth_gitea.go @@ -1,114 +1,114 @@ package writefreely import ( "context" "errors" "net/http" "net/url" "strings" ) type giteaOauthClient struct { ClientID string ClientSecret string AuthLocation string ExchangeLocation string InspectLocation string CallbackLocation string HttpClient HttpClient } var _ oauthClient = giteaOauthClient{} const ( giteaDisplayName = "Gitea" ) func (c giteaOauthClient) GetProvider() string { return "gitea" } func (c giteaOauthClient) GetClientID() string { return c.ClientID } func (c giteaOauthClient) GetCallbackLocation() string { return c.CallbackLocation } func (c giteaOauthClient) buildLoginURL(state string) (string, error) { u, err := url.Parse(c.AuthLocation) if err != nil { return "", err } q := u.Query() q.Set("client_id", c.ClientID) q.Set("redirect_uri", c.CallbackLocation) q.Set("response_type", "code") q.Set("state", state) // q.Set("scope", "read_user") u.RawQuery = q.Encode() return u.String(), nil } func (c giteaOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { form := url.Values{} form.Add("grant_type", "authorization_code") form.Add("redirect_uri", c.CallbackLocation) // form.Add("scope", "read_user") form.Add("code", code) req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode())) if err != nil { return nil, err } req.WithContext(ctx) - req.Header.Set("User-Agent", "writefreely") + req.Header.Set("User-Agent", ServerUserAgent("")) req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.SetBasicAuth(c.ClientID, c.ClientSecret) resp, err := c.HttpClient.Do(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { return nil, errors.New("unable to exchange code for access token") } var tokenResponse TokenResponse if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { return nil, err } if tokenResponse.Error != "" { return nil, errors.New(tokenResponse.Error) } return &tokenResponse, nil } func (c giteaOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { req, err := http.NewRequest("GET", c.InspectLocation, nil) if err != nil { return nil, err } req.WithContext(ctx) - req.Header.Set("User-Agent", "writefreely") + req.Header.Set("User-Agent", ServerUserAgent("")) req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+accessToken) resp, err := c.HttpClient.Do(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { return nil, errors.New("unable to inspect access token") } var inspectResponse InspectResponse if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { return nil, err } if inspectResponse.Error != "" { return nil, errors.New(inspectResponse.Error) } return &inspectResponse, nil } diff --git a/oauth_gitlab.go b/oauth_gitlab.go index c9c74aa..ad919e4 100644 --- a/oauth_gitlab.go +++ b/oauth_gitlab.go @@ -1,115 +1,115 @@ package writefreely import ( "context" "errors" "net/http" "net/url" "strings" ) type gitlabOauthClient struct { ClientID string ClientSecret string AuthLocation string ExchangeLocation string InspectLocation string CallbackLocation string HttpClient HttpClient } var _ oauthClient = gitlabOauthClient{} const ( gitlabHost = "https://gitlab.com" gitlabDisplayName = "GitLab" ) func (c gitlabOauthClient) GetProvider() string { return "gitlab" } func (c gitlabOauthClient) GetClientID() string { return c.ClientID } func (c gitlabOauthClient) GetCallbackLocation() string { return c.CallbackLocation } func (c gitlabOauthClient) buildLoginURL(state string) (string, error) { u, err := url.Parse(c.AuthLocation) if err != nil { return "", err } q := u.Query() q.Set("client_id", c.ClientID) q.Set("redirect_uri", c.CallbackLocation) q.Set("response_type", "code") q.Set("state", state) q.Set("scope", "read_user") u.RawQuery = q.Encode() return u.String(), nil } func (c gitlabOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { form := url.Values{} form.Add("grant_type", "authorization_code") form.Add("redirect_uri", c.CallbackLocation) form.Add("scope", "read_user") form.Add("code", code) req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode())) if err != nil { return nil, err } req.WithContext(ctx) - req.Header.Set("User-Agent", "writefreely") + req.Header.Set("User-Agent", ServerUserAgent("")) req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.SetBasicAuth(c.ClientID, c.ClientSecret) resp, err := c.HttpClient.Do(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { return nil, errors.New("unable to exchange code for access token") } var tokenResponse TokenResponse if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { return nil, err } if tokenResponse.Error != "" { return nil, errors.New(tokenResponse.Error) } return &tokenResponse, nil } func (c gitlabOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { req, err := http.NewRequest("GET", c.InspectLocation, nil) if err != nil { return nil, err } req.WithContext(ctx) - req.Header.Set("User-Agent", "writefreely") + req.Header.Set("User-Agent", ServerUserAgent("")) req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+accessToken) resp, err := c.HttpClient.Do(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { return nil, errors.New("unable to inspect access token") } var inspectResponse InspectResponse if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { return nil, err } if inspectResponse.Error != "" { return nil, errors.New(inspectResponse.Error) } return &inspectResponse, nil } diff --git a/oauth_slack.go b/oauth_slack.go index c881ab6..bad3775 100644 --- a/oauth_slack.go +++ b/oauth_slack.go @@ -1,178 +1,178 @@ /* * Copyright © 2019-2020 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package writefreely import ( "context" "errors" "github.com/writeas/slug" "net/http" "net/url" "strings" ) type slackOauthClient struct { ClientID string ClientSecret string TeamID string CallbackLocation string HttpClient HttpClient } type slackExchangeResponse struct { OK bool `json:"ok"` AccessToken string `json:"access_token"` Scope string `json:"scope"` TeamName string `json:"team_name"` TeamID string `json:"team_id"` Error string `json:"error"` } type slackIdentity struct { Name string `json:"name"` ID string `json:"id"` Email string `json:"email"` } type slackTeam struct { Name string `json:"name"` ID string `json:"id"` } type slackUserIdentityResponse struct { OK bool `json:"ok"` User slackIdentity `json:"user"` Team slackTeam `json:"team"` Error string `json:"error"` } const ( slackAuthLocation = "https://slack.com/oauth/authorize" slackExchangeLocation = "https://slack.com/api/oauth.access" slackIdentityLocation = "https://slack.com/api/users.identity" ) var _ oauthClient = slackOauthClient{} func (c slackOauthClient) GetProvider() string { return "slack" } func (c slackOauthClient) GetClientID() string { return c.ClientID } func (c slackOauthClient) GetCallbackLocation() string { return c.CallbackLocation } func (c slackOauthClient) buildLoginURL(state string) (string, error) { u, err := url.Parse(slackAuthLocation) if err != nil { return "", err } q := u.Query() q.Set("client_id", c.ClientID) q.Set("scope", "identity.basic identity.email identity.team") q.Set("redirect_uri", c.CallbackLocation) q.Set("state", state) // If this param is not set, the user can select which team they // authenticate through and then we'd have to match the configured team // against the profile get. That is extra work in the post-auth phase // that we don't want to do. q.Set("team", c.TeamID) // The Slack OAuth docs don't explicitly list this one, but it is part of // the spec, so we include it anyway. q.Set("response_type", "code") u.RawQuery = q.Encode() return u.String(), nil } func (c slackOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { form := url.Values{} // The oauth.access documentation doesn't explicitly mention this // parameter, but it is part of the spec, so we include it anyway. // https://api.slack.com/methods/oauth.access form.Add("grant_type", "authorization_code") form.Add("redirect_uri", c.CallbackLocation) form.Add("code", code) req, err := http.NewRequest("POST", slackExchangeLocation, strings.NewReader(form.Encode())) if err != nil { return nil, err } req.WithContext(ctx) - req.Header.Set("User-Agent", "writefreely") + req.Header.Set("User-Agent", ServerUserAgent("")) req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.SetBasicAuth(c.ClientID, c.ClientSecret) resp, err := c.HttpClient.Do(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { return nil, errors.New("unable to exchange code for access token") } var tokenResponse slackExchangeResponse if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { return nil, err } if !tokenResponse.OK { return nil, errors.New(tokenResponse.Error) } return tokenResponse.TokenResponse(), nil } func (c slackOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { req, err := http.NewRequest("GET", slackIdentityLocation, nil) if err != nil { return nil, err } req.WithContext(ctx) - req.Header.Set("User-Agent", "writefreely") + req.Header.Set("User-Agent", ServerUserAgent("")) req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+accessToken) resp, err := c.HttpClient.Do(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { return nil, errors.New("unable to inspect access token") } var inspectResponse slackUserIdentityResponse if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { return nil, err } if !inspectResponse.OK { return nil, errors.New(inspectResponse.Error) } return inspectResponse.InspectResponse(), nil } func (resp slackUserIdentityResponse) InspectResponse() *InspectResponse { return &InspectResponse{ UserID: resp.User.ID, Username: slug.Make(resp.User.Name), DisplayName: resp.User.Name, Email: resp.User.Email, } } func (resp slackExchangeResponse) TokenResponse() *TokenResponse { return &TokenResponse{ AccessToken: resp.AccessToken, } } diff --git a/oauth_writeas.go b/oauth_writeas.go index 6251a16..e58f6e9 100644 --- a/oauth_writeas.go +++ b/oauth_writeas.go @@ -1,114 +1,114 @@ package writefreely import ( "context" "errors" "net/http" "net/url" "strings" ) type writeAsOauthClient struct { ClientID string ClientSecret string AuthLocation string ExchangeLocation string InspectLocation string CallbackLocation string HttpClient HttpClient } var _ oauthClient = writeAsOauthClient{} const ( writeAsAuthLocation = "https://write.as/oauth/login" writeAsExchangeLocation = "https://write.as/oauth/token" writeAsIdentityLocation = "https://write.as/oauth/inspect" ) func (c writeAsOauthClient) GetProvider() string { return "write.as" } func (c writeAsOauthClient) GetClientID() string { return c.ClientID } func (c writeAsOauthClient) GetCallbackLocation() string { return c.CallbackLocation } func (c writeAsOauthClient) buildLoginURL(state string) (string, error) { u, err := url.Parse(c.AuthLocation) if err != nil { return "", err } q := u.Query() q.Set("client_id", c.ClientID) q.Set("redirect_uri", c.CallbackLocation) q.Set("response_type", "code") q.Set("state", state) u.RawQuery = q.Encode() return u.String(), nil } func (c writeAsOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { form := url.Values{} form.Add("grant_type", "authorization_code") form.Add("redirect_uri", c.CallbackLocation) form.Add("code", code) req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode())) if err != nil { return nil, err } req.WithContext(ctx) - req.Header.Set("User-Agent", "writefreely") + req.Header.Set("User-Agent", ServerUserAgent("")) req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.SetBasicAuth(c.ClientID, c.ClientSecret) resp, err := c.HttpClient.Do(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { return nil, errors.New("unable to exchange code for access token") } var tokenResponse TokenResponse if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { return nil, err } if tokenResponse.Error != "" { return nil, errors.New(tokenResponse.Error) } return &tokenResponse, nil } func (c writeAsOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { req, err := http.NewRequest("GET", c.InspectLocation, nil) if err != nil { return nil, err } req.WithContext(ctx) - req.Header.Set("User-Agent", "writefreely") + req.Header.Set("User-Agent", ServerUserAgent("")) req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+accessToken) resp, err := c.HttpClient.Do(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { return nil, errors.New("unable to inspect access token") } var inspectResponse InspectResponse if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &inspectResponse); err != nil { return nil, err } if inspectResponse.Error != "" { return nil, errors.New(inspectResponse.Error) } return &inspectResponse, nil }