mistralrs_server_core/lib.rs
1//! > **mistral.rs server core**
2//!
3//! ## About
4//!
5//! This crate powers the mistral.rs server. It exposes the underlying functionality
6//! so that others can build and extend their own server implementations.
7//!
8//! ### Features
9//! 1. Incorporate the mistral.rs server into another axum project.
10//! 2. Hook into the mistral.rs server lifecycle.
11//!
12//! ### Example
13//! ```ignore
14//! use std::sync::Arc;
15//!
16//! use axum::{
17//! Json, Router,
18//! extract::State,
19//! routing::{get, post},
20//! };
21//! use utoipa::OpenApi;
22//! use utoipa_swagger_ui::SwaggerUi;
23//!
24//! use mistralrs::{
25//! AutoDeviceMapParams, ChatCompletionChunkResponse, ModelDType, ModelSelected, initialize_logging,
26//! };
27//! use mistralrs_server_core::{
28//! chat_completion::{
29//! ChatCompletionResponder, OnChunkCallback, OnDoneCallback, create_chat_streamer,
30//! create_response_channel, handle_chat_completion_error, parse_request,
31//! process_non_streaming_chat_response, send_request,
32//! },
33//! mistralrs_for_server_builder::MistralRsForServerBuilder,
34//! mistralrs_server_router_builder::MistralRsServerRouterBuilder,
35//! openai::ChatCompletionRequest,
36//! openapi_doc::get_openapi_doc,
37//! types::SharedMistralRsState,
38//! };
39//!
40//! #[derive(OpenApi)]
41//! #[openapi(
42//! paths(root, custom_chat),
43//! tags(
44//! (name = "hello", description = "Hello world endpoints")
45//! ),
46//! info(
47//! title = "Hello World API",
48//! version = "1.0.0",
49//! description = "A simple API that responds with a greeting"
50//! )
51//! )]
52//! struct ApiDoc;
53//!
54//! #[derive(Clone)]
55//! pub struct AppState {
56//! pub mistralrs_state: SharedMistralRsState,
57//! pub db_create: fn(),
58//! }
59//!
60//! #[tokio::main]
61//! async fn main() {
62//! initialize_logging();
63//!
64//! let plain_model_id = String::from("meta-llama/Llama-3.2-1B-Instruct");
65//! let tokenizer_json = None;
66//! let arch = None;
67//! let organization = None;
68//! let write_uqff = None;
69//! let from_uqff = None;
70//! let imatrix = None;
71//! let calibration_file = None;
72//! let hf_cache_path = None;
73//!
74//! let dtype = ModelDType::Auto;
75//! let topology = None;
76//! let max_seq_len = AutoDeviceMapParams::DEFAULT_MAX_SEQ_LEN;
77//! let max_batch_size = AutoDeviceMapParams::DEFAULT_MAX_BATCH_SIZE;
78//!
79//! let model = ModelSelected::Plain {
80//! model_id: plain_model_id,
81//! tokenizer_json,
82//! arch,
83//! dtype,
84//! topology,
85//! organization,
86//! write_uqff,
87//! from_uqff,
88//! imatrix,
89//! calibration_file,
90//! max_seq_len,
91//! max_batch_size,
92//! hf_cache_path,
93//! };
94//!
95//! let shared_mistralrs = MistralRsForServerBuilder::new()
96//! .with_model(model)
97//! .with_in_situ_quant("8".to_string())
98//! .with_paged_attn(true)
99//! .build()
100//! .await
101//! .unwrap();
102//!
103//! let mistralrs_base_path = "/api/mistral";
104//!
105//! let mistralrs_routes = MistralRsServerRouterBuilder::new()
106//! .with_mistralrs(shared_mistralrs.clone())
107//! .with_include_swagger_routes(false)
108//! .with_base_path(mistralrs_base_path)
109//! .build()
110//! .await
111//! .unwrap();
112//!
113//! let mistralrs_doc = get_openapi_doc(Some(mistralrs_base_path));
114//! let mut api_docs = ApiDoc::openapi();
115//! api_docs.merge(mistralrs_doc);
116//!
117//! let app_state = Arc::new(AppState {
118//! mistralrs_state: shared_mistralrs,
119//! db_create: mock_db_call,
120//! });
121//!
122//! let app = Router::new()
123//! .route("/", get(root))
124//! .route("/chat", post(custom_chat))
125//! .with_state(app_state.clone())
126//! .nest(mistralrs_base_path, mistralrs_routes)
127//! .merge(SwaggerUi::new("/api-docs").url("/api-docs/openapi.json", api_docs));
128//!
129//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
130//! // Log before serving: `axum::serve` only resolves once the server stops.
131//! println!("Listening on 0.0.0.0:3000");
132//! axum::serve(listener, app).await.unwrap();
133//! }
134//!
135//! #[utoipa::path(
136//! get,
137//! path = "/",
138//! tag = "hello",
139//! responses(
140//! (status = 200, description = "Successful response with greeting message", body = String)
141//! )
142//! )]
143//! async fn root() -> &'static str {
144//! "Hello, World!"
145//! }
146//!
147//! #[utoipa::path(
148//! post,
149//! tag = "Custom",
150//! path = "/chat",
151//! request_body = ChatCompletionRequest,
152//! responses((status = 200, description = "Chat completions"))
153//! )]
154//! pub async fn custom_chat(
155//! State(state): State<Arc<AppState>>,
156//! Json(oai_request): Json<ChatCompletionRequest>,
157//! ) -> ChatCompletionResponder {
158//! let mistralrs_state = state.mistralrs_state.clone();
159//! let (tx, mut rx) = create_response_channel(None);
160//!
161//! let (request, is_streaming) = match parse_request(oai_request, mistralrs_state.clone(), tx).await
162//! {
163//! Ok(x) => x,
164//! Err(e) => return handle_chat_completion_error(mistralrs_state, e.into()),
165//! };
166//!
167//! dbg!(request.clone());
168//!
169//! if let Err(e) = send_request(&mistralrs_state, request).await {
170//! return handle_chat_completion_error(mistralrs_state, e.into());
171//! }
172//!
173//! if is_streaming {
174//! let db_fn = state.db_create;
175//!
176//! let on_chunk: OnChunkCallback = Box::new(move |mut chunk: ChatCompletionChunkResponse| {
177//! dbg!(&chunk);
178//!
179//! if let Some(original_content) = &chunk.choices[0].delta.content {
180//! chunk.choices[0].delta.content = Some(format!("CHANGED! {}", original_content));
181//! }
182//!
183//! chunk.clone()
184//! });
185//!
186//! let on_done: OnDoneCallback = Box::new(move |chunks: &[ChatCompletionChunkResponse]| {
187//! dbg!(chunks);
188//! (db_fn)();
189//! });
190//!
191//! let streamer =
192//! create_chat_streamer(rx, mistralrs_state.clone(), Some(on_chunk), Some(on_done));
193//!
194//! ChatCompletionResponder::Sse(streamer)
195//! } else {
196//! let response = process_non_streaming_chat_response(&mut rx, mistralrs_state.clone()).await;
197//!
198//! match &response {
199//! ChatCompletionResponder::Json(json_response) => {
200//! dbg!(json_response);
201//! (state.db_create)();
202//! }
203//! _ => {
204//! // Any other responder variant is returned unchanged below.
205//! }
206//! }
207//!
208//! response
209//! }
210//! }
211//!
212//! pub fn mock_db_call() {
213//! println!("Saving to DB");
214//! }
215//! ```
216
// Crate module layout. `pub mod` entries form the public API; plain `mod`
// entries are internal to the crate.

