Skip to content

Commit 7407611

Browse files
added more type hints
1 parent e001c79 commit 7407611

File tree

7 files changed

+191
-70
lines changed

7 files changed

+191
-70
lines changed

ssh_proxy_server/forwarders/tunnel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from typeguard import typechecked
1414

1515
import ssh_proxy_server
16-
from ssh_proxy_server.interfaces.server import ServerInterface
1716
if TYPE_CHECKING:
17+
from ssh_proxy_server.interfaces.server import ServerInterface
1818
from ssh_proxy_server.session import Session
1919

2020

@@ -164,7 +164,7 @@ class ServerTunnelForwarder(ServerTunnelBaseForwarder):
164164
def __init__(
165165
self,
166166
session: 'ssh_proxy_server.session.Session',
167-
server_interface: ServerInterface,
167+
server_interface: 'ssh_proxy_server.interfaces.server.ServerInterface',
168168
destination: Optional[Tuple[str, int]]
169169
) -> None:
170170
super(ServerTunnelBaseForwarder, self).__init__()

ssh_proxy_server/multisocket.py

Lines changed: 63 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# type: ignore
21
"""
32
Utility functions to create server sockets able to listen on both
43
IPv4 and IPv6.
@@ -39,13 +38,18 @@
3938
Tuple,
4039
Optional,
4140
Text,
42-
List
41+
List,
42+
Union,
43+
overload
4344
)
4445

46+
from typeguard import typechecked
47+
4548
__author__ = "Giampaolo Rodola' <g.rodola [AT] gmail [DOT] com>"
4649
__license__ = "MIT"
4750

4851

52+
@typechecked
4953
def has_dual_stack(sock: Optional[socket.socket] = None) -> bool:
5054
"""Return True if kernel allows creating a socket which is able to
5155
listen for both IPv4 and IPv6 connections.
@@ -64,6 +68,7 @@ def has_dual_stack(sock: Optional[socket.socket] = None) -> bool:
6468
return False
6569

6670

