|
| 1 | +/* |
| 2 | + * |
| 3 | + * Copyright 2017 gRPC authors. |
| 4 | + * |
| 5 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | + * you may not use this file except in compliance with the License. |
| 7 | + * You may obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + * |
| 17 | + */ |
| 18 | + |
| 19 | +// Package bufconn provides a net.Conn implemented by a buffer and related |
| 20 | +// dialing and listening functionality. |
| 21 | +package bufconn |
| 22 | + |
| 23 | +import ( |
| 24 | + "fmt" |
| 25 | + "io" |
| 26 | + "net" |
| 27 | + "sync" |
| 28 | + "time" |
| 29 | +) |
| 30 | + |
| 31 | +// Listener implements a net.Listener that creates local, buffered net.Conns |
| 32 | +// via its Accept and Dial method. |
| 33 | +type Listener struct { |
| 34 | + mu sync.Mutex |
| 35 | + sz int |
| 36 | + ch chan net.Conn |
| 37 | + closed bool |
| 38 | +} |
| 39 | + |
| 40 | +var errClosed = fmt.Errorf("Closed") |
| 41 | + |
| 42 | +// Listen returns a Listener that can only be contacted by its own Dialers and |
| 43 | +// creates buffered connections between the two. |
| 44 | +func Listen(sz int) *Listener { |
| 45 | + return &Listener{sz: sz, ch: make(chan net.Conn)} |
| 46 | +} |
| 47 | + |
| 48 | +// Accept blocks until Dial is called, then returns a net.Conn for the server |
| 49 | +// half of the connection. |
| 50 | +func (l *Listener) Accept() (net.Conn, error) { |
| 51 | + c := <-l.ch |
| 52 | + if c == nil { |
| 53 | + return nil, errClosed |
| 54 | + } |
| 55 | + return c, nil |
| 56 | +} |
| 57 | + |
| 58 | +// Close stops the listener. |
| 59 | +func (l *Listener) Close() error { |
| 60 | + l.mu.Lock() |
| 61 | + defer l.mu.Unlock() |
| 62 | + if l.closed { |
| 63 | + return nil |
| 64 | + } |
| 65 | + l.closed = true |
| 66 | + close(l.ch) |
| 67 | + return nil |
| 68 | +} |
| 69 | + |
| 70 | +// Addr reports the address of the listener. |
| 71 | +func (l *Listener) Addr() net.Addr { return addr{} } |
| 72 | + |
| 73 | +// Dial creates an in-memory full-duplex network connection, unblocks Accept by |
| 74 | +// providing it the server half of the connection, and returns the client half |
| 75 | +// of the connection. |
| 76 | +func (l *Listener) Dial() (net.Conn, error) { |
| 77 | + l.mu.Lock() |
| 78 | + defer l.mu.Unlock() |
| 79 | + if l.closed { |
| 80 | + return nil, errClosed |
| 81 | + } |
| 82 | + p1, p2 := newPipe(l.sz), newPipe(l.sz) |
| 83 | + l.ch <- &conn{p1, p2} |
| 84 | + return &conn{p2, p1}, nil |
| 85 | +} |
| 86 | + |
| 87 | +type pipe struct { |
| 88 | + mu sync.Mutex |
| 89 | + |
| 90 | + // buf contains the data in the pipe. It is a ring buffer of fixed capacity, |
| 91 | + // with r and w pointing to the offset to read and write, respsectively. |
| 92 | + // |
| 93 | + // Data is read between [r, w) and written to [w, r), wrapping around the end |
| 94 | + // of the slice if necessary. |
| 95 | + // |
| 96 | + // The buffer is empty if r == len(buf), otherwise if r == w, it is full. |
| 97 | + // |
| 98 | + // w and r are always in the range [0, cap(buf)) and [0, len(buf)]. |
| 99 | + buf []byte |
| 100 | + w, r int |
| 101 | + |
| 102 | + wwait sync.Cond |
| 103 | + rwait sync.Cond |
| 104 | + closed bool |
| 105 | +} |
| 106 | + |
| 107 | +func newPipe(sz int) *pipe { |
| 108 | + p := &pipe{buf: make([]byte, 0, sz)} |
| 109 | + p.wwait.L = &p.mu |
| 110 | + p.rwait.L = &p.mu |
| 111 | + return p |
| 112 | +} |
| 113 | + |
| 114 | +func (p *pipe) empty() bool { |
| 115 | + return p.r == len(p.buf) |
| 116 | +} |
| 117 | + |
| 118 | +func (p *pipe) full() bool { |
| 119 | + return p.r < len(p.buf) && p.r == p.w |
| 120 | +} |
| 121 | + |
| 122 | +func (p *pipe) Read(b []byte) (n int, err error) { |
| 123 | + p.mu.Lock() |
| 124 | + defer p.mu.Unlock() |
| 125 | + // Block until p has data. |
| 126 | + for { |
| 127 | + if p.closed { |
| 128 | + return 0, io.ErrClosedPipe |
| 129 | + } |
| 130 | + if !p.empty() { |
| 131 | + break |
| 132 | + } |
| 133 | + p.rwait.Wait() |
| 134 | + } |
| 135 | + wasFull := p.full() |
| 136 | + |
| 137 | + n = copy(b, p.buf[p.r:len(p.buf)]) |
| 138 | + p.r += n |
| 139 | + if p.r == cap(p.buf) { |
| 140 | + p.r = 0 |
| 141 | + p.buf = p.buf[:p.w] |
| 142 | + } |
| 143 | + |
| 144 | + // Signal a blocked writer, if any |
| 145 | + if wasFull { |
| 146 | + p.wwait.Signal() |
| 147 | + } |
| 148 | + |
| 149 | + return n, nil |
| 150 | +} |
| 151 | + |
| 152 | +func (p *pipe) Write(b []byte) (n int, err error) { |
| 153 | + p.mu.Lock() |
| 154 | + defer p.mu.Unlock() |
| 155 | + if p.closed { |
| 156 | + return 0, io.ErrClosedPipe |
| 157 | + } |
| 158 | + for len(b) > 0 { |
| 159 | + // Block until p is not full. |
| 160 | + for { |
| 161 | + if p.closed { |
| 162 | + return 0, io.ErrClosedPipe |
| 163 | + } |
| 164 | + if !p.full() { |
| 165 | + break |
| 166 | + } |
| 167 | + p.wwait.Wait() |
| 168 | + } |
| 169 | + wasEmpty := p.empty() |
| 170 | + |
| 171 | + end := cap(p.buf) |
| 172 | + if p.w < p.r { |
| 173 | + end = p.r |
| 174 | + } |
| 175 | + x := copy(p.buf[p.w:end], b) |
| 176 | + b = b[x:] |
| 177 | + n += x |
| 178 | + p.w += x |
| 179 | + if p.w > len(p.buf) { |
| 180 | + p.buf = p.buf[:p.w] |
| 181 | + } |
| 182 | + if p.w == cap(p.buf) { |
| 183 | + p.w = 0 |
| 184 | + } |
| 185 | + |
| 186 | + // Signal a blocked reader, if any. |
| 187 | + if wasEmpty { |
| 188 | + p.rwait.Signal() |
| 189 | + } |
| 190 | + } |
| 191 | + return n, nil |
| 192 | +} |
| 193 | + |
| 194 | +func (p *pipe) Close() error { |
| 195 | + p.mu.Lock() |
| 196 | + defer p.mu.Unlock() |
| 197 | + p.closed = true |
| 198 | + // Signal all blocked readers and writers to return an error. |
| 199 | + p.rwait.Broadcast() |
| 200 | + p.wwait.Broadcast() |
| 201 | + return nil |
| 202 | +} |
| 203 | + |
| 204 | +type conn struct { |
| 205 | + io.ReadCloser |
| 206 | + io.WriteCloser |
| 207 | +} |
| 208 | + |
| 209 | +func (c *conn) Close() error { |
| 210 | + err1 := c.ReadCloser.Close() |
| 211 | + err2 := c.WriteCloser.Close() |
| 212 | + if err1 != nil { |
| 213 | + return err1 |
| 214 | + } |
| 215 | + return err2 |
| 216 | +} |
| 217 | + |
| 218 | +func (*conn) LocalAddr() net.Addr { return addr{} } |
| 219 | +func (*conn) RemoteAddr() net.Addr { return addr{} } |
| 220 | +func (c *conn) SetDeadline(t time.Time) error { return fmt.Errorf("unsupported") } |
| 221 | +func (c *conn) SetReadDeadline(t time.Time) error { return fmt.Errorf("unsupported") } |
| 222 | +func (c *conn) SetWriteDeadline(t time.Time) error { return fmt.Errorf("unsupported") } |
| 223 | + |
| 224 | +type addr struct{} |
| 225 | + |
| 226 | +func (addr) Network() string { return "bufconn" } |
| 227 | +func (addr) String() string { return "bufconn" } |
0 commit comments