use std::{
    collections::HashMap,
    fmt::{Debug, Display},
    str::FromStr,
    sync::Arc,
};

use crate::{
    amoe::AnyMoeBaseModelMixin,
    device_map::DeviceMapper,
    lora::{LoraConfig, Ordering},
    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
    pipeline::{
        isq::IsqModelLoader,
        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
        EitherCache, IsqModel,
    },
    utils::varbuilder_utils::DeviceForLoadTensor,
    xlora_models::NonGranularState,
};
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use mistralrs_quant::log::once_log_info;

use indicatif::MultiProgress;
use mistralrs_quant::ShardedVarBuilder;
#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use crate::{
    models,
    xlora_models::{self, XLoraConfig},
};

use super::{AutoDeviceMapParams, DeviceMappedModelLoader};

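/// Interface implemented by every "normal" (decoder-only text) model: standard and
/// X-LoRA forward passes plus access to the KV cache, device, and config metadata.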
pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    #[allow(clippy::too_many_arguments)]
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn is_xlora(&self) -> bool;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
}

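/// Metadata passed to every normal-model load: the device mapper, whether in-situ
/// quantization (ISQ) is being applied, the primary device, and a progress-bar handle.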
pub struct NormalLoadingMetadata {
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    pub loading_isq: bool,
    pub real_device: Device,
    pub multi_progress: Arc<MultiProgress>,
}

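/// Factory trait that builds a [`NormalModel`] (plain or X-LoRA) from a raw JSON config
/// string, and reports paged-attention support and per-tensor load devices.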
pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    #[allow(clippy::too_many_arguments)]
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> Result<bool>;
    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

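/// The architecture to load a normal (text) model as. Each variant corresponds to one of
/// the concrete loaders below; the serde `rename` strings are also accepted by `FromStr`.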
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub enum NormalLoaderType {
    #[serde(rename = "mistral")]
    Mistral,
    #[serde(rename = "gemma")]
    Gemma,
    #[serde(rename = "mixtral")]
    Mixtral,
    #[serde(rename = "llama")]
    Llama,
    #[serde(rename = "phi2")]
    Phi2,
    #[serde(rename = "phi3")]
    Phi3,
    #[serde(rename = "qwen2")]
    Qwen2,
    #[serde(rename = "gemma2")]
    Gemma2,
    #[serde(rename = "starcoder2")]
    Starcoder2,
    #[serde(rename = "phi3.5moe")]
    Phi3_5MoE,
    #[serde(rename = "deepseekv2")]
    DeepSeekV2,
    #[serde(rename = "deepseekv3")]
    DeepSeekV3,
    #[serde(rename = "qwen3")]
    Qwen3,
    #[serde(rename = "glm4")]
    GLM4,
    #[serde(rename = "qwen3moe")]
    Qwen3Moe,
}

impl NormalLoaderType {
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "MistralForCausalLM" => Ok(Self::Mistral),
            "MixtralForCausalLM" => Ok(Self::Mixtral),
            "GemmaForCausalLM" => Ok(Self::Gemma),
            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
            "PhiForCausalLM" => Ok(Self::Phi2),
            "Phi3ForCausalLM" => Ok(Self::Phi3),
            "LlamaForCausalLM" => Ok(Self::Llama),
            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
            "Qwen3ForCausalLM" => Ok(Self::Qwen3),
            "Glm4ForCausalLM" => Ok(Self::GLM4),
            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers `-CausalLM` model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for NormalLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "mistral" => Ok(Self::Mistral),
            "gemma" => Ok(Self::Gemma),
            "mixtral" => Ok(Self::Mixtral),
            "llama" => Ok(Self::Llama),
            "phi2" => Ok(Self::Phi2),
            "phi3" => Ok(Self::Phi3),
            "qwen2" => Ok(Self::Qwen2),
            "gemma2" => Ok(Self::Gemma2),
            "starcoder2" => Ok(Self::Starcoder2),
            "phi3.5moe" => Ok(Self::Phi3_5MoE),
            "deepseekv2" => Ok(Self::DeepSeekV2),
            "deepseekv3" => Ok(Self::DeepSeekV3),
            "qwen3" => Ok(Self::Qwen3),
            "glm4" => Ok(Self::GLM4),
            "qwen3moe" => Ok(Self::Qwen3Moe),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `qwen3moe`.")),
        }
    }
}

impl Display for NormalLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Gemma => write!(f, "gemma"),
            Self::Gemma2 => write!(f, "gemma2"),
            Self::Llama => write!(f, "llama"),
            Self::Mistral => write!(f, "mistral"),
            Self::Mixtral => write!(f, "mixtral"),
            Self::Phi2 => write!(f, "phi2"),
            Self::Phi3 => write!(f, "phi3"),
            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
            Self::Qwen2 => write!(f, "qwen2"),
            Self::Starcoder2 => write!(f, "starcoder2"),
            Self::DeepSeekV2 => write!(f, "deepseekv2"),
            Self::DeepSeekV3 => write!(f, "deepseekv3"),
            Self::Qwen3 => write!(f, "qwen3"),
            Self::GLM4 => write!(f, "glm4"),
            Self::Qwen3Moe => write!(f, "qwen3moe"),
        }
    }
}

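/// Expands to `$size` when `$cond` is true and `0` otherwise; used below to account for
/// optional bias parameters when estimating per-layer weight sizes.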
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

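/// Loader that reads the `architectures` field of the model's JSON config and delegates
/// every call to the matching concrete loader.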
pub struct AutoNormalLoader;

#[derive(Deserialize)]
struct AutoNormalLoaderConfig {
    architectures: Vec<String>,
}

impl AutoNormalLoader {
    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
        let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one name in the `architectures` config field.")
        }

        let name = &auto_cfg.architectures[0];

        let tp = NormalLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        match tp {
            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
            NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
            NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
            NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
        }
    }
}

impl NormalModelLoader for AutoNormalLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        Self::get_loader(config)?.load_xlora(
            config,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            normal_loading_metadata,
            preload_adapters,
        )
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }
    fn supports_paged_attention(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.supports_paged_attention(config)
    }
    fn is_gptx(&self, config: &str) -> Result<bool> {
        Self::get_loader(config)?.is_gptx(config)
    }
}

impl IsqModelLoader for AutoNormalLoader {
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
    }
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
    }
}

impl DeviceMappedModelLoader for AutoNormalLoader {
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(config, dtype, weight_pack_factor)
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(config, dtype, weight_pack_factor)
    }
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &super::AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params, prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

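/// [`NormalModelLoader`] for Mistral models (`crate::models::mistral::Config`).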
pub struct MistralLoader;

impl NormalModelLoader for MistralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(models::mistral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(xlora_models::XLoraMistral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MistralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for MistralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

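/// [`NormalModelLoader`] for Gemma models (`crate::models::gemma::Config`).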
pub struct GemmaLoader;

impl NormalModelLoader for GemmaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::gemma::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraGemma::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for GemmaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for GemmaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
        };

        Ok(Box::new(cfg))
    }
}

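/// [`NormalModelLoader`] for Llama models (`crate::models::llama::Config`).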
pub struct LlamaLoader;

impl NormalModelLoader for LlamaLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::llama::Llama::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraLlama::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for LlamaLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for LlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

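/// [`NormalModelLoader`] for Mixtral models (`crate::models::mixtral::Config`).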
pub struct MixtralLoader;

impl NormalModelLoader for MixtralLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::mixtral::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraMixtral::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for MixtralLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for MixtralLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let moe_block = {
                let gate = cfg.hidden_size * cfg.num_local_experts;
                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
                gate + cfg.num_local_experts * w1
                    + cfg.num_local_experts * w2
                    + cfg.num_local_experts * w3
            };

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + moe_block
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

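/// [`NormalModelLoader`] for Phi 2 models (`crate::models::phi2::Config`).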
pub struct Phi2Loader;

impl NormalModelLoader for Phi2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::phi2::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraPhi2::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Phi2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size + cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim() * cfg.num_attention_heads;
            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor + size_in;
            let (q_norm, k_norm) = if cfg.qk_layernorm {
                (cfg.head_dim(), cfg.head_dim())
            } else {
                (0, 0)
            };

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let fc1 = h_size * i_size / weight_pack_factor;
            let fc2 = h_size * i_size / weight_pack_factor;

            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

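/// [`NormalModelLoader`] for Phi 3 models (`crate::models::phi3::Config`).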
pub struct Phi3Loader;

impl NormalModelLoader for Phi3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::phi3::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        lora_config: &[((String, String), LoraConfig)],
        xlora_config: Option<XLoraConfig>,
        xlora_ordering: Ordering,
        normal_loading_metadata: NormalLoadingMetadata,
        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(Box::new(xlora_models::XLoraPhi3::new(
            &cfg,
            vb,
            lora_config,
            xlora_config,
            xlora_ordering,
            self.is_gptx(config)?,
            normal_loading_metadata,
            preload_adapters,
        )?))
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Phi3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj =
                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
        };

        Ok(Box::new(cfg))
    }
}

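/// [`NormalModelLoader`] for Qwen 2 models (`crate::models::qwen2::Config`). X-LoRA
/// loading is not implemented for this architecture (`load_xlora` is `todo!()`).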
pub struct Qwen2Loader;

impl NormalModelLoader for Qwen2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        Ok(Box::new(models::qwen2::Model::new(
            &cfg,
            vb,
            self.is_gptx(config)?,
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn load_xlora(
        &self,
        _config: &str,
        _vb: ShardedVarBuilder,
        _lora_config: &[((String, String), LoraConfig)],
        _xlora_config: Option<XLoraConfig>,
        _xlora_ordering: Ordering,
        _normal_loading_metadata: NormalLoadingMetadata,
        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
        todo!()
    }
    fn is_gptx(&self, _: &str) -> Result<bool> {
        Ok(true)
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        Ok(Box::new(cfg))
    }
}

impl IsqModelLoader for Qwen2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
        prompt_chunksize: usize,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Text {
            max_seq_len: _,
            max_batch_size,
        } = params
        else {
            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
        };

        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        _params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Ok(0)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<usize> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
    ) -> Result<Vec<usize>> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
        };

        Ok(Box::new(cfg))
    }
}

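/// [`NormalModelLoader`] for Gemma 2 models (`crate::models::gemma2::Config`).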
1715pub struct Gemma2Loader;
1721
1722impl NormalModelLoader for Gemma2Loader {
1723 fn load(
1724 &self,
1725 config: &str,
1726 vb: ShardedVarBuilder,
1727 normal_loading_metadata: NormalLoadingMetadata,
1728 attention_mechanism: AttentionImplementation,
1729 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1730 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1731
1732 Ok(Box::new(models::gemma2::Model::new(
1733 &cfg,
1734 vb,
1735 self.is_gptx(config)?,
1736 normal_loading_metadata,
1737 attention_mechanism,
1738 )?))
1739 }
1740 fn load_xlora(
1741 &self,
1742 config: &str,
1743 vb: ShardedVarBuilder,
1744 lora_config: &[((String, String), LoraConfig)],
1745 xlora_config: Option<XLoraConfig>,
1746 xlora_ordering: Ordering,
1747 normal_loading_metadata: NormalLoadingMetadata,
1748 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1749 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1750 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1751
1752 Ok(Box::new(xlora_models::XLoraGemma2::new(
1753 &cfg,
1754 vb,
1755 lora_config,
1756 xlora_config,
1757 xlora_ordering,
1758 self.is_gptx(config)?,
1759 normal_loading_metadata,
1760 preload_adapters,
1761 )?))
1762 }
1763 fn is_gptx(&self, _: &str) -> Result<bool> {
1764 Ok(true)
1765 }
1766 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1767 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1768
1769 Ok(Box::new(cfg))
1770 }
1771}
1772
1773impl IsqModelLoader for Gemma2Loader {
1774 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1775 Ok(vec![
1776 Regex::new(r"lm_head\.(weight|bias)$")?,
1777 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1779 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1780 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1781 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1782 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1784 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1785 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1786 ])
1787 }
1788 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1789 self.isq_layer_regexes(config)
1790 }
1791}
1792
1793impl DeviceMappedModelLoader for Gemma2Loader {
1794 fn mapped_max_act_size_elems(
1795 &self,
1796 config: &str,
1797 params: &AutoDeviceMapParams,
1798 prompt_chunksize: usize,
1799 ) -> Result<usize> {
1800 let AutoDeviceMapParams::Text {
1801 max_seq_len: _,
1802 max_batch_size,
1803 } = params
1804 else {
1805 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1806 };
1807
1808 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1809
1810 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
1811 }
1812 fn non_mapped_max_act_size_elems(
1813 &self,
1814 _config: &str,
1815 _params: &AutoDeviceMapParams,
1816 ) -> Result<usize> {
1817 Ok(0)
1818 }
1819
1820 fn non_mapped_size_in_bytes(
1821 &self,
1822 config: &str,
1823 dtype: DType,
1824 weight_pack_factor: usize,
1825 ) -> Result<usize> {
1826 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1827
1828 let elems = {
1829 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1830 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1832 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1833 } else {
1834 0
1835 };
1836 let norm = cfg.hidden_size;
1837 embed_tokens + lm_head + norm
1838 };
1839 Ok(elems * dtype.size_in_bytes())
1840 }
1841
1842 fn layer_sizes_in_bytes(
1843 &self,
1844 config: &str,
1845 dtype: DType,
1846 weight_pack_factor: usize,
1847 ) -> Result<Vec<usize>> {
1848 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1849
1850 let per_layer_elems = {
1851 let input_layernorm = cfg.hidden_size;
1852 let post_attention_layernorm = cfg.hidden_size;
1853
1854 let size_in = cfg.hidden_size;
1855 let size_q = cfg.head_dim * cfg.num_attention_heads;
1856 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1857 let q_proj =
1858 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1859 let k_proj =
1860 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1861 let v_proj =
1862 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1863 let o_proj =
1864 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1865
1866 let h_size = cfg.hidden_size;
1867 let i_size = cfg.intermediate_size;
1868 let gate_proj = h_size * i_size / weight_pack_factor;
1869 let up_proj = h_size * i_size / weight_pack_factor;
1870 let down_proj = i_size * h_size / weight_pack_factor;
1871
1872 input_layernorm
1873 + post_attention_layernorm
1874 + q_proj
1875 + k_proj
1876 + v_proj
1877 + o_proj
1878 + gate_proj
1879 + up_proj
1880 + down_proj
1881 };
1882 Ok(vec![
1883 per_layer_elems * dtype.size_in_bytes();
1884 cfg.num_hidden_layers
1885 ])
1886 }
1887
1888 fn num_layers(&self, config: &str) -> Result<usize> {
1889 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1890
1891 Ok(cfg.num_hidden_layers)
1892 }
1893 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1894 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1895
1896 let cfg = ModelConfigMetadata {
1897 max_seq_len: cfg.max_position_embeddings,
1898 num_layers: cfg.num_hidden_layers,
1899 hidden_size: cfg.hidden_size,
1900 num_kv_heads: cfg.num_key_value_heads,
1901 num_attn_heads: cfg.num_attention_heads,
1902 sliding_window: None,
1903 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1904 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1905 };
1906
1907 Ok(Box::new(cfg))
1908 }
1909}
1910
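/// Loader for Starcoder2 models.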
1911pub struct Starcoder2Loader;
1917
1918impl NormalModelLoader for Starcoder2Loader {
1919 fn load(
1920 &self,
1921 config: &str,
1922 vb: ShardedVarBuilder,
1923 normal_loading_metadata: NormalLoadingMetadata,
1924 attention_mechanism: AttentionImplementation,
1925 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1926 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1927
1928 Ok(Box::new(models::starcoder2::Model::new(
1929 &cfg,
1930 vb,
1931 self.is_gptx(config)?,
1932 normal_loading_metadata,
1933 attention_mechanism,
1934 )?))
1935 }
1936 fn load_xlora(
1937 &self,
1938 config: &str,
1939 vb: ShardedVarBuilder,
1940 lora_config: &[((String, String), LoraConfig)],
1941 xlora_config: Option<XLoraConfig>,
1942 xlora_ordering: Ordering,
1943 normal_loading_metadata: NormalLoadingMetadata,
1944 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1945 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1946 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1947
1948 Ok(Box::new(xlora_models::XLoraStarcoder2::new(
1949 &cfg,
1950 vb,
1951 lora_config,
1952 xlora_config,
1953 xlora_ordering,
1954 self.is_gptx(config)?,
1955 normal_loading_metadata,
1956 preload_adapters,
1957 )?))
1958 }
1959 fn is_gptx(&self, _: &str) -> Result<bool> {
1960 Ok(true)
1961 }
1962 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1963 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
1964
1965 Ok(Box::new(cfg))
1966 }
1967}
1968
1969impl IsqModelLoader for Starcoder2Loader {
1970 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1971 Ok(vec![
1972 Regex::new(r"lm_head\.(weight|bias)$")?,
1973 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1975 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1976 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1977 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1978 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1980 Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
1981 ])
1982 }
1983 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1984 self.isq_layer_regexes(config)
1985 }
1986}
1987
1988impl DeviceMappedModelLoader for Starcoder2Loader {
1989 fn mapped_max_act_size_elems(
1990 &self,
1991 config: &str,
1992 params: &AutoDeviceMapParams,
1993 prompt_chunksize: usize,
1994 ) -> Result<usize> {
1995 let AutoDeviceMapParams::Text {
1996 max_seq_len: _,
1997 max_batch_size,
1998 } = params
1999 else {
2000 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2001 };
2002
2003 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2004
2005 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2006 }
2007 fn non_mapped_max_act_size_elems(
2008 &self,
2009 _config: &str,
2010 _params: &AutoDeviceMapParams,
2011 ) -> Result<usize> {
2012 Ok(0)
2013 }
2014
2015 fn non_mapped_size_in_bytes(
2016 &self,
2017 config: &str,
2018 dtype: DType,
2019 weight_pack_factor: usize,
2020 ) -> Result<usize> {
2021 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2022
2023 let elems = {
2024 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2025 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2027 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2028 } else {
2029 0
2030 };
2031 let norm = cfg.hidden_size + cfg.hidden_size;
2032 embed_tokens + lm_head + norm
2033 };
2034 Ok(elems * dtype.size_in_bytes())
2035 }
2036
2037 fn layer_sizes_in_bytes(
2038 &self,
2039 config: &str,
2040 dtype: DType,
2041 weight_pack_factor: usize,
2042 ) -> Result<Vec<usize>> {
2043 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2044
2045 let per_layer_elems = {
2046 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2047 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2048
2049 let size_in = cfg.hidden_size;
2050 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2051 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2052 let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2053 let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2054 let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2055 let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2056
2057 let h_size = cfg.hidden_size;
2058 let i_size = cfg.intermediate_size;
2059 let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2060 let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2061
2062 input_layernorm
2063 + post_attention_layernorm
2064 + q_proj
2065 + k_proj
2066 + v_proj
2067 + o_proj
2068 + fc1
2069 + fc2
2070 };
2071 Ok(vec![
2072 per_layer_elems * dtype.size_in_bytes();
2073 cfg.num_hidden_layers
2074 ])
2075 }
2076
2077 fn num_layers(&self, config: &str) -> Result<usize> {
2078 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2079
2080 Ok(cfg.num_hidden_layers)
2081 }
2082
2083 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2084 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2085
2086 let cfg = ModelConfigMetadata {
2087 max_seq_len: cfg.max_position_embeddings,
2088 num_layers: cfg.num_hidden_layers,
2089 hidden_size: cfg.hidden_size,
2090 num_kv_heads: cfg.num_key_value_heads,
2091 num_attn_heads: cfg.num_attention_heads,
2092 sliding_window: cfg.sliding_window,
2093 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2094 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2095 };
2096
2097 Ok(Box::new(cfg))
2098 }
2099}
2100
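/// Loader for Phi-3.5-MoE models (sparse mixture-of-experts MLP blocks).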
2101pub struct Phi3_5MoELoader;
2107
2108impl NormalModelLoader for Phi3_5MoELoader {
2109 fn load(
2110 &self,
2111 config: &str,
2112 vb: ShardedVarBuilder,
2113 normal_loading_metadata: NormalLoadingMetadata,
2114 attention_mechanism: AttentionImplementation,
2115 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2116 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2117
2118 Ok(Box::new(models::phi3_5_moe::Model::new(
2119 &cfg,
2120 vb,
2121 self.is_gptx(config)?,
2122 normal_loading_metadata,
2123 attention_mechanism,
2124 )?))
2125 }
2126 fn load_xlora(
2127 &self,
2128 config: &str,
2129 vb: ShardedVarBuilder,
2130 lora_config: &[((String, String), LoraConfig)],
2131 xlora_config: Option<XLoraConfig>,
2132 xlora_ordering: Ordering,
2133 normal_loading_metadata: NormalLoadingMetadata,
2134 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2135 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2136 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2137
2138 Ok(Box::new(xlora_models::XLoraPhi3::new(
2139 &cfg,
2140 vb,
2141 lora_config,
2142 xlora_config,
2143 xlora_ordering,
2144 self.is_gptx(config)?,
2145 normal_loading_metadata,
2146 preload_adapters,
2147 )?))
2148 }
2149 fn is_gptx(&self, _: &str) -> Result<bool> {
2150 Ok(true)
2151 }
2152 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2153 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2154
2155 Ok(Box::new(cfg))
2156 }
2157}
2158
2159impl IsqModelLoader for Phi3_5MoELoader {
2160 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2161 Ok(vec![
2162 Regex::new(r"lm_head\.(weight|bias)$")?,
2163 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2165 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2166 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2167 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2168 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2170 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2171 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2172 ])
2173 }
2174 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2175 self.isq_layer_regexes(config)
2176 }
2177
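// MoQE quantization only targets lm_head and the expert weights; attention projections are excluded.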
2178 fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2179 Ok(vec![
2180 Regex::new(r"lm_head\.(weight|bias)$")?,
2181 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2183 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2184 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2185 ])
2186 }
2187 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2188 self.isq_layer_regexes_moqe(config)
2189 }
2190}
2191
2192impl DeviceMappedModelLoader for Phi3_5MoELoader {
2193 fn mapped_max_act_size_elems(
2194 &self,
2195 config: &str,
2196 params: &AutoDeviceMapParams,
2197 prompt_chunksize: usize,
2198 ) -> Result<usize> {
2199 let AutoDeviceMapParams::Text {
2200 max_seq_len: _,
2201 max_batch_size,
2202 } = params
2203 else {
2204 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2205 };
2206
2207 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2208
2209 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2210 }
2211 fn non_mapped_max_act_size_elems(
2212 &self,
2213 _config: &str,
2214 _params: &AutoDeviceMapParams,
2215 ) -> Result<usize> {
2216 Ok(0)
2217 }
2218
2219 fn non_mapped_size_in_bytes(
2220 &self,
2221 config: &str,
2222 dtype: DType,
2223 weight_pack_factor: usize,
2224 ) -> Result<usize> {
2225 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2226
2227 let elems = {
2228 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2229 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2231 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2232 } else {
2233 0
2234 };
2235 let norm = cfg.hidden_size;
2236 embed_tokens + lm_head + norm
2237 };
2238 Ok(elems * dtype.size_in_bytes())
2239 }
2240
2241 fn layer_sizes_in_bytes(
2242 &self,
2243 config: &str,
2244 dtype: DType,
2245 weight_pack_factor: usize,
2246 ) -> Result<Vec<usize>> {
2247 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2248
2249 let per_layer_elems = {
2250 let input_layernorm = cfg.hidden_size;
2251 let post_attention_layernorm = cfg.hidden_size;
2252
2253 let size_in = cfg.hidden_size;
2254 let size_q = cfg.head_dim() * cfg.num_attention_heads;
2255 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2256 let q_proj =
2257 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2258 let k_proj =
2259 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2260 let v_proj =
2261 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2262 let o_proj =
2263 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2264
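// Sparse MoE block: router gate plus w1/w2/w3 for each of the num_local_experts experts.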
2265 let moe_block = {
2266 let gate = cfg.hidden_size * cfg.num_local_experts;
2267 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2269 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2270 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2271 gate + cfg.num_local_experts * w1
2272 + cfg.num_local_experts * w2
2273 + cfg.num_local_experts * w3
2274 };
2275
2276 input_layernorm
2277 + post_attention_layernorm
2278 + q_proj
2279 + k_proj
2280 + v_proj
2281 + o_proj
2282 + moe_block
2283 };
2284 Ok(vec![
2285 per_layer_elems * dtype.size_in_bytes();
2286 cfg.num_hidden_layers
2287 ])
2288 }
2289
2290 fn num_layers(&self, config: &str) -> Result<usize> {
2291 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2292
2293 Ok(cfg.num_hidden_layers)
2294 }
2295
2296 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2297 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2298
2299 let cfg = ModelConfigMetadata {
2300 max_seq_len: cfg.max_position_embeddings,
2301 num_layers: cfg.num_hidden_layers,
2302 hidden_size: cfg.hidden_size,
2303 num_kv_heads: cfg.num_key_value_heads,
2304 num_attn_heads: cfg.num_attention_heads,
2305 sliding_window: cfg.sliding_window,
2306 k_head_dim: cfg.head_dim(),
2307 v_head_dim: cfg.head_dim(),
2308 };
2309
2310 Ok(Box::new(cfg))
2311 }
2312}
2313
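/// Loader for DeepSeek-V2 models (multi-head latent attention with optional routed and shared experts).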
2314pub struct DeepSeekV2Loader;
2318
2319impl NormalModelLoader for DeepSeekV2Loader {
2320 fn load(
2321 &self,
2322 config: &str,
2323 vb: ShardedVarBuilder,
2324 normal_loading_metadata: NormalLoadingMetadata,
2325 attention_mechanism: AttentionImplementation,
2326 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2327 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2328
2329 Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2330 &cfg,
2331 vb,
2332 self.is_gptx(config)?,
2333 normal_loading_metadata,
2334 attention_mechanism,
2335 )?))
2336 }
2337 fn load_xlora(
2338 &self,
2339 _config: &str,
2340 _vb: ShardedVarBuilder,
2341 _lora_config: &[((String, String), LoraConfig)],
2342 _xlora_config: Option<XLoraConfig>,
2343 _xlora_ordering: Ordering,
2344 _normal_loading_metadata: NormalLoadingMetadata,
2345 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2346 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2347 todo!()
2348 }
2349 fn is_gptx(&self, _: &str) -> Result<bool> {
2350 Ok(true)
2351 }
2352 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2353 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2354 Ok(Box::new(cfg))
2355 }
2356}
2357
2358impl IsqModelLoader for DeepSeekV2Loader {
2359 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2360 let mut data = vec![
2361 Regex::new(r"lm_head\.(weight|bias)$")?,
2362 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2364 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2365 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2366 ];
2367 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2368 if cfg.q_lora_rank.is_some() {
2369 data.extend(vec![
2370 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2371 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2372 ]);
2373 } else {
2374 data.push(Regex::new(
2375 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2376 )?);
2377 }
2378 for layer_idx in 0..cfg.num_hidden_layers {
2379 if cfg.n_routed_experts.is_some()
2380 && layer_idx >= cfg.first_k_dense_replace
2381 && layer_idx % cfg.moe_layer_freq == 0
2382 {
2383 for i in 0..cfg.n_routed_experts.unwrap() {
2384 data.extend(vec![
2385 Regex::new(&format!(
2386 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2387 ))?,
2388 Regex::new(&format!(
2389 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2390 ))?,
2391 Regex::new(&format!(
2392 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2393 ))?,
2394 ]);
2395 }
2396 if cfg.n_shared_experts.is_some() {
2397 data.extend(vec![
2398 Regex::new(&format!(
2399 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2400 ))?,
2401 Regex::new(&format!(
2402 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2403 ))?,
2404 Regex::new(&format!(
2405 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2406 ))?,
2407 ]);
2408 }
2409 } else {
2410 data.extend(vec![
2411 Regex::new(&format!(
2412 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2413 ))?,
2414 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2415 Regex::new(&format!(
2416 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2417 ))?,
2418 ]);
2419 };
2420 }
2421 Ok(data)
2422 }
2423 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2424 self.isq_layer_regexes(config)
2425 }
2426
2427 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2428 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2429 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2430 for layer_idx in 0..cfg.num_hidden_layers {
2431 if cfg.n_routed_experts.is_some()
2432 && layer_idx >= cfg.first_k_dense_replace
2433 && layer_idx % cfg.moe_layer_freq == 0
2434 {
2435 for i in 0..cfg.n_routed_experts.unwrap() {
2436 data.extend(vec![
2437 Regex::new(&format!(
2438 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2439 ))?,
2440 Regex::new(&format!(
2441 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2442 ))?,
2443 Regex::new(&format!(
2444 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2445 ))?,
2446 ]);
2447 }
2448 if cfg.n_shared_experts.is_some() {
2449 data.extend(vec![
2450 Regex::new(&format!(
2451 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2452 ))?,
2453 Regex::new(&format!(
2454 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2455 ))?,
2456 Regex::new(&format!(
2457 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2458 ))?,
2459 ]);
2460 }
2461 } else {
2462 data.extend(vec![
2463 Regex::new(&format!(
2464 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2465 ))?,
2466 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2467 Regex::new(&format!(
2468 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2469 ))?,
2470 ]);
2471 };
2472 }
2473 Ok(data)
2474 }
2475 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2476 self.isq_layer_regexes_moqe(config)
2477 }
2478}
2479
2480impl DeviceMappedModelLoader for DeepSeekV2Loader {
2481 fn mapped_max_act_size_elems(
2482 &self,
2483 config: &str,
2484 params: &AutoDeviceMapParams,
2485 prompt_chunksize: usize,
2486 ) -> Result<usize> {
2487 let AutoDeviceMapParams::Text {
2488 max_seq_len: _,
2489 max_batch_size,
2490 } = params
2491 else {
2492 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2493 };
2494
2495 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2496
2497 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2498 }
2499 fn non_mapped_max_act_size_elems(
2500 &self,
2501 _config: &str,
2502 _params: &AutoDeviceMapParams,
2503 ) -> Result<usize> {
2504 Ok(0)
2505 }
2506
2507 fn non_mapped_size_in_bytes(
2508 &self,
2509 config: &str,
2510 dtype: DType,
2511 weight_pack_factor: usize,
2512 ) -> Result<usize> {
2513 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2514 let elems = {
2515 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2516 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2518 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2519 } else {
2520 0
2521 };
2522 let norm = cfg.hidden_size;
2523 embed_tokens + lm_head + norm
2524 };
2525 Ok(elems * dtype.size_in_bytes())
2526 }
2527
2528 fn layer_sizes_in_bytes(
2529 &self,
2530 config: &str,
2531 dtype: DType,
2532 weight_pack_factor: usize,
2533 ) -> Result<Vec<usize>> {
2534 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2535 let mut per_layer_elems = Vec::new();
2536
2537 for layer_idx in 0..cfg.num_hidden_layers {
2538 let input_layernorm = cfg.hidden_size;
2539 let post_attention_layernorm = cfg.hidden_size;
2540
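// MLA attention: q_proj is optionally factored through q_lora_rank, and kv goes through a joint kv_lora_rank projection plus a decoupled RoPE portion (qk_rope_head_dim).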
2541 let q_proj = match cfg.q_lora_rank {
2542 Some(lora_rank) => {
2543 let a = cfg.hidden_size * lora_rank;
2544 let norm = lora_rank;
2545 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2546 a + norm + b
2547 }
2548 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2549 };
2550 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2551 / weight_pack_factor
2552 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2553 let kv_a_layernorm = cfg.kv_lora_rank;
2554 let kv_b_proj = cfg.kv_lora_rank
2555 * cfg.num_attention_heads
2556 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2557 / weight_pack_factor;
2558 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2559 / weight_pack_factor
2560 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2561
2562 let moe_block = {
2563 let mut sum = 0;
2564 if cfg.n_routed_experts.is_some()
2565 && layer_idx >= cfg.first_k_dense_replace
2566 && layer_idx % cfg.moe_layer_freq == 0
2567 {
2568 let h_size = cfg.hidden_size;
2569 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2570 * cfg.n_routed_experts.unwrap();
2571 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2572 * cfg.n_routed_experts.unwrap();
2573 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2574 * cfg.n_routed_experts.unwrap();
2575 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2576 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2577 / weight_pack_factor;
2578 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2579 / weight_pack_factor;
2580 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2581 / weight_pack_factor;
2582 gate_proj + up_proj + down_proj
2583 } else {
2584 0
2585 };
2586 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2587 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2588 } else {
2589 let h_size = cfg.hidden_size;
2590 let i_size = cfg.intermediate_size;
2591 let gate_proj = h_size * i_size / weight_pack_factor;
2592 let up_proj = h_size * i_size / weight_pack_factor;
2593 let down_proj = i_size * h_size / weight_pack_factor;
2594 sum += gate_proj + up_proj + down_proj;
2595 }
2596 sum
2597 };
2598
2599 per_layer_elems.push(
2600 input_layernorm
2601 + post_attention_layernorm
2602 + q_proj
2603 + kv_a_layernorm
2604 + kv_a_proj_with_mqa
2605 + kv_b_proj
2606 + o_proj
2607 + moe_block,
2608 );
2609 }
2610
2611 Ok(per_layer_elems
2612 .into_iter()
2613 .map(|x| x * dtype.size_in_bytes())
2614 .collect())
2615 }
2616
2617 fn num_layers(&self, config: &str) -> Result<usize> {
2618 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2619 Ok(cfg.num_hidden_layers)
2620 }
2621
2622 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2623 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2624
2625 let cfg = ModelConfigMetadata {
2626 max_seq_len: cfg.max_position_embeddings,
2627 num_layers: cfg.num_hidden_layers,
2628 hidden_size: cfg.hidden_size,
2629 num_kv_heads: cfg.num_attention_heads,
2630 num_attn_heads: cfg.num_attention_heads,
2631 sliding_window: None,
2632 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2633 v_head_dim: cfg.v_head_dim,
2634 };
2635
2636 Ok(Box::new(cfg))
2637 }
2638}
2639
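/// Loader for DeepSeek-V3 models; the size accounting follows the same MLA/MoE layout as DeepSeek-V2.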
2640pub struct DeepSeekV3Loader;
2644
2645impl NormalModelLoader for DeepSeekV3Loader {
2646 fn load(
2647 &self,
2648 config: &str,
2649 vb: ShardedVarBuilder,
2650 normal_loading_metadata: NormalLoadingMetadata,
2651 attention_mechanism: AttentionImplementation,
2652 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2653 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2654 Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2655 &cfg,
2656 vb,
2657 self.is_gptx(config)?,
2658 normal_loading_metadata,
2659 attention_mechanism,
2660 )?))
2661 }
2662 fn load_xlora(
2663 &self,
2664 _config: &str,
2665 _vb: ShardedVarBuilder,
2666 _lora_config: &[((String, String), LoraConfig)],
2667 _xlora_config: Option<XLoraConfig>,
2668 _xlora_ordering: Ordering,
2669 _normal_loading_metadata: NormalLoadingMetadata,
2670 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2671 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2672 todo!()
2673 }
2674 fn is_gptx(&self, _: &str) -> Result<bool> {
2675 Ok(true)
2676 }
2677 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2678 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2679 Ok(Box::new(cfg))
2680 }
2681}
2682
2683impl IsqModelLoader for DeepSeekV3Loader {
2684 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2685 let mut data = vec![
2686 Regex::new(r"lm_head\.(weight|bias)$")?,
2687 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2689 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2690 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2691 ];
2692 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2693 if cfg.q_lora_rank.is_some() {
2694 data.extend(vec![
2695 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2696 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2697 ]);
2698 } else {
2699 data.push(Regex::new(
2700 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2701 )?);
2702 }
2703 for layer_idx in 0..cfg.num_hidden_layers {
2704 if cfg.n_routed_experts.is_some()
2705 && layer_idx >= cfg.first_k_dense_replace
2706 && layer_idx % cfg.moe_layer_freq == 0
2707 {
2708 for i in 0..cfg.n_routed_experts.unwrap() {
2709 data.extend(vec![
2710 Regex::new(&format!(
2711 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2712 ))?,
2713 Regex::new(&format!(
2714 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2715 ))?,
2716 Regex::new(&format!(
2717 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2718 ))?,
2719 ]);
2720 }
2721 if cfg.n_shared_experts.is_some() {
2722 data.extend(vec![
2723 Regex::new(&format!(
2724 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2725 ))?,
2726 Regex::new(&format!(
2727 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2728 ))?,
2729 Regex::new(&format!(
2730 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2731 ))?,
2732 ]);
2733 }
2734 } else {
2735 data.extend(vec![
2736 Regex::new(&format!(
2737 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2738 ))?,
2739 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2740 Regex::new(&format!(
2741 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2742 ))?,
2743 ]);
2744 };
2745 }
2746 Ok(data)
2747 }
2748 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2749 self.isq_layer_regexes(config)
2750 }
2751
2752 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2753 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2754 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2755 for layer_idx in 0..cfg.num_hidden_layers {
2756 if cfg.n_routed_experts.is_some()
2757 && layer_idx >= cfg.first_k_dense_replace
2758 && layer_idx % cfg.moe_layer_freq == 0
2759 {
2760 for i in 0..cfg.n_routed_experts.unwrap() {
2761 data.extend(vec![
2762 Regex::new(&format!(
2763 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2764 ))?,
2765 Regex::new(&format!(
2766 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2767 ))?,
2768 Regex::new(&format!(
2769 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2770 ))?,
2771 ]);
2772 }
2773 if cfg.n_shared_experts.is_some() {
2774 data.extend(vec![
2775 Regex::new(&format!(
2776 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2777 ))?,
2778 Regex::new(&format!(
2779 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2780 ))?,
2781 Regex::new(&format!(
2782 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2783 ))?,
2784 ]);
2785 }
2786 } else {
2787 data.extend(vec![
2788 Regex::new(&format!(
2789 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2790 ))?,
2791 Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2792 Regex::new(&format!(
2793 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2794 ))?,
2795 ]);
2796 };
2797 }
2798 Ok(data)
2799 }
2800 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2801 self.isq_layer_regexes_moqe(config)
2802 }
2803}
2804
2805impl DeviceMappedModelLoader for DeepSeekV3Loader {
2806 fn mapped_max_act_size_elems(
2807 &self,
2808 config: &str,
2809 params: &AutoDeviceMapParams,
2810 prompt_chunksize: usize,
2811 ) -> Result<usize> {
2812 let AutoDeviceMapParams::Text {
2813 max_seq_len: _,
2814 max_batch_size,
2815 } = params
2816 else {
2817 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2818 };
2819
2820 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2821
2822 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
2823 }
2824 fn non_mapped_max_act_size_elems(
2825 &self,
2826 _config: &str,
2827 _params: &AutoDeviceMapParams,
2828 ) -> Result<usize> {
2829 Ok(0)
2830 }
2831
2832 fn non_mapped_size_in_bytes(
2833 &self,
2834 config: &str,
2835 dtype: DType,
2836 weight_pack_factor: usize,
2837 ) -> Result<usize> {
2838 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2839 let elems = {
2840 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2841 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2843 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2844 } else {
2845 0
2846 };
2847 let norm = cfg.hidden_size;
2848 embed_tokens + lm_head + norm
2849 };
2850 Ok(elems * dtype.size_in_bytes())
2851 }
2852
2853 fn layer_sizes_in_bytes(
2854 &self,
2855 config: &str,
2856 dtype: DType,
2857 weight_pack_factor: usize,
2858 ) -> Result<Vec<usize>> {
2859 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2860 let mut per_layer_elems = Vec::new();
2861
2862 for layer_idx in 0..cfg.num_hidden_layers {
2863 let input_layernorm = cfg.hidden_size;
2864 let post_attention_layernorm = cfg.hidden_size;
2865
2866 let q_proj = match cfg.q_lora_rank {
2867 Some(lora_rank) => {
2868 let a = cfg.hidden_size * lora_rank;
2869 let norm = lora_rank;
2870 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2871 a + norm + b
2872 }
2873 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2874 };
2875 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2876 / weight_pack_factor
2877 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2878 let kv_a_layernorm = cfg.kv_lora_rank;
2879 let kv_b_proj = cfg.kv_lora_rank
2880 * cfg.num_attention_heads
2881 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2882 / weight_pack_factor;
2883 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2884 / weight_pack_factor
2885 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2886
2887 let moe_block = {
2888 let mut sum = 0;
2889 if cfg.n_routed_experts.is_some()
2890 && layer_idx >= cfg.first_k_dense_replace
2891 && layer_idx % cfg.moe_layer_freq == 0
2892 {
2893 let h_size = cfg.hidden_size;
2894 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2895 * cfg.n_routed_experts.unwrap();
2896 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
2897 * cfg.n_routed_experts.unwrap();
2898 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
2899 * cfg.n_routed_experts.unwrap();
2900 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2901 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2902 / weight_pack_factor;
2903 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2904 / weight_pack_factor;
2905 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2906 / weight_pack_factor;
2907 gate_proj + up_proj + down_proj
2908 } else {
2909 0
2910 };
2911 let gate_weight = cfg.n_routed_experts.unwrap() * cfg.hidden_size;
2912 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2913 } else {
2914 let h_size = cfg.hidden_size;
2915 let i_size = cfg.intermediate_size;
2916 let gate_proj = h_size * i_size / weight_pack_factor;
2917 let up_proj = h_size * i_size / weight_pack_factor;
2918 let down_proj = i_size * h_size / weight_pack_factor;
2919 sum += gate_proj + up_proj + down_proj;
2920 }
2921 sum
2922 };
2923
2924 per_layer_elems.push(
2925 input_layernorm
2926 + post_attention_layernorm
2927 + q_proj
2928 + kv_a_layernorm
2929 + kv_a_proj_with_mqa
2930 + kv_b_proj
2931 + o_proj
2932 + moe_block,
2933 );
2934 }
2935
2936 Ok(per_layer_elems
2937 .into_iter()
2938 .map(|x| x * dtype.size_in_bytes())
2939 .collect())
2940 }
2941
2942 fn num_layers(&self, config: &str) -> Result<usize> {
2943 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2944 Ok(cfg.num_hidden_layers)
2945 }
2946
2947 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2948 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2949
2950 let cfg = ModelConfigMetadata {
2951 max_seq_len: cfg.max_position_embeddings,
2952 num_layers: cfg.num_hidden_layers,
2953 hidden_size: cfg.hidden_size,
2954 num_kv_heads: cfg.num_attention_heads,
2955 num_attn_heads: cfg.num_attention_heads,
2956 sliding_window: None,
2957 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2958 v_head_dim: cfg.v_head_dim,
2959 };
2960
2961 Ok(Box::new(cfg))
2962 }
2963}
2964
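/// Loader for Qwen3 dense models (per-head q_norm/k_norm in attention).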
2965pub struct Qwen3Loader;
2969
2970impl NormalModelLoader for Qwen3Loader {
2971 fn load(
2972 &self,
2973 config: &str,
2974 vb: ShardedVarBuilder,
2975 normal_loading_metadata: NormalLoadingMetadata,
2976 attention_mechanism: AttentionImplementation,
2977 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2978 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
2979
2980 Ok(Box::new(models::qwen3::Model::new(
2981 &cfg,
2982 vb,
2983 self.is_gptx(config)?,
2984 normal_loading_metadata,
2985 attention_mechanism,
2986 )?))
2987 }
2988 fn load_xlora(
2989 &self,
2990 _config: &str,
2991 _vb: ShardedVarBuilder,
2992 _lora_config: &[((String, String), LoraConfig)],
2993 _xlora_config: Option<XLoraConfig>,
2994 _xlora_ordering: Ordering,
2995 _normal_loading_metadata: NormalLoadingMetadata,
2996 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2997 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2998 todo!()
2999 }
3000 fn is_gptx(&self, _: &str) -> Result<bool> {
3001 Ok(true)
3002 }
3003 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3004 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3005
3006 Ok(Box::new(cfg))
3007 }
3008}
3009
3010impl IsqModelLoader for Qwen3Loader {
3011 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3012 Ok(vec![
3013 Regex::new(r"lm_head\.(weight|bias)$")?,
3014 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3016 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3017 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3018 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3019 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3021 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3022 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3023 ])
3024 }
3025 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3026 self.isq_layer_regexes(config)
3027 }
3028}
3029
3030impl DeviceMappedModelLoader for Qwen3Loader {
3031 fn mapped_max_act_size_elems(
3032 &self,
3033 config: &str,
3034 params: &AutoDeviceMapParams,
3035 prompt_chunksize: usize,
3036 ) -> Result<usize> {
3037 let AutoDeviceMapParams::Text {
3038 max_seq_len: _,
3039 max_batch_size,
3040 } = params
3041 else {
3042 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3043 };
3044
3045 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3046
3047 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3048 }
3049 fn non_mapped_max_act_size_elems(
3050 &self,
3051 _config: &str,
3052 _params: &AutoDeviceMapParams,
3053 ) -> Result<usize> {
3054 Ok(0)
3055 }
3056
3057 fn non_mapped_size_in_bytes(
3058 &self,
3059 config: &str,
3060 dtype: DType,
3061 weight_pack_factor: usize,
3062 ) -> Result<usize> {
3063 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3064 let elems = {
3065 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3066 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3068 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3069 } else {
3070 0
3071 };
3072 let norm = cfg.hidden_size;
3073 embed_tokens + lm_head + norm
3074 };
3075 Ok(elems * dtype.size_in_bytes())
3076 }
3077
3078 fn layer_sizes_in_bytes(
3079 &self,
3080 config: &str,
3081 dtype: DType,
3082 weight_pack_factor: usize,
3083 ) -> Result<Vec<usize>> {
3084 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3085 let per_layer_elems = {
3086 let input_layernorm = cfg.hidden_size;
3087 let post_attention_layernorm = cfg.hidden_size;
3088
3089 let size_in = cfg.hidden_size;
3090 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3091 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3092 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3093 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3094 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3095 let o_proj = size_q * size_in / weight_pack_factor;
3096
3097 let h_size = cfg.hidden_size;
3098 let i_size = cfg.intermediate_size;
3099 let gate_proj = h_size * i_size / weight_pack_factor;
3100 let up_proj = h_size * i_size / weight_pack_factor;
3101 let down_proj = i_size * h_size / weight_pack_factor;
3102
3103 let q_norm = cfg.head_dim();
3104 let k_norm = cfg.head_dim();
3105
3106 input_layernorm
3107 + post_attention_layernorm
3108 + q_proj
3109 + k_proj
3110 + v_proj
3111 + o_proj
3112 + gate_proj
3113 + up_proj
3114 + down_proj
3115 + q_norm
3116 + k_norm
3117 };
3118 Ok(vec![
3119 per_layer_elems * dtype.size_in_bytes();
3120 cfg.num_hidden_layers
3121 ])
3122 }
3123
3124 fn num_layers(&self, config: &str) -> Result<usize> {
3125 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3126 Ok(cfg.num_hidden_layers)
3127 }
3128
3129 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3130 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3131
3132 let cfg = ModelConfigMetadata {
3133 max_seq_len: cfg.max_position_embeddings,
3134 num_layers: cfg.num_hidden_layers,
3135 hidden_size: cfg.hidden_size,
3136 num_kv_heads: cfg.num_key_value_heads,
3137 num_attn_heads: cfg.num_attention_heads,
3138 sliding_window: cfg.sliding_window,
3139 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3140 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3141 };
3142
3143 Ok(Box::new(cfg))
3144 }
3145}
3146
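/// Loader for GLM4 models.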
3147pub struct GLM4Loader;
3151
3152impl NormalModelLoader for GLM4Loader {
3153 fn load(
3154 &self,
3155 config: &str,
3156 vb: ShardedVarBuilder,
3157 normal_loading_metadata: NormalLoadingMetadata,
3158 attention_mechanism: AttentionImplementation,
3159 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3160 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3161
3162 Ok(Box::new(models::glm4::Model::new(
3163 &cfg,
3164 vb,
3165 self.is_gptx(config)?,
3166 normal_loading_metadata,
3167 attention_mechanism,
3168 )?))
3169 }
3170 fn load_xlora(
3171 &self,
3172 _config: &str,
3173 _vb: ShardedVarBuilder,
3174 _lora_config: &[((String, String), LoraConfig)],
3175 _xlora_config: Option<XLoraConfig>,
3176 _xlora_ordering: Ordering,
3177 _normal_loading_metadata: NormalLoadingMetadata,
3178 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3179 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3180 todo!()
3181 }
3182 fn is_gptx(&self, _: &str) -> Result<bool> {
3183 Ok(true)
3184 }
3185 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3186 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3187
3188 Ok(Box::new(cfg))
3189 }
3190}
3191
3192impl IsqModelLoader for GLM4Loader {
3193 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3194 Ok(vec![
3195 Regex::new(r"lm_head\.(weight|bias)$")?,
3196 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3198 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3199 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3200 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3201 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3203 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3204 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3205 ])
3206 }
3207 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3208 self.isq_layer_regexes(config)
3209 }
3210}
3211
3212impl DeviceMappedModelLoader for GLM4Loader {
3213 fn mapped_max_act_size_elems(
3214 &self,
3215 config: &str,
3216 params: &AutoDeviceMapParams,
3217 prompt_chunksize: usize,
3218 ) -> Result<usize> {
3219 let AutoDeviceMapParams::Text {
3220 max_seq_len: _,
3221 max_batch_size,
3222 } = params
3223 else {
3224 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3225 };
3226
3227 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3228
3229 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3230 }
3231 fn non_mapped_max_act_size_elems(
3232 &self,
3233 _config: &str,
3234 _params: &AutoDeviceMapParams,
3235 ) -> Result<usize> {
3236 Ok(0)
3237 }
3238
3239 fn non_mapped_size_in_bytes(
3240 &self,
3241 config: &str,
3242 dtype: DType,
3243 weight_pack_factor: usize,
3244 ) -> Result<usize> {
3245 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3246 let elems = {
3247 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3248 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3250 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3251 } else {
3252 0
3253 };
3254 let norm = cfg.hidden_size;
3255 embed_tokens + lm_head + norm
3256 };
3257 Ok(elems * dtype.size_in_bytes())
3258 }
3259
3260 fn layer_sizes_in_bytes(
3261 &self,
3262 config: &str,
3263 dtype: DType,
3264 weight_pack_factor: usize,
3265 ) -> Result<Vec<usize>> {
3266 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3267 let per_layer_elems = {
3268 let input_layernorm = cfg.hidden_size;
3269 let post_attention_layernorm = cfg.hidden_size * 3; // post-attention norm plus GLM4's post_self_attn and post_mlp norms
3270
3271 let size_in = cfg.hidden_size;
3272 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3273 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3274 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3275 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3276 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3277 let o_proj = size_q * size_in / weight_pack_factor;
3278
3279 let h_size = cfg.hidden_size;
3280 let i_size = cfg.intermediate_size;
3281 let gate_proj = h_size * i_size / weight_pack_factor;
3282 let up_proj = h_size * i_size / weight_pack_factor;
3283 let down_proj = i_size * h_size / weight_pack_factor;
3284
3285 input_layernorm
3286 + post_attention_layernorm
3287 + q_proj
3288 + k_proj
3289 + v_proj
3290 + o_proj
3291 + gate_proj
3292 + up_proj
3293 + down_proj
3294 };
3295 Ok(vec![
3296 per_layer_elems * dtype.size_in_bytes();
3297 cfg.num_hidden_layers
3298 ])
3299 }
3300
3301 fn num_layers(&self, config: &str) -> Result<usize> {
3302 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3303 Ok(cfg.num_hidden_layers)
3304 }
3305
3306 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3307 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3308
3309 let cfg = ModelConfigMetadata {
3310 max_seq_len: cfg.max_position_embeddings,
3311 num_layers: cfg.num_hidden_layers,
3312 hidden_size: cfg.hidden_size,
3313 num_kv_heads: cfg.num_key_value_heads,
3314 num_attn_heads: cfg.num_attention_heads,
3315 sliding_window: cfg.sliding_window,
3316 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3317 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3318 };
3319
3320 Ok(Box::new(cfg))
3321 }
3322}
3323
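/// Loader for Qwen3-MoE models (expert MLPs on the decoder_sparse_step cadence, dense MLPs elsewhere).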
3324pub struct Qwen3MoELoader;
3328
3329impl NormalModelLoader for Qwen3MoELoader {
3330 fn load(
3331 &self,
3332 config: &str,
3333 vb: ShardedVarBuilder,
3334 normal_loading_metadata: NormalLoadingMetadata,
3335 attention_mechanism: AttentionImplementation,
3336 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3337 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3338
3339 Ok(Box::new(models::qwen3_moe::Model::new(
3340 &cfg,
3341 vb,
3342 self.is_gptx(config)?,
3343 normal_loading_metadata,
3344 attention_mechanism,
3345 )?))
3346 }
3347 fn load_xlora(
3348 &self,
3349 _config: &str,
3350 _vb: ShardedVarBuilder,
3351 _lora_config: &[((String, String), LoraConfig)],
3352 _xlora_config: Option<XLoraConfig>,
3353 _xlora_ordering: Ordering,
3354 _normal_loading_metadata: NormalLoadingMetadata,
3355 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3356 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3357 todo!()
3358 }
3359 fn is_gptx(&self, _: &str) -> Result<bool> {
3360 Ok(true)
3361 }
3362 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3363 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
3364
3365 Ok(Box::new(cfg))
3366 }
3367}
3368
3369impl IsqModelLoader for Qwen3MoELoader {
3370 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3371 Ok(vec![
3372 Regex::new(r"lm_head\.(weight|bias)$")?,
3373 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3375 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3376 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3377 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3378 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3380 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3381 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3382 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
3384 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
3385 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
3386 ])
3387 }
3388 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3389 self.isq_layer_regexes(config)
3390 }
3391 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3392 self.isq_layer_regexes_moqe(config)
3393 }
3394}
3395
3396impl DeviceMappedModelLoader for Qwen3MoELoader {
3397 fn mapped_max_act_size_elems(
3398 &self,
3399 config: &str,
3400 params: &AutoDeviceMapParams,
3401 prompt_chunksize: usize,
3402 ) -> Result<usize> {
3403 let AutoDeviceMapParams::Text {
3404 max_seq_len: _,
3405 max_batch_size,
3406 } = params
3407 else {
3408 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3409 };
3410
3411 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3412
3413 Ok(max_batch_size * cfg.num_attention_heads * prompt_chunksize * prompt_chunksize)
3414 }
3415 fn non_mapped_max_act_size_elems(
3416 &self,
3417 _config: &str,
3418 _params: &AutoDeviceMapParams,
3419 ) -> Result<usize> {
3420 Ok(0)
3421 }
3422
3423 fn non_mapped_size_in_bytes(
3424 &self,
3425 config: &str,
3426 dtype: DType,
3427 weight_pack_factor: usize,
3428 ) -> Result<usize> {
3429 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3430 let elems = {
3431 let embed_tokens = cfg.hidden_size * cfg.vocab_size;
3432 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3434 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3435 } else {
3436 0
3437 };
3438 let norm = cfg.hidden_size;
3439 embed_tokens + lm_head + norm
3440 };
3441 Ok(elems * dtype.size_in_bytes())
3442 }
3443
3444 fn layer_sizes_in_bytes(
3445 &self,
3446 config: &str,
3447 dtype: DType,
3448 weight_pack_factor: usize,
3449 ) -> Result<Vec<usize>> {
3450 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3451
3452 let mut layer_sizes_in_bytes = Vec::new();
3453 for layer_idx in 0..cfg.num_hidden_layers {
3454 let input_layernorm = cfg.hidden_size;
3455 let post_attention_layernorm = cfg.hidden_size;
3456
3457 let size_in = cfg.hidden_size;
3458 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3459 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3460 let q_proj = size_in * size_q / weight_pack_factor;
3461 let k_proj = size_in * size_kv / weight_pack_factor;
3462 let v_proj = size_in * size_kv / weight_pack_factor;
3463 let o_proj = size_q * size_in / weight_pack_factor;
3464
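// Layers listed in mlp_only_layers, or off the decoder_sparse_step cadence, use a dense MLP; the others use num_experts expert MLPs plus a router gate.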
3465 let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
3466 && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
3467 {
3468 let gate_size = cfg.hidden_size * cfg.num_experts;
3469 let expert_size = {
3470 let h_size = cfg.hidden_size;
3471 let i_size = cfg.moe_intermediate_size;
3472 let gate_proj = h_size * i_size / weight_pack_factor;
3473 let up_proj = h_size * i_size / weight_pack_factor;
3474 let down_proj = i_size * h_size / weight_pack_factor;
3475 gate_proj + up_proj + down_proj
3476 };
3477 expert_size * cfg.num_experts + gate_size
3478 } else {
3479 let h_size = cfg.hidden_size;
3480 let i_size = cfg.intermediate_size;
3481 let gate_proj = h_size * i_size / weight_pack_factor;
3482 let up_proj = h_size * i_size / weight_pack_factor;
3483 let down_proj = i_size * h_size / weight_pack_factor;
3484 gate_proj + up_proj + down_proj
3485 };
3486
3487 let q_norm = cfg.head_dim();
3488 let k_norm = cfg.head_dim();
3489
3490 let size_elems = input_layernorm
3491 + post_attention_layernorm
3492 + q_proj
3493 + k_proj
3494 + v_proj
3495 + o_proj
3496 + mlp_size
3497 + q_norm
3498 + k_norm;
3499
3500 let size_in_bytes = size_elems * dtype.size_in_bytes();
3501 layer_sizes_in_bytes.push(size_in_bytes);
3502 }
3503
3504 Ok(layer_sizes_in_bytes)
3505 }
3506
3507 fn num_layers(&self, config: &str) -> Result<usize> {
3508 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3509 Ok(cfg.num_hidden_layers)
3510 }
3511
3512 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3513 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
3514
3515 let cfg = ModelConfigMetadata {
3516 max_seq_len: cfg.max_position_embeddings,
3517 num_layers: cfg.num_hidden_layers,
3518 hidden_size: cfg.hidden_size,
3519 num_kv_heads: cfg.num_key_value_heads,
3520 num_attn_heads: cfg.num_attention_heads,
3521 sliding_window: cfg.sliding_window,
3522 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3523 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3524 };
3525
3526 Ok(Box::new(cfg))
3527 }
3528}