diff --git a/oauth.go b/oauth.go index e28e21a..2958721 100644 --- a/oauth.go +++ b/oauth.go @@ -1,467 +1,467 @@ /* * Copyright © 2019-2021 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package writefreely import ( "context" "encoding/json" "fmt" "io" "io/ioutil" "net/http" "net/url" "strings" "time" "github.com/gorilla/mux" "github.com/gorilla/sessions" "github.com/writeas/impart" "github.com/writeas/web-core/log" "github.com/writefreely/writefreely/config" ) // OAuthButtons holds display information for different OAuth providers we support. type OAuthButtons struct { SlackEnabled bool WriteAsEnabled bool GitLabEnabled bool GitLabDisplayName string GiteaEnabled bool GiteaDisplayName string GenericEnabled bool GenericDisplayName string } // NewOAuthButtons creates a new OAuthButtons struct based on our app configuration. func NewOAuthButtons(cfg *config.Config) *OAuthButtons { return &OAuthButtons{ SlackEnabled: cfg.SlackOauth.ClientID != "", WriteAsEnabled: cfg.WriteAsOauth.ClientID != "", GitLabEnabled: cfg.GitlabOauth.ClientID != "", GitLabDisplayName: config.OrDefaultString(cfg.GitlabOauth.DisplayName, gitlabDisplayName), GiteaEnabled: cfg.GiteaOauth.ClientID != "", GiteaDisplayName: config.OrDefaultString(cfg.GiteaOauth.DisplayName, giteaDisplayName), GenericEnabled: cfg.GenericOauth.ClientID != "", GenericDisplayName: config.OrDefaultString(cfg.GenericOauth.DisplayName, genericOauthDisplayName), } } // TokenResponse contains data returned when a token is created either // through a code exchange or using a refresh token. type TokenResponse struct { AccessToken string `json:"access_token"` ExpiresIn int `json:"expires_in"` RefreshToken string `json:"refresh_token"` TokenType string `json:"token_type"` Error string `json:"error"` } // InspectResponse contains data returned when an access token is inspected. type InspectResponse struct { ClientID string `json:"client_id"` UserID string `json:"user_id"` ExpiresAt time.Time `json:"expires_at"` Username string `json:"username"` DisplayName string `json:"-"` Email string `json:"email"` Error string `json:"error"` } // tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token // endpoint. One megabyte is plenty. const tokenRequestMaxLen = 1000000 // infoRequestMaxLen is the most bytes that we'll read from the // /oauth/inspect endpoint. const infoRequestMaxLen = 1000000 // OAuthDatastoreProvider provides a minimal interface of data store, config, // and session store for use with the oauth handlers. type OAuthDatastoreProvider interface { DB() OAuthDatastore Config() *config.Config SessionStore() sessions.Store } // OAuthDatastore provides a minimal interface of data store methods used in // oauth functionality. type OAuthDatastore interface { GetIDForRemoteUser(context.Context, string, string, string) (int64, error) RecordRemoteUserID(context.Context, int64, string, string, string, string) error ValidateOAuthState(context.Context, string) (string, string, int64, string, error) GenerateOAuthState(context.Context, string, string, int64, string) (string, error) - CreateUser(*config.Config, *User, string) error + CreateUser(*config.Config, *User, string, string) error GetUserByID(int64) (*User, error) } type HttpClient interface { Do(req *http.Request) (*http.Response, error) } type oauthClient interface { GetProvider() string GetClientID() string GetCallbackLocation() string buildLoginURL(state string) (string, error) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) } type callbackProxyClient struct { server string callbackLocation string httpClient HttpClient } type oauthHandler struct { Config *config.Config DB OAuthDatastore Store sessions.Store EmailKey []byte oauthClient oauthClient callbackProxy *callbackProxyClient } func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { ctx := r.Context() var attachUser int64 if attach := r.URL.Query().Get("attach"); attach == "t" { user, _ := getUserAndSession(app, r) if user == nil { return impart.HTTPError{http.StatusInternalServerError, "cannot attach auth to user: user not found in session"} } attachUser = user.ID } state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser, r.FormValue("invite_code")) if err != nil { log.Error("viewOauthInit error: %s", err) return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} } if h.callbackProxy != nil { if err := h.callbackProxy.register(ctx, state); err != nil { log.Error("viewOauthInit error: %s", err) return impart.HTTPError{http.StatusInternalServerError, "could not register state server"} } } location, err := h.oauthClient.buildLoginURL(state) if err != nil { log.Error("viewOauthInit error: %s", err) return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} } return impart.HTTPError{http.StatusTemporaryRedirect, location} } func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) { if app.Config().SlackOauth.ClientID != "" { callbackLocation := app.Config().App.Host + "/oauth/callback/slack" var stateRegisterClient *callbackProxyClient = nil if app.Config().SlackOauth.CallbackProxyAPI != "" { stateRegisterClient = &callbackProxyClient{ server: app.Config().SlackOauth.CallbackProxyAPI, callbackLocation: app.Config().App.Host + "/oauth/callback/slack", httpClient: config.DefaultHTTPClient(), } callbackLocation = app.Config().SlackOauth.CallbackProxy } oauthClient := slackOauthClient{ ClientID: app.Config().SlackOauth.ClientID, ClientSecret: app.Config().SlackOauth.ClientSecret, TeamID: app.Config().SlackOauth.TeamID, HttpClient: config.DefaultHTTPClient(), CallbackLocation: callbackLocation, } configureOauthRoutes(parentHandler, r, app, oauthClient, stateRegisterClient) } } func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) { if app.Config().WriteAsOauth.ClientID != "" { callbackLocation := app.Config().App.Host + "/oauth/callback/write.as" var callbackProxy *callbackProxyClient = nil if app.Config().WriteAsOauth.CallbackProxy != "" { callbackProxy = &callbackProxyClient{ server: app.Config().WriteAsOauth.CallbackProxyAPI, callbackLocation: app.Config().App.Host + "/oauth/callback/write.as", httpClient: config.DefaultHTTPClient(), } callbackLocation = app.Config().WriteAsOauth.CallbackProxy } oauthClient := writeAsOauthClient{ ClientID: app.Config().WriteAsOauth.ClientID, ClientSecret: app.Config().WriteAsOauth.ClientSecret, ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation), InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation), AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation), HttpClient: config.DefaultHTTPClient(), CallbackLocation: callbackLocation, } configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) } } func configureGitlabOauth(parentHandler *Handler, r *mux.Router, app *App) { if app.Config().GitlabOauth.ClientID != "" { callbackLocation := app.Config().App.Host + "/oauth/callback/gitlab" var callbackProxy *callbackProxyClient = nil if app.Config().GitlabOauth.CallbackProxy != "" { callbackProxy = &callbackProxyClient{ server: app.Config().GitlabOauth.CallbackProxyAPI, callbackLocation: app.Config().App.Host + "/oauth/callback/gitlab", httpClient: config.DefaultHTTPClient(), } callbackLocation = app.Config().GitlabOauth.CallbackProxy } address := config.OrDefaultString(app.Config().GitlabOauth.Host, gitlabHost) oauthClient := gitlabOauthClient{ ClientID: app.Config().GitlabOauth.ClientID, ClientSecret: app.Config().GitlabOauth.ClientSecret, ExchangeLocation: address + "/oauth/token", InspectLocation: address + "/api/v4/user", AuthLocation: address + "/oauth/authorize", HttpClient: config.DefaultHTTPClient(), CallbackLocation: callbackLocation, } configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) } } func configureGenericOauth(parentHandler *Handler, r *mux.Router, app *App) { if app.Config().GenericOauth.ClientID != "" { callbackLocation := app.Config().App.Host + "/oauth/callback/generic" var callbackProxy *callbackProxyClient = nil if app.Config().GenericOauth.CallbackProxy != "" { callbackProxy = &callbackProxyClient{ server: app.Config().GenericOauth.CallbackProxyAPI, callbackLocation: app.Config().App.Host + "/oauth/callback/generic", httpClient: config.DefaultHTTPClient(), } callbackLocation = app.Config().GenericOauth.CallbackProxy } oauthClient := genericOauthClient{ ClientID: app.Config().GenericOauth.ClientID, ClientSecret: app.Config().GenericOauth.ClientSecret, ExchangeLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.TokenEndpoint, InspectLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.InspectEndpoint, AuthLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.AuthEndpoint, HttpClient: config.DefaultHTTPClient(), CallbackLocation: callbackLocation, Scope: config.OrDefaultString(app.Config().GenericOauth.Scope, "read_user"), MapUserID: config.OrDefaultString(app.Config().GenericOauth.MapUserID, "user_id"), MapUsername: config.OrDefaultString(app.Config().GenericOauth.MapUsername, "username"), MapDisplayName: config.OrDefaultString(app.Config().GenericOauth.MapDisplayName, "-"), MapEmail: config.OrDefaultString(app.Config().GenericOauth.MapEmail, "email"), } configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) } } func configureGiteaOauth(parentHandler *Handler, r *mux.Router, app *App) { if app.Config().GiteaOauth.ClientID != "" { callbackLocation := app.Config().App.Host + "/oauth/callback/gitea" var callbackProxy *callbackProxyClient = nil if app.Config().GiteaOauth.CallbackProxy != "" { callbackProxy = &callbackProxyClient{ server: app.Config().GiteaOauth.CallbackProxyAPI, callbackLocation: app.Config().App.Host + "/oauth/callback/gitea", httpClient: config.DefaultHTTPClient(), } callbackLocation = app.Config().GiteaOauth.CallbackProxy } oauthClient := giteaOauthClient{ ClientID: app.Config().GiteaOauth.ClientID, ClientSecret: app.Config().GiteaOauth.ClientSecret, ExchangeLocation: app.Config().GiteaOauth.Host + "/login/oauth/access_token", InspectLocation: app.Config().GiteaOauth.Host + "/api/v1/user", AuthLocation: app.Config().GiteaOauth.Host + "/login/oauth/authorize", HttpClient: config.DefaultHTTPClient(), CallbackLocation: callbackLocation, } configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) } } func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) { handler := &oauthHandler{ Config: app.Config(), DB: app.DB(), Store: app.SessionStore(), oauthClient: oauthClient, EmailKey: app.keys.EmailKey, callbackProxy: callbackProxy, } r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET") r.HandleFunc("/oauth/callback/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET") r.HandleFunc("/oauth/signup", parentHandler.OAuth(handler.viewOauthSignup)).Methods("POST") } func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error { ctx := r.Context() code := r.FormValue("code") state := r.FormValue("state") provider, clientID, attachUserID, inviteCode, err := h.DB.ValidateOAuthState(ctx, state) if err != nil { log.Error("Unable to ValidateOAuthState: %s", err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code) if err != nil { log.Error("Unable to exchangeOauthCode: %s", err) // TODO: show user friendly message if needed // TODO: show NO message for cases like user pressing "Cancel" on authorize step addSessionFlash(app, w, r, err.Error(), nil) if attachUserID > 0 { return impart.HTTPError{http.StatusFound, "/me/settings"} } return impart.HTTPError{http.StatusInternalServerError, err.Error()} } // Now that we have the access token, let's use it real quick to make sure // it really really works. tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) if err != nil { log.Error("Unable to inspectOauthAccessToken: %s", err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID) if err != nil { log.Error("Unable to GetIDForRemoteUser: %s", err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } if localUserID != -1 && attachUserID > 0 { if err = addSessionFlash(app, w, r, "This Slack account is already attached to another user.", nil); err != nil { return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()} } return impart.HTTPError{http.StatusFound, "/me/settings"} } if localUserID != -1 { // Existing user, so log in now user, err := h.DB.GetUserByID(localUserID) if err != nil { log.Error("Unable to GetUserByID %d: %s", localUserID, err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } if err = loginOrFail(h.Store, w, r, user); err != nil { log.Error("Unable to loginOrFail %d: %s", localUserID, err) return impart.HTTPError{http.StatusInternalServerError, err.Error()} } return nil } if attachUserID > 0 { log.Info("attaching to user %d", attachUserID) err = h.DB.RecordRemoteUserID(r.Context(), attachUserID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken) if err != nil { return impart.HTTPError{http.StatusInternalServerError, err.Error()} } return impart.HTTPError{http.StatusFound, "/me/settings"} } // New user registration below. // First, verify that user is allowed to register if inviteCode != "" { // Verify invite code is valid i, err := app.db.GetUserInvite(inviteCode) if err != nil { return impart.HTTPError{http.StatusInternalServerError, err.Error()} } if !i.Active(app.db) { return impart.HTTPError{http.StatusNotFound, "Invite link has expired."} } } else if !app.cfg.App.OpenRegistration { addSessionFlash(app, w, r, ErrUserNotFound.Error(), nil) return impart.HTTPError{http.StatusFound, "/login"} } displayName := tokenInfo.DisplayName if len(displayName) == 0 { displayName = tokenInfo.Username } tp := &oauthSignupPageParams{ AccessToken: tokenResponse.AccessToken, TokenUsername: tokenInfo.Username, TokenAlias: tokenInfo.DisplayName, TokenEmail: tokenInfo.Email, TokenRemoteUser: tokenInfo.UserID, Provider: provider, ClientID: clientID, InviteCode: inviteCode, } tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed) return h.showOauthSignupPage(app, w, r, tp, nil) } func (r *callbackProxyClient) register(ctx context.Context, state string) error { form := url.Values{} form.Add("state", state) form.Add("location", r.callbackLocation) req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode())) if err != nil { return err } req.Header.Set("User-Agent", ServerUserAgent("")) req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/x-www-form-urlencoded") resp, err := r.httpClient.Do(req) if err != nil { return err } if resp.StatusCode != http.StatusCreated { return fmt.Errorf("unable register state location: %d", resp.StatusCode) } return nil } func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error { lr := io.LimitReader(body, int64(n+1)) data, err := ioutil.ReadAll(lr) if err != nil { return err } if len(data) == n+1 { return fmt.Errorf("content larger than max read allowance: %d", n) } return json.Unmarshal(data, thing) } func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error { // An error may be returned, but a valid session should always be returned. session, _ := store.Get(r, cookieName) session.Values[cookieUserVal] = user.Cookie() if err := session.Save(r, w); err != nil { fmt.Println("error saving session", err) return err } http.Redirect(w, r, "/", http.StatusTemporaryRedirect) return nil } diff --git a/oauth_signup.go b/oauth_signup.go index b1256be..8dff416 100644 --- a/oauth_signup.go +++ b/oauth_signup.go @@ -1,231 +1,231 @@ /* * Copyright © 2020-2021 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package writefreely import ( "crypto/sha256" "encoding/hex" "fmt" "github.com/writeas/impart" "github.com/writeas/web-core/auth" "github.com/writeas/web-core/log" "github.com/writefreely/writefreely/page" "html/template" "net/http" "strings" "time" ) type viewOauthSignupVars struct { page.StaticPage To string Message template.HTML Flashes []template.HTML AccessToken string TokenUsername string TokenAlias string // TODO: rename this to match the data it represents: the collection title TokenEmail string TokenRemoteUser string Provider string ClientID string TokenHash string InviteCode string LoginUsername string Alias string // TODO: rename this to match the data it represents: the collection title Email string } const ( oauthParamAccessToken = "access_token" oauthParamTokenUsername = "token_username" oauthParamTokenAlias = "token_alias" oauthParamTokenEmail = "token_email" oauthParamTokenRemoteUserID = "token_remote_user" oauthParamClientID = "client_id" oauthParamProvider = "provider" oauthParamHash = "signature" oauthParamUsername = "username" oauthParamAlias = "alias" oauthParamEmail = "email" oauthParamPassword = "password" oauthParamInviteCode = "invite_code" ) type oauthSignupPageParams struct { AccessToken string TokenUsername string TokenAlias string // TODO: rename this to match the data it represents: the collection title TokenEmail string TokenRemoteUser string ClientID string Provider string TokenHash string InviteCode string } func (p oauthSignupPageParams) HashTokenParams(key string) string { hasher := sha256.New() hasher.Write([]byte(key)) hasher.Write([]byte(p.AccessToken)) hasher.Write([]byte(p.TokenUsername)) hasher.Write([]byte(p.TokenAlias)) hasher.Write([]byte(p.TokenEmail)) hasher.Write([]byte(p.TokenRemoteUser)) hasher.Write([]byte(p.ClientID)) hasher.Write([]byte(p.Provider)) return hex.EncodeToString(hasher.Sum(nil)) } func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.Request) error { tp := &oauthSignupPageParams{ AccessToken: r.FormValue(oauthParamAccessToken), TokenUsername: r.FormValue(oauthParamTokenUsername), TokenAlias: r.FormValue(oauthParamTokenAlias), TokenEmail: r.FormValue(oauthParamTokenEmail), TokenRemoteUser: r.FormValue(oauthParamTokenRemoteUserID), ClientID: r.FormValue(oauthParamClientID), Provider: r.FormValue(oauthParamProvider), InviteCode: r.FormValue(oauthParamInviteCode), } if tp.HashTokenParams(h.Config.Server.HashSeed) != r.FormValue(oauthParamHash) { return impart.HTTPError{Status: http.StatusBadRequest, Message: "Request has been tampered with."} } tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed) if err := h.validateOauthSignup(r); err != nil { return h.showOauthSignupPage(app, w, r, tp, err) } var err error hashedPass := []byte{} clearPass := r.FormValue(oauthParamPassword) hasPass := clearPass != "" if hasPass { hashedPass, err = auth.HashPass([]byte(clearPass)) if err != nil { return h.showOauthSignupPage(app, w, r, tp, fmt.Errorf("unable to hash password")) } } newUser := &User{ Username: r.FormValue(oauthParamUsername), HashedPass: hashedPass, HasPass: hasPass, Email: prepareUserEmail(r.FormValue(oauthParamEmail), h.EmailKey), Created: time.Now().Truncate(time.Second).UTC(), } displayName := r.FormValue(oauthParamAlias) if len(displayName) == 0 { displayName = r.FormValue(oauthParamUsername) } - err = h.DB.CreateUser(h.Config, newUser, displayName) + err = h.DB.CreateUser(h.Config, newUser, displayName, "") if err != nil { return h.showOauthSignupPage(app, w, r, tp, err) } // Log invite if needed if tp.InviteCode != "" { err = app.db.CreateInvitedUser(tp.InviteCode, newUser.ID) if err != nil { return err } } err = h.DB.RecordRemoteUserID(r.Context(), newUser.ID, r.FormValue(oauthParamTokenRemoteUserID), r.FormValue(oauthParamProvider), r.FormValue(oauthParamClientID), r.FormValue(oauthParamAccessToken)) if err != nil { return h.showOauthSignupPage(app, w, r, tp, err) } if err := loginOrFail(h.Store, w, r, newUser); err != nil { return h.showOauthSignupPage(app, w, r, tp, err) } return nil } func (h oauthHandler) validateOauthSignup(r *http.Request) error { username := r.FormValue(oauthParamUsername) if len(username) < h.Config.App.MinUsernameLen { return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too short."} } if len(username) > 100 { return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too long."} } collTitle := r.FormValue(oauthParamAlias) if len(collTitle) == 0 { collTitle = username } email := r.FormValue(oauthParamEmail) if len(email) > 0 { parts := strings.Split(email, "@") if len(parts) != 2 || (len(parts[0]) < 1 || len(parts[1]) < 1) { return impart.HTTPError{Status: http.StatusBadRequest, Message: "Invalid email address"} } } return nil } func (h oauthHandler) showOauthSignupPage(app *App, w http.ResponseWriter, r *http.Request, tp *oauthSignupPageParams, errMsg error) error { username := tp.TokenUsername collTitle := tp.TokenAlias email := tp.TokenEmail session, err := app.sessionStore.Get(r, cookieName) if err != nil { // Ignore this log.Error("Unable to get session; ignoring: %v", err) } if tmpValue := r.FormValue(oauthParamUsername); len(tmpValue) > 0 { username = tmpValue } if tmpValue := r.FormValue(oauthParamAlias); len(tmpValue) > 0 { collTitle = tmpValue } if tmpValue := r.FormValue(oauthParamEmail); len(tmpValue) > 0 { email = tmpValue } p := &viewOauthSignupVars{ StaticPage: pageForReq(app, r), To: r.FormValue("to"), Flashes: []template.HTML{}, AccessToken: tp.AccessToken, TokenUsername: tp.TokenUsername, TokenAlias: tp.TokenAlias, TokenEmail: tp.TokenEmail, TokenRemoteUser: tp.TokenRemoteUser, Provider: tp.Provider, ClientID: tp.ClientID, TokenHash: tp.TokenHash, InviteCode: tp.InviteCode, LoginUsername: username, Alias: collTitle, Email: email, } // Display any error messages flashes, _ := getSessionFlashes(app, w, r, session) for _, flash := range flashes { p.Flashes = append(p.Flashes, template.HTML(flash)) } if errMsg != nil { p.Flashes = append(p.Flashes, template.HTML(errMsg.Error())) } err = pages["signup-oauth.tmpl"].ExecuteTemplate(w, "base", p) if err != nil { log.Error("Unable to render signup-oauth: %v", err) return err } return nil } diff --git a/oauth_test.go b/oauth_test.go index 5694416..553c1cb 100644 --- a/oauth_test.go +++ b/oauth_test.go @@ -1,261 +1,261 @@ /* * Copyright © 2019-2021 A Bunch Tell LLC. * * This file is part of WriteFreely. * * WriteFreely is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License, included * in the LICENSE file in this source code package. */ package writefreely import ( "context" "fmt" "github.com/gorilla/sessions" "github.com/stretchr/testify/assert" "github.com/writeas/impart" "github.com/writeas/web-core/id" "github.com/writefreely/writefreely/config" "net/http" "net/http/httptest" "net/url" "strings" "testing" ) type MockOAuthDatastoreProvider struct { DoDB func() OAuthDatastore DoConfig func() *config.Config DoSessionStore func() sessions.Store } type MockOAuthDatastore struct { DoGenerateOAuthState func(context.Context, string, string, int64, string) (string, error) DoValidateOAuthState func(context.Context, string) (string, string, int64, string, error) DoGetIDForRemoteUser func(context.Context, string, string, string) (int64, error) DoCreateUser func(*config.Config, *User, string) error DoRecordRemoteUserID func(context.Context, int64, string, string, string, string) error DoGetUserByID func(int64) (*User, error) } var _ OAuthDatastore = &MockOAuthDatastore{} type StringReadCloser struct { *strings.Reader } func (src *StringReadCloser) Close() error { return nil } type MockHTTPClient struct { DoDo func(req *http.Request) (*http.Response, error) } func (m *MockHTTPClient) Do(req *http.Request) (*http.Response, error) { if m.DoDo != nil { return m.DoDo(req) } return &http.Response{}, nil } func (m *MockOAuthDatastoreProvider) SessionStore() sessions.Store { if m.DoSessionStore != nil { return m.DoSessionStore() } return sessions.NewCookieStore([]byte("secret-key")) } func (m *MockOAuthDatastoreProvider) DB() OAuthDatastore { if m.DoDB != nil { return m.DoDB() } return &MockOAuthDatastore{} } func (m *MockOAuthDatastoreProvider) Config() *config.Config { if m.DoConfig != nil { return m.DoConfig() } cfg := config.New() cfg.UseSQLite(true) cfg.WriteAsOauth = config.WriteAsOauthCfg{ ClientID: "development", ClientSecret: "development", AuthLocation: "https://write.as/oauth/login", TokenLocation: "https://write.as/oauth/token", InspectLocation: "https://write.as/oauth/inspect", } cfg.SlackOauth = config.SlackOauthCfg{ ClientID: "development", ClientSecret: "development", TeamID: "development", } return cfg } func (m *MockOAuthDatastore) ValidateOAuthState(ctx context.Context, state string) (string, string, int64, string, error) { if m.DoValidateOAuthState != nil { return m.DoValidateOAuthState(ctx, state) } return "", "", 0, "", nil } func (m *MockOAuthDatastore) GetIDForRemoteUser(ctx context.Context, remoteUserID, provider, clientID string) (int64, error) { if m.DoGetIDForRemoteUser != nil { return m.DoGetIDForRemoteUser(ctx, remoteUserID, provider, clientID) } return -1, nil } -func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username string) error { +func (m *MockOAuthDatastore) CreateUser(cfg *config.Config, u *User, username, description string) error { if m.DoCreateUser != nil { return m.DoCreateUser(cfg, u, username) } u.ID = 1 return nil } func (m *MockOAuthDatastore) RecordRemoteUserID(ctx context.Context, localUserID int64, remoteUserID, provider, clientID, accessToken string) error { if m.DoRecordRemoteUserID != nil { return m.DoRecordRemoteUserID(ctx, localUserID, remoteUserID, provider, clientID, accessToken) } return nil } func (m *MockOAuthDatastore) GetUserByID(userID int64) (*User, error) { if m.DoGetUserByID != nil { return m.DoGetUserByID(userID) } user := &User{} return user, nil } func (m *MockOAuthDatastore) GenerateOAuthState(ctx context.Context, provider string, clientID string, attachUserID int64, inviteCode string) (string, error) { if m.DoGenerateOAuthState != nil { return m.DoGenerateOAuthState(ctx, provider, clientID, attachUserID, inviteCode) } return id.Generate62RandomString(14), nil } func TestViewOauthInit(t *testing.T) { t.Run("success", func(t *testing.T) { app := &MockOAuthDatastoreProvider{} h := oauthHandler{ Config: app.Config(), DB: app.DB(), Store: app.SessionStore(), EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, oauthClient: writeAsOauthClient{ ClientID: app.Config().WriteAsOauth.ClientID, ClientSecret: app.Config().WriteAsOauth.ClientSecret, ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, InspectLocation: app.Config().WriteAsOauth.InspectLocation, AuthLocation: app.Config().WriteAsOauth.AuthLocation, CallbackLocation: "http://localhost/oauth/callback", HttpClient: nil, }, } req, err := http.NewRequest("GET", "/oauth/client", nil) assert.NoError(t, err) rr := httptest.NewRecorder() err = h.viewOauthInit(nil, rr, req) assert.NotNil(t, err) httpErr, ok := err.(impart.HTTPError) assert.True(t, ok) assert.Equal(t, http.StatusTemporaryRedirect, httpErr.Status) assert.NotEmpty(t, httpErr.Message) locURI, err := url.Parse(httpErr.Message) assert.NoError(t, err) assert.Equal(t, "/oauth/login", locURI.Path) assert.Equal(t, "development", locURI.Query().Get("client_id")) assert.Equal(t, "http://localhost/oauth/callback", locURI.Query().Get("redirect_uri")) assert.Equal(t, "code", locURI.Query().Get("response_type")) assert.NotEmpty(t, locURI.Query().Get("state")) }) t.Run("state failure", func(t *testing.T) { app := &MockOAuthDatastoreProvider{ DoDB: func() OAuthDatastore { return &MockOAuthDatastore{ DoGenerateOAuthState: func(ctx context.Context, provider, clientID string, attachUserID int64, inviteCode string) (string, error) { return "", fmt.Errorf("pretend unable to write state error") }, } }, } h := oauthHandler{ Config: app.Config(), DB: app.DB(), Store: app.SessionStore(), EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, oauthClient: writeAsOauthClient{ ClientID: app.Config().WriteAsOauth.ClientID, ClientSecret: app.Config().WriteAsOauth.ClientSecret, ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, InspectLocation: app.Config().WriteAsOauth.InspectLocation, AuthLocation: app.Config().WriteAsOauth.AuthLocation, CallbackLocation: "http://localhost/oauth/callback", HttpClient: nil, }, } req, err := http.NewRequest("GET", "/oauth/client", nil) assert.NoError(t, err) rr := httptest.NewRecorder() err = h.viewOauthInit(nil, rr, req) httpErr, ok := err.(impart.HTTPError) assert.True(t, ok) assert.NotEmpty(t, httpErr.Message) assert.Equal(t, http.StatusInternalServerError, httpErr.Status) assert.Equal(t, "could not prepare oauth redirect url", httpErr.Message) }) } func TestViewOauthCallback(t *testing.T) { t.Run("success", func(t *testing.T) { app := &MockOAuthDatastoreProvider{} h := oauthHandler{ Config: app.Config(), DB: app.DB(), Store: app.SessionStore(), EmailKey: []byte{0xd, 0xe, 0xc, 0xa, 0xf, 0xf, 0xb, 0xa, 0xd}, oauthClient: writeAsOauthClient{ ClientID: app.Config().WriteAsOauth.ClientID, ClientSecret: app.Config().WriteAsOauth.ClientSecret, ExchangeLocation: app.Config().WriteAsOauth.TokenLocation, InspectLocation: app.Config().WriteAsOauth.InspectLocation, AuthLocation: app.Config().WriteAsOauth.AuthLocation, CallbackLocation: "http://localhost/oauth/callback", HttpClient: &MockHTTPClient{ DoDo: func(req *http.Request) (*http.Response, error) { switch req.URL.String() { case "https://write.as/oauth/token": return &http.Response{ StatusCode: 200, Body: &StringReadCloser{strings.NewReader(`{"access_token": "access_token", "expires_in": 1000, "refresh_token": "refresh_token", "token_type": "access"}`)}, }, nil case "https://write.as/oauth/inspect": return &http.Response{ StatusCode: 200, Body: &StringReadCloser{strings.NewReader(`{"client_id": "development", "user_id": "1", "expires_at": "2019-12-19T11:42:01Z", "username": "nick", "email": "nick@testing.write.as"}`)}, }, nil } return &http.Response{ StatusCode: http.StatusNotFound, }, nil }, }, }, } req, err := http.NewRequest("GET", "/oauth/callback", nil) assert.NoError(t, err) rr := httptest.NewRecorder() err = h.viewOauthCallback(&App{cfg: app.Config(), sessionStore: app.SessionStore()}, rr, req) assert.NoError(t, err) assert.Equal(t, http.StatusTemporaryRedirect, rr.Code) }) }