diff --git a/README.md b/README.md index a3f8d75..1636efe 100644 --- a/README.md +++ b/README.md @@ -57,8 +57,6 @@ Rust-powered WebSocket server with Python API for remote command execution and i └───────┘ └───────┘ └───────┘ ``` -**Key Design**: Daemons connect **TO** the agent (not the other way around), so no ports need to be exposed on the execution plane. - ## Installation ### Python Package (Controller) diff --git a/python/sandd/__init__.py b/python/sandd/__init__.py index aeca31c..557030c 100644 --- a/python/sandd/__init__.py +++ b/python/sandd/__init__.py @@ -7,7 +7,7 @@ - Interactive session (PTY) - File transfer -Example: +Example (Sync API): >>> from sandd import Server >>> server = Server(host="0.0.0.0", port=8765) >>> @@ -23,15 +23,28 @@ >>> # File transfer >>> server.upload_file("daemon-1", "/remote/path", data) >>> data = server.download_file("daemon-1", "/remote/file") + +Example (Async API - Not Yet Implemented): + >>> from sandd import AsyncServer + >>> server = AsyncServer(host="0.0.0.0", port=8765) + >>> + >>> # Execute command + >>> result = await server.exec("daemon-1", "ls -la") + >>> print(result.stdout) + >>> + >>> # Concurrent execution + >>> results = await asyncio.gather( + ... server.exec("daemon-1", "hostname"), + ... server.exec("daemon-2", "uptime") + ... ) """ -from typing import Optional, Dict, List -import time -import sys -import select +from .models import CommandResult, ServerStats +from .server import Server +from .async_server import AsyncServer try: - from ._core import Server as _RustServer, Session, PyCommandResult, PyStats + from ._core import Session except ImportError as e: raise ImportError( "Failed to import Rust extension. " @@ -40,395 +53,8 @@ __all__ = [ "Server", + "AsyncServer", "Session", "CommandResult", "ServerStats", ] - - -class CommandResult: - """Result from command execution - - Attributes: - stdout: Standard output from the command - stderr: Standard error from the command - exit_code: Exit code (0 = success) - duration_ms: Execution duration in milliseconds - """ - - def __init__(self, result: PyCommandResult): - self._result = result - - @property - def stdout(self) -> str: - """Standard output""" - return self._result.stdout - - @property - def stderr(self) -> str: - """Standard error""" - return self._result.stderr - - @property - def exit_code(self) -> int: - """Exit code (0 = success)""" - return self._result.exit_code - - @property - def duration_ms(self) -> int: - """Execution duration in milliseconds""" - return self._result.duration_ms - - @property - def success(self) -> bool: - """Whether the command succeeded (exit_code == 0)""" - return self.exit_code == 0 - - def __repr__(self) -> str: - return ( - f"CommandResult(exit_code={self.exit_code}, " - f"duration_ms={self.duration_ms}, " - f"stdout={len(self.stdout)} bytes, " - f"stderr={len(self.stderr)} bytes)" - ) - - -class ServerStats: - """Server statistics - - Attributes: - total_daemons: Total number of connected daemons - by_platform: Daemon count by platform (e.g., {"linux": 150, "darwin": 50}) - oldest_connection_secs: Age of oldest connection in seconds - """ - - def __init__(self, stats: PyStats): - self._stats = stats - - @property - def total_daemons(self) -> int: - """Total connected daemons""" - return self._stats.total_daemons - - @property - def by_platform(self) -> Dict[str, int]: - """Daemon count by platform""" - return self._stats.by_platform - - @property - def oldest_connection_secs(self) -> int: - """Age of oldest connection in seconds""" - return self._stats.oldest_connection_secs - - def __repr__(self) -> str: - return ( - f"ServerStats(total={self.total_daemons}, " - f"platforms={self.by_platform})" - ) - - -class Server: - """Sandbox execution server - - High-performance WebSocket server for managing remote daemon connections. - Built with Rust for efficient handling of high-concurrency workloads. - - Args: - host: Bind address (default: "0.0.0.0") - port: Bind port (default: 8765) - verbose: Enable logging at INFO level (default: True) - Set to False to disable logs (useful for interactive sessions) - - Example: - >>> server = Server("0.0.0.0", 8765) - >>> server.wait_for_daemon("daemon-1", timeout=30) - >>> result = server.exec("daemon-1", "hostname") - >>> print(result.stdout) - - >>> # Disable logs for clean output, useful for interactive sessions - >>> server = Server("0.0.0.0", 8765, verbose=False) - """ - - def __init__(self, host: str = "0.0.0.0", port: int = 8765, verbose: bool = True): - self._server = _RustServer(host, port, verbose) - self._host = host - self._port = port - - def exec( - self, - daemon_id: str, - command: str, - timeout: int = 300, - env: Optional[Dict[str, str]] = None, - cwd: Optional[str] = None, - ) -> CommandResult: - """Execute a command on a daemon - - Args: - daemon_id: Target daemon ID - command: Command to execute (session string) - timeout: Execution timeout in seconds (default: 300) - env: Environment variables to set - cwd: Working directory - - Returns: - CommandResult with stdout, stderr, exit_code, duration - - Raises: - ValueError: If daemon not found - TimeoutError: If command times out - RuntimeError: If command fails to execute - - Example: - >>> result = server.exec("daemon-1", "ls -la /tmp") - >>> if result.success: - ... print(result.stdout) - """ - result = self._server.exec( - daemon_id, command, timeout, env, cwd - ) - return CommandResult(result) - - def new_session( - self, - daemon_id: str, - rows: int = 24, - cols: int = 80, - term: str = "xterm-256color", - interactive: bool = False, - ) -> Session: - """Create a new interactive session - - Args: - daemon_id: Target daemon ID - rows: Terminal rows (default: 24) - cols: Terminal columns (default: 80) - term: TERM environment variable (default: "xterm-256color") - interactive: If True, enters interactive mode with live terminal (default: False) - - Returns: - Session for interactive I/O (or None if interactive=True, runs in foreground) - - Raises: - ValueError: If daemon not found - RuntimeError: If session fails to start - - Example (Programmatic): - >>> session = server.new_session("daemon-1") - >>> session.write(b"ls -la\\n") - >>> output = session.read(timeout=1.0) - >>> if output: - ... print(output.decode()) - - Example (Interactive): - >>> server.new_session("daemon-1", interactive=True) - # Enters interactive terminal session - type commands directly - """ - session = self._server.new_session(daemon_id, rows, cols, term) - - if interactive: - self._run_interactive(session) - return None - - return session - - def upload_file( - self, - daemon_id: str, - remote_path: str, - data: bytes, - ) -> None: - """Upload a file to a daemon - - Args: - daemon_id: Target daemon ID - remote_path: Destination path on daemon - data: File data to upload - - Raises: - ValueError: If daemon not found - RuntimeError: If upload fails - - Example: - >>> with open("config.yaml", "rb") as f: - ... data = f.read() - >>> server.upload_file("daemon-1", "/etc/app/config.yaml", data) - """ - self._server.upload_file(daemon_id, remote_path, data) - - def download_file( - self, - daemon_id: str, - remote_path: str, - ) -> bytes: - """Download a file from a daemon - - Args: - daemon_id: Target daemon ID - remote_path: Source path on daemon - - Returns: - File data as bytes - - Raises: - ValueError: If daemon not found - RuntimeError: If download fails - - Example: - >>> data = server.download_file("daemon-1", "/var/log/app.log") - >>> with open("app.log", "wb") as f: - ... f.write(data) - """ - return self._server.download_file(daemon_id, remote_path) - - def list_daemons( - self, - labels: Optional[Dict[str, str]] = None, - ) -> List[str]: - """List all connected daemon IDs, optionally filtered by labels - - Args: - labels: Dictionary of label key-value pairs to filter by (AND logic) - All specified labels must match for a daemon to be included - - Returns: - List of daemon IDs - - Example: - >>> # List all daemons - >>> daemons = server.list_daemons() - >>> print(f"Connected: {len(daemons)} daemons") - >>> - >>> # List daemons with single label - >>> prod_daemons = server.list_daemons(labels={"env": "prod"}) - >>> - >>> # List daemons with multiple labels (AND logic) - >>> west_prod = server.list_daemons(labels={"env": "prod", "region": "us-west"}) - >>> for daemon_id in west_prod: - ... print(f" - {daemon_id}") - """ - return self._server.list_daemons(labels) - - def daemon_count(self) -> int: - """Get number of connected daemons - - Returns: - Count of connected daemons - """ - return self._server.daemon_count() - - def get_stats(self) -> ServerStats: - """Get server statistics - - Returns: - ServerStats with connection metrics - - Example: - >>> stats = server.get_stats() - >>> print(f"Total: {stats.total_daemons}") - >>> print(f"Platforms: {stats.by_platform}") - """ - return ServerStats(self._server.get_stats()) - - def _run_interactive(self, session: Session) -> None: - """Run session in interactive mode with live terminal - - Args: - session: Session to make interactive - """ - print("Entering interactive session. Press Ctrl+D to exit.") - print("-" * 60) - - # Set terminal to raw mode on Unix systems - if sys.platform != "win32": - import tty - import termios - old_settings = termios.tcgetattr(sys.stdin) - try: - tty.setraw(sys.stdin.fileno()) - self._interactive_loop(session) - finally: - termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings) - else: - # Windows - just run without raw mode - self._interactive_loop(session) - - print("\n" + "-" * 60) - print("Interactive session ended.") - session.close() - - def _interactive_loop(self, session: Session) -> None: - """Main interactive I/O loop - - Args: - session: Session for I/O - """ - try: - while True: - # Check for input from stdin - if sys.platform != "win32": - rlist, _, _ = select.select([sys.stdin], [], [], 0.01) - if rlist: - data = sys.stdin.read(1) - if not data or data == '\x04': # Ctrl+D - break - session.write(data.encode()) - else: - # Windows - simple blocking read - import msvcrt - if msvcrt.kbhit(): - data = msvcrt.getch() - if data == b'\x04': # Ctrl+D - break - session.write(data) - - # Read output from session - output = session.read(timeout=0.01) - if output: - sys.stdout.buffer.write(output) - sys.stdout.buffer.flush() - - except KeyboardInterrupt: - # Ctrl+C - exit gracefully - pass - - def wait_for_daemon( - self, - daemon_id: str, - timeout: float = 30.0, - poll_interval: float = 0.5, - ) -> bool: - """Wait for a daemon to connect - - Args: - daemon_id: Daemon ID to wait for - timeout: Maximum wait time in seconds - poll_interval: How often to check (seconds) - - Returns: - True if daemon connected, False if timed out - - Example: - >>> if server.wait_for_daemon("daemon-1", timeout=60): - ... print("Daemon connected!") - ... result = server.exec("daemon-1", "hostname") - ... else: - ... print("Timeout waiting for daemon") - """ - start = time.time() - while time.time() - start < timeout: - if daemon_id in self.list_daemons(): - return True - time.sleep(poll_interval) - return False - - @property - def address(self) -> str: - """Server address (host:port)""" - return f"{self._host}:{self._port}" - - def __repr__(self) -> str: - return ( - f"Server(address={self.address}, " - f"daemons={self.daemon_count()})" - ) diff --git a/python/sandd/async_server.py b/python/sandd/async_server.py new file mode 100644 index 0000000..bb01fc9 --- /dev/null +++ b/python/sandd/async_server.py @@ -0,0 +1,197 @@ +"""Async API for SandD server""" + +from typing import Optional, Dict, List + +from .models import CommandResult, ServerStats + + +class AsyncServer: + """Async API for SandD server + + Provides async/await interface for managing remote daemons and executing commands. + All I/O operations are async and can be used with asyncio.gather() for concurrency. + + Example: + >>> server = AsyncServer(host="0.0.0.0", port=8765) + >>> + >>> # Execute command + >>> result = await server.exec("daemon-1", "ls -la") + >>> print(result.stdout) + >>> + >>> # Broadcast to multiple daemons + >>> results = await server.broadcast( + ... labels={"env": "prod"}, + ... command="git pull" + ... ) + >>> + >>> # Concurrent execution + >>> results = await asyncio.gather( + ... server.exec("daemon-1", "hostname"), + ... server.exec("daemon-2", "uptime"), + ... server.exec("daemon-3", "whoami") + ... ) + """ + + def __init__(self, host: str = "0.0.0.0", port: int = 8765): + """Initialize async server + + Args: + host: Host address to bind to + port: Port number to listen on + """ + raise NotImplementedError( + "AsyncServer is not yet implemented. " + "Track progress at: https://github.com/InftyAI/SandD/issues/TBD" + ) + + async def exec( + self, + daemon_id: str, + command: str, + timeout: int = 300, + env: Optional[Dict[str, str]] = None, + cwd: Optional[str] = None, + ) -> CommandResult: + """Execute command on daemon (async) + + Args: + daemon_id: Target daemon identifier + command: Shell command to execute + timeout: Execution timeout in seconds (default: 300) + env: Environment variables to set + cwd: Working directory + + Returns: + CommandResult with stdout, stderr, exit_code, duration_ms + + Example: + >>> result = await server.exec("daemon-1", "hostname") + >>> if result.success: + ... print(f"Hostname: {result.stdout}") + """ + raise NotImplementedError("AsyncServer.exec() not yet implemented") + + async def broadcast( + self, + labels: Dict[str, str], + command: str, + timeout: int = 300, + env: Optional[Dict[str, str]] = None, + cwd: Optional[str] = None, + ) -> Dict[str, CommandResult]: + """Broadcast command to all daemons matching labels (async) + + Executes the same command on all daemons that match the label filters, + running them concurrently using asyncio.gather(). + + Args: + labels: Label filters (all must match, AND logic) + command: Command to execute on all matching daemons + timeout: Execution timeout in seconds (default: 300) + env: Environment variables to set + cwd: Working directory + + Returns: + Dict mapping daemon_id -> CommandResult + + Example: + >>> results = await server.broadcast( + ... labels={"env": "prod", "role": "worker"}, + ... command="git pull && systemctl restart app" + ... ) + >>> for daemon_id, result in results.items(): + ... print(f"{daemon_id}: {'OK' if result.success else 'FAILED'}") + """ + raise NotImplementedError("AsyncServer.broadcast() not yet implemented") + + async def new_session(self, daemon_id: str): + """Create new interactive session (async) + + Args: + daemon_id: Target daemon identifier + + Returns: + AsyncSession object for interactive command execution + + Note: + AsyncSession is not yet defined. Will support async read/write. + """ + raise NotImplementedError("AsyncServer.new_session() not yet implemented") + + def list_daemons(self, labels: Optional[Dict[str, str]] = None) -> List[str]: + """List connected daemon IDs + + Args: + labels: Optional label filters (AND logic) + + Returns: + List of daemon IDs matching the filters + """ + raise NotImplementedError("AsyncServer.list_daemons() not yet implemented") + + def daemon_count(self) -> int: + """Get total number of connected daemons + + Returns: + Number of connected daemons + """ + raise NotImplementedError("AsyncServer.daemon_count() not yet implemented") + + async def wait_for_daemon(self, daemon_id: str, timeout: float = 30.0) -> bool: + """Wait for daemon to connect (async) + + Args: + daemon_id: Daemon identifier to wait for + timeout: Maximum wait time in seconds + + Returns: + True if daemon connected, False if timeout + """ + raise NotImplementedError("AsyncServer.wait_for_daemon() not yet implemented") + + def get_stats(self) -> ServerStats: + """Get server statistics + + Returns: + ServerStats object with daemon counts and platform info + """ + raise NotImplementedError("AsyncServer.get_stats() not yet implemented") + + async def upload_file( + self, + daemon_id: str, + remote_path: str, + data: bytes + ) -> None: + """Upload file to daemon (async) + + Args: + daemon_id: Target daemon identifier + remote_path: Destination path on daemon + data: File content as bytes + """ + raise NotImplementedError("AsyncServer.upload_file() not yet implemented") + + async def download_file( + self, + daemon_id: str, + remote_path: str + ) -> bytes: + """Download file from daemon (async) + + Args: + daemon_id: Target daemon identifier + remote_path: Source path on daemon + + Returns: + File content as bytes + """ + raise NotImplementedError("AsyncServer.download_file() not yet implemented") + + @property + def address(self) -> str: + """Server address (host:port)""" + raise NotImplementedError("AsyncServer.address not yet implemented") + + def __repr__(self) -> str: + return "AsyncServer(not yet implemented)" diff --git a/python/sandd/models.py b/python/sandd/models.py new file mode 100644 index 0000000..0f3642a --- /dev/null +++ b/python/sandd/models.py @@ -0,0 +1,92 @@ +"""Data models for SandD""" + +from typing import Dict + +try: + from ._core import PyCommandResult, PyStats +except ImportError as e: + raise ImportError( + "Failed to import Rust extension. " + "Please build the package with: make install" + ) from e + + +class CommandResult: + """Result from command execution + + Attributes: + stdout: Standard output from the command + stderr: Standard error from the command + exit_code: Exit code (0 = success) + duration_ms: Execution duration in milliseconds + """ + + def __init__(self, result: PyCommandResult): + self._result = result + + @property + def stdout(self) -> str: + """Standard output""" + return self._result.stdout + + @property + def stderr(self) -> str: + """Standard error""" + return self._result.stderr + + @property + def exit_code(self) -> int: + """Exit code (0 = success)""" + return self._result.exit_code + + @property + def duration_ms(self) -> int: + """Execution duration in milliseconds""" + return self._result.duration_ms + + @property + def success(self) -> bool: + """Whether the command succeeded (exit_code == 0)""" + return self.exit_code == 0 + + def __repr__(self) -> str: + return ( + f"CommandResult(exit_code={self.exit_code}, " + f"duration_ms={self.duration_ms}, " + f"stdout={len(self.stdout)} bytes, " + f"stderr={len(self.stderr)} bytes)" + ) + + +class ServerStats: + """Server statistics + + Attributes: + total_daemons: Total number of connected daemons + by_platform: Daemon count by platform (e.g., {"linux": 150, "darwin": 50}) + oldest_connection_secs: Age of oldest connection in seconds + """ + + def __init__(self, stats: PyStats): + self._stats = stats + + @property + def total_daemons(self) -> int: + """Total connected daemons""" + return self._stats.total_daemons + + @property + def by_platform(self) -> Dict[str, int]: + """Daemon count by platform""" + return self._stats.by_platform + + @property + def oldest_connection_secs(self) -> int: + """Age of oldest connection in seconds""" + return self._stats.oldest_connection_secs + + def __repr__(self) -> str: + return ( + f"ServerStats(total={self.total_daemons}, " + f"platforms={self.by_platform})" + ) diff --git a/python/sandd/server.py b/python/sandd/server.py new file mode 100644 index 0000000..4688347 --- /dev/null +++ b/python/sandd/server.py @@ -0,0 +1,407 @@ +"""Sync API for SandD server""" + +from typing import Optional, Dict, List +import time +import sys +import select + +from .models import CommandResult, ServerStats + +try: + from ._core import Server as _RustServer, Session +except ImportError as e: + raise ImportError( + "Failed to import Rust extension. " + "Please build the package with: make install" + ) from e + + +class Server: + """Sandbox execution server + + High-performance WebSocket server for managing remote daemon connections. + Built with Rust for efficient handling of high-concurrency workloads. + + Args: + host: Bind address (default: "0.0.0.0") + port: Bind port (default: 8765) + verbose: Enable logging at INFO level (default: True) + Set to False to disable logs (useful for interactive sessions) + + Example: + >>> server = Server("0.0.0.0", 8765) + >>> server.wait_for_daemon("daemon-1", timeout=30) + >>> result = server.exec("daemon-1", "hostname") + >>> print(result.stdout) + + >>> # Disable logs for clean output, useful for interactive sessions + >>> server = Server("0.0.0.0", 8765, verbose=False) + """ + + def __init__(self, host: str = "0.0.0.0", port: int = 8765, verbose: bool = True): + self._server = _RustServer(host, port, verbose) + self._host = host + self._port = port + + def exec( + self, + daemon_id: str, + command: str, + timeout: int = 300, + env: Optional[Dict[str, str]] = None, + cwd: Optional[str] = None, + ) -> CommandResult: + """Execute a command on a daemon + + Commands are processed sequentially by each daemon. If multiple commands + are sent to the same daemon, they will queue and execute one at a time. + + Args: + daemon_id: Target daemon ID + command: Command to execute (shell string) + timeout: Execution timeout in seconds (default: 300) + env: Environment variables to set + cwd: Working directory + + Returns: + CommandResult with stdout, stderr, exit_code, duration + + Raises: + ValueError: If daemon not found + TimeoutError: If command times out + RuntimeError: If command fails to execute + + Example: + >>> # Single command + >>> result = server.exec("daemon-1", "ls -la /tmp") + >>> if result.success: + ... print(result.stdout) + >>> + >>> # Multiple commands to same daemon execute sequentially + >>> result1 = server.exec("daemon-1", "sleep 5") # Takes 5s + >>> result2 = server.exec("daemon-1", "echo hi") # Waits for first to finish + + Note: + Each daemon processes commands sequentially to ensure predictable + execution order and avoid resource conflicts. + """ + result = self._server.exec( + daemon_id, command, timeout, env, cwd + ) + return CommandResult(result) + + def new_session( + self, + daemon_id: str, + rows: int = 24, + cols: int = 80, + term: str = "xterm-256color", + interactive: bool = False, + ) -> Session: + """Create a new interactive session + + Args: + daemon_id: Target daemon ID + rows: Terminal rows (default: 24) + cols: Terminal columns (default: 80) + term: TERM environment variable (default: "xterm-256color") + interactive: If True, enters interactive mode with live terminal (default: False) + + Returns: + Session for interactive I/O (or None if interactive=True, runs in foreground) + + Raises: + ValueError: If daemon not found + RuntimeError: If session fails to start + + Example (Programmatic): + >>> session = server.new_session("daemon-1") + >>> session.write(b"ls -la\\n") + >>> output = session.read(timeout=1.0) + >>> if output: + ... print(output.decode()) + + Example (Interactive): + >>> server.new_session("daemon-1", interactive=True) + # Enters interactive terminal session - type commands directly + """ + session = self._server.new_session(daemon_id, rows, cols, term) + + if interactive: + self._run_interactive(session) + return None + + return session + + def upload_file( + self, + daemon_id: str, + remote_path: str, + data: bytes, + ) -> None: + """Upload a file to a daemon + + Args: + daemon_id: Target daemon ID + remote_path: Destination path on daemon + data: File data to upload + + Raises: + ValueError: If daemon not found + RuntimeError: If upload fails + + Example: + >>> with open("config.yaml", "rb") as f: + ... data = f.read() + >>> server.upload_file("daemon-1", "/etc/app/config.yaml", data) + """ + self._server.upload_file(daemon_id, remote_path, data) + + def download_file( + self, + daemon_id: str, + remote_path: str, + ) -> bytes: + """Download a file from a daemon + + Args: + daemon_id: Target daemon ID + remote_path: Source path on daemon + + Returns: + File data as bytes + + Raises: + ValueError: If daemon not found + RuntimeError: If download fails + + Example: + >>> data = server.download_file("daemon-1", "/var/log/app.log") + >>> with open("app.log", "wb") as f: + ... f.write(data) + """ + return self._server.download_file(daemon_id, remote_path) + + def list_daemons( + self, + labels: Optional[Dict[str, str]] = None, + ) -> List[str]: + """List all connected daemon IDs, optionally filtered by labels + + Args: + labels: Dictionary of label key-value pairs to filter by (AND logic) + All specified labels must match for a daemon to be included + + Returns: + List of daemon IDs + + Example: + >>> # List all daemons + >>> daemons = server.list_daemons() + >>> print(f"Connected: {len(daemons)} daemons") + >>> + >>> # List daemons with single label + >>> prod_daemons = server.list_daemons(labels={"env": "prod"}) + >>> + >>> # List daemons with multiple labels (AND logic) + >>> west_prod = server.list_daemons(labels={"env": "prod", "region": "us-west"}) + >>> for daemon_id in west_prod: + ... print(f" - {daemon_id}") + """ + return self._server.list_daemons(labels) + + def daemon_count(self) -> int: + """Get number of connected daemons + + Returns: + Count of connected daemons + """ + return self._server.daemon_count() + + def get_stats(self) -> ServerStats: + """Get server statistics + + Returns: + ServerStats with connection metrics + + Example: + >>> stats = server.get_stats() + >>> print(f"Total: {stats.total_daemons}") + >>> print(f"Platforms: {stats.by_platform}") + """ + return ServerStats(self._server.get_stats()) + + def _run_interactive(self, session: Session) -> None: + """Run session in interactive mode with live terminal + + Args: + session: Session to make interactive + """ + print("Entering interactive session. Press Ctrl+D to exit.") + print("-" * 60) + + # Set terminal to raw mode on Unix systems + if sys.platform != "win32": + import tty + import termios + old_settings = termios.tcgetattr(sys.stdin) + try: + tty.setraw(sys.stdin.fileno()) + self._interactive_loop(session) + finally: + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings) + else: + # Windows - just run without raw mode + self._interactive_loop(session) + + print("\n" + "-" * 60) + print("Interactive session ended.") + session.close() + + def _interactive_loop(self, session: Session) -> None: + """Main interactive I/O loop + + Args: + session: Session for I/O + """ + try: + while True: + # Check for input from stdin + if sys.platform != "win32": + rlist, _, _ = select.select([sys.stdin], [], [], 0.01) + if rlist: + data = sys.stdin.read(1) + if not data or data == '\x04': # Ctrl+D + break + session.write(data.encode()) + else: + # Windows - simple blocking read + import msvcrt + if msvcrt.kbhit(): + data = msvcrt.getch() + if data == b'\x04': # Ctrl+D + break + session.write(data) + + # Read output from session + output = session.read(timeout=0.01) + if output: + sys.stdout.buffer.write(output) + sys.stdout.buffer.flush() + + except KeyboardInterrupt: + # Ctrl+C - exit gracefully + pass + + def broadcast( + self, + labels: Dict[str, str], + command: str, + timeout: int = 300, + env: Optional[Dict[str, str]] = None, + cwd: Optional[str] = None, + ) -> Dict[str, CommandResult]: + """Broadcast a command to all daemons matching labels + + Executes the same command on all matching daemons concurrently using + Python threads. All daemons receive and execute the command in parallel, + making this much faster than calling exec() in a loop. + + Args: + labels: Label filters (all must match, AND logic) + command: Command to execute on all matching daemons + timeout: Execution timeout in seconds (default: 300) + env: Environment variables to set + cwd: Working directory + + Returns: + Dict mapping daemon_id -> CommandResult + + Example: + >>> # Update all production workers concurrently + >>> results = server.broadcast( + ... labels={"env": "prod", "role": "worker"}, + ... command="git pull && systemctl restart app" + ... ) + >>> + >>> # Check results + >>> for daemon_id, result in results.items(): + ... if result.success: + ... print(f"{daemon_id}: OK") + ... else: + ... print(f"{daemon_id}: FAILED - {result.stderr}") + + Performance: + Broadcasting to N daemons takes approximately the same time as + executing on a single daemon (all run in parallel), rather than + N times longer (sequential execution). + """ + import concurrent.futures + + # Get matching daemons + daemon_ids = self.list_daemons(labels=labels) + if not daemon_ids: + return {} + + # Execute command on all daemons concurrently + def run_command(daemon_id): + try: + return self.exec(daemon_id, command, timeout, env, cwd) + except Exception as e: + # Create error result + return CommandResult(type('obj', (object,), { + 'stdout': '', + 'stderr': str(e), + 'exit_code': -1, + 'duration_ms': 0, + })()) + + results = {} + with concurrent.futures.ThreadPoolExecutor(max_workers=len(daemon_ids)) as executor: + futures = {executor.submit(run_command, did): did for did in daemon_ids} + for future in concurrent.futures.as_completed(futures): + daemon_id = futures[future] + results[daemon_id] = future.result() + + return results + + def wait_for_daemon( + self, + daemon_id: str, + timeout: float = 30.0, + poll_interval: float = 0.5, + ) -> bool: + """Wait for a daemon to connect + + Args: + daemon_id: Daemon ID to wait for + timeout: Maximum wait time in seconds + poll_interval: How often to check (seconds) + + Returns: + True if daemon connected, False if timed out + + Example: + >>> if server.wait_for_daemon("daemon-1", timeout=60): + ... print("Daemon connected!") + ... result = server.exec("daemon-1", "hostname") + ... else: + ... print("Timeout waiting for daemon") + """ + start = time.time() + while time.time() - start < timeout: + if daemon_id in self.list_daemons(): + return True + time.sleep(poll_interval) + return False + + @property + def address(self) -> str: + """Server address (host:port)""" + return f"{self._host}:{self._port}" + + def __repr__(self) -> str: + return ( + f"Server(address={self.address}, " + f"daemons={self.daemon_count()})" + ) diff --git a/python/tests/test_e2e.py b/python/tests/test_e2e.py index 61c5ac1..e3311af 100644 --- a/python/tests/test_e2e.py +++ b/python/tests/test_e2e.py @@ -108,40 +108,114 @@ def run_cmd(daemon_id): assert all("Response from" in r.stdout for r in results) def test_concurrent_execution_same_daemon(self, server): - """Execute multiple commands concurrently on the same daemon""" + """Execute multiple commands on the same daemon (processed sequentially)""" import concurrent.futures - import time daemon_id = "daemon-debian-1" def run_sleep(n): - start = time.time() result = server.exec(daemon_id, f"sleep {n} && echo 'slept {n}s'", timeout=10) - duration = time.time() - start - return result, duration + return result def run_fast(): - start = time.time() result = server.exec(daemon_id, "echo 'fast command'", timeout=5) - duration = time.time() - start - return result, duration + return result - # Start slow command (3s) and fast command concurrently + start = time.time() + # Submit both commands - daemon processes them sequentially with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: slow_future = executor.submit(run_sleep, 3) fast_future = executor.submit(run_fast) - # Fast command should complete quickly, not wait for slow one - fast_result, fast_duration = fast_future.result() + # Both commands succeed + fast_result = fast_future.result() assert fast_result.success assert "fast command" in fast_result.stdout - assert fast_duration < 1.0 # Should finish in <1s, not wait for 3s sleep - # Slow command completes independently - slow_result, slow_duration = slow_future.result() + slow_result = slow_future.result() assert slow_result.success assert "slept 3s" in slow_result.stdout - assert 2.5 < slow_duration < 4.0 + + # Total time is ~3s (sequential: slow command blocks fast one) + duration = time.time() - start + assert 2.5 < duration < 4.0 # Sequential processing + + +class TestE2EBroadcast: + """Test broadcast operations""" + + def test_broadcast_simple_command(self, server): + """Broadcast a simple command to multiple daemons""" + results = server.broadcast( + labels={"env": "test"}, + command="echo 'hello from broadcast'" + ) + + # Should have 4 test daemons + assert len(results) == 4 + + # Check all succeeded + for _, result in results.items(): + assert result.success + assert "hello from broadcast" in result.stdout + + def test_broadcast_with_multiple_labels(self, server): + """Broadcast with multiple label filters (AND logic)""" + results = server.broadcast( + labels={"env": "test", "distro": "debian"}, + command="hostname" + ) + + # Should match only debian test daemons + assert len(results) == 2 + assert "daemon-debian-1" in results + assert "daemon-debian-2" in results + + for result in results.values(): + assert result.success + + def test_broadcast_no_matching_daemons(self, server): + """Broadcast with labels that match no daemons""" + results = server.broadcast( + labels={"env": "nonexistent"}, + command="hostname" + ) + + # Should return empty dict + assert len(results) == 0 + + def test_broadcast_with_failure(self, server): + """Broadcast command that fails on some daemons""" + results = server.broadcast( + labels={"env": "prod"}, + command="exit 1" + ) + + # Should have results for prod daemons + assert len(results) == 2 + + # All should have exit code 1 + for result in results.values(): + assert not result.success + assert result.exit_code == 1 + + def test_broadcast_concurrent_execution(self, server): + """Verify broadcast executes concurrently, not serially""" + + # Broadcast a 2-second sleep to test daemons + start = time.time() + results = server.broadcast( + labels={"env": "test"}, + command="sleep 2" + ) + duration = time.time() - start + + # Should complete in ~2-3 seconds (concurrent), not 8+ seconds (serial) + assert len(results) == 4 + assert 2.0 < duration < 3.0 + + for result in results.values(): + assert result.success class TestE2ELabels: diff --git a/server/src/lib.rs b/server/src/lib.rs index 139aacc..0a573f4 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -73,6 +73,7 @@ impl Server { #[pyo3(signature = (daemon_id, command, timeout=300, env=None, cwd=None))] fn exec( &self, + py: Python, daemon_id: String, command: String, timeout: u64, @@ -101,18 +102,22 @@ impl Server { conn.send_message(msg) .map_err(|e| PyRuntimeError::new_err(format!("Failed to send command: {}", e)))?; - self.runtime.block_on(async { - // Wait for result with timeout - match tokio::time::timeout(Duration::from_secs(timeout + 5), rx).await { - Ok(Ok(result)) => Ok(PyCommandResult { - stdout: result.stdout, - stderr: result.stderr, - exit_code: result.exit_code, - duration_ms: result.duration_ms, - }), - Ok(Err(_)) => Err(PyRuntimeError::new_err("Command channel closed")), - Err(_) => Err(PyTimeoutError::new_err("Command execution timed out")), - } + // Release GIL while waiting for result to allow Python thread concurrency + // Re-acquire GIL to return result or raise timeout error + py.allow_threads(|| { + self.runtime.block_on(async { + // Wait for result with timeout + match tokio::time::timeout(Duration::from_secs(timeout), rx).await { + Ok(Ok(result)) => Ok(PyCommandResult { + stdout: result.stdout, + stderr: result.stderr, + exit_code: result.exit_code, + duration_ms: result.duration_ms, + }), + Ok(Err(_)) => Err(PyRuntimeError::new_err("Command channel closed")), + Err(_) => Err(PyTimeoutError::new_err("Command execution timed out")), + } + }) }) } @@ -220,10 +225,7 @@ impl Server { /// List all connected daemons, optionally filtered by labels #[pyo3(signature = (labels=None))] - fn list_daemons( - &self, - labels: Option>, - ) -> PyResult> { + fn list_daemons(&self, labels: Option>) -> PyResult> { Ok(self.registry.list_all(labels.as_ref())) } diff --git a/server/src/registry.rs b/server/src/registry.rs index 7ad71e9..232e1c9 100644 --- a/server/src/registry.rs +++ b/server/src/registry.rs @@ -121,6 +121,10 @@ impl DaemonConnection { } } + pub fn is_busy(&self) -> bool { + !self.pending_commands.is_empty() + } + pub fn register_session(&self, session_id: String, tx: mpsc::UnboundedSender>) { self.sessions.insert(session_id, tx); } @@ -196,7 +200,10 @@ impl DaemonRegistry { } } - pub fn list_all(&self, labels: Option<&std::collections::HashMap>) -> Vec { + pub fn list_all( + &self, + labels: Option<&std::collections::HashMap>, + ) -> Vec { self.connections .iter() .filter(|entry| {