Skip to content

Commit

Permalink
Fix timeout handling, add timeout testing
Browse files Browse the repository at this point in the history
  • Loading branch information
mcg1969 committed Feb 2, 2025
1 parent 87a163a commit 1bea81a
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 85 deletions.
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ jobs:
env:
ANACONDA_ANON_USAGE_DEBUG: 1
ANACONDA_ANON_USAGE_RAISE: 1
PYTHONUNBUFFERED: 1
defaults:
run:
# https://github.com/conda-incubator/setup-miniconda#use-a-default-shell
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ repos:
rev: v3.19.1
hooks:
- id: pyupgrade
args: [--py36-plus]
args: [--py36-plus, --keep-percent-format]
- repo: https://github.com/PyCQA/isort
rev: 6.0.0
hooks:
Expand Down
31 changes: 18 additions & 13 deletions anaconda_anon_usage/heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""

import argparse
import os
import sys
from threading import Thread
from urllib.parse import urljoin
Expand All @@ -20,13 +19,16 @@

VERBOSE = False
STANDALONE = False
DRY_RUN = os.environ.get("ANACONDA_HEARTBEAT_DRY_RUN")

CLD_REPO = "https://repo.anaconda.cloud/"
ORG_REPO = "https://conda.anaconda.org/"
COM_REPO = "https://repo.anaconda.com/pkgs/"
REPOS = (CLD_REPO, COM_REPO, ORG_REPO)
HEARTBEAT_PATH = "noarch/activate-0.0.0-0.conda"
# How long, in seconds, to attempt the connection. When a connection
# to our repository is blocked or slow, a long timeout would lead to
# a slow activation and a poor user experience.
TIMEOUT = 0.5


def _print(msg, *args, standalone=False, error=False):
Expand All @@ -45,15 +47,19 @@ def _print(msg, *args, standalone=False, error=False):

def _ping(session, url, wait):
try:
response = session.head(url, proxies=session.proxies)
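# A short timeout is necessary here so that the activation
# is not unduly delayed by a blocked internet connection.
# TIMEOUT is divided by the configured retry count, presumably
# so that all attempts together stay close to that budget.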
timeout = TIMEOUT / context.remote_max_retries
response = session.head(url, proxies=session.proxies, timeout=timeout)
_print("Status code (expect 404): %s", response.status_code)
except Exception as exc:
if type(exc).__name__ != "ConnectionError":
_print("Heartbeat error: %s", exc, error=True)
_print("Unexpected heartbeat error: %s", exc, error=True)
elif "timeout=" in str(exc):
_print("Timeout exceeded; heartbeat likely not sent.")


def attempt_heartbeat(channel=None, path=None, wait=False):
global DRY_RUN
def attempt_heartbeat(channel=None, path=None, wait=False, dry_run=False):
line = "------------------------"
_print(line, standalone=True)
_print("anaconda-anon-usage heartbeat", standalone=True)
Expand Down Expand Up @@ -84,20 +90,20 @@ def attempt_heartbeat(channel=None, path=None, wait=False):

_print("Heartbeat url: %s", url)
_print("User agent: %s", context.user_agent)
if DRY_RUN:
if dry_run:
_print("Dry run selected, not sending heartbeat.")
else:
session = get_session(url)
t = Thread(target=_ping, args=(session, url, wait), daemon=True)
# Run in the background so we can proceed with the rest of the
# activation tasks while the request fires. Because the thread is
# not a daemon, the process will not exit until it has completed.
t = Thread(target=_ping, args=(session, url, wait), daemon=False)
t.start()
_print("%saiting for response", "W" if wait else "Not w")
t.join(timeout=None if wait else 0.1)
_print(line, standalone=True)


def main():
global VERBOSE
global DRY_RUN
global STANDALONE
p = argparse.ArgumentParser()
p.add_argument("-c", "--channel", default=None)
Expand All @@ -108,8 +114,7 @@ def main():
args = p.parse_args()
STANDALONE = True
VERBOSE = not args.quiet
DRY_RUN = args.dry_run
attempt_heartbeat(args.channel, args.path, args.wait)
attempt_heartbeat(args.channel, args.path, args.wait, args.dry_run)


if __name__ == "__main__":
Expand Down
96 changes: 60 additions & 36 deletions tests/integration/proxy_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
--logfile, -l FILE Write logs to FILE instead of stdout
--port, -p PORT Listen on PORT (default: 8080)
--keep-certs Keep certificates in current directory
--delay TIME Emulate a connection delay of TIME seconds
--return-code, -r N Return status code N for all requests
--return-header H Add header H to responses (can repeat)
--return-data DATA Return DATA as response body
Expand All @@ -18,8 +19,8 @@
# Log all HTTPS requests to test.log:
./proxy_tester.py --logfile test.log -- curl https://httpbin.org/ip
# Return 404 for all requests:
./proxy_tester.py --return-code 404 -- python my_script.py
# Return 404 for all requests, but with a half-second delay:
./proxy_tester.py --return-code 404 --delay 0.5 -- python my_script.py
# Return custom response with headers and body:
./proxy_tester.py --return-code 200 \\
Expand Down Expand Up @@ -190,8 +191,6 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Connection counter
self.counter = 0
# Dictionary to store host certificates: hostname -> (cert_path, key_path)
self.host_certs = {}
# Lock serializing operations that must not run concurrently across
# handler threads (connection counter updates, certificate creation)
self.lock = Lock()
# Interception settings
Expand All @@ -203,27 +202,22 @@ def __init__(self, *args, **kwargs):

class ProxyHandler(BaseHTTPRequestHandler):

def __init__(self, *args, **kwargs):
def setup(self):
self.start_time = time.perf_counter()
self.cid = None
super().__init__(*args, **kwargs)
with self.server.lock:
self.server.counter += 1
self.cid = "%04d" % self.server.counter
super().setup()

def log_message(self, format, *args):
"""Override to prevent access log messages from appearing on stderr"""
pass

def _get_connection_id(self):
if self.cid is None:
with self.server.lock:
self.server.counter += 1
self.cid = "%04d" % self.server.counter
return self.cid

def _log(self, *args, **kwargs):
"""Log a message tagged with the connection ID and the time elapsed since this handler started"""
level = kwargs.pop("level", "info")
delta = time.perf_counter() - self.start_time
fmt = CONNECTION_FORMAT % (self._get_connection_id(), delta, args[0])
fmt = CONNECTION_FORMAT % (self.cid, delta, args[0])
getattr(logger, level)(fmt, *args[1:], **kwargs)

def _multiline_log(
Expand Down Expand Up @@ -253,7 +247,7 @@ def _multiline_log(
lines.append("<+ %d bytes>" % len(blob))
blob = ""
if direction:
lines[0] = f"[{direction}] {lines[0]}"
lines[0] = "[%s] %s" % (direction, lines[0])
self._log("\n | ".join(lines))
return len(blob)

Expand All @@ -268,23 +262,29 @@ def do_CONNECT(self):

remote = None
client = None
error_code = 0
error_msg = None

try:
# Obtain MITM certificates for this host
with self.server.lock:
cert_file, key_file = read_or_create_cert(host)

if self.server.delay:
self._log("Adding %gs connection delay", self.server.delay)
time.sleep(self.server.delay)

# Establish tunnel
self.send_response(200, "Connection Established")
self._multiline_log(
b"".join(self._headers_buffer), direction="P->C", include_binary=True
)
self.end_headers()

# Create SSL context with the host certificate
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
context.load_cert_chain(cert_file, key_file)
client = context.wrap_socket(self.connection, server_side=True)
# Create SSL context for the client connection (MITM certificate)
client_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
client_context.load_cert_chain(cert_file, key_file)
client = client_context.wrap_socket(self.connection, server_side=True)
self._log("[C<>P] SSL handshake completed")

if self.server.intercept_mode:
Expand All @@ -307,21 +307,31 @@ def do_CONNECT(self):
self._multiline_log(response, direction="P->C", include_binary=True)
client.send(response)
else:
# In MITM mode, forward all requests to the real server
# Create SSL context for the server connection (verify remote)
remote = socket.create_connection((host, int(port)))
context = ssl.create_default_context()
remote = context.wrap_socket(remote, server_hostname=host)
server_context = ssl.create_default_context()
remote = server_context.wrap_socket(remote, server_hostname=host)
self._log("[P<>S] SSL handshake completed")
# Forward all requests to the real server
self._forward_data(client, remote)

except ssl.SSLError as ssl_err:
self._log("SSL error: %s", ssl_err, level="error")
error_code, error_msg = 502, "SSL Handshake Failed"
except OSError as sock_err:
self._log("Socket error: %s", sock_err, level="error")
error_code, error_msg = 504, "Gateway Timeout"
except Exception as exc:
self._log("CONNECT error: %s", exc, level="exception")
try:
self.send_error(502)
except Exception:
# If connection is already dead, sending error would raise socket.error
pass
self._log("CONNECT error: %s", exc, level="error")
error_code, error_msg = 502, "Proxy Error"
finally:
if error_code:
try:
self.send_error(error_code, error_msg)
except Exception:
# If connection is already dead, sending an
# error would raise socket.error
pass
self.close_connection = True
if remote:
remote.close()
Expand All @@ -337,16 +347,22 @@ def forward(source, destination, direction, bcount):
data = source.recv(BUFFER_SIZE)
if not data:
return False, bcount
if bcount == 0:
# First chunk contains headers; subsequent chunks may be binary
ncount = self._multiline_log(data, direction=direction)
bcount += ncount
else:
bcount += len(data)
except (OSError, ssl.SSLError) as exc:
self._log("%s: Receive error: %s", direction, exc, level="error")
return False, bcount

if bcount == 0:
# First chunk contains headers; subsequent chunks may be binary
ncount = self._multiline_log(data, direction=direction)
bcount += ncount
else:
bcount += len(data)

try:
destination.sendall(data)
return True, bcount
except Exception as exc:
self._log("%s: Forwarding error: %s", direction, exc, level="exception")
self._log("%s: Send error: %s", direction, exc, level="error")
return False, bcount

# Track binary data for each direction separately
Expand Down Expand Up @@ -394,6 +410,13 @@ def main():
default=8080,
help="Port for the proxy server (default: 8080)",
)
parser.add_argument(
"--delay",
type=float,
action="store",
default=0,
help="Add a delay, in seconds, to each connection request, to test connection issues.",
)
parser.add_argument(
"--keep-certs",
action="store_true",
Expand Down Expand Up @@ -442,6 +465,7 @@ def cleanup():

# Start and configure server
server = ThreadingHTTPServer(("0.0.0.0", args.port), ProxyHandler)
server.delay = max(0, args.delay)

# Enable interception if any response-related args are provided
if (
Expand Down
Loading

0 comments on commit 1bea81a

Please sign in to comment.