transocks/server.go

283 lines
6.9 KiB
Go
Raw Permalink Normal View History

2016-03-04 07:54:59 +06:00
package transocks
import (
2023-02-28 23:31:06 +07:00
"bufio"
"bytes"
"context"
2023-02-28 23:31:06 +07:00
"crypto/tls"
2016-03-04 07:54:59 +06:00
"io"
"net"
2023-02-28 23:31:06 +07:00
"net/http"
"strings"
"sync"
2016-03-04 07:54:59 +06:00
"time"
"github.com/cybozu-go/log"
"github.com/cybozu-go/netutil"
"github.com/cybozu-go/well"
2016-03-04 07:54:59 +06:00
"golang.org/x/net/proxy"
)
// Tuning knobs for proxied connections.
const (
	// keepAliveTimeout is the TCP keep-alive period used by the default
	// dialer built in NewServer.
	keepAliveTimeout = 3 * time.Minute

	// copyBufferSize is the byte size of each pooled copy buffer (64 KiB).
	copyBufferSize = 64 * 1024
)
// Listeners opens a TCP listener on c.Addr and returns it as the only element
// of a net.Listener slice. Any error from net.Listen is returned unchanged.
func Listeners(c *Config) ([]net.Listener, error) {
	var listeners []net.Listener
	l, err := net.Listen("tcp", c.Addr)
	if err != nil {
		return nil, err
	}
	return append(listeners, l), nil
}
2016-03-04 07:54:59 +06:00
// Server provides transparent proxy server functions.
type Server struct {
well.Server
mode Mode
logger *log.Logger
dialer proxy.Dialer
pool sync.Pool
2016-03-04 07:54:59 +06:00
}
// NewServer creates Server.
// If c is not valid, this returns non-nil error.
func NewServer(c *Config) (*Server, error) {
if err := c.validate(); err != nil {
return nil, err
}
dialer := c.Dialer
if dialer == nil {
dialer = &net.Dialer{
KeepAlive: keepAliveTimeout,
DualStack: true,
}
2016-03-04 07:54:59 +06:00
}
pdialer, err := proxy.FromURL(c.ProxyURL, dialer)
2016-03-04 07:54:59 +06:00
if err != nil {
return nil, err
}
logger := c.Logger
if logger == nil {
logger = log.DefaultLogger()
2016-03-04 07:54:59 +06:00
}
s := &Server{
Server: well.Server{
ShutdownTimeout: c.ShutdownTimeout,
Env: c.Env,
},
mode: c.Mode,
logger: logger,
dialer: pdialer,
pool: sync.Pool{
New: func() interface{} {
return make([]byte, copyBufferSize)
},
},
}
s.Server.Handler = s.handleConnection
return s, nil
2016-03-04 07:54:59 +06:00
}
func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
tc, ok := conn.(*net.TCPConn)
if !ok {
2024-11-01 22:43:08 +07:00
_ = s.logger.Error("non-TCP connection", map[string]interface{}{
"conn": conn,
})
return
2016-03-04 07:54:59 +06:00
}
fields := well.FieldsFromContext(ctx)
fields[log.FnType] = "access"
fields["client_addr"] = conn.RemoteAddr().String()
2016-03-04 07:54:59 +06:00
var addr string
switch s.mode {
2016-03-04 07:54:59 +06:00
case ModeNAT:
origAddr, err := GetOriginalDST(tc)
2016-03-04 07:54:59 +06:00
if err != nil {
fields[log.FnError] = err.Error()
2024-11-01 22:43:08 +07:00
_ = s.logger.Error("GetOriginalDST failed", fields)
2016-03-04 07:54:59 +06:00
return
}
addr = origAddr.String()
2016-03-04 07:54:59 +06:00
default:
addr = tc.LocalAddr().String()
2016-03-04 07:54:59 +06:00
}
2023-02-28 23:31:06 +07:00
2024-11-01 22:43:08 +07:00
var (
reader io.Reader = tc
isTLS bool
reader_n io.Reader
)
2023-02-28 23:31:06 +07:00
// Check if TLS
2024-11-01 22:43:08 +07:00
if isTLSloc, reader_nloc, err := peekSSL(tc); err != nil {
2023-02-28 23:31:06 +07:00
fields[log.FnError] = err.Error()
2024-11-01 22:43:08 +07:00
_ = s.logger.Error("peekSSL failed", fields)
2023-02-28 23:31:06 +07:00
return
2024-11-01 22:43:08 +07:00
} else {
isTLS = isTLSloc
reader_n = reader_nloc
2023-02-28 23:31:06 +07:00
}
reader = reader_n
fields["is_tls"] = isTLS
if isTLS {
// Peek ClientHello message from conn and returns SNI.
hello, reader_n2, err := peekClientHello(reader)
if err != nil {
fields[log.FnError] = err.Error()
2024-11-01 22:43:08 +07:00
_ = s.logger.Warn("peekClientHello failed", fields)
2023-02-28 23:31:06 +07:00
}
if err == nil && hello.ServerName != "" {
addr = hello.ServerName + addr[strings.Index(addr, ":"):]
}
reader = reader_n2
} else {
// Get HOST Header if http
host, reader_n3, err := peekHTTP(reader)
if err != nil {
fields[log.FnError] = err.Error()
2024-11-01 22:43:08 +07:00
_ = s.logger.Warn("peekHTTP failed", fields)
} else if host != "" {
if strings.Contains(host, ":") {
host = host[:strings.Index(host, ":")]
}
addr = host + addr[strings.Index(addr, ":"):]
2023-02-28 23:31:06 +07:00
}
reader = reader_n3
}
fields["dest_addr"] = addr
2016-03-04 07:54:59 +06:00
destConn, err := s.dialer.Dial("tcp", addr)
2016-03-04 07:54:59 +06:00
if err != nil {
fields[log.FnError] = err.Error()
2024-11-01 22:43:08 +07:00
_ = s.logger.Error("failed to connect to proxy server", fields)
2016-03-04 07:54:59 +06:00
return
}
2024-11-01 22:43:08 +07:00
defer func() { _ = destConn.Close() }()
2024-11-01 22:43:08 +07:00
_ = s.logger.Info("proxy starts", fields)
// do proxy
st := time.Now()
env := well.NewEnvironment(ctx)
env.Go(func(ctx context.Context) error {
buf := s.pool.Get().([]byte)
2024-11-01 22:43:08 +07:00
_, err = io.CopyBuffer(destConn, reader, buf)
s.pool.Put(&buf)
if hc, ok := destConn.(netutil.HalfCloser); ok {
2024-11-01 22:43:08 +07:00
_ = hc.CloseWrite()
2016-03-04 07:54:59 +06:00
}
2024-11-01 22:43:08 +07:00
_ = tc.CloseRead()
return err
})
env.Go(func(ctx context.Context) error {
buf := s.pool.Get().([]byte)
2024-11-01 22:43:08 +07:00
_, err = io.CopyBuffer(tc, destConn, buf)
s.pool.Put(&buf)
_ = tc.CloseWrite()
if hc, ok := destConn.(netutil.HalfCloser); ok {
2024-11-01 22:43:08 +07:00
_ = hc.CloseRead()
}
return err
})
env.Stop()
err = env.Wait()
2016-03-04 07:54:59 +06:00
fields = well.FieldsFromContext(ctx)
fields["elapsed"] = time.Since(st).Seconds()
if err != nil {
fields[log.FnError] = err.Error()
2024-11-01 22:43:08 +07:00
_ = s.logger.Error("proxy ends with an error", fields)
return
2016-03-04 07:54:59 +06:00
}
2024-11-01 22:43:08 +07:00
_ = s.logger.Info("proxy ends", fields)
2016-03-04 07:54:59 +06:00
}
2023-02-28 23:31:06 +07:00
// peekClientHello parses a TLS ClientHello from reader without losing data.
//
// Every byte consumed while parsing is captured via a TeeReader. The returned
// reader replays those bytes and then continues with the rest of reader.
// That replay reader is returned even when parsing fails, matching peekSSL
// and peekHTTP, so callers can still forward the stream unchanged. On error
// the ClientHelloInfo is nil.
func peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) {
	peekedBytes := new(bytes.Buffer)
	hello, err := readClientHello(io.TeeReader(reader, peekedBytes))
	replay := io.MultiReader(peekedBytes, reader)
	if err != nil {
		return nil, replay, err
	}
	return hello, replay, nil
}
// readClientHello extracts the client's ClientHelloInfo from reader.
//
// It runs a server-side TLS handshake over a read-only view of reader, solely
// so that GetConfigForClient fires and the ClientHelloInfo can be copied out.
// The handshake itself cannot succeed: no certificate is configured, and
// readOnlyConn rejects writes. Its error is therefore returned only when no
// ClientHello was captured.
func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
	var captured *tls.ClientHelloInfo

	cfg := &tls.Config{
		GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
			snapshot := *info
			captured = &snapshot
			return nil, nil
		},
	}

	handshakeErr := tls.Server(readOnlyConn{reader: reader}, cfg).Handshake()
	if captured != nil {
		return captured, nil
	}
	return nil, handshakeErr
}
// readOnlyConn adapts an io.Reader to net.Conn so it can be handed to
// tls.Server.
//
// Reads are delegated to reader. Writes always fail with io.ErrClosedPipe.
// Close and the deadline setters are no-ops returning nil, and both
// addresses are nil.
type readOnlyConn struct {
	reader io.Reader
}

