mistralrs_server_core/lib.rs

//! > **mistral.rs server core**
//!
//! ## About
//!
//! This crate powers the mistral.rs server. It exposes the underlying functionality
//! so that others can implement and extend the server.
//!
//! ### Features
//! 1. Incorporate the mistral.rs server into another axum project.
//! 2. Hook into the mistral.rs server lifecycle.
//!
//! ### Example
//! ```ignore
//! use std::sync::Arc;
//!
//! use axum::{
//!     Json, Router,
//!     extract::State,
//!     routing::{get, post},
//! };
//! use utoipa::OpenApi;
//! use utoipa_swagger_ui::SwaggerUi;
//!
//! use mistralrs::{
//!     AutoDeviceMapParams, ChatCompletionChunkResponse, ModelDType, ModelSelected, initialize_logging,
//! };
//! use mistralrs_server_core::{
//!     chat_completion::{
//!         ChatCompletionResponder, OnChunkCallback, OnDoneCallback, create_chat_streamer,
//!         create_response_channel, handle_chat_completion_error, parse_request,
//!         process_non_streaming_chat_response, send_request,
//!     },
//!     mistralrs_for_server_builder::MistralRsForServerBuilder,
//!     mistralrs_server_router_builder::MistralRsServerRouterBuilder,
//!     openai::ChatCompletionRequest,
//!     openapi_doc::get_openapi_doc,
//!     types::SharedMistralRsState,
//! };
//!
//! #[derive(OpenApi)]
//! #[openapi(
//!     paths(root, custom_chat),
//!     tags(
//!         (name = "hello", description = "Hello world endpoints")
//!     ),
//!     info(
//!         title = "Hello World API",
//!         version = "1.0.0",
//!         description = "A simple API that responds with a greeting"
//!     )
//! )]
//! struct ApiDoc;
//!
//! #[derive(Clone)]
//! pub struct AppState {
//!     pub mistralrs_state: SharedMistralRsState,
//!     pub db_create: fn(),
//! }
//!
//! #[tokio::main]
//! async fn main() {
//!     initialize_logging();
//!
//!     let plain_model_id = String::from("meta-llama/Llama-3.2-1B-Instruct");
//!     let tokenizer_json = None;
//!     let arch = None;
//!     let organization = None;
//!     let write_uqff = None;
//!     let from_uqff = None;
//!     let imatrix = None;
//!     let calibration_file = None;
//!     let hf_cache_path = None;
//!
//!     let dtype = ModelDType::Auto;
//!     let topology = None;
//!     let max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
//!     let max_batch_size = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE;
//!
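//!     // Select a plain (safetensors) Hugging Face model by ID. ModelSelected
//!     // also has variants for other formats (e.g. GGUF) and modalities.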
//!     let model = ModelSelected::Plain {
//!         model_id: plain_model_id,
//!         tokenizer_json,
//!         arch,
//!         dtype,
//!         topology,
//!         organization,
//!         write_uqff,
//!         from_uqff,
//!         imatrix,
//!         calibration_file,
//!         max_seq_len,
//!         max_batch_size,
//!         hf_cache_path,
//!     };
//!
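//!     // Build the shared engine state: "8" selects the in-situ quantization
//!     // (ISQ) level, and paged attention is enabled explicitly.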
//!     let shared_mistralrs = MistralRsForServerBuilder::new()
//!         .with_model(model)
//!         .with_in_situ_quant("8".to_string())
//!         .with_paged_attn(true)
//!         .build()
//!         .await
//!         .unwrap();
//!
//!     let mistralrs_base_path = "/api/mistral";
//!
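//!     // Build only the mistral.rs routes (Swagger routes disabled) so they can
//!     // be nested under the base path and documented by this app's Swagger UI.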
//!     let mistralrs_routes = MistralRsServerRouterBuilder::new()
//!         .with_mistralrs(shared_mistralrs.clone())
//!         .with_include_swagger_routes(false)
//!         .with_base_path(mistralrs_base_path)
//!         .build()
//!         .await
//!         .unwrap();
//!
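//!     // Merge the mistral.rs OpenAPI spec (its paths prefixed with the base
//!     // path) into this app's own API docs.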
//!     let mistralrs_doc = get_openapi_doc(Some(mistralrs_base_path));
//!     let mut api_docs = ApiDoc::openapi();
//!     api_docs.merge(mistralrs_doc);
//!
//!     let app_state = Arc::new(AppState {
//!         mistralrs_state: shared_mistralrs,
//!         db_create: mock_db_call,
//!     });
//!
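//!     // Compose the final router: this app's routes, the nested mistral.rs
//!     // routes, and a single Swagger UI serving the merged docs.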
//!     let app = Router::new()
//!         .route("/", get(root))
//!         .route("/chat", post(custom_chat))
//!         .with_state(app_state.clone())
//!         .nest(mistralrs_base_path, mistralrs_routes)
//!         .merge(SwaggerUi::new("/api-docs").url("/api-docs/openapi.json", api_docs));
//!
//!     let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
//!     println!("Listening on 0.0.0.0:3000");
//!     axum::serve(listener, app).await.unwrap();
//! }
//!
//! #[utoipa::path(
//!     get,
//!     path = "/",
//!     tag = "hello",
//!     responses(
//!         (status = 200, description = "Successful response with greeting message", body = String)
//!     )
//! )]
//! async fn root() -> &'static str {
//!     "Hello, World!"
//! }
//!
//! #[utoipa::path(
//!     post,
//!     tag = "Custom",
//!     path = "/chat",
//!     request_body = ChatCompletionRequest,
//!     responses((status = 200, description = "Chat completions"))
//! )]
//! pub async fn custom_chat(
//!     State(state): State<Arc<AppState>>,
//!     Json(oai_request): Json<ChatCompletionRequest>,
//! ) -> ChatCompletionResponder {
//!     let mistralrs_state = state.mistralrs_state.clone();
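//!     // Channel on which the engine sends responses back to this handler
//!     // (None falls back to the default buffer size).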
//!     let (tx, mut rx) = create_response_channel(None);
//!
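//!     // parse_request converts the OpenAI-style body into an internal request
//!     // and reports whether the client asked for a streaming response.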
//!     let (request, is_streaming) = match parse_request(oai_request, mistralrs_state.clone(), tx).await
//!     {
//!         Ok(x) => x,
//!         Err(e) => return handle_chat_completion_error(mistralrs_state, e.into()),
//!     };
//!
//!     dbg!(request.clone());
//!
//!     if let Err(e) = send_request(&mistralrs_state, request).await {
//!         return handle_chat_completion_error(mistralrs_state, e.into());
//!     }
//!
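//!     // Streaming requests are answered with an SSE stream; non-streaming
//!     // requests block until the full JSON response is ready.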
//!     if is_streaming {
//!         let db_fn = state.db_create;
//!
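//!         // on_chunk runs for every chunk before it is sent to the client and
//!         // may modify it in place.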
//!         let on_chunk: OnChunkCallback = Box::new(move |mut chunk: ChatCompletionChunkResponse| {
//!             dbg!(&chunk);
//!
//!             if let Some(original_content) = &chunk.choices[0].delta.content {
//!                 chunk.choices[0].delta.content = Some(format!("CHANGED! {}", original_content));
//!             }
//!
//!             chunk
//!         });
//!
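//!         // on_done runs once after the final chunk, receiving every chunk
//!         // that was produced; useful for persistence or logging.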
//!         let on_done: OnDoneCallback = Box::new(move |chunks: &[ChatCompletionChunkResponse]| {
//!             dbg!(chunks);
//!             (db_fn)();
//!         });
//!
//!         let streamer =
//!             create_chat_streamer(rx, mistralrs_state.clone(), Some(on_chunk), Some(on_done));
//!
//!         ChatCompletionResponder::Sse(streamer)
//!     } else {
//!         let response = process_non_streaming_chat_response(&mut rx, mistralrs_state.clone()).await;
//!
//!         match &response {
//!             ChatCompletionResponder::Json(json_response) => {
//!                 dbg!(json_response);
//!                 (state.db_create)();
//!             }
//!             _ => {}
//!         }
//!
//!         response
//!     }
//! }
//!
//! pub fn mock_db_call() {
//!     println!("Saving to DB");
//! }
//! ```

pub mod chat_completion;
mod completions;
mod handlers;
mod image_generation;
pub mod mistralrs_for_server_builder;
pub mod mistralrs_server_router_builder;
pub mod openai;
pub mod openapi_doc;
mod speech_generation;
pub mod types;
pub mod util;