forked from ms/transocks
peek SNI or http header
This commit is contained in:
parent
2198aaeb4d
commit
0ab2f5ce50
126
server.go
126
server.go
@ -1,9 +1,14 @@
|
||||
package transocks
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -103,6 +108,45 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
default:
|
||||
addr = tc.LocalAddr().String()
|
||||
}
|
||||
|
||||
var reader io.Reader = tc
|
||||
|
||||
// Check if TLS
|
||||
isTLS, reader_n, err := peekSSL(tc)
|
||||
if err != nil {
|
||||
fields[log.FnError] = err.Error()
|
||||
s.logger.Error("peekSSL failed", fields)
|
||||
return
|
||||
}
|
||||
reader = reader_n
|
||||
fields["is_tls"] = isTLS
|
||||
|
||||
if isTLS {
|
||||
// Peek ClientHello message from conn and returns SNI.
|
||||
hello, reader_n2, err := peekClientHello(reader)
|
||||
if err != nil {
|
||||
fields[log.FnError] = err.Error()
|
||||
s.logger.Error("peekClientHello failed", fields)
|
||||
return
|
||||
}
|
||||
if err == nil && hello.ServerName != "" {
|
||||
addr = hello.ServerName + addr[strings.Index(addr, ":"):]
|
||||
}
|
||||
reader = reader_n2
|
||||
} else {
|
||||
// Get HOST Header if http
|
||||
host, reader_n3, err := peekHTTP(reader)
|
||||
if err != nil {
|
||||
fields[log.FnError] = err.Error()
|
||||
s.logger.Error("peekHTTP failed", fields)
|
||||
return
|
||||
}
|
||||
if err == nil && host != "" {
|
||||
addr = host + addr[strings.Index(addr, ":"):]
|
||||
}
|
||||
reader = reader_n3
|
||||
}
|
||||
|
||||
fields["dest_addr"] = addr
|
||||
|
||||
destConn, err := s.dialer.Dial("tcp", addr)
|
||||
@ -120,7 +164,7 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
env := well.NewEnvironment(ctx)
|
||||
env.Go(func(ctx context.Context) error {
|
||||
buf := s.pool.Get().([]byte)
|
||||
_, err := io.CopyBuffer(destConn, tc, buf)
|
||||
_, err := io.CopyBuffer(destConn, reader, buf)
|
||||
s.pool.Put(buf)
|
||||
if hc, ok := destConn.(netutil.HalfCloser); ok {
|
||||
hc.CloseWrite()
|
||||
@ -150,3 +194,83 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
|
||||
}
|
||||
s.logger.Info("proxy ends", fields)
|
||||
}
|
||||
|
||||
// Peek ClientHello message from conn and returns SNI.
|
||||
func peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) {
|
||||
peekedBytes := new(bytes.Buffer)
|
||||
hello, err := readClientHello(io.TeeReader(reader, peekedBytes))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return hello, io.MultiReader(peekedBytes, reader), nil
|
||||
}
|
||||
|
||||
func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
|
||||
var hello *tls.ClientHelloInfo
|
||||
|
||||
err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
|
||||
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
hello = new(tls.ClientHelloInfo)
|
||||
*hello = *argHello
|
||||
return nil, nil
|
||||
},
|
||||
}).Handshake() // Handshake() always returns error, but we can get ClientHelloInfo from GetConfigForClient.
|
||||
|
||||
if hello == nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hello, nil
|
||||
}
|
||||
|
||||
type readOnlyConn struct {
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func (conn readOnlyConn) Read(p []byte) (int, error) { return conn.reader.Read(p) }
|
||||
func (conn readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
|
||||
func (conn readOnlyConn) Close() error { return nil }
|
||||
func (conn readOnlyConn) LocalAddr() net.Addr { return nil }
|
||||
func (conn readOnlyConn) RemoteAddr() net.Addr { return nil }
|
||||
func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
|
||||
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
|
||||
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
// Check if tcp connection is SSL/TLS. Leave all bytes untouched by using TeeReader.
|
||||
// Peek ClientHello message from conn and returns SNI.
|
||||
func peekSSL(reader io.Reader) (bool, io.Reader, error) {
|
||||
peekedBytes := new(bytes.Buffer)
|
||||
isTLS, err := isTLS(io.TeeReader(reader, peekedBytes))
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
return isTLS, io.MultiReader(peekedBytes, reader), nil
|
||||
}
|
||||
|
||||
func isTLS(reader io.Reader) (bool, error) {
|
||||
buf := make([]byte, 1)
|
||||
_, err := reader.Read(buf)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return buf[0] == 0x16, nil
|
||||
}
|
||||
|
||||
// Get HOST Header if http. Leave all bytes untouched by using TeeReader.
|
||||
func peekHTTP(reader io.Reader) (string, io.Reader, error) {
|
||||
peekedBytes := new(bytes.Buffer)
|
||||
host, err := getHost(io.TeeReader(reader, peekedBytes))
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return host, io.MultiReader(peekedBytes, reader), nil
|
||||
}
|
||||
|
||||
// Return the HOST from http headers.
|
||||
func getHost(reader io.Reader) (string, error) {
|
||||
req, err := http.ReadRequest(bufio.NewReader(reader))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return req.Host, nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user