diff --git a/oauth_generic.go b/oauth_generic.go index ba8b97e..ccf6a0f 100644 --- a/oauth_generic.go +++ b/oauth_generic.go @@ -1,126 +1,142 @@ +/* + * Copyright © 2020-2021 A Bunch Tell LLC and respective authors. + * + * This file is part of WriteFreely. + * + * WriteFreely is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, included + * in the LICENSE file in this source code package. + */ + package writefreely import ( "context" "errors" + "fmt" + "github.com/writeas/web-core/log" "net/http" "net/url" "strings" ) type genericOauthClient struct { ClientID string ClientSecret string AuthLocation string ExchangeLocation string InspectLocation string CallbackLocation string Scope string MapUserID string MapUsername string MapDisplayName string MapEmail string HttpClient HttpClient } var _ oauthClient = genericOauthClient{} const ( genericOauthDisplayName = "OAuth" ) func (c genericOauthClient) GetProvider() string { return "generic" } func (c genericOauthClient) GetClientID() string { return c.ClientID } func (c genericOauthClient) GetCallbackLocation() string { return c.CallbackLocation } func (c genericOauthClient) buildLoginURL(state string) (string, error) { u, err := url.Parse(c.AuthLocation) if err != nil { return "", err } q := u.Query() q.Set("client_id", c.ClientID) q.Set("redirect_uri", c.CallbackLocation) q.Set("response_type", "code") q.Set("state", state) q.Set("scope", c.Scope) u.RawQuery = q.Encode() return u.String(), nil } func (c genericOauthClient) exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) { form := url.Values{} form.Add("grant_type", "authorization_code") form.Add("redirect_uri", c.CallbackLocation) form.Add("scope", c.Scope) form.Add("code", code) req, err := http.NewRequest("POST", c.ExchangeLocation, strings.NewReader(form.Encode())) if err != nil { return nil, err } req.WithContext(ctx) req.Header.Set("User-Agent", ServerUserAgent("")) req.Header.Set("Accept", "application/json") req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.SetBasicAuth(c.ClientID, c.ClientSecret) resp, err := c.HttpClient.Do(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { return nil, errors.New("unable to exchange code for access token") } var tokenResponse TokenResponse if err := limitedJsonUnmarshal(resp.Body, tokenRequestMaxLen, &tokenResponse); err != nil { return nil, err } if tokenResponse.Error != "" { return nil, errors.New(tokenResponse.Error) } return &tokenResponse, nil } func (c genericOauthClient) inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) { req, err := http.NewRequest("GET", c.InspectLocation, nil) if err != nil { return nil, err } req.WithContext(ctx) req.Header.Set("User-Agent", ServerUserAgent("")) req.Header.Set("Accept", "application/json") req.Header.Set("Authorization", "Bearer "+accessToken) resp, err := c.HttpClient.Do(req) if err != nil { return nil, err } if resp.StatusCode != http.StatusOK { return nil, errors.New("unable to inspect access token") } // since we don't know what the JSON from the server will look like, we create a // generic interface and then map manually to values set in the config var genericInterface map[string]interface{} if err := limitedJsonUnmarshal(resp.Body, infoRequestMaxLen, &genericInterface); err != nil { return nil, err } // map each relevant field in inspectResponse to the mapped field from the config var inspectResponse InspectResponse inspectResponse.UserID, _ = genericInterface[c.MapUserID].(string) + if inspectResponse.UserID == "" { + log.Error("[CONFIGURATION ERROR] Generic OAuth provider returned empty UserID value (`%s`).\n Do you need to configure a different `map_user_id` value for this provider?", c.MapUserID) + return nil, fmt.Errorf("no UserID (`%s`) value returned", c.MapUserID) + } inspectResponse.Username, _ = genericInterface[c.MapUsername].(string) inspectResponse.DisplayName, _ = genericInterface[c.MapDisplayName].(string) inspectResponse.Email, _ = genericInterface[c.MapEmail].(string) return &inspectResponse, nil }