Skip to content

Commit 4d789d2

Browse files
marcosnilshickford
authored andcommitted
oauth2: add device flow support
Signed-off-by: Marcos Lilljedahl <[email protected]>
1 parent e48dfd9 commit 4d789d2

File tree

2 files changed

+140
-3
lines changed

2 files changed

+140
-3
lines changed

deviceauth.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package oauth2
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"io/ioutil"
9+
"net/http"
10+
"net/url"
11+
"strings"
12+
13+
"golang.org/x/net/context/ctxhttp"
14+
)
15+
16+
const (
17+
errAuthorizationPending = "authorization_pending"
18+
errSlowDown = "slow_down"
19+
errAccessDenied = "access_denied"
20+
errExpiredToken = "expired_token"
21+
)
22+
23+
type DeviceAuth struct {
24+
DeviceCode string `json:"device_code"`
25+
UserCode string `json:"user_code"`
26+
VerificationURI string `json:"verification_uri,verification_url"`
27+
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
28+
ExpiresIn int `json:"expires_in"`
29+
Interval int `json:"interval,omitempty"`
30+
raw map[string]interface{}
31+
}
32+
33+
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuth, error) {
34+
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
35+
if err != nil {
36+
return nil, err
37+
}
38+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
39+
40+
r, err := ctxhttp.Do(ctx, nil, req)
41+
if err != nil {
42+
return nil, err
43+
}
44+
45+
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
46+
if err != nil {
47+
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
48+
}
49+
if code := r.StatusCode; code < 200 || code > 299 {
50+
return nil, &RetrieveError{
51+
Response: r,
52+
Body: body,
53+
}
54+
}
55+
56+
da := &DeviceAuth{}
57+
err = json.Unmarshal(body, &da)
58+
if err != nil {
59+
return nil, err
60+
}
61+
62+
_ = json.Unmarshal(body, &da.raw)
63+
64+
// Azure AD supplies verification_url instead of verification_uri
65+
if da.VerificationURI == "" {
66+
da.VerificationURI, _ = da.raw["verification_url"].(string)
67+
}
68+
69+
return da, nil
70+
}
71+
72+
func parseError(err error) string {
73+
e, ok := err.(*RetrieveError)
74+
if ok {
75+
eResp := make(map[string]string)
76+
_ = json.Unmarshal(e.Body, &eResp)
77+
return eResp["error"]
78+
}
79+
return ""
80+
}

oauth2.go

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"net/url"
1717
"strings"
1818
"sync"
19+
"time"
1920

2021
"golang.org/x/oauth2/internal"
2122
)
@@ -70,8 +71,9 @@ type TokenSource interface {
7071
// Endpoint represents an OAuth 2.0 provider's authorization and token
7172
// endpoint URLs.
7273
type Endpoint struct {
73-
AuthURL string
74-
TokenURL string
74+
AuthURL string
75+
DeviceAuthURL string
76+
TokenURL string
7577

7678
// AuthStyle optionally specifies how the endpoint wants the
7779
// client ID & client secret sent. The zero value means to
@@ -224,6 +226,62 @@ func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOpti
224226
return retrieveToken(ctx, c, v)
225227
}
226228

229+
// AuthDevice returns a device auth struct which contains a device code
230+
// and authorization information provided for users to enter on another device.
231+
func (c *Config) AuthDevice(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuth, error) {
232+
v := url.Values{
233+
"client_id": {c.ClientID},
234+
}
235+
if len(c.Scopes) > 0 {
236+
v.Set("scope", strings.Join(c.Scopes, " "))
237+
}
238+
for _, opt := range opts {
239+
opt.setValue(v)
240+
}
241+
return retrieveDeviceAuth(ctx, c, v)
242+
}
243+
244+
// Poll does a polling to exchange an device code for a token.
245+
func (c *Config) Poll(ctx context.Context, da *DeviceAuth, opts ...AuthCodeOption) (*Token, error) {
246+
v := url.Values{
247+
"client_id": {c.ClientID},
248+
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
249+
"device_code": {da.DeviceCode},
250+
}
251+
if len(c.Scopes) > 0 {
252+
v.Set("scope", strings.Join(c.Scopes, " "))
253+
}
254+
for _, opt := range opts {
255+
opt.setValue(v)
256+
}
257+
258+
// If no interval was provided, the client MUST use a reasonable default polling interval.
259+
// See https://tools.ietf.org/html/draft-ietf-oauth-device-flow-07#section-3.5
260+
interval := da.Interval
261+
if interval == 0 {
262+
interval = 5
263+
}
264+
265+
for {
266+
time.Sleep(time.Duration(interval) * time.Second)
267+
268+
tok, err := retrieveToken(ctx, c, v)
269+
if err == nil {
270+
return tok, nil
271+
}
272+
273+
errTyp := parseError(err)
274+
switch errTyp {
275+
case errAccessDenied, errExpiredToken:
276+
return tok, errors.New("oauth2: " + errTyp)
277+
case errSlowDown:
278+
interval += 1
279+
fallthrough
280+
case errAuthorizationPending:
281+
}
282+
}
283+
}
284+
227285
// Client returns an HTTP client using the provided token.
228286
// The token will auto-refresh as necessary. The underlying
229287
// HTTP transport will be obtained using the provided context.
@@ -271,7 +329,6 @@ func (tf *tokenRefresher) Token() (*Token, error) {
271329
"grant_type": {"refresh_token"},
272330
"refresh_token": {tf.refreshToken},
273331
})
274-
275332
if err != nil {
276333
return nil, err
277334
}

0 commit comments

Comments
 (0)