Skip to content
Open
26 changes: 19 additions & 7 deletions pymongo/asynchronous/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,21 @@ async def join(self) -> None:

async def close(self) -> None:
self.gc_safe_close()
await self._rtt_monitor.close()
# Increment the generation and maybe close the socket. If the executor
# thread has the socket checked out, it will be closed when checked in.
await self._reset_connection()
# Run rtt_monitor.close() and self._pool.close() independently so a
# failure in rtt cleanup does not skip the monitor pool's close, which
# would otherwise orphan its conn deque.
rtt_error: Optional[BaseException] = None
try:
await self._rtt_monitor.close()
except BaseException as exc:
rtt_error = exc
# Close the monitor pool. This both closes the conn in the deque
# and marks the pool CLOSED, so any in-flight check_once that checks
# the conn back in will close it via close_conn(POOL_CLOSED) rather
# than returning it to the deque of a pool that is about to be GC'd.
await self._pool.close()
if rtt_error is not None:
raise rtt_error

async def _reset_connection(self) -> None:
# Clear our pooled connection.
Expand Down Expand Up @@ -456,9 +467,10 @@ def __init__(self, topology: Topology, topology_settings: TopologySettings, pool

async def close(self) -> None:
self.gc_safe_close()
# Increment the generation and maybe close the socket. If the executor
# thread has the socket checked out, it will be closed when checked in.
await self._pool.reset()
# Close the RTT monitor pool. If the executor task has the socket
# checked out, checkin will close it via close_conn(POOL_CLOSED)
# because the pool is now CLOSED.
await self._pool.close()

async def add_sample(self, sample: float) -> None:
"""Add a RTT sample."""
Expand Down
191 changes: 140 additions & 51 deletions pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,16 @@ def __init__(
# For gossiping $clusterTime from the connection handshake to the client.
self._cluster_time = None

def __del__(self) -> None:
    """Abort the underlying transport if an async connection is collected while still open.

    Does nothing for the synchronous build or when the connection is
    already marked closed. Any exception raised while looking up or
    aborting the transport is suppressed, since this is best-effort
    cleanup during garbage collection.
    """
    if _IS_SYNC or self.closed:
        return
    try:
        transport = self.conn.get_conn.transport
        if transport is None:
            return
        transport.abort()
    except Exception:  # noqa: S110
        # Best-effort: never let cleanup errors escape from a finalizer.
        pass

def set_conn_timeout(self, timeout: Optional[float]) -> None:
"""Cache last timeout to avoid duplicate calls to conn.settimeout."""
if timeout == self.last_timeout:
Expand Down Expand Up @@ -565,6 +575,19 @@ async def close_conn(self, reason: Optional[str]) -> None:

async def _close_conn(self) -> None:
"""Close this connection."""
# Force-abort the underlying transport first so the socket fd is
# released even if a previous _close_conn already set self.closed
# but didn't reach transport.abort() (e.g. cancelled mid-close), or
# if the graceful close path below raises. transport.abort() is
# idempotent — _SelectorTransport._force_close returns early when
# _conn_lost is already set.
if not _IS_SYNC:
transport = self.conn.get_conn.transport
if transport is not None:
try:
transport.abort()
except Exception: # noqa: S110
pass
if self.closed:
return
self.closed = True
Expand Down Expand Up @@ -847,6 +870,23 @@ async def _reset(
for context in self.active_contexts:
context.cancel()

# Synchronously abort the transports of all snapshotted conns. This
# releases the socket fd and schedules _call_connection_lost before
# any await. Without this step, if the gather below were cancelled
# (e.g. by test teardown propagating a CancelledError into _reset),
# the inner close_conn() Tasks would be cancelled before their bodies
# run, so the transport.abort() inside _close_conn() would never fire
# and the snapshotted conns would leak. transport.abort() is
# idempotent, so the close_conn coroutines below remain safe to run.
if not _IS_SYNC:
for conn in sockets:
try:
transport = conn.conn.get_conn.transport
if transport is not None:
transport.abort()
except Exception: # noqa: S110
pass

listeners = self.opts._event_listeners
# CMAP spec says that close() MUST close sockets before publishing the
# PoolClosedEvent but that reset() SHOULD close sockets *after*
Expand Down Expand Up @@ -969,18 +1009,31 @@ async def remove_stale_sockets(self, reference_generation: int) -> None:
self._pending += 1
incremented = True
conn = await self.connect()
close_conn = False
async with self.lock:
# Close connection and return if the pool was reset during
# socket creation or while acquiring the pool lock.
if self.gen.get_overall() != reference_generation:
close_conn = True
if not close_conn:
self.conns.appendleft(conn)
self.active_contexts.discard(conn.cancel_context)
if close_conn:
await conn.close_conn(ConnectionClosedReason.STALE)
return
try:
close_conn = False
async with self.lock:
# Close connection and return if the pool was reset during
# socket creation or while acquiring the pool lock.
if self.gen.get_overall() != reference_generation:
close_conn = True
if not close_conn:
self.conns.appendleft(conn)
self.active_contexts.discard(conn.cancel_context)
if close_conn:
await conn.close_conn(ConnectionClosedReason.STALE)
return
except BaseException:
# If cancellation or any other exception lands between
# connect() returning and the conn being appended to the
# deque (or closed via stale-generation), the conn would
# otherwise be orphaned. Close it explicitly to release
# the underlying socket.
if not conn.closed:
try:
await conn.close_conn(ConnectionClosedReason.ERROR)
except BaseException: # noqa: S110
pass
raise
finally:
if incremented:
# Notify after adding the socket to the pool.
Expand Down Expand Up @@ -1064,35 +1117,58 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
_raise_connection_failure(self.address, error, timeout_details=details)
raise

conn = AsyncConnection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
async with self.lock:
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
await conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)

await conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as e:
conn = AsyncConnection(networking_interface, self, self.address, conn_id, self.is_sdam) # type: ignore[arg-type]
except BaseException:
# Release the networking_interface's transport if AsyncConnection
# construction failed, since no AsyncConnection exists yet to
# close_conn() through the outer cleanup path.
transport = getattr(networking_interface.get_conn, "transport", None)
if transport is not None:
try:
transport.abort()
except Exception: # noqa: S110
pass
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
await conn.close_conn(ConnectionClosedReason.ERROR)
self.active_contexts.discard(tmp_context)
raise
try:
async with self.lock:
self.active_contexts.add(conn.cancel_context)
self.active_contexts.discard(tmp_context)
if tmp_context.cancelled:
conn.cancel_context.cancel()
completed_hello = False
try:
if not self.is_sdam:
await conn.hello()
completed_hello = True
self.is_writable = conn.is_writable
if handler:
handler.contribute_socket(conn, completed_handshake=False)

await conn.authenticate()
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
except BaseException as e:
async with self.lock:
self.active_contexts.discard(conn.cancel_context)
if not completed_hello:
self._handle_connection_error(e)
await conn.close_conn(ConnectionClosedReason.ERROR)
raise

if handler:
await handler.client._topology.receive_cluster_time(conn._cluster_time)
if handler:
await handler.client._topology.receive_cluster_time(conn._cluster_time)

return conn
return conn
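# Catch cancellation (or any other exception) raised outside the inner
# try block above, e.g. while acquiring the lock or in
# receive_cluster_time, and close the conn so its socket is not leaked.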
except BaseException:
if not conn.closed:
try:
await conn.close_conn(ConnectionClosedReason.ERROR)
except BaseException: # noqa: S110
pass
raise
Comment on lines +1164 to +1171
Comment on lines +1164 to +1171

@contextlib.asynccontextmanager
async def checkout(
Expand Down Expand Up @@ -1129,21 +1205,21 @@ async def checkout(

conn = await self._get_conn(checkout_started_time, handler=handler)

duration = time.monotonic() - checkout_started_time
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_connection_checked_out(self.address, conn.id, duration)
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_CONNECTION_LOGGER,
message=_ConnectionStatusMessage.CHECKOUT_SUCCEEDED,
clientId=self._client_id,
serverHost=self.address[0],
serverPort=self.address[1],
driverConnectionId=conn.id,
durationMS=duration,
)
try:
duration = time.monotonic() - checkout_started_time
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_connection_checked_out(self.address, conn.id, duration)
if self.enabled_for_logging and _CONNECTION_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_CONNECTION_LOGGER,
message=_ConnectionStatusMessage.CHECKOUT_SUCCEEDED,
clientId=self._client_id,
serverHost=self.address[0],
serverPort=self.address[1],
driverConnectionId=conn.id,
durationMS=duration,
)
async with self.lock:
self.active_contexts.add(conn.cancel_context)
yield conn
Expand All @@ -1160,7 +1236,20 @@ async def checkout(
exc_type, exc_val, _ = sys.exc_info()
await handler.handle(exc_type, exc_val)
if not pinned and conn.active:
await self.checkin(conn)
try:
await self.checkin(conn)
except BaseException:
# If checkin is interrupted (e.g., cancellation during a
# lock acquire), the conn is left neither in the deque
# nor closed. Force-abort the transport so the socket fd
# is released instead of leaking on GC.
transport = getattr(conn.conn.get_conn, "transport", None)
if transport is not None:
try:
transport.abort()
except Exception: # noqa: S110
pass
raise
raise
if conn.pinned_txn:
async with self.lock:
Expand Down
11 changes: 10 additions & 1 deletion pymongo/asynchronous/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,17 @@ async def close(self) -> None:
serverPort=self._description.address[1],
)

await self._monitor.close()
# Run monitor.close() and pool.close() independently so a failure in
# one (e.g. monitor's rtt_monitor.close() raising) does not skip the
# other and orphan the server pool's connections.
monitor_error: Optional[BaseException] = None
try:
await self._monitor.close()
except BaseException as exc:
monitor_error = exc
await self._pool.close()
if monitor_error is not None:
raise monitor_error

def request_check(self) -> None:
"""Check the server's state soon."""
Expand Down
24 changes: 22 additions & 2 deletions pymongo/asynchronous/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,8 +722,16 @@ async def close(self) -> None:
"""
async with self._lock:
old_td = self._description
first_close_error: Optional[BaseException] = None
for server in self._servers.values():
await server.close()
# Each server.close must run independently. A failure on one
# server (e.g. its monitor's cleanup raising) must not skip
# close() on the remaining servers, or their pool conns leak.
try:
await server.close()
except BaseException as exc:
if first_close_error is None:
first_close_error = exc
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)

Expand Down Expand Up @@ -783,6 +791,11 @@ async def close(self) -> None:
await self.__events_executor.join(1)
process_events_queue(weakref.ref(self._events)) # type: ignore[arg-type]

# Re-raise the first server.close() error (if any) after we've done
# everything we can to clean up the rest of the topology.
if first_close_error is not None:
raise first_close_error

@property
def description(self) -> TopologyDescription:
return self._description
Expand Down Expand Up @@ -981,7 +994,14 @@ async def _update_servers(self) -> None:

for address, server in list(self._servers.items()):
if not self._description.has_server(address):
await server.close()
# Each server.close must run independently. If one server's
# cleanup raises, we still need to remove it from _servers and
# add its monitor to _monitor_tasks, and we must process the
# remaining servers, otherwise their pool conns leak.
try:
await server.close()
except BaseException: # noqa: S110
pass
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
self._servers.pop(address)
Expand Down
Loading
Loading