spin_factor_outbound_mqtt/
host.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use std::{sync::Arc, time::Duration};

use anyhow::Result;
use spin_core::{async_trait, wasmtime::component::Resource};
use spin_factor_outbound_networking::OutboundAllowedHosts;
use spin_world::v2::mqtt::{self as v2, Connection, Error, Qos};
use tracing::{instrument, Level};

use crate::ClientCreator;

pub struct InstanceState {
    allowed_hosts: OutboundAllowedHosts,
    connections: spin_resource_table::Table<Arc<dyn MqttClient>>,
    create_client: Arc<dyn ClientCreator>,
}

impl InstanceState {
    pub fn new(allowed_hosts: OutboundAllowedHosts, create_client: Arc<dyn ClientCreator>) -> Self {
        Self {
            allowed_hosts,
            create_client,
            connections: spin_resource_table::Table::new(1024),
        }
    }
}

#[async_trait]
pub trait MqttClient: Send + Sync {
    async fn publish_bytes(&self, topic: String, qos: Qos, payload: Vec<u8>) -> Result<(), Error>;
}

impl InstanceState {
    async fn is_address_allowed(&self, address: &str) -> Result<bool> {
        self.allowed_hosts.check_url(address, "mqtt").await
    }

    async fn establish_connection(
        &mut self,
        address: String,
        username: String,
        password: String,
        keep_alive_interval: Duration,
    ) -> Result<Resource<Connection>, Error> {
        self.connections
            .push((self.create_client).create(address, username, password, keep_alive_interval)?)
            .map(Resource::new_own)
            .map_err(|_| Error::TooManyConnections)
    }

    async fn get_conn(&self, connection: Resource<Connection>) -> Result<&dyn MqttClient, Error> {
        self.connections
            .get(connection.rep())
            .ok_or(Error::Other(
                "could not find connection for resource".into(),
            ))
            .map(|c| c.as_ref())
    }
}

impl v2::Host for InstanceState {
    fn convert_error(&mut self, error: Error) -> Result<Error> {
        Ok(error)
    }
}

#[async_trait]
impl v2::HostConnection for InstanceState {
    #[instrument(name = "spin_outbound_mqtt.open_connection", skip(self, password), err(level = Level::INFO), fields(otel.kind = "client"))]
    async fn open(
        &mut self,
        address: String,
        username: String,
        password: String,
        keep_alive_interval: u64,
    ) -> Result<Resource<Connection>, Error> {
        if !self
            .is_address_allowed(&address)
            .await
            .map_err(|e| v2::Error::Other(e.to_string()))?
        {
            return Err(v2::Error::ConnectionFailed(format!(
                "address {address} is not permitted"
            )));
        }
        self.establish_connection(
            address,
            username,
            password,
            Duration::from_secs(keep_alive_interval),
        )
        .await
    }

    /// Publish a message to the MQTT broker.
    ///
    /// OTEL trace propagation is not directly supported in MQTT V3. You will need to embed the
    /// current trace context into the payload yourself.
    /// https://w3c.github.io/trace-context-mqtt/#mqtt-v3-recommendation.
    #[instrument(name = "spin_outbound_mqtt.publish", skip(self, connection, payload), err(level = Level::INFO),
        fields(otel.kind = "producer", otel.name = format!("{} publish", topic), messaging.operation = "publish",
        messaging.system = "mqtt"))]
    async fn publish(
        &mut self,
        connection: Resource<Connection>,
        topic: String,
        payload: Vec<u8>,
        qos: Qos,
    ) -> Result<(), Error> {
        let conn = self.get_conn(connection).await.map_err(other_error)?;

        conn.publish_bytes(topic, qos, payload).await?;

        Ok(())
    }

    async fn drop(&mut self, connection: Resource<Connection>) -> anyhow::Result<()> {
        self.connections.remove(connection.rep());
        Ok(())
    }
}

pub fn other_error(e: impl std::fmt::Display) -> Error {
    Error::Other(e.to_string())
}