diff --git a/server.go b/server.go index 35e534e..86f770a 100644 --- a/server.go +++ b/server.go @@ -1,9 +1,14 @@ package transocks import ( + "bufio" + "bytes" "context" + "crypto/tls" "io" "net" + "net/http" + "strings" "sync" "time" @@ -103,6 +108,45 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) { default: addr = tc.LocalAddr().String() } + + var reader io.Reader = tc + + // Check if TLS + isTLS, reader_n, err := peekSSL(tc) + if err != nil { + fields[log.FnError] = err.Error() + s.logger.Error("peekSSL failed", fields) + return + } + reader = reader_n + fields["is_tls"] = isTLS + + if isTLS { + // Peek ClientHello message from conn and returns SNI. + hello, reader_n2, err := peekClientHello(reader) + if err != nil { + fields[log.FnError] = err.Error() + s.logger.Error("peekClientHello failed", fields) + return + } + if err == nil && hello.ServerName != "" { + addr = hello.ServerName + addr[strings.Index(addr, ":"):] + } + reader = reader_n2 + } else { + // Get HOST Header if http + host, reader_n3, err := peekHTTP(reader) + if err != nil { + fields[log.FnError] = err.Error() + s.logger.Error("peekHTTP failed", fields) + return + } + if err == nil && host != "" { + addr = host + addr[strings.Index(addr, ":"):] + } + reader = reader_n3 + } + fields["dest_addr"] = addr destConn, err := s.dialer.Dial("tcp", addr) @@ -120,7 +164,7 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) { env := well.NewEnvironment(ctx) env.Go(func(ctx context.Context) error { buf := s.pool.Get().([]byte) - _, err := io.CopyBuffer(destConn, tc, buf) + _, err := io.CopyBuffer(destConn, reader, buf) s.pool.Put(buf) if hc, ok := destConn.(netutil.HalfCloser); ok { hc.CloseWrite() @@ -150,3 +194,83 @@ func (s *Server) handleConnection(ctx context.Context, conn net.Conn) { } s.logger.Info("proxy ends", fields) } + +// Peek ClientHello message from conn and returns SNI. +func peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) { + peekedBytes := new(bytes.Buffer) + hello, err := readClientHello(io.TeeReader(reader, peekedBytes)) + if err != nil { + return nil, nil, err + } + return hello, io.MultiReader(peekedBytes, reader), nil +} + +func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) { + var hello *tls.ClientHelloInfo + + err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{ + GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { + hello = new(tls.ClientHelloInfo) + *hello = *argHello + return nil, nil + }, + }).Handshake() // Handshake() always returns error, but we can get ClientHelloInfo from GetConfigForClient. + + if hello == nil { + return nil, err + } + + return hello, nil +} + +type readOnlyConn struct { + reader io.Reader +} + +func (conn readOnlyConn) Read(p []byte) (int, error) { return conn.reader.Read(p) } +func (conn readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe } +func (conn readOnlyConn) Close() error { return nil } +func (conn readOnlyConn) LocalAddr() net.Addr { return nil } +func (conn readOnlyConn) RemoteAddr() net.Addr { return nil } +func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil } +func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } +func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } + +// Check if tcp connection is SSL/TLS. Leave all bytes untouched by using TeeReader. +// Peek ClientHello message from conn and returns SNI. +func peekSSL(reader io.Reader) (bool, io.Reader, error) { + peekedBytes := new(bytes.Buffer) + isTLS, err := isTLS(io.TeeReader(reader, peekedBytes)) + if err != nil { + return false, nil, err + } + return isTLS, io.MultiReader(peekedBytes, reader), nil +} + +func isTLS(reader io.Reader) (bool, error) { + buf := make([]byte, 1) + _, err := reader.Read(buf) + if err != nil { + return false, err + } + return buf[0] == 0x16, nil +} + +// Get HOST Header if http. Leave all bytes untouched by using TeeReader. +func peekHTTP(reader io.Reader) (string, io.Reader, error) { + peekedBytes := new(bytes.Buffer) + host, err := getHost(io.TeeReader(reader, peekedBytes)) + if err != nil { + return "", nil, err + } + return host, io.MultiReader(peekedBytes, reader), nil +} + +// Return the HOST from http headers. +func getHost(reader io.Reader) (string, error) { + req, err := http.ReadRequest(bufio.NewReader(reader)) + if err != nil { + return "", err + } + return req.Host, nil +}