// ssh_core/lib.rs

#![no_std]

extern crate alloc;

use alloc::{vec, vec::Vec};

7pub type Result<T> = core::result::Result<T, SshCoreError>;
8
/// Errors reported by the SSH core, its backends, and its providers.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum SshCoreError {
    /// A packet was empty, truncated, or carried a length field pointing past
    /// its end.
    InvalidPacket,
    /// An event arrived that is not allowed in the current
    /// [`ConnectionState`], or referred to a channel other than the active one.
    InvalidState,
    /// Authentication was denied.
    ///
    /// NOTE(review): the core state machine reports rejected keys through
    /// `CoreDirective::AuthRejected` rather than this error; presumably meant
    /// for provider implementations — confirm.
    AuthDenied,
    /// A backend or provider failed.
    ///
    /// NOTE(review): not produced by the core state machine itself; presumably
    /// for backend/provider implementations — confirm.
    Backend,
    /// A caller-supplied output buffer was too small.
    ///
    /// NOTE(review): presumably returned by providers such as
    /// `HostKeyProvider::sign_exchange_hash` — confirm.
    BufferTooSmall,
    /// The packet type or operation is not supported (e.g. an unknown packet
    /// tag in `MinimalBackend::parse_packet`).
    Unsupported,
}

impl core::fmt::Display for SshCoreError {
    /// Writes a short, lowercase description of the error, suitable for logs.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match self {
            Self::InvalidPacket => "invalid packet",
            Self::InvalidState => "operation not valid in the current connection state",
            Self::AuthDenied => "authentication denied",
            Self::Backend => "backend failure",
            Self::BufferTooSmall => "buffer too small",
            Self::Unsupported => "unsupported packet or operation",
        };
        f.write_str(msg)
    }
}

/// Phase of an SSH connection as tracked by [`SshCore`].
///
/// Typical progression: `BannerExchange` -> `KeyExchange` -> `Authentication`
/// -> `SessionOpen` -> `ExecRunning` -> `Closing` -> `Closed`. A disconnect
/// moves the connection straight to `Closed` from any phase.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConnectionState {
    /// Initial phase: waiting for the client's banner.
    BannerExchange,
    /// Server banner sent; waiting for the client's key-exchange init.
    KeyExchange,
    /// Key-exchange reply sent; waiting for a public-key authentication
    /// request. Rejected attempts leave the connection in this phase.
    Authentication,
    /// User authenticated; session-open and exec requests are accepted.
    SessionOpen,
    /// An exec session is running; client stdin data is forwarded.
    ExecRunning,
    /// Transient phase while the disconnect reply is built after a channel EOF.
    Closing,
    /// The connection has been shut down.
    Closed,
}

/// Stream a piece of channel data belongs to.
///
/// [`SshCore`] forwards only `Stdin` data (as `CoreDirective::StdinData`);
/// client data tagged `Stdout` or `Stderr` is accepted but not forwarded.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ChannelStream {
    /// Input destined for the running command.
    Stdin,
    /// The command's standard output.
    Stdout,
    /// The command's standard error.
    Stderr,
}

/// Identifiers describing a spawned exec session, as returned by
/// [`ExecSessionProvider::spawn_exec`] and handed back to
/// [`ExecSessionProvider::close_exec`].
///
/// `#[repr(C)]` pins the field order and layout, presumably so the value can
/// cross an FFI or IPC boundary — confirm with the consumers.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ExecSessionWiring {
    /// Provider-assigned session identifier.
    pub session_id: u32,
    /// Identifier of the ring associated with the session's stdin
    /// (semantics defined by the provider — TODO confirm).
    pub stdin_ring: u32,
    /// Identifier of the ring associated with the session's stdout.
    pub stdout_ring: u32,
    /// Identifier of the ring associated with the session's stderr.
    pub stderr_ring: u32,
}