// Chat-completion building blocks. The crate-level example imports
// `ChatCompletionResponder`, `OnChunkCallback`, `OnDoneCallback`,
// `create_chat_streamer`, `create_response_channel`,
// `handle_chat_completion_error`, `parse_request`,
// `process_non_streaming_chat_response` and `send_request` from here.
217pub mod chat_completion;
// Internal module; presumably (non-chat) completion handling — confirm in the module itself.
218mod completions;
// Internal module; presumably the route handlers — confirm in the module itself.
219mod handlers;
// Internal module; presumably image-generation handling — confirm in the module itself.
220mod image_generation;
// Home of `MistralRsForServerBuilder`, used in the crate-level example to
// build the shared mistral.rs instance.
221pub mod mistralrs_for_server_builder;
// Home of `MistralRsServerRouterBuilder`, used in the crate-level example to
// build the router that is nested into the host application.
222pub mod mistralrs_server_router_builder;
// Request types such as `ChatCompletionRequest` (see the crate-level example).
223pub mod openai;
// Home of `get_openapi_doc`, which the crate-level example merges into its own OpenAPI doc.
224pub mod openapi_doc;
// Internal module; presumably speech-generation handling — confirm in the module itself.
225mod speech_generation;
// Shared types such as `SharedMistralRsState` (see the crate-level example).
226pub mod types;
// Public utilities; contents are not visible from this file.
227pub mod util;