transocks/server.go

153 lines
3.1 KiB
Go
Raw Normal View History

2016-03-04 07:54:59 +06:00
package transocks
import (
"context"
2016-03-04 07:54:59 +06:00
"io"
"net"
"sync"
2016-03-04 07:54:59 +06:00
"time"
"github.com/cybozu-go/cmd"
2016-03-04 07:54:59 +06:00
"github.com/cybozu-go/log"
"github.com/cybozu-go/netutil"
2016-03-04 07:54:59 +06:00
"golang.org/x/net/proxy"
)
const (
	// keepAliveTimeout is the TCP keep-alive period configured on the
	// fallback net.Dialer that NewServer creates when Config.Dialer is nil.
	keepAliveTimeout = 3 * time.Minute

	// copyBufferSize is the length, in bytes, of each buffer handed out by
	// the Server's buffer pool for relaying data (64 KiB).
	copyBufferSize = 64 * 1024
)
// Listeners returns a list of net.Listener.
func Listeners(c *Config) ([]net.Listener, error) {
	// Listeners returns a list of net.Listener.
	//
	// It opens exactly one TCP listener bound to c.Addr. If net.Listen
	// fails, its error is returned unchanged together with a nil slice.
	var listeners []net.Listener
	tcpLn, err := net.Listen("tcp", c.Addr)
	if err != nil {
		return nil, err
	}
	listeners = append(listeners, tcpLn)
	return listeners, nil
}
2016-03-04 07:54:59 +06:00
// Server provides transparent proxy server functions.
type Server struct {
cmd.Server
mode Mode
logger *log.Logger
dialer proxy.Dialer
pool sync.Pool
2016-03-04 07:54:59 +06:00
}
// NewServer creates Server.
// If c is not valid, this returns non-nil error.
func NewServer(c *Config) (*Server, error) {
if err := c.validate(); err != nil {
return nil, err
}
dialer := c.Dialer
if dialer == nil {
dialer = &net.Dialer{
KeepAlive: keepAliveTimeout,
DualStack: true,
}
2016-03-04 07:54:59 +06:00
}
pdialer, err := proxy.FromURL(c.ProxyURL, dialer)
2016-03-04 07:54:59 +06:00
if err != nil {
return nil, err
}
logger := c.Logger
if logger == nil {
logger = log.DefaultLogger()
2016-03-04 07:54:59 +06:00
}
s := &Server{
Server: cmd.Server{
ShutdownTimeout: c.ShutdownTimeout,
Env: c.Env,
},
mode: c.Mode,
logger: logger,
dialer: pdialer,
pool: sync.Pool{
New: func() interface{} {
return make([]byte, copyBufferSize)
},
},
}
s.Server.Handler = s.handleConnection
return s, nil
2016-03-04 07:54:59 +06:00
}
func (s *Server) handleConnection(ctx context.Context, conn net.Conn) {
tc, ok := conn.(*net.TCPConn)
if !ok {
s.logger.Error("non-TCP connection", map[string]interface{}{
"conn": conn,
})
return
2016-03-04 07:54:59 +06:00
}
fields := cmd.FieldsFromContext(ctx)
fields[log.FnType] = "access"
fields["client_addr"] = conn.RemoteAddr().String()
2016-03-04 07:54:59 +06:00
var addr string
switch s.mode {
2016-03-04 07:54:59 +06:00
case ModeNAT:
origAddr, err := GetOriginalDST(tc)
2016-03-04 07:54:59 +06:00
if err != nil {
fields[log.FnError] = err.Error()
s.logger.Error("GetOriginalDST failed", fields)
2016-03-04 07:54:59 +06:00
return
}
addr = origAddr.String()
2016-03-04 07:54:59 +06:00
default:
addr = tc.LocalAddr().String()
2016-03-04 07:54:59 +06:00
}
fields["dest_addr"] = addr
2016-03-04 07:54:59 +06:00
destConn, err := s.dialer.Dial("tcp", addr)
2016-03-04 07:54:59 +06:00
if err != nil {
fields[log.FnError] = err.Error()
s.logger.Error("failed to connect to proxy server", fields)
2016-03-04 07:54:59 +06:00
return
}
defer destConn.Close()
s.logger.Info("proxy starts", fields)
// do proxy
st := time.Now()
env := cmd.NewEnvironment(ctx)
env.Go(func(ctx context.Context) error {
buf := s.pool.Get().([]byte)
_, err := io.CopyBuffer(destConn, tc, buf)
s.pool.Put(buf)
if hc, ok := destConn.(netutil.HalfCloser); ok {
hc.CloseWrite()
2016-03-04 07:54:59 +06:00
}
tc.CloseRead()
return err
})
env.Go(func(ctx context.Context) error {
buf := s.pool.Get().([]byte)
_, err := io.CopyBuffer(tc, destConn, buf)
s.pool.Put(buf)
tc.CloseWrite()
if hc, ok := destConn.(netutil.HalfCloser); ok {
hc.CloseRead()
}
return err
})
env.Stop()
err = env.Wait()
2016-03-04 07:54:59 +06:00
fields = cmd.FieldsFromContext(ctx)
fields["elapsed"] = time.Since(st).Seconds()
if err != nil {
fields[log.FnError] = err.Error()
s.logger.Error("proxy ends with an error", fields)
return
2016-03-04 07:54:59 +06:00
}
s.logger.Info("proxy ends", fields)
2016-03-04 07:54:59 +06:00
}