[feature] add 'state' oauth2 param to /oauth/authorize (#730)

This commit is contained in:
tobi 2022-07-28 16:43:27 +02:00 committed by GitHub
parent 7ca5bac7c6
commit 8106b69856
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 35 additions and 20 deletions

View file

@ -55,13 +55,14 @@ const (
callbackStateParam = "state" callbackStateParam = "state"
callbackCodeParam = "code" callbackCodeParam = "code"
sessionUserID = "userid" sessionUserID = "userid"
sessionClientID = "client_id" sessionClientID = "client_id"
sessionRedirectURI = "redirect_uri" sessionRedirectURI = "redirect_uri"
sessionForceLogin = "force_login" sessionForceLogin = "force_login"
sessionResponseType = "response_type" sessionResponseType = "response_type"
sessionScope = "scope" sessionScope = "scope"
sessionState = "state" sessionInternalState = "internal_state"
sessionClientState = "client_state"
) )
// Module implements the ClientAPIModule interface for // Module implements the ClientAPIModule interface for

View file

@ -189,6 +189,11 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionScope)) errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionScope))
} }
var clientState string
if s, ok := s.Get(sessionClientState).(string); ok {
clientState = s
}
userID, ok := s.Get(sessionUserID).(string) userID, ok := s.Get(sessionUserID).(string)
if !ok { if !ok {
errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionUserID)) errs = append(errs, fmt.Sprintf("key %s was not found in session", sessionUserID))
@ -246,6 +251,10 @@ func (m *Module) AuthorizePOSTHandler(c *gin.Context) {
sessionUserID: {userID}, sessionUserID: {userID},
} }
if clientState != "" {
c.Request.Form.Set("state", clientState)
}
if err := m.processor.OAuthHandleAuthorizeRequest(c.Writer, c.Request); err != nil { if err := m.processor.OAuthHandleAuthorizeRequest(c.Writer, c.Request); err != nil {
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error(), helpfulAdvice), m.processor.InstanceGet) api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error(), helpfulAdvice), m.processor.InstanceGet)
} }
@ -285,7 +294,8 @@ func saveAuthFormToSession(s sessions.Session, form *model.OAuthAuthorize) gtser
s.Set(sessionClientID, form.ClientID) s.Set(sessionClientID, form.ClientID)
s.Set(sessionRedirectURI, form.RedirectURI) s.Set(sessionRedirectURI, form.RedirectURI)
s.Set(sessionScope, form.Scope) s.Set(sessionScope, form.Scope)
s.Set(sessionState, uuid.NewString()) s.Set(sessionInternalState, uuid.NewString())
s.Set(sessionClientState, form.State)
if err := s.Save(); err != nil { if err := s.Save(); err != nil {
err := fmt.Errorf("error saving form values onto session: %s", err) err := fmt.Errorf("error saving form values onto session: %s", err)

View file

@ -45,26 +45,26 @@ func (m *Module) CallbackGETHandler(c *gin.Context) {
// check the query vs session state parameter to mitigate csrf // check the query vs session state parameter to mitigate csrf
// https://auth0.com/docs/secure/attack-protection/state-parameters // https://auth0.com/docs/secure/attack-protection/state-parameters
state := c.Query(callbackStateParam) returnedInternalState := c.Query(callbackStateParam)
if state == "" { if returnedInternalState == "" {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("%s parameter not found on callback query", callbackStateParam) err := fmt.Errorf("%s parameter not found on callback query", callbackStateParam)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
savedStateI := s.Get(sessionState) savedInternalStateI := s.Get(sessionInternalState)
savedState, ok := savedStateI.(string) savedInternalState, ok := savedInternalStateI.(string)
if !ok { if !ok {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("key %s was not found in session", sessionState) err := fmt.Errorf("key %s was not found in session", sessionInternalState)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
if state != savedState { if returnedInternalState != savedInternalState {
m.clearSession(s) m.clearSession(s)
err := errors.New("mismatch between query state and session state") err := errors.New("mismatch between callback state and saved state")
api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet) api.ErrorHandler(c, gtserror.NewErrorUnauthorized(err, err.Error()), m.processor.InstanceGet)
return return
} }

View file

@ -58,16 +58,16 @@ func (m *Module) SignInGETHandler(c *gin.Context) {
// idp provider is in use, so redirect to it // idp provider is in use, so redirect to it
s := sessions.Default(c) s := sessions.Default(c)
stateI := s.Get(sessionState) internalStateI := s.Get(sessionInternalState)
state, ok := stateI.(string) internalState, ok := internalStateI.(string)
if !ok { if !ok {
m.clearSession(s) m.clearSession(s)
err := fmt.Errorf("key %s was not found in session", sessionState) err := fmt.Errorf("key %s was not found in session", sessionInternalState)
api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet) api.ErrorHandler(c, gtserror.NewErrorBadRequest(err, err.Error()), m.processor.InstanceGet)
return return
} }
c.Redirect(http.StatusSeeOther, m.idp.AuthCodeURL(state)) c.Redirect(http.StatusSeeOther, m.idp.AuthCodeURL(internalState))
} }
// SignInPOSTHandler should be served at https://example.org/auth/sign_in. // SignInPOSTHandler should be served at https://example.org/auth/sign_in.

View file

@ -312,7 +312,7 @@ func (suite *StatusCreateTestSuite) TestAttachNewMediaSuccess() {
ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", status.BasePath), nil) // the endpoint we're hitting ctx.Request = httptest.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:8080/%s", status.BasePath), nil) // the endpoint we're hitting
ctx.Request.Header.Set("accept", "application/json") ctx.Request.Header.Set("accept", "application/json")
ctx.Request.Form = url.Values{ ctx.Request.Form = url.Values{
"status": {"here's an image attachment"}, "status": {"here's an image attachment"},
"media_ids[]": {attachment.ID}, "media_ids[]": {attachment.ID},
} }
suite.statusModule.StatusCreatePOSTHandler(ctx) suite.statusModule.StatusCreatePOSTHandler(ctx)

View file

@ -33,4 +33,8 @@ type OAuthAuthorize struct {
// List of requested OAuth scopes, separated by spaces (or by pluses, if using query parameters). // List of requested OAuth scopes, separated by spaces (or by pluses, if using query parameters).
// Must be a subset of scopes declared during app registration. If not provided, defaults to read. // Must be a subset of scopes declared during app registration. If not provided, defaults to read.
Scope string `form:"scope" json:"scope"` Scope string `form:"scope" json:"scope"`
// State is used by the application to store request-specific data and/or prevent CSRF attacks.
// The authorization server must return the unmodified state value back to the application.
// See https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
State string `form:"state" json:"state"`
} }