2021-08-12 21:03:24 +02:00
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"errors"
"io"
"net/http"
"net/url"
"strings"
"time"
)
// HandshakeError describes an error with the handshake from the peer.
type HandshakeError struct {
	message string
}

// Error implements the error interface by returning the stored handshake
// failure message.
func (he HandshakeError) Error() string {
	return he.message
}
// Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection.
2022-05-02 14:05:18 +01:00
//
// It is safe to call Upgrader's methods concurrently.
2021-08-12 21:03:24 +02:00
type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time . Duration
// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
// size is zero, then buffers allocated by the HTTP server are used. The
// I/O buffer sizes do not limit the size of the messages that can be sent
// or received.
2024-06-10 07:43:38 +00:00
// The default value is 4096 bytes, 4kb.
2021-08-12 21:03:24 +02:00
ReadBufferSize , WriteBufferSize int
// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool
// Subprotocols specifies the server's supported protocols in order of
// preference. If this field is not nil, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol
// requested by the client. If there's no match, then no protocol is
// negotiated (the Sec-Websocket-Protocol header is not included in the
// handshake response).
Subprotocols [ ] string
// Error specifies the function for generating HTTP error responses. If Error
// is nil, then http.Error is used to generate the HTTP response.
Error func ( w http . ResponseWriter , r * http . Request , status int , reason error )
// CheckOrigin returns true if the request Origin header is acceptable. If
// CheckOrigin is nil, then a safe default is used: return false if the
// Origin request header is present and the origin host is not equal to
// request Host header.
//
// A CheckOrigin function should carefully validate the request origin to
// prevent cross-site request forgery.
CheckOrigin func ( r * http . Request ) bool
// EnableCompression specify if the server should attempt to negotiate per
// message compression (RFC 7692). Setting this value to true does not
// guarantee that compression will be supported. Currently only "no context
// takeover" modes are supported.
EnableCompression bool
}
// returnError reports a failed handshake to the client and returns the
// failure to the caller as a HandshakeError carrying reason.
//
// When u.Error is set it is responsible for the HTTP response; otherwise a
// plain http.Error response with the given status is written, advertising
// the supported protocol version.
func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
	handshakeErr := HandshakeError{reason}
	if u.Error == nil {
		w.Header().Set("Sec-Websocket-Version", "13")
		http.Error(w, http.StatusText(status), status)
		return nil, handshakeErr
	}
	u.Error(w, r, status, handshakeErr)
	return nil, handshakeErr
}
// checkSameOrigin reports whether the request is same-origin: true when no
// Origin header is present, false when the first Origin value cannot be
// parsed, and otherwise the result of comparing its host with r.Host via
// equalASCIIFold.
func checkSameOrigin(r *http.Request) bool {
	origins := r.Header["Origin"]
	if len(origins) == 0 {
		return true
	}
	parsed, err := url.Parse(origins[0])
	// Short-circuit keeps parsed (nil on error) from being dereferenced.
	return err == nil && equalASCIIFold(parsed.Host, r.Host)
}
func ( u * Upgrader ) selectSubprotocol ( r * http . Request , responseHeader http . Header ) string {
if u . Subprotocols != nil {
clientProtocols := Subprotocols ( r )
2024-06-10 07:43:38 +00:00
for _ , clientProtocol := range clientProtocols {
for _ , serverProtocol := range u . Subprotocols {
2021-08-12 21:03:24 +02:00
if clientProtocol == serverProtocol {
return clientProtocol
}
}
}
} else if responseHeader != nil {
return responseHeader . Get ( "Sec-Websocket-Protocol" )
}
return ""
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie). To specify
// subprotocols supported by the server, set Upgrader.Subprotocols directly.
// The responseHeader must not contain Sec-WebSocket-Extensions; control
// characters in its values are replaced with spaces.
//
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
	const badHandshake = "websocket: the client is not using the websocket protocol: "

	checkOrigin := u.CheckOrigin
	if checkOrigin == nil {
		checkOrigin = checkSameOrigin
	}
	_, hasExtensions := responseHeader["Sec-Websocket-Extensions"]

	// Cases are evaluated top to bottom, so the first failing check decides
	// the error and later checks (including CheckOrigin) are not run.
	switch {
	case !tokenListContainsValue(r.Header, "Connection", "upgrade"):
		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
	case !tokenListContainsValue(r.Header, "Upgrade", "websocket"):
		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
	case r.Method != http.MethodGet:
		return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
	case !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13"):
		return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
	case hasExtensions:
		return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
	case !checkOrigin(r):
		return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
	}

	challengeKey := r.Header.Get("Sec-Websocket-Key")
	if !isValidChallengeKey(challengeKey) {
		return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
	}

	subprotocol := u.selectSubprotocol(r, responseHeader)

	// Negotiate PMCE: accept as soon as the client offers permessage-deflate.
	compress := false
	if u.EnableCompression {
		for _, ext := range parseExtensions(r.Header) {
			if ext[""] == "permessage-deflate" {
				compress = true
				break
			}
		}
	}

	nc, rw, err := http.NewResponseController(w).Hijack()
	if err != nil {
		return u.returnError(w, r, http.StatusInternalServerError, err.Error())
	}
	if rw.Reader.Buffered() > 0 {
		nc.Close()
		return nil, errors.New("websocket: client sent data before handshake is complete")
	}

	// Reuse the hijacked reader only when no explicit size was requested and
	// its buffer is big enough to be worthwhile.
	var reader *bufio.Reader
	if u.ReadBufferSize == 0 && bufioReaderSize(nc, rw.Reader) > 256 {
		reader = rw.Reader
	}

	hijackedBuf := bufioWriterBuffer(nc, rw.Writer)
	var writeBuf []byte
	if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(hijackedBuf) >= maxFrameHeaderSize+256 {
		writeBuf = hijackedBuf
	}

	c := newConn(nc, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, reader, writeBuf)
	c.subprotocol = subprotocol
	if compress {
		c.newCompressionWriter = compressNoContextTakeover
		c.newDecompressionReader = decompressNoContextTakeover
	}

	// Assemble the 101 response in whichever of the hijacked buffer and the
	// connection write buffer is larger.
	resp := hijackedBuf
	if len(c.writeBuf) > len(resp) {
		resp = c.writeBuf
	}
	resp = append(resp[:0], "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
	resp = append(resp, computeAcceptKey(challengeKey)...)
	resp = append(resp, "\r\n"...)
	if c.subprotocol != "" {
		resp = append(resp, "Sec-WebSocket-Protocol: "...)
		resp = append(resp, c.subprotocol...)
		resp = append(resp, "\r\n"...)
	}
	if compress {
		resp = append(resp, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
	}
	for name, values := range responseHeader {
		// The negotiated subprotocol was already written above.
		if name == "Sec-Websocket-Protocol" {
			continue
		}
		for _, value := range values {
			resp = append(resp, name...)
			resp = append(resp, ": "...)
			for _, b := range []byte(value) {
				if b <= 31 {
					// Replace control characters to prevent response splitting.
					b = ' '
				}
				resp = append(resp, b)
			}
			resp = append(resp, "\r\n"...)
		}
	}
	resp = append(resp, "\r\n"...)

	// Clear deadlines set by HTTP server. Errors are deliberately ignored,
	// as in the rest of the handshake's deadline handling.
	nc.SetDeadline(time.Time{})
	if u.HandshakeTimeout > 0 {
		nc.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
	}
	if _, err := nc.Write(resp); err != nil {
		nc.Close()
		return nil, err
	}
	if u.HandshakeTimeout > 0 {
		nc.SetWriteDeadline(time.Time{})
	}
	return c, nil
}
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// Deprecated: Use websocket.Upgrader instead.
//
// Upgrade does not perform origin checking. The application is responsible for
// checking the Origin header before calling Upgrade. An example implementation
// of the same origin policy check is:
//
//	if req.Header.Get("Origin") != "http://"+req.Host {
//		http.Error(w, "Origin not allowed", http.StatusForbidden)
//		return
//	}
//
// If the endpoint supports subprotocols, then the application is responsible
// for negotiating the protocol used on the connection. Use the Subprotocols()
// function to get the subprotocols requested by the client. Use the
// Sec-Websocket-Protocol response header to specify the subprotocol selected
// by the application.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// negotiated subprotocol (Sec-Websocket-Protocol).
//
// The connection buffers IO to the underlying network connection. The
// readBufSize and writeBufSize parameters specify the size of the buffers to
// use. Messages can be larger than the buffers.
//
// If the request is not a valid WebSocket handshake, then Upgrade returns an
// error of type HandshakeError. Applications should handle this error by
// replying to the client with an HTTP error response.
func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
	legacy := Upgrader{
		ReadBufferSize:  readBufSize,
		WriteBufferSize: writeBufSize,
		// Write nothing to the client on failure, to maintain backwards
		// compatibility; the error is still returned to the caller.
		Error: func(http.ResponseWriter, *http.Request, int, error) {},
		// Accept every origin; the caller is responsible for origin checks.
		CheckOrigin: func(*http.Request) bool { return true },
	}
	return legacy.Upgrade(w, r, responseHeader)
}
// Subprotocols returns the subprotocols requested by the client in the
// Sec-Websocket-Protocol header.
func Subprotocols ( r * http . Request ) [ ] string {
h := strings . TrimSpace ( r . Header . Get ( "Sec-Websocket-Protocol" ) )
if h == "" {
return nil
}
protocols := strings . Split ( h , "," )
for i := range protocols {
protocols [ i ] = strings . TrimSpace ( protocols [ i ] )
}
return protocols
}
// IsWebSocketUpgrade reports whether the client asked to switch to the
// WebSocket protocol: the Connection header must carry the "upgrade" token
// and the Upgrade header the "websocket" token.
func IsWebSocketUpgrade(r *http.Request) bool {
	if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
		return false
	}
	return tokenListContainsValue(r.Header, "Upgrade", "websocket")
}
// bufioReaderSize resets br to read from originalReader and returns the size
// of br's underlying buffer in bytes.
//
// The reset discards any data currently buffered in br, and subsequent reads
// from br come directly from originalReader.
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
	br.Reset(originalReader)
	return br.Size()
}
// writeHook is an io.Writer that records the most recent slice passed to its
// Write method, keeping a reference to it rather than a copy.
type writeHook struct {
	p []byte
}

// Write stores p in the hook and reports it as fully written.
func (hook *writeHook) Write(p []byte) (int, error) {
	hook.p = p
	return len(p), nil
}
// bufioWriterBuffer grabs the buffer from a bufio.Writer.
func bufioWriterBuffer ( originalWriter io . Writer , bw * bufio . Writer ) [ ] byte {
// This code assumes that bufio.Writer.buf[:1] is passed to the
// bufio.Writer's underlying writer.
var wh writeHook
bw . Reset ( & wh )
2024-06-10 07:43:38 +00:00
bw . WriteByte ( 0 )
bw . Flush ( )
2021-08-12 21:03:24 +02:00
bw . Reset ( originalWriter )
return wh . p [ : cap ( wh . p ) ]
}