diff --git a/Dockerfile b/Dockerfile index 08486af19..1fd5e8278 100644 --- a/Dockerfile +++ b/Dockerfile @@ -100,6 +100,7 @@ ENV VPN_SERVICE_PROVIDER=pia \ WIREGUARD_PRESHARED_KEY_SECRETFILE=/run/secrets/wireguard_preshared_key \ WIREGUARD_PUBLIC_KEY= \ WIREGUARD_ALLOWED_IPS= \ + WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL=0 \ WIREGUARD_ADDRESSES= \ WIREGUARD_ADDRESSES_SECRETFILE=/run/secrets/wireguard_addresses \ WIREGUARD_MTU=1400 \ diff --git a/internal/configuration/settings/errors.go b/internal/configuration/settings/errors.go index 73c61b840..ad60b02cd 100644 --- a/internal/configuration/settings/errors.go +++ b/internal/configuration/settings/errors.go @@ -50,5 +50,6 @@ var ( ErrWireguardPrivateKeyNotSet = errors.New("private key is not set") ErrWireguardPublicKeyNotSet = errors.New("public key is not set") ErrWireguardPublicKeyNotValid = errors.New("public key is not valid") + ErrWireguardKeepAliveNegative = errors.New("persistent keep alive interval is negative") ErrWireguardImplementationNotValid = errors.New("implementation is not valid") ) diff --git a/internal/configuration/settings/wireguard.go b/internal/configuration/settings/wireguard.go index 180086987..d27fc0423 100644 --- a/internal/configuration/settings/wireguard.go +++ b/internal/configuration/settings/wireguard.go @@ -5,6 +5,7 @@ import ( "net/netip" "regexp" "strings" + "time" "github.com/qdm12/gluetun/internal/configuration/settings/helpers" "github.com/qdm12/gluetun/internal/constants/providers" @@ -34,7 +35,8 @@ type Wireguard struct { // Interface is the name of the Wireguard interface // to create. It cannot be the empty string in the // internal state. - Interface string `json:"interface"` + Interface string `json:"interface"` + PersistentKeepaliveInterval *time.Duration `json:"persistent_keep_alive_interval"` // Maximum Transmission Unit (MTU) of the Wireguard interface. // It cannot be zero in the internal state, and defaults to // 1400. Note it is not the wireguard-go MTU default of 1420 @@ -123,6 +125,11 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error) } } + if *w.PersistentKeepaliveInterval < 0 { + return fmt.Errorf("%w: %s", ErrWireguardKeepAliveNegative, + *w.PersistentKeepaliveInterval) + } + // Validate interface if !regexpInterfaceName.MatchString(w.Interface) { return fmt.Errorf("%w: '%s' does not match regex '%s'", @@ -139,13 +146,14 @@ func (w Wireguard) validate(vpnProvider string, ipv6Supported bool) (err error) func (w *Wireguard) copy() (copied Wireguard) { return Wireguard{ - PrivateKey: gosettings.CopyPointer(w.PrivateKey), - PreSharedKey: gosettings.CopyPointer(w.PreSharedKey), - Addresses: gosettings.CopySlice(w.Addresses), - AllowedIPs: gosettings.CopySlice(w.AllowedIPs), - Interface: w.Interface, - MTU: w.MTU, - Implementation: w.Implementation, + PrivateKey: gosettings.CopyPointer(w.PrivateKey), + PreSharedKey: gosettings.CopyPointer(w.PreSharedKey), + Addresses: gosettings.CopySlice(w.Addresses), + AllowedIPs: gosettings.CopySlice(w.AllowedIPs), + PersistentKeepaliveInterval: gosettings.CopyPointer(w.PersistentKeepaliveInterval), + Interface: w.Interface, + MTU: w.MTU, + Implementation: w.Implementation, } } @@ -154,6 +162,8 @@ func (w *Wireguard) overrideWith(other Wireguard) { w.PreSharedKey = gosettings.OverrideWithPointer(w.PreSharedKey, other.PreSharedKey) w.Addresses = gosettings.OverrideWithSlice(w.Addresses, other.Addresses) w.AllowedIPs = gosettings.OverrideWithSlice(w.AllowedIPs, other.AllowedIPs) + w.PersistentKeepaliveInterval = gosettings.OverrideWithPointer(w.PersistentKeepaliveInterval, + other.PersistentKeepaliveInterval) w.Interface = gosettings.OverrideWithComparable(w.Interface, other.Interface) w.MTU = gosettings.OverrideWithComparable(w.MTU, other.MTU) w.Implementation = gosettings.OverrideWithComparable(w.Implementation, other.Implementation) @@ -172,6 +182,7 @@ func (w *Wireguard) setDefaults(vpnProvider string) { netip.PrefixFrom(netip.IPv6Unspecified(), 0), } w.AllowedIPs = gosettings.DefaultSlice(w.AllowedIPs, defaultAllowedIPs) + w.PersistentKeepaliveInterval = gosettings.DefaultPointer(w.PersistentKeepaliveInterval, 0) w.Interface = gosettings.DefaultComparable(w.Interface, "wg0") const defaultMTU = 1400 w.MTU = gosettings.DefaultComparable(w.MTU, defaultMTU) @@ -205,6 +216,10 @@ func (w Wireguard) toLinesNode() (node *gotree.Node) { allowedIPsNode.Appendf(allowedIP.String()) } + if *w.PersistentKeepaliveInterval > 0 { + node.Appendf("Persistent keepalive interval: %s", w.PersistentKeepaliveInterval) + } + interfaceNode := node.Appendf("Network interface: %s", w.Interface) interfaceNode.Appendf("MTU: %d", w.MTU) @@ -241,6 +256,12 @@ func (w *Wireguard) read(r *reader.Reader) (err error) { if err != nil { return err // already wrapped } + + w.PersistentKeepaliveInterval, err = r.DurationPtr("WIREGUARD_PERSISTENT_KEEPALIVE_INTERVAL") + if err != nil { + return err + } + mtuPtr, err := r.Uint16Ptr("WIREGUARD_MTU") if err != nil { return err diff --git a/internal/provider/utils/wireguard.go b/internal/provider/utils/wireguard.go index 6f1a1f3c6..b016ce1a6 100644 --- a/internal/provider/utils/wireguard.go +++ b/internal/provider/utils/wireguard.go @@ -40,5 +40,7 @@ func BuildWireguardSettings(connection models.Connection, settings.AllowedIPs = append(settings.AllowedIPs, allowedIP) } + settings.PersistentKeepaliveInterval = *userSettings.PersistentKeepaliveInterval + return settings } diff --git a/internal/provider/utils/wireguard_test.go b/internal/provider/utils/wireguard_test.go index 7d0aa71c2..3b79ea619 100644 --- a/internal/provider/utils/wireguard_test.go +++ b/internal/provider/utils/wireguard_test.go @@ -3,6 +3,7 @@ package utils import ( "net/netip" "testing" + "time" "github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/models" @@ -10,7 +11,7 @@ import ( "github.com/stretchr/testify/assert" ) -func stringPtr(s string) *string { return &s } +func ptrTo[T any](x T) *T { return &x } func Test_BuildWireguardSettings(t *testing.T) { t.Parallel() @@ -28,8 +29,8 @@ func Test_BuildWireguardSettings(t *testing.T) { PubKey: "public", }, userSettings: settings.Wireguard{ - PrivateKey: stringPtr("private"), - PreSharedKey: stringPtr("pre-shared"), + PrivateKey: ptrTo("private"), + PreSharedKey: ptrTo("pre-shared"), Addresses: []netip.Prefix{ netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32), netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32), @@ -38,7 +39,8 @@ func Test_BuildWireguardSettings(t *testing.T) { netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), netip.PrefixFrom(netip.AddrFrom16([16]byte{}), 32), }, - Interface: "wg1", + PersistentKeepaliveInterval: ptrTo(time.Hour), + Interface: "wg1", }, ipv6Supported: false, settings: wireguard.Settings{ @@ -53,8 +55,9 @@ func Test_BuildWireguardSettings(t *testing.T) { AllowedIPs: []netip.Prefix{ netip.PrefixFrom(netip.AddrFrom4([4]byte{2, 2, 2, 2}), 32), }, - RulePriority: 101, - IPv6: boolPtr(false), + PersistentKeepaliveInterval: time.Hour, + RulePriority: 101, + IPv6: boolPtr(false), }, }, } diff --git a/internal/wireguard/config.go b/internal/wireguard/config.go index f351ed593..189bb134a 100644 --- a/internal/wireguard/config.go +++ b/internal/wireguard/config.go @@ -4,6 +4,7 @@ import ( "fmt" "net" "net/netip" + "time" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -43,6 +44,12 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) { preSharedKey = &preSharedKeyValue } + var persistentKeepaliveInterval *time.Duration + if settings.PersistentKeepaliveInterval > 0 { + persistentKeepaliveInterval = new(time.Duration) + *persistentKeepaliveInterval = settings.PersistentKeepaliveInterval + } + firewallMark := settings.FirewallMark config = wgtypes.Config{ @@ -63,7 +70,8 @@ func makeDeviceConfig(settings Settings) (config wgtypes.Config, err error) { Mask: []byte(net.IPv6zero), }, }, - ReplaceAllowedIPs: true, + PersistentKeepaliveInterval: persistentKeepaliveInterval, + ReplaceAllowedIPs: true, Endpoint: &net.UDPAddr{ IP: settings.Endpoint.Addr().AsSlice(), Port: int(settings.Endpoint.Port()), diff --git a/internal/wireguard/settings.go b/internal/wireguard/settings.go index fcf659e98..32d2d328e 100644 --- a/internal/wireguard/settings.go +++ b/internal/wireguard/settings.go @@ -6,6 +6,7 @@ import ( "net/netip" "regexp" "strings" + "time" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -30,6 +31,8 @@ type Settings struct { // the Wireguard interface. // Note IPv6 addresses are ignored if IPv6 is not supported. AllowedIPs []netip.Prefix + // PersistentKeepaliveInterval defines the keep alive interval, if not zero. + PersistentKeepaliveInterval time.Duration // FirewallMark to be used in routing tables and IP rules. // It defaults to 51820 if left to 0. FirewallMark int @@ -99,6 +102,7 @@ var ( ErrAllowedIPsMissing = errors.New("allowed IPs are missing") ErrAllowedIPNotValid = errors.New("allowed IP is not valid") ErrAllowedIPv6NotSupported = errors.New("allowed IPv6 address not supported") + ErrKeepaliveIsNegative = errors.New("keep alive interval is negative") ErrFirewallMarkMissing = errors.New("firewall mark is missing") ErrMTUMissing = errors.New("MTU is missing") ErrImplementationInvalid = errors.New("invalid implementation") @@ -160,6 +164,11 @@ func (s *Settings) Check() (err error) { } } + if s.PersistentKeepaliveInterval < 0 { + return fmt.Errorf("%w: %s", ErrKeepaliveIsNegative, + s.PersistentKeepaliveInterval) + } + if s.FirewallMark == 0 { return fmt.Errorf("%w", ErrFirewallMarkMissing) } @@ -286,5 +295,10 @@ func (s Settings) ToLines(settings ToLinesSettings) (lines []string) { } } + if s.PersistentKeepaliveInterval > 0 { + lines = append(lines, fieldPrefix+"Persistent keep alive interval: "+ + s.PersistentKeepaliveInterval.String()) + } + return lines }