spin_factor_llm/
host.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
use async_trait::async_trait;
use spin_world::v1::llm::{self as v1};
use spin_world::v2::llm::{self as v2};
use tracing::field::Empty;
use tracing::{instrument, Level};

use crate::InstanceState;

#[async_trait]
impl v2::Host for InstanceState {
    /// Runs inferencing against the configured engine, first enforcing the
    /// component's model allow-list from the manifest.
    #[instrument(name = "spin_llm.infer", skip(self, prompt), err(level = Level::INFO), fields(otel.kind = "client", llm.backend = Empty))]
    async fn infer(
        &mut self,
        model: v2::InferencingModel,
        prompt: String,
        params: Option<v2::InferencingParams>,
    ) -> Result<v2::InferencingResult, v2::Error> {
        // Reject models the component was not granted access to.
        if !self.allowed_models.contains(&model) {
            return Err(access_denied_error(&model));
        }

        // Fill in default inferencing parameters when the guest passed none.
        let params = params.unwrap_or(v2::InferencingParams {
            max_tokens: 100,
            repeat_penalty: 1.1,
            repeat_penalty_last_n_token_count: 64,
            temperature: 0.8,
            top_k: 40,
            top_p: 0.9,
        });

        let mut locked_engine = self.engine.lock().await;
        // Record which backend served this request on the current span.
        tracing::Span::current().record("llm.backend", locked_engine.summary());
        locked_engine.infer(model, prompt, params).await
    }

    /// Generates embeddings for the given texts, enforcing the component's
    /// model allow-list before dispatching to the engine.
    #[instrument(name = "spin_llm.generate_embeddings", skip(self, data), err(level = Level::INFO), fields(otel.kind = "client", llm.backend = Empty))]
    async fn generate_embeddings(
        &mut self,
        model: v1::EmbeddingModel,
        data: Vec<String>,
    ) -> Result<v2::EmbeddingsResult, v2::Error> {
        if self.allowed_models.contains(&model) {
            let mut locked_engine = self.engine.lock().await;
            // Record which backend served this request on the current span.
            tracing::Span::current().record("llm.backend", locked_engine.summary());
            locked_engine.generate_embeddings(model, data).await
        } else {
            Err(access_denied_error(&model))
        }
    }

    /// Errors are surfaced to the guest as-is; no host-side conversion needed.
    fn convert_error(&mut self, error: v2::Error) -> anyhow::Result<v2::Error> {
        Ok(error)
    }
}

#[async_trait]
impl v1::Host for InstanceState {
    /// v1 shim: converts the request into its v2 shape, delegates to the v2
    /// implementation, and converts the result/error back to the v1 types.
    async fn infer(
        &mut self,
        model: v1::InferencingModel,
        prompt: String,
        params: Option<v1::InferencingParams>,
    ) -> Result<v1::InferencingResult, v1::Error> {
        let v2_params = params.map(|p| p.into());
        match <Self as v2::Host>::infer(self, model, prompt, v2_params).await {
            Ok(result) => Ok(result.into()),
            Err(error) => Err(error.into()),
        }
    }

    /// v1 shim: delegates to the v2 implementation and converts the
    /// result/error back to the v1 types.
    async fn generate_embeddings(
        &mut self,
        model: v1::EmbeddingModel,
        data: Vec<String>,
    ) -> Result<v1::EmbeddingsResult, v1::Error> {
        match <Self as v2::Host>::generate_embeddings(self, model, data).await {
            Ok(result) => Ok(result.into()),
            Err(error) => Err(error.into()),
        }
    }

    /// Errors are surfaced to the guest as-is; no host-side conversion needed.
    fn convert_error(&mut self, error: v1::Error) -> anyhow::Result<v1::Error> {
        Ok(error)
    }
}

fn access_denied_error(model: &str) -> v2::Error {
    v2::Error::InvalidInput(format!(
        "The component does not have access to use '{model}'. To give the component access, add '{model}' to the 'ai_models' key for the component in your spin.toml manifest"
    ))
}