diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index a9c7ce2..9b3f660 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -44,6 +44,7 @@ jobs: env: ANACONDA_ANON_USAGE_DEBUG: 1 ANACONDA_ANON_USAGE_RAISE: 1 + PYTHONUNBUFFERED: 1 defaults: run: # https://github.com/conda-incubator/setup-miniconda#use-a-default-shell diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4517326..45080fd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,7 @@ repos: rev: v3.19.1 hooks: - id: pyupgrade - args: [--py36-plus] + args: [--py36-plus, --keep-percent-format] - repo: https://github.com/PyCQA/isort rev: 6.0.0 hooks: diff --git a/anaconda_anon_usage/heartbeat.py b/anaconda_anon_usage/heartbeat.py index ad583b3..8bcefd4 100644 --- a/anaconda_anon_usage/heartbeat.py +++ b/anaconda_anon_usage/heartbeat.py @@ -9,10 +9,11 @@ import argparse import os import sys +import time from threading import Thread from urllib.parse import urljoin -from conda.base.context import context +from conda.base.context import Context, context, locate_prefix_by_name from conda.gateways.connection.session import get_session from conda.models.channel import Channel @@ -20,7 +21,6 @@ VERBOSE = False STANDALONE = False -DRY_RUN = os.environ.get("ANACONDA_HEARTBEAT_DRY_RUN") CLD_REPO = "https://repo.anaconda.cloud/" ORG_REPO = "https://conda.anaconda.org/" @@ -28,14 +28,19 @@ REPOS = (CLD_REPO, COM_REPO, ORG_REPO) HEARTBEAT_PATH = "noarch/activate-0.0.0-0.conda" +# How long to attempt the connection. When a connection to our +# repository is blocked or slow, a long timeout would lead to +# a slow activation and a poor user experience. This is a total +# timeout value, inclusive of all retries. +TIMEOUT = 0.75 # seconds +ATTEMPTS = 3 -def _print(msg, *args, standalone=False, error=False): + +def _print(msg, *args, error=False): global VERBOSE global STANDALONE if not (VERBOSE or utils.DEBUG or error): return - if standalone and not STANDALONE: - return # It is very important that these messages are printed to stderr # when called from within the activate script. Otherwise they # will insert themselves into the activation command set @@ -43,22 +48,25 @@ def _print(msg, *args, standalone=False, error=False): print(msg % args, file=ofile) -def _ping(session, url, wait): +def _ping(session, url, timeout): try: - response = session.head(url, proxies=session.proxies) - _print("Status code (expect 404): %s", response.status_code) + # A short timeout is necessary here so that the activation + # is not unduly delayed by a blocked internet connection + start_time = time.perf_counter() + response = session.head(url, proxies=session.proxies, timeout=timeout) + delta = time.perf_counter() - start_time + _print( + "Success after %.3fs; code (expect 404): %d", delta, response.status_code + ) except Exception as exc: if type(exc).__name__ != "ConnectionError": - _print("Heartbeat error: %s", exc, error=True) + _print("Unexpected heartbeat error: %s", exc, error=True) + elif "timeout=" in str(exc): + delta = time.perf_counter() - start_time + _print("NO heartbeat sent after %.3fs.", delta) -def attempt_heartbeat(channel=None, path=None, wait=False): - global DRY_RUN - line = "------------------------" - _print(line, standalone=True) - _print("anaconda-anon-usage heartbeat", standalone=True) - _print(line, standalone=True) - +def attempt_heartbeat(prefix=None, dry_run=False, channel=None, path=None): if not hasattr(context, "_aau_initialized"): from . import patch @@ -77,39 +85,99 @@ def attempt_heartbeat(channel=None, path=None, wait=False): break else: _print("No valid heartbeat channel") - _print(line, standalone=True) return url = urljoin(base, channel or "main") + "/" url = urljoin(url, path or HEARTBEAT_PATH) _print("Heartbeat url: %s", url) + if prefix: + Context.checked_prefix = prefix + _print("Prefix: %s", prefix) _print("User agent: %s", context.user_agent) - if DRY_RUN: + + if dry_run: _print("Dry run selected, not sending heartbeat.") - else: - session = get_session(url) - t = Thread(target=_ping, args=(session, url, wait), daemon=True) - t.start() - _print("%saiting for response", "W" if wait else "Not w") - t.join(timeout=None if wait else 0.1) - _print(line, standalone=True) + return + + # Build and configure the session object + timeout = TIMEOUT / ATTEMPTS + context.remote_max_retries = ATTEMPTS - 1 + # No backoff between attempts + context.remote_backoff_factor = 0 + session = get_session(url) + + # Run in the background so we can proceed with the rest of the + # activation tasks while the request fires. The process will wait + # to terminate until the thread is complete. + t = Thread(target=_ping, args=(session, url, timeout), daemon=False) + t.start() + if STANDALONE: + t.join() def main(): global VERBOSE - global DRY_RUN global STANDALONE - p = argparse.ArgumentParser() - p.add_argument("-c", "--channel", default=None) - p.add_argument("-p", "--path", default=None) - p.add_argument("-d", "--dry-run", action="store_true") - p.add_argument("-q", "--quiet", action="store_true") - p.add_argument("-w", "--wait", action="store_true") - args = p.parse_args() STANDALONE = True - VERBOSE = not args.quiet - DRY_RUN = args.dry_run - attempt_heartbeat(args.channel, args.path, args.wait) + VERBOSE = "--quiet" not in sys.argv and "-q" not in sys.argv + + line = "-----------------------------" + _print(line) + _print("anaconda-anon-usage heartbeat") + _print(line) + + def environment_path(s): + assert os.path.isdir(s) + return s + + def environment_name(s): + return locate_prefix_by_name(s) + + p = argparse.ArgumentParser() + g = p.add_mutually_exclusive_group() + g.add_argument( + "-n", + "--name", + type=environment_name, + default=None, + help="Environment name; defaults to the current environment.", + ) + g.add_argument( + "-p", + "--prefix", + type=environment_path, + default=None, + help="Environment prefix; defaults to the current environment.", + ) + p.add_argument( + "-d", + "--dry-run", + action="store_true", + help="Do not send the heartbeat; just show the steps.", + ) + p.add_argument("-q", "--quiet", action="store_true", help="Suppress console logs.") + p.add_argument( + "--channel", + default=None, + help="(advanced) The full URL to a custom repository channel. By default, an " + "Anaconda-hosted channel listed in the user's channel configuration is used.", + ) + p.add_argument( + "--path", + default=None, + help="(advanced) A custom path to append to the channel URL.", + ) + + try: + args = p.parse_args() + attempt_heartbeat( + prefix=args.prefix or args.name, + dry_run=args.dry_run, + channel=args.channel, + path=args.path, + ) + finally: + _print(line) if __name__ == "__main__": diff --git a/anaconda_anon_usage/patch.py b/anaconda_anon_usage/patch.py index 5bbe63c..21efc11 100644 --- a/anaconda_anon_usage/patch.py +++ b/anaconda_anon_usage/patch.py @@ -55,8 +55,7 @@ def _new_activate(self): env = self.env_name_or_prefix if env and os.sep not in env: env = locate_prefix_by_name(env) - Context.checked_prefix = env or sys.prefix - attempt_heartbeat() + attempt_heartbeat(env or sys.prefix) except Exception as exc: _debug("Failed to attempt heartbeat: %s", exc, error=True) finally: diff --git a/tests/integration/proxy_tester.py b/tests/integration/proxy_tester.py new file mode 100755 index 0000000..64ef9ee --- /dev/null +++ b/tests/integration/proxy_tester.py @@ -0,0 +1,535 @@ +#!/usr/bin/env python3 + +"""HTTPS debugging proxy that logs or intercepts HTTPS requests. + +Launches a proxy server that either forwards HTTPS requests while logging +headers and content, or intercepts requests and returns specified responses. +Manages certificates automatically and supports concurrent connections. + +Arguments: + --logfile, -l FILE Write logs to FILE instead of stdout + --port, -p PORT Listen on PORT (default: 8080) + --keep-certs Keep certificates in current directory + --delay TIME Emulate a connection delay of TIME seconds + --return-code, -r N Return status code N for all requests + --return-header H Add header H to responses (can repeat) + --return-data DATA Return DATA as response body + +Examples: + # Log all HTTPS requests to test.log: + ./proxy_tester.py --logfile test.log -- curl https://httpbin.org/ip + + # Return 404 for all requests, but with a half-second delay: + ./proxy_tester.py --return-code 404 --delay 0.5 -- python my_script.py + + # Return custom response with headers and body: + ./proxy_tester.py --return-code 200 \\ + --return-header "Content-Type: application/json" \\ + --return-data '{"status": "ok"}' \\ + -- ./my_script.py +""" + +import argparse +import atexit +import logging +import os +import re +import select +import shutil +import socket +import ssl +import subprocess +import sys +import tempfile +import time +from datetime import datetime, timedelta +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from os.path import isfile, join +from threading import Lock, Thread + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID + +# _forward_data buffer size +BUFFER_SIZE = 65536 +# regex to find newlines in binary data +BINARY_NEWLINE = re.compile(rb"\r?\n") +LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" +CONNECTION_FORMAT = "[%s/%.3f/%.3f] %s" # cid, split, elapsed, message + +logger = logging.getLogger(__name__) + +# +# Certificate operations +# + + +CERT_DIR = None +CA_CERT = None +CA_KEY = None +# Track which host certificates we've logged about to prevent duplicate messages +CERT_READ = set() + + +def read_or_create_cert(host=None): + """Reads and/or creates the SSL certificates for the proxy, including + both the CA certificate and the host certificates signed with it. If + --keep-certs is set, then certificates will be saved between runs.""" + + global CA_CERT + global CA_KEY + + is_CA = host is None + + assert CERT_DIR + cert_path = join(CERT_DIR, "cert.pem" if is_CA else "%s-cert.pem" % host) + key_path = join(CERT_DIR, "key.pem" if is_CA else "%s-key.pem" % host) + + # return quickly if the files already exist + if isfile(cert_path) and isfile(key_path): + if is_CA: + logger.info("Using existing CA certificate") + with open(cert_path, "rb") as f: + CA_CERT = x509.load_pem_x509_certificate(f.read()) + with open(key_path, "rb") as f: + CA_KEY = serialization.load_pem_private_key(f.read(), password=None) + elif host not in CERT_READ: + logger.info("Using existing host certificate for %s", host) + CERT_READ.add(host) + return cert_path, key_path + + if is_CA: + logger.info("Generating CA certificate") + else: + assert CA_CERT and CA_KEY + logger.info("Generating host certificate for %s", host) + + # Generate CSR-like data + hostname = "Debug Proxy CA" if is_CA else host + host_info = [x509.NameAttribute(NameOID.COMMON_NAME, hostname)] + if is_CA: + host_info.append(x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Debug Proxy")) + name = x509.Name(host_info) + + key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + pub = key.public_key() + if not host: + CA_KEY = key + cert = ( + x509.CertificateBuilder() + .subject_name(name) + .issuer_name(name if is_CA else CA_CERT.subject) + .public_key(pub) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.now()) + .not_valid_after(datetime.now() + timedelta(days=365)) + .add_extension(x509.BasicConstraints(ca=is_CA, path_length=None), critical=True) + ) + if is_CA: + # Enable certificate signing + cert = cert.add_extension( + x509.KeyUsage( + digital_signature=True, + content_commitment=False, + key_encipherment=True, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ).add_extension( + x509.SubjectKeyIdentifier.from_public_key(pub), + critical=False, + ) + else: + cert = cert.add_extension( + x509.SubjectAlternativeName([x509.DNSName(host)]), critical=False + ) + + # Sign with CA key + cert = cert.sign(CA_KEY, hashes.SHA256()) + if is_CA: + CA_CERT = cert + + # Save and return the certificate and private key in PEM format + cert_pem = cert.public_bytes(serialization.Encoding.PEM) + key_pem = key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + # Write to files + with open(key_path, "wb") as f: + f.write(key_pem) + with open(cert_path, "wb") as f: + f.write(cert_pem) + + return cert_path, key_path + + +# +# Server implementation +# + + +class MyHTTPServer(ThreadingHTTPServer): + """HTTPS proxy server with thread-per-connection handling""" + + daemon_threads = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Connection counter + self.counter = 0 + # Lock for single-threaded operations + self.lock = Lock() + # Interception settings + self.intercept_mode = False + self.return_code = 200 # Default if in intercept mode + self.return_headers = [] # List of (name, value) tuples + self.return_data = "" # Response body + + +class ProxyHandler(BaseHTTPRequestHandler): + + def setup(self): + self.start_time = time.perf_counter() + self.last_time = self.start_time + with self.server.lock: + self.server.counter += 1 + self.cid = "%04d" % self.server.counter + super().setup() + + def log_message(self, format, *args): + """Override to prevent access log messages from appearing on stderr""" + pass + + def _log(self, *args, **kwargs): + """Log message with elapsed time since first message for this connection ID""" + level = kwargs.pop("level", "info") + n_time = time.perf_counter() + d1 = n_time - self.last_time + d2 = n_time - self.start_time + fmt = CONNECTION_FORMAT % (self.cid, d1, d2, args[0]) + getattr(logger, level)(fmt, *args[1:], **kwargs) + self.last_time = n_time + + def _multiline_log( + self, blob, firstline=None, direction=None, include_binary=False + ): + """Split binary/text data into lines for logging, logging text and remaining byte count""" + global BINARY_NEWLINE + lines = [] + if firstline is not None: + lines.append(firstline) + if isinstance(blob, bytes): + while blob: + m = BINARY_NEWLINE.search(blob) + line = blob if m is None else blob[: m.start()] + try: + line = line.decode("iso-8859-1") + blob = blob[m.end() :] # noqa + if not line: + break + lines.append(line) + except UnicodeDecodeError: + break + else: + lines.extend(str(blob).strip().splitlines()) + blob = "" + if include_binary and blob: + lines.append("<+ %d bytes>" % len(blob)) + blob = "" + if direction: + lines[0] = "[%s] %s" % (direction, lines[0]) + self._log("\n | ".join(lines)) + return len(blob) + + def do_CONNECT(self): + self._multiline_log( + self.headers, + firstline=self.requestline, + direction="C->P", + include_binary=True, + ) + host, port = self.path.split(":") + + remote = None + client = None + error_code = 0 + error_msg = None + + try: + # Obtain MITM certificates for this host + with self.server.lock: + cert_file, key_file = read_or_create_cert(host) + + if self.server.delay: + self._log("Enforcing %gs delay", self.server.delay) + current = self.last_time + finish = self.start_time + self.server.delay + while finish - current > 0.001: + time.sleep(finish - current) + current = time.perf_counter() + self._log("End of connection delay") + + # Establish tunnel + self.send_response(200, "Connection Established") + self._multiline_log( + b"".join(self._headers_buffer), direction="P->C", include_binary=True + ) + self.end_headers() + + # Create SSL context for the client connection (MITM certificate) + client_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + client_context.load_cert_chain(cert_file, key_file) + client = client_context.wrap_socket(self.connection, server_side=True) + self._log("[C<>P] SSL handshake completed") + + if self.server.intercept_mode: + # Read the decrypted request + request = t_request = client.recv(BUFFER_SIZE) + while len(t_request) == BUFFER_SIZE: + t_request = client.recv(BUFFER_SIZE) + request += t_request + self._multiline_log(request, direction="C->P", include_binary=True) + + # Build and send custom response + response = ["HTTP/1.1 %d Intercepted" % self.server.return_code] + response.extend(": ".join(h) for h in self.server.return_headers) + if self.server.return_data: + c_len = len(self.server.return_data) + response.append("Content-Length: %d" % c_len) + response.extend(["", self.server.return_data or ""]) + response = "\r\n".join(response).encode("iso-8859-1") + + self._multiline_log(response, direction="P->C", include_binary=True) + client.send(response) + else: + # Create SSL context for the server connection (verify remote) + remote = socket.create_connection((host, int(port))) + server_context = ssl.create_default_context() + remote = server_context.wrap_socket(remote, server_hostname=host) + self._log("[P<>S] SSL handshake completed") + # Forward all requests to the real server + self._forward_data(client, remote) + + except ssl.SSLError as ssl_err: + self._log("SSL error: %s", ssl_err, level="error") + error_code, error_msg = 502, "SSL Handshake Failed" + except OSError as sock_err: + self._log("Socket error: %s", sock_err, level="error") + error_code, error_msg = 504, "Gateway Timeout" + except Exception as exc: + self._log("CONNECT error: %s", exc, level="error") + error_code, error_msg = 502, "Proxy Error" + finally: + if error_code: + try: + self.send_error(error_code, error_msg) + except Exception: + # If connection is already dead, sending an + # error would raise socket.error + pass + self.close_connection = True + if remote: + remote.close() + if client: + client.close() + self._log("Connection closed") + + def _forward_data(self, client, remote): + """Forward data between client and remote, logging headers and tracking binary data size""" + + def forward(source, destination, direction, bcount): + try: + data = source.recv(BUFFER_SIZE) + if not data: + return False, bcount + except (OSError, ssl.SSLError) as exc: + self._log("%s: Receive error: %s", direction, exc, level="error") + return False, bcount + + if bcount == 0: + # First chunk contains headers; subsequent chunks may be binary + ncount = self._multiline_log(data, direction=direction) + bcount += ncount + else: + bcount += len(data) + + try: + destination.sendall(data) + return True, bcount + except Exception as exc: + self._log("%s: Send error: %s", direction, exc, level="error") + return False, bcount + + # Track binary data for each direction separately + c_total = r_total = 0 + while True: + # 1 second timeout to check for connection closure + r, w, e = select.select([client, remote], [], [], 1.0) + if not r: + break + if client in r: + success, c_total = forward(client, remote, "C->S", c_total) + if not success: + break + if remote in r: + success, r_total = forward(remote, client, "S->C", r_total) + if not success: + break + + # Deliver final binary totals + if c_total: + self._log("[C->S] %d additional bytes sent", c_total) + if r_total: + self._log("[S->C] %d additional bytes received", r_total) + + +# +# Command-line interface +# + + +def main(): + global CERT_DIR + + # Parse arguments + parser = argparse.ArgumentParser( + description="HTTPS debugging proxy that logs or intercepts HTTPS requests" + ) + parser.add_argument( + "--logfile", "-l", help="File to write logs to (defaults to stdout)" + ) + parser.add_argument( + "--port", + "-p", + type=int, + default=8080, + help="Port for the proxy server (default: 8080)", + ) + parser.add_argument( + "--delay", + type=float, + action="store", + default=0, + help="Add a delay, in seconds, to each connection request, to test connection issues.", + ) + parser.add_argument( + "--keep-certs", + action="store_true", + help="Keep certificates in current directory instead of using a temporary directory", + ) + parser.add_argument( + "--return-code", + "-r", + type=int, + help="HTTP status code to return for all requests", + ) + parser.add_argument( + "--return-header", + action="append", + help='Response header in format "Name: Value" (can be repeated)', + ) + parser.add_argument("--return-data", help="Response body to return") + parser.add_argument("command", nargs="+", help="Command to run and its arguments") + args = parser.parse_args() + + # Configure logging + logging_config = { + "level": logging.INFO, + "format": LOG_FORMAT, + "handlers": [], + } + if args.logfile: + logging_config["handlers"].append(logging.FileHandler(args.logfile)) + else: + logging_config["handlers"].append(logging.StreamHandler(sys.stdout)) + logging.basicConfig(**logging_config) + + # Set up certificate generation + if args.keep_certs: + CERT_DIR = os.getcwd() + else: + CERT_DIR = tempfile.mkdtemp() + + def cleanup(): + logger.info("Removing temporary certificate directory") + shutil.rmtree(CERT_DIR, ignore_errors=True) + + atexit.register(cleanup) + logger.info("Certificate directory: %s", CERT_DIR) + cert_path, key_path = read_or_create_cert() + + # Start and configure server + server = MyHTTPServer(("0.0.0.0", args.port), ProxyHandler) + server.delay = max(0, args.delay) + + # Enable interception if any response-related args are provided + if ( + any(x is not None for x in [args.return_code, args.return_data]) + or args.return_header + ): + server.intercept_mode = True + server.return_code = args.return_code or 200 + server.return_data = args.return_data or "" + + # Parse headers + server.return_headers = [] + if args.return_header: + for header in args.return_header: + try: + name, value = header.split(":", 1) + server.return_headers.append((name.strip(), value.strip())) + except ValueError: + logger.error("Invalid header format: %s", header) + return 1 + server_thread = Thread(target=server.serve_forever) + server_thread.daemon = True + server_thread.start() + logger.info("Proxy server started on port %d", args.port) + + # Proxy configuration + env = os.environ.copy() + proxy_host = "http://localhost:%d" % args.port + env["HTTPS_PROXY"] = proxy_host + env["https_proxy"] = proxy_host + env["HTTP_PROXY"] = proxy_host + env["http_proxy"] = proxy_host + env["NO_PROXY"] = "" + env["no_proxy"] = "" + + # Certificate configuration + env["CURL_CA_BUNDLE"] = cert_path + env["SSL_CERT_FILE"] = cert_path + env["REQUESTS_CA_BUNDLE"] = cert_path + env["CONDA_SSL_VERIFY"] = cert_path + + # Run child process + returncode = 0 + try: + process = subprocess.Popen(args.command, env=env) + returncode = process.wait() + logger.info("Child process exited with code %d", returncode) + except Exception as exc: + logger.error("Error running child process: %s", exc) + returncode = 255 + finally: + server.shutdown() + server.server_close() + + return returncode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/integration/test_config.py b/tests/integration/test_config.py index 3d2e034..ca7e11c 100644 --- a/tests/integration/test_config.py +++ b/tests/integration/test_config.py @@ -3,7 +3,7 @@ import re import subprocess import sys -from os.path import basename, expanduser, isfile, join +from os.path import basename, dirname, expanduser, isfile, join from anaconda_anon_usage import tokens as m_tokens @@ -85,10 +85,15 @@ def _config(value, ctype): if mode == "default" and ctype == "env": continue enabled = _config(mode, ctype) - # Make sure to leave override-channels and the full channel URL in here. - # This allows this command to run fully no matter what we do to channel_alias - # and default_channels - cmd = ["conda", "install", "-vvv", "fakepackage"] + # Using the proxy tester allows us to test this without the requests actually + # making it to repo.anaconda.com. The tester returns 404 for all requests. It + # also has the advantage of making sure our code respects proxies properly + pscript = join(dirname(__file__), "proxy_tester.py") + # fmt: off + cmd = ["python", pscript, "--return-code", "404", "--", + "python", "-m", "conda", "install", "--override-channels", + "-c", "defaults", "fakepackage"] + # fmt: on if envname: cmd.extend(["-n", envname]) skip = False @@ -98,18 +103,8 @@ def _config(value, ctype): capture_output=True, text=True, ) - user_agent = "" - for v in proc.stderr.splitlines(): - # Unfortunately conda has evolved how it logs request headers - # So this regular expression attempts to match multiple forms - # > User-Agent: conda/... - # .... {'User-Agent': 'conda/...', ...} - match = re.match(r'.*User-Agent(["\']?): *(["\']?)(.+)', v) - if match: - _, delim, user_agent = match.groups() - if delim and delim in user_agent: - user_agent = user_agent.split(delim, 1)[0] - break + match = re.search(r"^.*User-Agent: (.+)$", proc.stdout, re.MULTILINE) + user_agent = match.groups()[0] if match else "" if first: if user_agent: print(user_agent) diff --git a/tests/integration/test_heartbeats.py b/tests/integration/test_heartbeats.py index 3a54107..3c91fec 100644 --- a/tests/integration/test_heartbeats.py +++ b/tests/integration/test_heartbeats.py @@ -3,6 +3,7 @@ import re import subprocess import sys +from os.path import dirname, join from conda.base.context import context from conda.models.channel import Channel @@ -19,8 +20,6 @@ if os.path.isfile("/etc/conda/machine_token"): expected += ("m",) -os.environ["ANACONDA_ANON_USAGE_DEBUG"] = "1" -os.environ["ANACONDA_HEARTBEAT_DRY_RUN"] = "1" ALL_FIELDS = {"aau", "aid", "c", "s", "e", "u", "h", "n", "m", "o", "U", "H", "N"} @@ -48,24 +47,9 @@ def get_test_envs(): all_environments = set() -def verify_user_agent(output, expected, envname=None, marker=None): - # Unfortunately conda has evolved how it logs request headers - # So this regular expression attempts to match multiple forms - # > User-Agent: conda/... - # .... {'User-Agent': 'conda/...', ...} +def verify_user_agent(user_agent, expected, envname=None, marker=None): other_tokens["n"] = envname if envname else "base" - user_agent = "" - marker = marker or "[uU]ser.[aA]gent" # codespell:ignore - MATCH_RE = r".*" + marker + r'(["\']?): *(["\']?)(.+)' - for v in output.splitlines(): - match = re.match(MATCH_RE, v) - if match: - _, delim, user_agent = match.groups() - if delim and delim in user_agent: - user_agent = user_agent.split(delim, 1)[0] - break - new_values = [t.split("/", 1) for t in user_agent.split(" ") if "/" in t] new_values = {k: v for k, v in new_values if k in ALL_FIELDS} header = " ".join(f"{k}/{v}" for k, v in new_values.items()) @@ -110,25 +94,34 @@ def verify_user_agent(output, expected, envname=None, marker=None): urls = [u for c in context.channels for u in Channel(c).urls()] urls.extend(u.rstrip("/") for u in context.channel_alias.urls()) if any(".anaconda.cloud" in u for u in urls): - hb_url = "https://repo.anaconda.cloud/" + exp_host = "repo.anaconda.cloud:443" elif any(".anaconda.com" in u for u in urls): - hb_url = "https://repo.anaconda.com/" + exp_host = "repo.anaconda.com:443" elif any(".anaconda.org" in u for u in urls): - hb_url = "https://conda.anaconda.org/" + exp_host = "conda.anaconda.org:443" else: - hb_url = None -if hb_url: - hb_url += "pkgs/main/noarch/activate-0.0.0-0.conda" -print("Expected heartbeat url:", hb_url) -print("Expected user agent tokens:", ",".join(expected)) + raise RuntimeError("No heartbeat URL available.") +exp_path = "/pkgs/main/noarch/activate-0.0.0-0.conda" +print("Expected host:", exp_host) +print("Expected path:", exp_path) +print("Expected tokens:", ",".join(expected)) need_header = True -for hval in ("true", "false"): - os.environ["CONDA_ANACONDA_HEARTBEAT"] = hval +for hval in ("true", "false", "delay"): + os.environ["CONDA_ANACONDA_HEARTBEAT"] = str(hval != "false").lower() for envname in envs: # Do each one twice to make sure the user agent string # remains correct on repeated attempts for stype in shells: - cmd = ["conda", "shell." + stype, "activate", envname] + # Using the proxy tester allows us to test this without the requests actually + # making it to repo.anaconda.com. The tester returns 404 for all requests. It + # also has the advantage of making sure our code respects proxies properly + pscript = join(dirname(__file__), "proxy_tester.py") + # fmt: off + cmd = ["python", pscript, "--return-code", "404"] + if hval == "delay": + cmd.extend(["--delay", "2.0"]) + cmd.extend(["--", "python", "-m", "conda", "shell." + stype, "activate", envname]) + # fmt: on proc = subprocess.run( cmd, check=False, @@ -136,26 +129,22 @@ def verify_user_agent(output, expected, envname=None, marker=None): text=True, ) header = status = "" - no_hb_url = "No valid heartbeat channel" in proc.stderr - hb_urls = { - line.rsplit(" ", 1)[-1] - for line in proc.stderr.splitlines() - if "Heartbeat url:" in line - } - status = "" - if hval == "true": - if not (no_hb_url or hb_urls): - status = "NOT ENABLED" - elif hb_url and not hb_urls: - status = "NO HEARTBEAT URL" - elif not hb_url and hb_urls: - status = "UNEXPECTED URLS: " + ",".join(hb_urls) - elif hb_url and any(hb_url not in u for u in hb_urls): - status = "INCORRECT URLS: " + ",".join(hb_urls) - elif hval == "false" and (no_hb_url or hb_urls): + t_host = re.search(r"^.* CONNECT (.*) HTTP/1.1$", proc.stdout, re.MULTILINE) + t_host = t_host.groups()[0] if t_host else "" + t_path = re.search(r"^.* HEAD (.*) HTTP/1.1$", proc.stdout, re.MULTILINE) + t_path = t_path.groups()[0] if t_path else "" + t_uagent = re.search(r"^ . User-Agent: (.*)", proc.stdout, re.MULTILINE) + t_uagent = t_uagent.groups()[0] if t_uagent else "" + if hval != "false" and not t_host: + status = "NOT ENABLED" + elif hval == "false" and t_host: status = "NOT DISABLED" - if hb_urls and not status: - status, header = verify_user_agent(proc.stderr, expected, envname) + elif hval == "delay" and t_path: + status = "TIMEOUT FAILED" + elif t_host and t_path and (t_host != exp_host or t_path != exp_path): + status = f"INCORRECT URL: {t_host}{t_path}" + if not status and hval == "true": + status, header = verify_user_agent(t_uagent, expected, envname) if need_header: if header: print("|", header) diff --git a/tests/requirements.txt b/tests/requirements.txt index 9955dec..f294ac6 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,2 +1,3 @@ pytest pytest-cov +cryptography