71+
@typechecked
6772
def create_server_sock(
6873
address: Tuple[Text, int],
6974
family: Optional[socket.AddressFamily] = None,
@@ -170,6 +175,7 @@ class MultipleSocketsListener:
170175
socket in the list.
171176
"""
172177

178+
@typechecked
173179
def __init__(
174180
self,
175181
addresses: List[Tuple[Text, int]],
@@ -179,8 +185,8 @@ def __init__(
179185
queue_size: int = 5
180186
) -> None:
181187
self._pollster: Optional[select.poll]
182-
self._socks = []
183-
self._sockmap = {}
188+
self._socks: List[socket.socket] = []
189+
self._sockmap: Dict[int, socket.socket] = {}
184190
if hasattr(select, 'poll'):
185191
self._pollster = select.poll()
186192
else:
@@ -206,13 +212,16 @@ def __init__(
206212
if not completed:
207213
self.close()
208214

209-
def __enter__(self):
215+
@typechecked
216+
def __enter__(self) -> 'MultipleSocketsListener':
210217
return self
211218

212-
def __exit__(self, *args):
219+
@typechecked
220+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
213221
self.close()
214222

215-
def __repr__(self):
223+
@typechecked
224+
def __repr__(self) -> Text:
216225
addrs = []
217226
for sock in self._socks:
218227
try:
@@ -221,72 +230,101 @@ def __repr__(self):
221230
addrs.append(())
222231
return '<%s (%r) at %#x>' % (self.__class__.__name__, addrs, id(self))
223232

224-
def _poll(self):
233+
@typechecked
234+
def _poll(self) -> Optional[Any]:
225235
"""Return the first readable fd."""
236+
fds_select: Optional[Tuple[List[Any], List[Any], List[Any]]] = None
237+
fds_poll: Optional[List[Tuple[int, int]]] = None
226238
timeout = self.gettimeout()
227239
if self._pollster is None:
228-
fds = select.select(self._sockmap.keys(), [], [], timeout)
229-
if timeout and fds == ([], [], []):
230-
raise socket.timeout('timed out')
240+
fds_select = select.select(self._sockmap.keys(), [], [], timeout)
241+
if timeout and fds_select == ([], [], []):
242+
raise TimeoutError('timed out')
231243
else:
232244
if timeout is not None:
233245
timeout *= 1000
234-
fds = self._pollster.poll(timeout)
235-
if timeout and fds == []:
236-
raise socket.timeout('timed out')
246+
fds_poll = self._pollster.poll(timeout)
247+
if timeout and fds_poll == []:
248+
raise TimeoutError('timed out')
237249
try:
238-
return fds[0][0]
250+
if fds_select is not None:
251+
return fds_select[0][0]
252+
if fds_poll is not None:
253+
return fds_poll[0][0]
239254
except IndexError:
240255
pass # non-blocking socket
256+
return None
241257

242-
def _multicall(self, name: Text, *args: Tuple[Any], **kwargs: Dict[Text, Any]) -> None:
258+
@typechecked
259+
def _multicall(self, name: Text, *args: Any, **kwargs: Any) -> None:
243260
for sock in self._socks:
244261
meth = getattr(sock, name)
245262
meth(*args, **kwargs)
246263

247-
def accept(self) -> None:
264+
@typechecked
265+
def accept(self) -> Tuple[socket.socket, Any]:
248266
"""Accept a connection from the first socket which is ready
249267
to do so.
250268
"""
251269
fd = self._poll()
252270
sock = self._sockmap[fd] if fd else self._socks[0]
253271
return sock.accept()
254272

255-
def filenos(self):
273+
@typechecked
274+
def filenos(self) -> List[int]:
256275
"""Return sockets' file descriptors as a list of integers.
257276
This is useful with select().
258277
"""
259278
return list(self._sockmap.keys())
260279

261280

262-
def getsockname(self):
281+
@typechecked
282+
def getsockname(self) -> Any:
263283
"""Return first registered socket's own address."""
264284
return self._socks[0].getsockname()
265285

266-
def getsockopt(self, level, optname, buflen=0):
286+
@overload
287+
def getsockopt(self, level: int, optname: int) -> int: ...
288+
289+
@overload
290+
def getsockopt(self, level: int, optname: int, buflen: int) -> bytes: ...
291+
292+
@typechecked
293+
def getsockopt(self, level: int, optname: int, buflen: int = 0) -> Union[int, bytes]:
267294
"""Return first registered socket's options."""
268295
return self._socks[0].getsockopt(level, optname, buflen)
269296

270-
def gettimeout(self) -> float:
297+
@typechecked
298+
def gettimeout(self) -> Optional[float]:
271299
"""Return first registered socket's timeout."""
272300
return self._socks[0].gettimeout()
273301

302+
@typechecked
274303
def settimeout(self, timeout: float) -> None:
275304
"""Set timeout for all registered sockets."""
276305
self._multicall('settimeout', timeout)
277306

278-
def setblocking(self, flag):
307+
@typechecked
308+
def setblocking(self, flag: bool) -> None:
279309
"""Set non/blocking mode for all registered sockets."""
280310
self._multicall('setblocking', flag)
281311

282-
def setsockopt(self, level, optname, value):
312+
@overload
313+
def setsockopt(self, level: int, optname: int, value: Union[int, bytes], optlen: None) -> None: ...
314+
@overload
315+
def setsockopt(self, level: int, optname: int, value: None, optlen: int) -> None: ...
316+
317+
@typechecked
318+
def setsockopt(self, level: int, optname: int, value: Optional[Union[int, bytes]], optlen: Optional[int]) -> None:
283319
"""Set option for all registered sockets."""
284-
self._multicall('setsockopt', level, optname, value)
320+
self._multicall('setsockopt', level, optname, value, optlen)
285321

286-
def shutdown(self, how) -> None:
322+
@typechecked
323+
def shutdown(self, how: int) -> None:
287324
"""Shut down all registered sockets."""
288325
self._multicall('shutdown', how)
289326

327+
@typechecked
290328
def close(self) -> None:
291329
"""Close all registered sockets."""
292330
self._multicall('close')

ssh_proxy_server/plugins/session/tcpserver.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,46 @@
22
import socket
33
import threading
44
import time
5+
from typing import (
6+
Callable,
7+
List,
8+
Text,
9+
Any,
10+
Union,
11+
Tuple,
12+
Optional
13+
)
14+
import paramiko
15+
16+
from typeguard import typechecked
517

618

719
class TCPServerThread(threading.Thread):
820

9-
def __init__(self, request_handler, network='127.0.0.1', port=0, run_status=True, daemon=False):
21+
@typechecked
22+
def __init__(
23+
self,
24+
request_handler: Optional[Callable[[paramiko.Channel, Tuple[Text, int]], None]] = None,
25+
network: Text = '127.0.0.1',
26+
port: int = 0,
27+
run_status: bool = True,
28+
daemon: bool = False
29+
) -> None:
1030
super(TCPServerThread, self).__init__()
1131
self.running = run_status
1232
self.network = network
1333
self.port = port
14-
self.handle_request = request_handler
34+
self.handle_request_callback = request_handler
1535
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1636
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
1737
if daemon:
1838
self.daemon = True
1939
self.socket.bind((self.network, self.port))
2040
self.network, self.port = self.socket.getsockname()
2141
self.socket.listen(5)
22-
self.threads = []
42+
self.threads: List[threading.Thread] = []
2343

44+
@typechecked
2445
def run(self) -> None:
2546
while self.running:
2647
readable = select.select([self.socket], [], [], 0.5)[0]
@@ -30,10 +51,13 @@ def run(self) -> None:
3051
t.start()
3152
time.sleep(0.1)
3253

33-
def handle_request(self, client, addr):
34-
pass
54+
@typechecked
55+
def handle_request(self, client: paramiko.Channel, addr: Tuple[Text, int]) -> None:
56+
if self.handle_request_callback is not None:
57+
self.handle_request_callback(client, addr)
3558

36-
def close(self):
59+
@typechecked
60+
def close(self) -> None:
3761
for t in self.threads:
3862
t.join()
3963
self.socket.close()

0 commit comments

Comments
 (0)