1use super::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
2use super::isq::ImatrixDataSource;
3use super::llg::build_llg_factory;
4use super::{
5 get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
6 CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
7 TokenSource,
8};
9use super::{
10 AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
11 IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
12};
13use super::{
14 AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, Gemma2Loader, GemmaLoader, LlamaLoader,
15 MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader, Phi3_5MoELoader,
16 Qwen2Loader, Qwen3Loader, Qwen3MoELoader, Starcoder2Loader,
17};
18use crate::amoe::AnyMoeExpertType;
19use crate::device_map::{self, DeviceMapper};
20use crate::distributed::{self, WorkerTransferData};
21use crate::kv_cache::{FullCacheManager, NormalCacheManager};
22use crate::lora::Ordering;
23use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
24use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25use crate::pipeline::isq::UqffFullSer;
26use crate::pipeline::loaders::auto_device_map;
27use crate::pipeline::loaders::QuantizationConfigShim;
28use crate::pipeline::sampling::sample_and_add_toks;
29use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
30use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
31use crate::pipeline::{ChatTemplate, LocalModelPaths};
32use crate::prefix_cacher::PrefixCacheManagerV2;
33use crate::sequence::Sequence;
34use crate::utils::tokenizer::get_tokenizer;
35use crate::utils::varbuilder_utils::DeviceForLoadTensor;
36use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
37use crate::xlora_models::NonGranularState;
38use crate::{
39 api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
40 normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
41 PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
42};
43use anyhow::Result;
44use candle_core::{Device, Tensor, Var};
45use hf_hub::Cache;
46use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
47use indicatif::MultiProgress;
48use mistralrs_quant::log::once_log_info;
49use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
50use rand_isaac::Isaac64Rng;
51use regex_automata::meta::Regex;
52use std::any::Any;
53use std::borrow::Cow;
54use std::num::{NonZero, NonZeroUsize};
55use std::path::{Path, PathBuf};
56use std::str::FromStr;
57use std::sync::{Arc, RwLock};
58use std::time::Instant;
59use std::{env, fs};
60use tokenizers::Tokenizer;
61use tokio::sync::Mutex;
62use tracing::{info, warn};
63
/// A fully loaded text-only pipeline: the model itself plus everything needed
/// to tokenize, apply chat templates, and run/re-quantize it.
pub struct NormalPipeline {
    // The loaded model implementation, dispatched dynamically per architecture.
    model: Box<dyn NormalModel + Send + Sync>,
    tokenizer: Arc<Tokenizer>,
    // If true, generation runs without a KV cache (used by the X-LoRA forward path).
    no_kv_cache: bool,
    chat_template: Arc<ChatTemplate>,
    // X-LoRA non-granular scan state; present when a target non-granular index was set.
    non_granular_state: Option<NonGranularState>,
    model_id: String,
    metadata: Arc<GeneralMetadata>,
    // Optional per-layer topology, consulted when re-applying ISQ.
    topology: Option<Topology>,
    silent: bool,
    // ISQ organization: default (all weights) or MoE-experts-only.
    organization: IsqOrganization,
    // Paths/config retained so the model can be re-serialized (UQFF) later.
    template_filename: Option<PathBuf>,
    generation_config: Option<PathBuf>,
    // Raw JSON model config as read from disk.
    config: String,
    // Optional imatrix file used when re-applying ISQ.
    imatrix: Option<PathBuf>,
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
82
/// A [`Loader`] for "normal" (plain text, safetensors-based) models.
///
/// Built by [`NormalLoaderBuilder`]. The `RwLock` fields are populated during
/// `load_model_from_hf` so later stages (e.g. UQFF artifact loading) can reuse
/// the resolved values.
pub struct NormalLoader {
    // Architecture-specific loader (Mistral, Llama, …, or auto-detected).
    inner: Box<dyn NormalModelLoader>,
    model_id: String,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    // Recorded at HF-load time for reuse (e.g. by the `get_uqff_paths!` macro).
    token_source: RwLock<Option<TokenSource>>,
    revision: RwLock<Option<String>>,
    // Resolved UQFF artifact paths, filled in during `load_model_from_hf`.
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    // Explicit JINJA chat template, forwarded to `get_chat_template`.
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}
102
/// Builder for [`NormalLoader`]. Construct with [`NormalLoaderBuilder::new`],
/// optionally attach adapters via `with_xlora`/`with_lora`, then call `build`.
#[derive(Default)]
pub struct NormalLoaderBuilder {
    // May be None at first; `with_adapter` falls back to the adapter ordering's
    // base model ID, and `build` requires one of the two to be set.
    model_id: Option<String>,
    config: NormalSpecificConfig,
    xlora_model_id: Option<String>,
    lora_adapter_ids: Option<Vec<String>>,
    kind: ModelKind,
    xlora_order: Option<Ordering>,
    no_kv_cache: bool,
    chat_template: Option<String>,
    tokenizer_json: Option<String>,
    tgt_non_granular_index: Option<usize>,
    jinja_explicit: Option<String>,
    hf_cache_path: Option<PathBuf>,
}
119
/// Loading options specific to "normal" (text) models.
#[derive(Clone, Default)]
pub struct NormalSpecificConfig {
    // Tokens per prompt chunk; defaults to `DEFAULT_PROMPT_CHUNK_SIZE` when unset.
    pub prompt_chunksize: Option<NonZeroUsize>,
    // Optional per-layer device/ISQ topology.
    pub topology: Option<Topology>,
    // ISQ organization: default (all weights) or MoE-experts-only.
    pub organization: IsqOrganization,
    // If set, serialize the quantized model as UQFF artifacts to this path.
    pub write_uqff: Option<PathBuf>,
    // If set, load pre-quantized UQFF artifacts instead of quantizing in situ.
    pub from_uqff: Option<Vec<PathBuf>>,
    // Pre-computed imatrix file to guide ISQ.
    pub imatrix: Option<PathBuf>,
    // Calibration text file; an imatrix is collected from it during loading.
    // Mutually exclusive with `imatrix`.
    pub calibration_file: Option<PathBuf>,
    // Override for the Hugging Face cache directory.
    pub hf_cache_path: Option<PathBuf>,
}
132
133impl NormalLoaderBuilder {
134 pub fn new(
135 config: NormalSpecificConfig,
136 chat_template: Option<String>,
137 tokenizer_json: Option<String>,
138 model_id: Option<String>,
139 no_kv_cache: bool,
140 jinja_explicit: Option<String>,
141 ) -> Self {
142 Self {
143 config,
144 chat_template,
145 tokenizer_json,
146 model_id,
147 kind: ModelKind::Normal,
148 jinja_explicit,
149 no_kv_cache,
150 ..Default::default()
151 }
152 }
153
154 fn with_adapter(
155 mut self,
156 xlora_model_id: String,
157 xlora_order: Ordering,
158 no_kv_cache: bool,
159 tgt_non_granular_index: Option<usize>,
160 ) -> Self {
161 self.xlora_model_id = Some(xlora_model_id);
162 self.xlora_order = Some(xlora_order);
163 self.no_kv_cache = no_kv_cache;
164 self.tgt_non_granular_index = tgt_non_granular_index;
165 self.model_id = if let Some(id) = self.model_id {
166 Some(id)
167 } else {
168 info!(
169 "Using adapter base model ID: `{}`",
170 self.xlora_order.as_ref().unwrap().base_model_id
171 );
172 Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
173 };
174 self
175 }
176
177 pub fn with_xlora(
178 mut self,
179 xlora_model_id: String,
180 xlora_order: Ordering,
181 no_kv_cache: bool,
182 tgt_non_granular_index: Option<usize>,
183 ) -> Self {
184 self.kind = ModelKind::Adapter {
185 adapter: AdapterKind::XLora,
186 };
187 self.with_adapter(
188 xlora_model_id,
189 xlora_order,
190 no_kv_cache,
191 tgt_non_granular_index,
192 )
193 }
194
195 pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
196 self.kind = ModelKind::Adapter {
197 adapter: AdapterKind::Lora,
198 };
199 self.lora_adapter_ids = Some(lora_adapter_ids);
200 self
201 }
202
203 pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
204 self.hf_cache_path = Some(hf_cache_path);
205 self
206 }
207
208 pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
211 let loader: Box<dyn NormalModelLoader> = match loader_tp {
212 Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
213 Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
214 Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
215 Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
216 Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
217 Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
218 Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
219 Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
220 Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
221 Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
222 Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
223 Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
224 Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
225 Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
226 None => Box::new(AutoNormalLoader),
227 };
228 Ok(Box::new(NormalLoader {
229 inner: loader,
230 model_id: self.model_id.unwrap(),
231 config: self.config,
232 xlora_model_id: self.xlora_model_id,
233 lora_adapter_ids: self.lora_adapter_ids,
234 kind: self.kind,
235 xlora_order: self.xlora_order,
236 no_kv_cache: self.no_kv_cache,
237 chat_template: self.chat_template,
238 tokenizer_json: self.tokenizer_json,
239 tgt_non_granular_index: self.tgt_non_granular_index,
240 jinja_explicit: self.jinja_explicit,
241 token_source: RwLock::new(None),
242 revision: RwLock::new(None),
243 from_uqff: RwLock::new(None),
244 hf_cache_path: self.hf_cache_path,
245 }))
246 }
247}
248
impl Loader for NormalLoader {
    /// Resolve model files from the Hugging Face Hub (or local cache), record
    /// the token source/revision for later reuse, resolve UQFF artifact paths
    /// if requested, then delegate to [`Self::load_model_from_path`].
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Install the HF cache (possibly overridden) exactly once, globally.
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        // Resolve UQFF artifact locations up front so `load_model_from_path`
        // can read them from `self.from_uqff`.
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        // Stash the token source and revision for later macro-based downloads.
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }

    /// Load a model from already-resolved paths.
    ///
    /// This performs, in order: config parsing, device selection (NCCL/daemon
    /// aware), automatic device mapping (sized using ISQ/UQFF pack factors),
    /// model construction, optional imatrix calibration, optional in-situ
    /// quantization or UQFF artifact loading, optional PagedAttention cache
    /// setup, and finally pipeline assembly.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mut mapper: DeviceMapSetting,
        mut in_situ_quant: Option<IsqType>,
        mut paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        let config = std::fs::read_to_string(paths.get_config_filename())?;

        // Architectures that cannot use PagedAttention silently fall back to
        // eager attention.
        if !self.inner.supports_paged_attention(&config)? {
            paged_attn_config = None;
        }

        let prompt_chunksize = self
            .config
            .prompt_chunksize
            .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
            .get();

        info!("Prompt chunk size is {prompt_chunksize}.",);

        let use_nccl = mistralrs_quant::distributed::use_nccl();

        // Device selection: a distributed daemon worker gets CUDA device
        // `rank + 1`; NCCL uses CUDA 0; otherwise enumerate devices similar to
        // the requested one for mapping.
        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
            let WorkerTransferData::Init { id: _, worker_rank } = payload;
            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
        } else if use_nccl {
            vec![candle_core::Device::new_cuda(0)?]
        } else {
            device_map::get_all_similar_devices(device)?
        };
        let device = if use_nccl || cfg!(feature = "ring") {
            available_devices[0].clone()
        } else {
            device.clone()
        };

        if use_nccl || cfg!(feature = "ring") {
            // NCCL/ring handles sharding itself: use a dummy mapper.
            mapper = DeviceMapSetting::DummyNccl {
                nm_device: available_devices[0].clone(),
            };
        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
            // Automatic device mapping: estimate per-layer sizes (accounting
            // for quantization pack factors) and distribute layers.
            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;

            // Pre-quantized checkpoints cannot be ISQ'd again.
            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
                in_situ_quant = None;
            }

            // Pick the size source: UQFF artifacts (read the average pack
            // factor out of the serialized tensors), an ISQ type, or the
            // checkpoint's own quantization config.
            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
                    let weight_pack_factor = {
                        let ser_artifacts = unsafe {
                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
                        };
                        let mut total_pack_factors = 0;
                        let total_tensors = ser_artifacts.tensors().len();
                        for (_, artifact) in ser_artifacts.tensors() {
                            let artifact = artifact.data();
                            // Each UQFF tensor records its quantization type at
                            // a fixed byte offset.
                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
                            {
                                QuantizedSerdeType::Hqq => {
                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Gguf => {
                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
                                QuantizedSerdeType::Unquant => 1,
                                QuantizedSerdeType::Afq => {
                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
                                        .pack_factor(dtype)
                                }
                            };
                            total_pack_factors += pack_factor;
                        }

                        // Average pack factor over all tensors.
                        total_pack_factors / total_tensors
                    };

                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else if let Some(isq) = in_situ_quant {
                    let weight_pack_factor = isq.pack_factor(dtype);
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                } else {
                    let weight_pack_factor =
                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
                    let layer_sizes_in_bytes =
                        self.inner
                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
                    let non_mapped_size_in_bytes =
                        self.inner
                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
                    (
                        layer_sizes_in_bytes,
                        non_mapped_size_in_bytes,
                        layer_sizes_sum + non_mapped_size_in_bytes,
                    )
                };

            let new = auto_device_map::get_device_layers(
                &*self.inner,
                &config,
                self.inner.num_layers(&config)?,
                layer_sizes_in_bytes,
                non_mapped_size_in_bytes,
                total_model_size_in_bytes,
                &available_devices,
                dtype,
                &params,
                prompt_chunksize,
                paged_attn_config.as_ref(),
            )?;
            mapper = DeviceMapSetting::Map(new);
        }

        // Two independent mappers: one kept by the pipeline, one consumed
        // during model construction below.
        let pipeline_mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mapper = mapper.into_mapper(
            self.inner.num_layers(&config)?,
            &device,
            self.config.topology.as_ref(),
        )?;
        let mut layer_devices = Vec::new();
        for layer in 0..self.inner.num_layers(&config)? {
            let device = mapper.device_for(layer, false).cloned();
            layer_devices.push(device);
        }
        let dtype = mapper.get_min_dtype(dtype)?;

        // TODO: PagedAttention is not supported with CPU for now.
        // This check is not really necessary because `get_device_layers` should prevent it.
        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
        if mapping_uses_cpu {
            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
            paged_attn_config = None;
        }

        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
        if crate::using_flash_attn() {
            once_log_info("FlashAttention is enabled.");
        }

        // Decide between "immediate" ISQ (quantize each weight as it is
        // loaded) and deferred ISQ (load everything, then quantize). Immediate
        // ISQ is taken only when no imatrix/calibration/UQFF-write is involved
        // and the target device is not CUDA.
        let mut loading_isq = if self.config.imatrix.is_none()
            && self.config.calibration_file.is_none()
            && !device.is_cuda()
            && self.config.write_uqff.is_none()
            && in_situ_quant.is_some()
        {
            let predicates = if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly)
            {
                self.inner.immediate_isq_predicates_moqe(&config)?
            } else {
                self.inner.immediate_isq_predicates(&config)?
            };
            info!("Applying ISQ to {in_situ_quant:?}");
            if predicates.is_empty() {
                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
            }
            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
            // Immediate ISQ is active; no deferred quantization pass needed.
            false
        } else {
            in_situ_quant.is_some()
        };

        // A topology that requests ISQ for any layer also forces the deferred
        // quantization pass.
        if let Some(ref topology) = self.config.topology {
            loading_isq |= topology
                .0
                .iter()
                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
        }

        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
            anyhow::bail!(
                "`imatrix` and `calibration_file` were both specified, this is not allowed."
            );
        }

        // Load weights onto CPU only when a deferred ISQ pass will follow;
        // calibration requires running the model, so it loads on-device.
        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
            loading_isq = false;
            device.clone()
        } else {
            Device::Cpu
        };

        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());

        let attention_mechanism = if paged_attn_config.is_some() {
            AttentionImplementation::PagedAttention
        } else {
            AttentionImplementation::Eager
        };

        let multi_progress = Arc::new(MultiProgress::new());

        // Construct the model: sharded path for NCCL/ring, otherwise the
        // standard per-kind loader macros.
        let mut model = if use_nccl || cfg!(feature = "ring") {
            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
                dtype,
                &device,
                &available_devices,
                silent,
                &config,
                loading_isq,
                self.config.from_uqff.is_some(),
                self.config.organization,
                &*self.inner,
                paths.as_ref(),
            )?;

            match self.kind {
                ModelKind::Normal => normal_model_loader_sharded!(
                    sharded_vb,
                    config,
                    self.inner,
                    mapper,
                    loading_isq,
                    device.clone(),
                    attention_mechanism,
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        } else {
            match self.kind {
                ModelKind::Normal => normal_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::XLora,
                } => xlora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    device.clone(),
                    multi_progress.clone(),
                ),
                ModelKind::Adapter {
                    adapter: AdapterKind::Lora,
                } => lora_model_loader!(
                    paths,
                    Some(dtype),
                    &load_device,
                    layer_devices.clone(),
                    config,
                    self.inner,
                    silent,
                    mapper,
                    loading_isq,
                    self.config.from_uqff.is_some(),
                    device.clone(),
                    attention_mechanism,
                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
                    multi_progress.clone(),
                ),
                _ => unreachable!(),
            }
        };

        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
            serde_json::from_str(&fs::read_to_string(f).unwrap())
                .expect("bos_token_id/eos_token_id missing in generation_config.json")
        });

        let chat_template_explicit = paths
            .get_chat_template_explicit()
            .as_ref()
            .map(|x| x.to_string_lossy().to_string());
        let chat_template = get_chat_template(
            paths,
            self.jinja_explicit.as_ref(),
            chat_template_explicit.as_ref(),
            self.chat_template.as_ref(),
            None,
        );

        // Optional imatrix collection: run the calibration text through the
        // model in chunks, tracking activation statistics, then clear the KV
        // cache between chunks.
        if let Some(calibration_file) = &self.config.calibration_file {
            let calibration_data = std::fs::read_to_string(calibration_file)?;
            let tokens = tokenizer
                .encode_fast(calibration_data, false)
                .map_err(anyhow::Error::msg)?
                .get_ids()
                .to_vec();
            info!(
                "Collecting imatrix from calibration file `{}` of {} tokens.",
                calibration_file.display(),
                tokens.len()
            );
            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
            let bos_tok_id = tokenizer
                .token_to_id(&bos_toks[0])
                .expect("Somehow the bos token is not present.");

            match self.config.organization {
                IsqOrganization::Default => model.begin_track_stats()?,
                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
            }

            const CHUNK_SIZE: usize = 1024;
            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
            let start = Instant::now();
            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
                // Prepend BOS to every chunk.
                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
                let chunk_len = chunk.len();

                // Shadows the outer `start`: per-chunk timing.
                let start = Instant::now();
                let inputs = make_prompt_chunk(
                    0,
                    vec![&chunk],
                    &[0],
                    &load_device,
                    None,
                    false,
                    None,
                    Some(pipeline_mapper.as_ref()),
                )?;

                model.forward(
                    &inputs.input.to_device(model.device())?,
                    &inputs.positions,
                    inputs.context_lens.clone(),
                    inputs.position_ids.clone(),
                    None,
                    &inputs.flash_meta.clone(),
                )?;

                // Reset the KV cache so chunks do not attend to each other.
                match model.cache_mut() {
                    EitherCache::Full(full) => {
                        for layer in &mut *full.lock() {
                            *layer = None
                        }
                    }
                    EitherCache::Normal(normal) => {
                        for layer in &mut *normal.lock().unwrap().0 {
                            layer.reset();
                        }
                    }
                }

                let end = Instant::now();
                info!(
                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
                    i + 1,
                    end.duration_since(start).as_secs_f32()
                );
            }
            load_device.synchronize()?;
            let end = Instant::now();
            info!(
                "Finished collecting imatrix in {:.2}s",
                end.duration_since(start).as_secs_f32()
            );
        }

        // Deferred quantization: either quantize now (possibly writing UQFF
        // artifacts), or load pre-quantized UQFF artifacts.
        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
            let imatrix_source = match (
                self.config.imatrix.as_ref(),
                self.config.calibration_file.is_some(),
            ) {
                (None, false) => None,
                (Some(file), false) => Some(ImatrixDataSource::File(file)),
                (None, true) => Some(ImatrixDataSource::Collected),
                // Guarded by the bail! above.
                (Some(_), true) => unreachable!(),
            };

            info!("Applying ISQ to all ranks.");

            let multi_progress = Arc::new(MultiProgress::new());

            model.quantize(
                in_situ_quant,
                model.device().clone(),
                self.config.topology.as_ref(),
                silent,
                imatrix_source,
                self.config.organization,
                self.config.write_uqff.as_ref(),
                UqffFullSer {
                    tokenizer: &tokenizer,
                    template_filename: paths.get_template_filename(),
                    generation_config: paths.get_gen_conf_filename(),
                    config: config.clone(),
                    processor_filename: &None,
                    preprocessor_filename: &None,
                },
                multi_progress.clone(),
            )?;
        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
            model.load_from_artifacts(
                device.clone(),
                self.config.topology.as_ref(),
                silent,
                from_uqff,
            )?;
        }

        // X-LoRA models run without PagedAttention.
        let paged_attn_config = if matches!(
            self.kind,
            ModelKind::Adapter {
                adapter: AdapterKind::XLora
            }
        ) {
            warn!(
                "Adapter parallel_models do not currently support PagedAttention, running without"
            );
            None
        } else {
            paged_attn_config
        };

        // Build the PagedAttention KV-cache engine, if enabled.
        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
            let cache_config = calculate_cache_config(
                paged_attn_config.mem_gpu,
                paged_attn_config.mem_cpu,
                paged_attn_config.block_size,
                dtype,
                model.config(),
                &device,
                &pipeline_mapper
                    .get_unique_devices()
                    .into_iter()
                    .map(Some)
                    .collect::<Vec<_>>(),
                silent,
            )?;

            // Re-derive per-layer devices from the constructed model, which
            // may differ from the pre-construction mapping.
            let mut layer_devices = Vec::new();
            for layer in 0..self.inner.num_layers(&config)? {
                let device = model.get_layers().1.device_for(layer, false).cloned();
                layer_devices.push(device);
            }
            let cache_engine = CacheEngine::new(
                model.config(),
                &cache_config,
                dtype,
                model.device(),
                layer_devices.clone(),
            )?;

            (Some(cache_config), Some(cache_engine))
        } else {
            (None, None)
        };

        let max_seq_len = model.max_seq_len();
        let llg_factory = build_llg_factory(tokenizer.clone())?;
        let num_hidden_layers = match model.cache() {
            EitherCache::Full(full) => full.lock().len(),
            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
        };
        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
        let sliding_window = model.config().sliding_window;
        let model_metadata = Arc::new(model.config().clone());

        Ok(Arc::new(Mutex::new(NormalPipeline {
            model,
            tokenizer: tokenizer.into(),
            no_kv_cache: self.no_kv_cache,
            chat_template: Arc::new(chat_template),
            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
                NonGranularState {
                    non_granular_index: Arc::new(Mutex::new(0)),
                    tgt_non_granular_index,
                }
            }),
            model_id: self.model_id.clone(),
            metadata: Arc::new(GeneralMetadata {
                max_seq_len,
                llg_factory: Some(llg_factory),
                no_kv_cache: self.no_kv_cache,
                no_prefix_cache: is_xlora,
                num_hidden_layers,
                eos_tok: eos,
                kind: self.kind.clone(),
                is_xlora,
                activation_dtype: dtype,
                sliding_window,
                cache_config,
                cache_engine,
                prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
                model_metadata: Some(model_metadata),
                modalities: Modalities {
                    input: vec![SupportedModality::Text],
                    output: vec![SupportedModality::Text],
                },
            }),
            topology: self.config.topology.clone(),
            silent,
            organization: self.config.organization,
            template_filename: paths.get_template_filename().clone(),
            generation_config: paths.get_gen_conf_filename().cloned(),
            config,
            imatrix: self.config.imatrix.clone(),
            mapper: pipeline_mapper,
        })))
    }

    /// The model ID this loader was configured with.
    fn get_id(&self) -> String {
        self.model_id.clone()
    }

    /// The kind of model (normal, X-LoRA adapter, LoRA adapter).
    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
}
901
902impl PreProcessingMixin for NormalPipeline {
903 fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
904 Some(self.chat_template.clone())
905 }
906 fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
907 None
908 }
909}
910
911impl IsqPipelineMixin for NormalPipeline {
912 fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
913 let device = self.device().clone();
914 let multi_progress = Arc::new(MultiProgress::new());
915 self.model.quantize(
916 Some(dtype),
917 device.clone(),
918 self.topology.as_ref(),
919 self.silent,
920 self.imatrix.as_ref().map(ImatrixDataSource::File),
921 self.organization,
922 None,
923 UqffFullSer {
924 tokenizer: &self.tokenizer,
925 template_filename: &self.template_filename,
926 generation_config: self.generation_config.as_ref(),
927 config: self.config.clone(),
928 processor_filename: &None,
929 preprocessor_filename: &None,
930 },
931 multi_progress.clone(),
932 )?;
933 Ok(())
934 }
935}
936
937impl CacheManagerMixin for NormalPipeline {
938 fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
939 if matches!(self.model.cache(), EitherCache::Full(_)) {
940 FullCacheManager.clone_in_cache(self, seqs, false)
941 } else {
942 NormalCacheManager.clone_in_cache(self, seqs, false)
943 }
944 }
945 fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
946 if matches!(self.model.cache(), EitherCache::Full(_)) {
947 FullCacheManager.clone_out_cache(self, seqs, false)
948 } else {
949 NormalCacheManager.clone_out_cache(self, seqs, false)
950 }
951 }
952 fn set_none_cache(
953 &self,
954 seqs: &mut [&mut Sequence],
955 reset_non_granular: bool,
956 modify_draft_cache: bool,
957 load_preallocated_cache: bool,
958 ) {
959 if matches!(self.model.cache(), EitherCache::Full(_)) {
960 FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
961 } else {
962 NormalCacheManager.set_none_cache(
963 self,
964 seqs,
965 modify_draft_cache,
966 load_preallocated_cache,
967 );
968 }
969 if reset_non_granular {
970 self.reset_non_granular_state()
971 }
972 }
973 fn cache(&self) -> &EitherCache {
974 self.model.cache()
975 }
976}
977
978impl MetadataMixin for NormalPipeline {
979 fn device(&self) -> Device {
980 self.model.device().clone()
981 }
982 fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
983 Some(self.tokenizer.clone())
984 }
985 fn name(&self) -> String {
986 self.model_id.clone()
987 }
988 fn reset_non_granular_state(&self) {
989 if let Some(s) = self.non_granular_state.as_ref() {
990 *self.cache().full().get_scalings_cache() = None;
991 *get_mut_arcmutex!(s.non_granular_index) = 0;
992 }
993 }
994 fn get_metadata(&self) -> Arc<GeneralMetadata> {
995 self.metadata.clone()
996 }
997 fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
998 Some(&*self.mapper)
999 }
1000}
1001
#[async_trait::async_trait]
impl Pipeline for NormalPipeline {
    /// Run one forward pass over a batch of `ModelInputs`.
    ///
    /// Dispatches to `xlora_forward` for X-LoRA models, otherwise the plain
    /// `forward`, threading PagedAttention metadata through when a cache
    /// engine exists. Returns raw or causal-generation logits depending on
    /// `return_raw_logits`.
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        // Pair the cache engine with the PagedAttention metadata; a mismatch
        // in either direction is a configuration error.
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
            (Some(_), None) => {
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = match self.model.is_xlora() {
            false => {
                // Materialize the engine's KV cache alongside the metadata for
                // the model's forward signature.
                let paged_attn_meta = paged_attn_meta
                    .as_ref()
                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));

                self.model.forward(
                    &input_ids,
                    &seqlen_offsets,
                    context_lens,
                    position_ids,
                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
                    &flash_meta,
                )?
            }
            // X-LoRA path: the `*_full` variants fall back to the regular
            // inputs when absent.
            true => self.model.xlora_forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                position_ids,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }

    /// Sample next tokens from the logits and append them to the sequences.
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }

    /// This is a text-only pipeline.
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}
1081
1082impl AnyMoePipelineMixin for NormalPipeline {
1083 fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1084 self.model.finish_training(gate_model_id)
1085 }
1086 fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1087 self.model.get_vars()
1088 }
1089 fn amoe_base_model_trainable_params(&self) -> usize {
1090 self.model.trainable_params()
1091 }
1092 fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1093 self.model.take_cached_gating_outputs()
1094 }
1095 fn amoe_create_layers(
1096 &mut self,
1097 model_ids: Vec<String>,
1098 token: &TokenSource,
1099 revision: Option<String>,
1100 match_regex: &str,
1101 config: crate::amoe::AnyMoeConfig,
1102 dtype: candle_core::DType,
1103 dev: &Device,
1104 (prefix, mlp): (String, String),
1105 layers: Vec<usize>,
1106 expert_type: AnyMoeExpertType,
1107 silent: bool,
1108 gate_model_id: Option<String>,
1109 ) -> candle_core::Result<()> {
1110 let mut vbs = Vec::new();
1111 let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1113 for model_id in model_ids {
1114 let model_id_str = &model_id;
1115 let model_id = Path::new(&model_id);
1116
1117 let api = {
1118 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1119 let mut api = ApiBuilder::from_cache(cache)
1120 .with_progress(!silent)
1121 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1122 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1123 api = api.with_cache_dir(x.into());
1124 }
1125 api.build().map_err(candle_core::Error::msg)?
1126 };
1127 let revision = revision.clone().unwrap_or("main".to_string());
1128 let api = api.repo(Repo::with_revision(
1129 model_id_str.clone(),
1130 RepoType::Model,
1131 revision.clone(),
1132 ));
1133
1134 let mut filenames = vec![];
1135 for rfilename in
1136 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1137 {
1138 filenames.push(api_get_file!(api, &rfilename, model_id));
1139 }
1140
1141 let regex = regex.clone();
1142 let match_regex_clone = match_regex.to_string();
1143 let layers_clone = layers.clone();
1144 let vb = from_mmaped_safetensors(
1145 filenames,
1146 vec![],
1147 Some(dtype),
1148 dev,
1149 vec![None],
1150 silent,
1151 None,
1152 move |key| {
1153 if regex.is_match(&key) {
1154 let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1157 let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1158 let layer_n = key[first_layer_idx + 1..last_layer_idx]
1159 .parse::<usize>()
1160 .unwrap();
1161 layers_clone.contains(&layer_n) || layers_clone.is_empty()
1162 } else {
1163 false
1164 }
1165 },
1166 Arc::new(|_| DeviceForLoadTensor::Base),
1167 )?;
1168 vbs.push(vb);
1169 }
1170
1171 let gate_vb = if let Some(gate_model_id) = gate_model_id {
1172 let model_id_str = &gate_model_id;
1173 let model_id = Path::new(&gate_model_id);
1174
1175 let api = {
1176 let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1177 let mut api = ApiBuilder::from_cache(cache)
1178 .with_progress(!silent)
1179 .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1180 if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1181 api = api.with_cache_dir(x.into());
1182 }
1183 api.build().map_err(candle_core::Error::msg)?
1184 };
1185 let revision = revision.clone().unwrap_or("main".to_string());
1186 let api = api.repo(Repo::with_revision(
1187 model_id_str.clone(),
1188 RepoType::Model,
1189 revision.clone(),
1190 ));
1191
1192 let mut gate_filenames = vec![];
1193 for rfilename in
1194 api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1195 {
1196 gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1197 }
1198 assert_eq!(
1199 gate_filenames.len(),
1200 1,
1201 "Gate model ID must contain only one .safetensors file"
1202 );
1203
1204 let vb = from_mmaped_safetensors(
1205 gate_filenames.clone(),
1206 vec![],
1207 Some(dtype),
1208 dev,
1209 vec![None],
1210 silent,
1211 None,
1212 |_| true,
1213 Arc::new(|_| DeviceForLoadTensor::Base),
1214 )?;
1215 info!(
1216 "Loaded gating layers from `{}`",
1217 gate_filenames[0].display()
1218 );
1219 Some(vb)
1220 } else {
1221 None
1222 };
1223
1224 self.model.create_anymoe_layers(
1225 vbs.clone(),
1226 config.clone(),
1227 (prefix.clone(), mlp.clone()),
1228 layers.clone(),
1229 expert_type.clone(),
1230 gate_vb.clone(),
1231 )?;
1232
1233 Ok(())
1234 }
1235 fn amoe_supported(&self) -> bool {
1236 self.model.amoe_supported()
1237 }
1238}