// Compile-time check that readOnlyConn satisfies net.Conn.
var _ net.Conn = readOnlyConn{}

// Read delegates to the wrapped reader.
func (c readOnlyConn) Read(b []byte) (int, error) {
	return c.reader.Read(b)
}

// Write rejects all data; the connection is read-only.
func (c readOnlyConn) Write(b []byte) (int, error) {
	return 0, io.ErrClosedPipe
}

// Close is a no-op.
func (c readOnlyConn) Close() error {
	return nil
}

// LocalAddr reports no address.
func (c readOnlyConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr reports no address.
func (c readOnlyConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline is a no-op.
func (c readOnlyConn) SetDeadline(deadline time.Time) error {
	return nil
}

// SetReadDeadline is a no-op.
func (c readOnlyConn) SetReadDeadline(deadline time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op.
func (c readOnlyConn) SetWriteDeadline(deadline time.Time) error {
	return nil
}
// peekSSL reports whether the stream in reader starts like a TLS connection.
//
// It consumes the first byte through a TeeReader. It returns a reader that
// replays everything consumed and then continues with the rest of reader, so
// the caller can forward the stream unchanged. The replay reader is returned
// even when an error occurs.
func peekSSL(reader io.Reader) (bool, io.Reader, error) {
	peekedBytes := new(bytes.Buffer)
	looksTLS, err := isTLS(io.TeeReader(reader, peekedBytes))
	replay := io.MultiReader(peekedBytes, reader)
	if err != nil {
		return false, replay, err
	}
	return looksTLS, replay, nil
}

// isTLS reads exactly one byte from reader and reports whether it is 0x16.
// 0x16 is the TLS record content type for handshake records, and a client's
// first record carries its ClientHello.
func isTLS(reader io.Reader) (bool, error) {
	var first [1]byte
	// io.ReadFull instead of a single Read: Read may legally return 0 bytes
	// with a nil error, which would leave first[0] as 0 and misclassify TLS.
	if _, err := io.ReadFull(reader, first[:]); err != nil {
		return false, err
	}
	return first[0] == 0x16, nil
}
// peekHTTP tries to parse an HTTP request head from reader and returns its
// Host, as reported by http.Request.Host; it may include a port.
//
// Every byte consumed while parsing is captured via a TeeReader. The returned
// reader replays those bytes and then continues with the rest of reader, so
// the stream can be forwarded unchanged. The replay reader is returned even
// when parsing fails.
func peekHTTP(reader io.Reader) (string, io.Reader, error) {
	peekedBytes := new(bytes.Buffer)
	host, err := getHost(io.TeeReader(reader, peekedBytes))
	replay := io.MultiReader(peekedBytes, reader)
	if err != nil {
		return "", replay, err
	}
	return host, replay, nil
}

// getHost parses an HTTP request head from reader and returns req.Host.
//
// http.ReadRequest does not bound header size by itself. Reading is therefore
// capped at http.DefaultMaxHeaderBytes, so a client that never finishes its
// header block cannot make the caller buffer unbounded data. Such a request
// yields a parse error instead.
func getHost(reader io.Reader) (string, error) {
	limited := io.LimitReader(reader, http.DefaultMaxHeaderBytes)
	req, err := http.ReadRequest(bufio.NewReader(limited))
	if err != nil {
		return "", err
	}
	return req.Host, nil
}