Skip to content

Commit 821b207

Browse files
committed
http2/h2c: Respect the req.Context()
When using h2c.NewHandler, the *http.Request object for h2c requests has a .Context() that doesn't inherit from the *http.Server's BaseContext. This is surprising for users of vanilla net/http, and is surprising to users of http2.ConfigureServer; HTTP/1 requests inherit from that BaseContext, and TLS h2 requests inherit from that BaseContext, but cleartext h2c requests don't. So, modify h2c.NewHander to respect that base Context, by way of the hijacked Request's .Context().
1 parent 5f4716e commit 821b207

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

http2/h2c/h2c.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,20 @@ func (s h2cHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
8484
}
8585
defer conn.Close()
8686

87-
s.s.ServeConn(conn, &http2.ServeConnOpts{Handler: s.Handler})
87+
s.s.ServeConn(conn, &http2.ServeConnOpts{
88+
Context: r.Context(),
89+
Handler: s.Handler,
90+
})
8891
return
8992
}
9093
// Handle Upgrade to h2c (RFC 7540 Section 3.2)
9194
if conn, err := h2cUpgrade(w, r); err == nil {
9295
defer conn.Close()
9396

94-
s.s.ServeConn(conn, &http2.ServeConnOpts{Handler: s.Handler})
97+
s.s.ServeConn(conn, &http2.ServeConnOpts{
98+
Context: r.Context(),
99+
Handler: s.Handler,
100+
})
95101
return
96102
}
97103

http2/h2c/h2c_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@ package h2c
77
import (
88
"bufio"
99
"bytes"
10+
"context"
11+
"crypto/tls"
1012
"fmt"
13+
"io/ioutil"
1114
"log"
15+
"net"
1216
"net/http"
17+
"net/http/httptest"
1318
"testing"
1419

1520
"golang.org/x/net/http2"
@@ -56,3 +61,46 @@ func ExampleNewHandler() {
5661
}
5762
log.Fatal(h1s.ListenAndServe())
5863
}
64+
65+
func TestContext(t *testing.T) {
66+
baseCtx := context.WithValue(context.Background(), "testkey", "testvalue")
67+
68+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
69+
if r.ProtoMajor != 2 {
70+
t.Errorf("Request wasn't handled by h2c. Got ProtoMajor=%v", r.ProtoMajor)
71+
}
72+
if r.Context().Value("testkey") != "testvalue" {
73+
t.Errorf("Request doesn't have expected base context: %v", r.Context())
74+
}
75+
fmt.Fprint(w, "Hello world")
76+
})
77+
78+
h2s := &http2.Server{}
79+
h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s))
80+
h1s.Config.BaseContext = func(_ net.Listener) context.Context {
81+
return baseCtx
82+
}
83+
h1s.Start()
84+
defer h1s.Close()
85+
86+
client := &http.Client{
87+
Transport: &http2.Transport{
88+
AllowHTTP: true,
89+
DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
90+
return net.Dial(network, addr)
91+
},
92+
},
93+
}
94+
95+
resp, err := client.Get(h1s.URL)
96+
if err != nil {
97+
t.Fatal(err)
98+
}
99+
_, err = ioutil.ReadAll(resp.Body)
100+
if err != nil {
101+
t.Fatal(err)
102+
}
103+
if err := resp.Body.Close(); err != nil {
104+
t.Fatal(err)
105+
}
106+
}

0 commit comments

Comments
 (0)