46pub trait Transport {
47    /// Implements recv.
48    fn recv(&mut self, out: &mut [u8]) -> Result<usize>;
49    /// Implements send.
50    fn send(&mut self, data: &[u8]) -> Result<usize>;
51}
52
53pub trait HostKeyProvider {
54    /// Implements host public key.
55    fn host_public_key(&self) -> &[u8];
56    /// Implements sign exchange hash.
57    fn sign_exchange_hash(&mut self, exchange_hash: &[u8], out: &mut [u8]) -> Result<usize>;
58}
59
60pub trait AuthProvider {
61    /// Implements authorize public key.
62    fn authorize_public_key(
63        &mut self,
64        username: &[u8],
65        algorithm: &[u8],
66        public_key: &[u8],
67        signed_data: &[u8],
68        signature: &[u8],
69    ) -> Result<bool>;
70}
71
72pub trait ExecSessionProvider {
73    /// Implements spawn exec.
74    fn spawn_exec(&mut self, username: &[u8], command: &[u8]) -> Result<ExecSessionWiring>;
75    /// Closes exec.
76    fn close_exec(&mut self, wiring: &ExecSessionWiring) -> Result<()>;
77}
78
79pub enum ParsedPacket<'a> {
80    ClientBanner(&'a [u8]),
81    KexInit(&'a [u8]),
82    UserAuthPublicKey {
83        username: &'a [u8],
84        algorithm: &'a [u8],
85        public_key: &'a [u8],
86        signed_data: &'a [u8],
87        signature: &'a [u8],
88    },
89    ChannelOpenSession {
90        channel_id: u32,
91    },
92    ChannelExec {
93        channel_id: u32,
94        command: &'a [u8],
95    },
96    ChannelData {
97        channel_id: u32,
98        stream: ChannelStream,
99        data: &'a [u8],
100    },
101    ChannelEof {
102        channel_id: u32,
103    },
104    Disconnect,
105}
106
107pub trait SshBackend {
108    fn parse_packet<'a>(&mut self, packet: &'a [u8]) -> Result<ParsedPacket<'a>>;
109    /// Implements make server banner.
110    fn make_server_banner(&mut self) -> Result<Vec<u8>>;
111    /// Implements make kex reply.
112    fn make_kex_reply(
113        &mut self,
114        client_kex: &[u8],
115        host_keys: &mut dyn HostKeyProvider,
116    ) -> Result<Vec<u8>>;
117    /// Implements make auth reply.
118    fn make_auth_reply(&mut self, accepted: bool) -> Result<Vec<u8>>;
119    /// Implements make exec reply.
120    fn make_exec_reply(&mut self, channel_id: u32, accepted: bool) -> Result<Vec<u8>>;
121    /// Implements make disconnect.
122    fn make_disconnect(&mut self) -> Result<Vec<u8>>;
123}
124
125pub enum CoreDirective {
126    SendPacket(Vec<u8>),
127    AuthAccepted {
128        username: Vec<u8>,
129    },
130    AuthRejected,
131    ExecStarted {
132        channel_id: u32,
133        wiring: ExecSessionWiring,
134    },
135    StdinData {
136        channel_id: u32,
137        data: Vec<u8>,
138    },
139    CloseConnection,
140}
141
142pub struct SshCore<B, A, H, S>
143where
144    B: SshBackend,
145    A: AuthProvider,
146    H: HostKeyProvider,
147    S: ExecSessionProvider,
148{
149    backend: B,
150    auth: A,
151    host_keys: H,
152    sessions: S,
153    state: ConnectionState,
154    active_channel: Option<u32>,
155    active_user: Option<Vec<u8>>,
156    active_exec: Option<ExecSessionWiring>,
157}
158
159impl<B, A, H, S> SshCore<B, A, H, S>
160where
161    B: SshBackend,
162    A: AuthProvider,
163    H: HostKeyProvider,
164    S: ExecSessionProvider,
165{
166    /// Creates a new instance.
167    pub fn new(backend: B, auth: A, host_keys: H, sessions: S) -> Self {
168        Self {
169            backend,
170            auth,
171            host_keys,
172            sessions,
173            state: ConnectionState::BannerExchange,
174            active_channel: None,
175            active_user: None,
176            active_exec: None,
177        }
178    }
179
180    /// Implements state.
181    pub fn state(&self) -> ConnectionState {
182        self.state
183    }
184
185    /// Implements auth mut.
186    pub fn auth_mut(&mut self) -> &mut A {
187        &mut self.auth
188    }
189
190    /// Implements sessions mut.
191    pub fn sessions_mut(&mut self) -> &mut S {
192        &mut self.sessions
193    }
194
195    /// Implements ingest packet.
196    pub fn ingest_packet(&mut self, packet: &[u8]) -> Result<Vec<CoreDirective>> {
197        let event = self.backend.parse_packet(packet)?;
198        self.handle_event(event)
199    }
200
201    /// Implements handle event.
202    fn handle_event(&mut self, event: ParsedPacket<'_>) -> Result<Vec<CoreDirective>> {
203        let mut out = Vec::new();
204
205        match event {
206            ParsedPacket::ClientBanner(_banner) => {
207                if self.state != ConnectionState::BannerExchange {
208                    return Err(SshCoreError::InvalidState);
209                }
210                self.state = ConnectionState::KeyExchange;
211                out.push(CoreDirective::SendPacket(
212                    self.backend.make_server_banner()?,
213                ));
214            }
215            ParsedPacket::KexInit(client_kex) => {
216                if self.state != ConnectionState::KeyExchange {
217                    return Err(SshCoreError::InvalidState);
218                }
219                self.state = ConnectionState::Authentication;
220                out.push(CoreDirective::SendPacket(
221                    self.backend
222                        .make_kex_reply(client_kex, &mut self.host_keys)?,
223                ));
224            }
225            ParsedPacket::UserAuthPublicKey {
226                username,
227                algorithm,
228                public_key,
229                signed_data,
230                signature,
231            } => {
232                if self.state != ConnectionState::Authentication {
233                    return Err(SshCoreError::InvalidState);
234                }
235
236                let accepted = self.auth.authorize_public_key(
237                    username,
238                    algorithm,
239                    public_key,
240                    signed_data,
241                    signature,
242                )?;
243
244                out.push(CoreDirective::SendPacket(
245                    self.backend.make_auth_reply(accepted)?,
246                ));
247
248                if accepted {
249                    self.active_user = Some(username.to_vec());
250                    self.state = ConnectionState::SessionOpen;
251                    out.push(CoreDirective::AuthAccepted {
252                        username: username.to_vec(),
253                    });
254                } else {
255                    out.push(CoreDirective::AuthRejected);
256                }
257            }
258            ParsedPacket::ChannelOpenSession { channel_id } => {
259                if self.state != ConnectionState::SessionOpen {
260                    return Err(SshCoreError::InvalidState);
261                }
262                self.active_channel = Some(channel_id);
263            }
264            ParsedPacket::ChannelExec {
265                channel_id,
266                command,
267            } => {
268                if self.state != ConnectionState::SessionOpen {
269                    return Err(SshCoreError::InvalidState);
270                }
271                if self.active_channel != Some(channel_id) {
272                    return Err(SshCoreError::InvalidState);
273                }
274                let user = self
275                    .active_user
276                    .as_ref()
277                    .ok_or(SshCoreError::InvalidState)?;
278
279                let wiring = self.sessions.spawn_exec(user, command)?;
280                self.active_exec = Some(wiring);
281                self.state = ConnectionState::ExecRunning;
282
283                out.push(CoreDirective::SendPacket(
284                    self.backend.make_exec_reply(channel_id, true)?,
285                ));
286                out.push(CoreDirective::ExecStarted { channel_id, wiring });
287            }
288            ParsedPacket::ChannelData {
289                channel_id,
290                stream,
291                data,
292            } => {
293                if self.state != ConnectionState::ExecRunning {
294                    return Err(SshCoreError::InvalidState);
295                }
296                if self.active_channel != Some(channel_id) {
297                    return Err(SshCoreError::InvalidState);
298                }
299                if stream == ChannelStream::Stdin {
300                    out.push(CoreDirective::StdinData {
301                        channel_id,
302                        data: data.to_vec(),
303                    });
304                }
305            }
306            ParsedPacket::ChannelEof { channel_id } => {
307                if self.active_channel != Some(channel_id) {
308                    return Err(SshCoreError::InvalidState);
309                }
310                if let Some(wiring) = self.active_exec.take() {
311                    self.sessions.close_exec(&wiring)?;
312                }
313                self.state = ConnectionState::Closing;
314                out.push(CoreDirective::SendPacket(self.backend.make_disconnect()?));
315                out.push(CoreDirective::CloseConnection);
316                self.state = ConnectionState::Closed;
317            }
318            ParsedPacket::Disconnect => {
319                if let Some(wiring) = self.active_exec.take() {
320                    self.sessions.close_exec(&wiring)?;
321                }
322                self.state = ConnectionState::Closed;
323                out.push(CoreDirective::CloseConnection);
324            }
325        }
326
327        Ok(out)
328    }
329}
330
/// Stateless backend implementing a simple one-byte-tag packet format (see
/// its [`SshBackend`] implementation for the layout).
#[derive(Debug, Default, Clone, Copy)]
pub struct MinimalBackend;

334impl MinimalBackend {
335    /// Reads u16 be.
336    fn read_u16_be(input: &[u8], off: usize) -> Result<u16> {
337        if off + 2 > input.len() {
338            return Err(SshCoreError::InvalidPacket);
339        }
340        Ok(u16::from_be_bytes([input[off], input[off + 1]]))
341    }
342
343    /// Reads u32 be.
344    fn read_u32_be(input: &[u8], off: usize) -> Result<u32> {
345        if off + 4 > input.len() {
346            return Err(SshCoreError::InvalidPacket);
347        }
348        Ok(u32::from_be_bytes([
349            input[off],
350            input[off + 1],
351            input[off + 2],
352            input[off + 3],
353        ]))
354    }
355}
356
357impl SshBackend for MinimalBackend {
358    fn parse_packet<'a>(&mut self, packet: &'a [u8]) -> Result<ParsedPacket<'a>> {
359        if packet.is_empty() {
360            return Err(SshCoreError::InvalidPacket);
361        }
362
363        match packet[0] {
364            0x01 => Ok(ParsedPacket::ClientBanner(&packet[1..])),
365            0x02 => Ok(ParsedPacket::KexInit(&packet[1..])),
366            0x03 => {
367                if packet.len() < 3 {
368                    return Err(SshCoreError::InvalidPacket);
369                }
370
371                let mut off = 1;
372                let user_len = packet[off] as usize;
373                off += 1;
374                if off + user_len > packet.len() {
375                    return Err(SshCoreError::InvalidPacket);
376                }
377                let username = &packet[off..off + user_len];
378                off += user_len;
379
380                if off >= packet.len() {
381                    return Err(SshCoreError::InvalidPacket);
382                }
383                let algo_len = packet[off] as usize;
384                off += 1;
385                if off + algo_len > packet.len() {
386                    return Err(SshCoreError::InvalidPacket);
387                }
388                let algorithm = &packet[off..off + algo_len];
389                off += algo_len;
390
391                let key_len = Self::read_u16_be(packet, off)? as usize;
392                off += 2;
393                if off + key_len > packet.len() {
394                    return Err(SshCoreError::InvalidPacket);
395                }
396                let public_key = &packet[off..off + key_len];
397                off += key_len;
398
399                let sig_len = Self::read_u16_be(packet, off)? as usize;
400                off += 2;
401                if off + sig_len > packet.len() {
402                    return Err(SshCoreError::InvalidPacket);
403                }
404                let signature = &packet[off..off + sig_len];
405
406                Ok(ParsedPacket::UserAuthPublicKey {
407                    username,
408                    algorithm,
409                    public_key,
410                    signed_data: packet,
411                    signature,
412                })
413            }
414            0x04 => Ok(ParsedPacket::ChannelOpenSession {
415                channel_id: Self::read_u32_be(packet, 1)?,
416            }),
417            0x05 => {
418                let channel_id = Self::read_u32_be(packet, 1)?;
419                let cmd_len = Self::read_u16_be(packet, 5)? as usize;
420                if 7 + cmd_len > packet.len() {
421                    return Err(SshCoreError::InvalidPacket);
422                }
423                Ok(ParsedPacket::ChannelExec {
424                    channel_id,
425                    command: &packet[7..7 + cmd_len],
426                })
427            }
428            0x06 => {
429                let channel_id = Self::read_u32_be(packet, 1)?;
430                let data_len = Self::read_u16_be(packet, 5)? as usize;
431                if 7 + data_len > packet.len() {
432                    return Err(SshCoreError::InvalidPacket);
433                }
434                Ok(ParsedPacket::ChannelData {
435                    channel_id,
436                    stream: ChannelStream::Stdin,
437                    data: &packet[7..7 + data_len],
438                })
439            }
440            0x07 => Ok(ParsedPacket::ChannelEof {
441                channel_id: Self::read_u32_be(packet, 1)?,
442            }),
443            0x08 => Ok(ParsedPacket::Disconnect),
444            _ => Err(SshCoreError::Unsupported),
445        }
446    }
447
448    /// Implements make server banner.
449    fn make_server_banner(&mut self) -> Result<Vec<u8>> {
450        Ok(b"\x81SSH-2.0-Strat9\n".to_vec())
451    }
452
453    /// Implements make kex reply.
454    fn make_kex_reply(
455        &mut self,
456        client_kex: &[u8],
457        host_keys: &mut dyn HostKeyProvider,
458    ) -> Result<Vec<u8>> {
459        let mut sig = [0u8; 128];
460        let sig_len = host_keys.sign_exchange_hash(client_kex, &mut sig)?;
461        let host_key = host_keys.host_public_key();
462
463        let mut out = Vec::with_capacity(1 + 2 + host_key.len() + 2 + sig_len);
464        out.push(0x82);
465        out.extend_from_slice(&(host_key.len() as u16).to_be_bytes());
466        out.extend_from_slice(host_key);
467        out.extend_from_slice(&(sig_len as u16).to_be_bytes());
468        out.extend_from_slice(&sig[..sig_len]);
469        Ok(out)
470    }
471
472    /// Implements make auth reply.
473    fn make_auth_reply(&mut self, accepted: bool) -> Result<Vec<u8>> {
474        Ok(vec![0x83, u8::from(accepted)])
475    }
476
477    /// Implements make exec reply.
478    fn make_exec_reply(&mut self, channel_id: u32, accepted: bool) -> Result<Vec<u8>> {
479        let mut out = Vec::with_capacity(6);
480        out.push(0x84);
481        out.extend_from_slice(&channel_id.to_be_bytes());
482        out.push(u8::from(accepted));
483        Ok(out)
484    }
485
486    /// Implements make disconnect.
487    fn make_disconnect(&mut self) -> Result<Vec<u8>> {
488        Ok(vec![0x85])
489    }
490}
491
/// Backend selected by the `backend-zssh` feature.
///
/// NOTE(review): every operation currently delegates to [`MinimalBackend`];
/// the `zssh` crate is only referenced from `Default::default`, not used for
/// protocol work.
#[cfg(feature = "backend-zssh")]
pub struct ZsshBackend {
    /// Backend that actually parses and builds packets.
    fallback: MinimalBackend,
}

#[cfg(feature = "backend-zssh")]
impl Default for ZsshBackend {
    /// Builds a backend wrapping a default [`MinimalBackend`].
    fn default() -> Self {
        // Presumably keeps the optional `zssh` dependency referenced while it
        // is otherwise unused — confirm.
        use zssh as _;
        let fallback = MinimalBackend::default();
        Self { fallback }
    }
}

/// Every operation forwards unchanged to the wrapped [`MinimalBackend`].
#[cfg(feature = "backend-zssh")]
impl SshBackend for ZsshBackend {
    /// Delegates to [`MinimalBackend`]'s packet parser.
    fn parse_packet<'a>(&mut self, packet: &'a [u8]) -> Result<ParsedPacket<'a>> {
        <MinimalBackend as SshBackend>::parse_packet(&mut self.fallback, packet)
    }

    /// Delegates to [`MinimalBackend`]'s server banner.
    fn make_server_banner(&mut self) -> Result<Vec<u8>> {
        <MinimalBackend as SshBackend>::make_server_banner(&mut self.fallback)
    }

    /// Delegates to [`MinimalBackend`]'s key-exchange reply.
    fn make_kex_reply(
        &mut self,
        client_kex: &[u8],
        host_keys: &mut dyn HostKeyProvider,
    ) -> Result<Vec<u8>> {
        <MinimalBackend as SshBackend>::make_kex_reply(&mut self.fallback, client_kex, host_keys)
    }

    /// Delegates to [`MinimalBackend`]'s authentication reply.
    fn make_auth_reply(&mut self, accepted: bool) -> Result<Vec<u8>> {
        <MinimalBackend as SshBackend>::make_auth_reply(&mut self.fallback, accepted)
    }

    /// Delegates to [`MinimalBackend`]'s exec reply.
    fn make_exec_reply(&mut self, channel_id: u32, accepted: bool) -> Result<Vec<u8>> {
        <MinimalBackend as SshBackend>::make_exec_reply(&mut self.fallback, channel_id, accepted)
    }

    /// Delegates to [`MinimalBackend`]'s disconnect packet.
    fn make_disconnect(&mut self) -> Result<Vec<u8>> {
        <MinimalBackend as SshBackend>::make_disconnect(&mut self.fallback)
    }
}

544#[cfg(feature = "backend-zssh")]
545pub type DefaultBackend = ZsshBackend;
546
547#[cfg(not(feature = "backend-zssh"))]
548pub type DefaultBackend = MinimalBackend;
549
550/// Implements default backend.
551pub fn default_backend() -> DefaultBackend {
552    Default::default()
553}