// mistralrs_core/pipeline/normal.rs

1use super::inputs_processor::DEFAULT_PROMPT_CHUNK_SIZE;
2use super::isq::ImatrixDataSource;
3use super::llg::build_llg_factory;
4use super::{
5    get_model_paths, get_xlora_paths, text_models_inputs_processor::ModelInputs, AdapterKind,
6    CacheManager, GeneralMetadata, Loader, ModelKind, ModelPaths, NormalModel, NormalModelLoader,
7    TokenSource,
8};
9use super::{
10    AnyMoePipelineMixin, CacheManagerMixin, EitherCache, ForwardInputsResult, IsqOrganization,
11    IsqPipelineMixin, MetadataMixin, ModelCategory, PreProcessingMixin,
12};
13use super::{
14    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, Gemma2Loader, GemmaLoader, LlamaLoader,
15    MistralLoader, MixtralLoader, NormalLoaderType, Phi2Loader, Phi3Loader, Phi3_5MoELoader,
16    Qwen2Loader, Qwen3Loader, Qwen3MoELoader, Starcoder2Loader,
17};
18use crate::amoe::AnyMoeExpertType;
19use crate::device_map::{self, DeviceMapper};
20use crate::distributed::{self, WorkerTransferData};
21use crate::kv_cache::{FullCacheManager, NormalCacheManager};
22use crate::lora::Ordering;
23use crate::paged_attention::{calculate_cache_config, AttentionImplementation, CacheEngine};
24use crate::pipeline::chat_template::{calculate_eos_tokens, GenerationConfig};
25use crate::pipeline::isq::UqffFullSer;
26use crate::pipeline::loaders::auto_device_map;
27use crate::pipeline::loaders::QuantizationConfigShim;
28use crate::pipeline::sampling::sample_and_add_toks;
29use crate::pipeline::text_models_inputs_processor::make_prompt_chunk;
30use crate::pipeline::{get_chat_template, Modalities, SupportedModality};
31use crate::pipeline::{ChatTemplate, LocalModelPaths};
32use crate::prefix_cacher::PrefixCacheManagerV2;
33use crate::sequence::Sequence;
34use crate::utils::tokenizer::get_tokenizer;
35use crate::utils::varbuilder_utils::DeviceForLoadTensor;
36use crate::utils::{tokens::get_token, varbuilder_utils::from_mmaped_safetensors};
37use crate::xlora_models::NonGranularState;
38use crate::{
39    api_dir_list, api_get_file, get_mut_arcmutex, get_paths, get_uqff_paths, lora_model_loader,
40    normal_model_loader, normal_model_loader_sharded, xlora_model_loader, DeviceMapSetting,
41    PagedAttentionConfig, Pipeline, Topology, TryIntoDType, GLOBAL_HF_CACHE,
42};
43use anyhow::Result;
44use candle_core::{Device, Tensor, Var};
45use hf_hub::Cache;
46use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
47use indicatif::MultiProgress;
48use mistralrs_quant::log::once_log_info;
49use mistralrs_quant::{AfqLayer, GgufMatMul, HqqLayer, IsqType, QuantizedSerdeType};
50use rand_isaac::Isaac64Rng;
51use regex_automata::meta::Regex;
52use std::any::Any;
53use std::borrow::Cow;
54use std::num::{NonZero, NonZeroUsize};
55use std::path::{Path, PathBuf};
56use std::str::FromStr;
57use std::sync::{Arc, RwLock};
58use std::time::Instant;
59use std::{env, fs};
60use tokenizers::Tokenizer;
61use tokio::sync::Mutex;
62use tracing::{info, warn};
63
pub struct NormalPipeline {
    /// The loaded model behind the object-safe `NormalModel` interface.
    model: Box<dyn NormalModel + Send + Sync>,
    /// Shared tokenizer for this model.
    tokenizer: Arc<Tokenizer>,
    /// When true, no KV cache is maintained between forward passes.
    no_kv_cache: bool,
    /// Chat template used to render conversations into prompts.
    chat_template: Arc<ChatTemplate>,
    /// Non-granular X-LoRA scheduling state, if applicable.
    non_granular_state: Option<NonGranularState>,
    /// Model ID this pipeline was loaded from.
    model_id: String,
    /// General pipeline metadata.
    metadata: Arc<GeneralMetadata>,
    /// Optional per-layer device/ISQ topology.
    topology: Option<Topology>,
    /// Suppress progress output when true.
    silent: bool,
    /// ISQ organization (whole model vs. MoE experts only).
    organization: IsqOrganization,
    // For full UQFF serialization
    /// Chat template file path, retained so a full UQFF snapshot can be written.
    template_filename: Option<PathBuf>,
    /// `generation_config.json` path, retained for UQFF serialization.
    generation_config: Option<PathBuf>,
    /// Raw model config JSON, retained for UQFF serialization.
    config: String,
    /// Optional imatrix data file used for quantization.
    imatrix: Option<PathBuf>,
    /// Device mapper describing which device each layer lives on.
    mapper: Box<dyn DeviceMapper + Send + Sync>,
}
82
/// A loader for a "normal" (non-quantized) model.
pub struct NormalLoader {
    /// Architecture-specific loader implementation, selected in
    /// `NormalLoaderBuilder::build`.
    inner: Box<dyn NormalModelLoader>,
    /// Model ID (Hub repo or local path) to load from.
    model_id: String,
    /// Normal-model-specific configuration.
    config: NormalSpecificConfig,
    /// X-LoRA adapter model ID, if configured.
    xlora_model_id: Option<String>,
    /// LoRA adapter IDs, if configured.
    lora_adapter_ids: Option<Vec<String>>,
    /// Kind of model being loaded (normal or adapter).
    kind: ModelKind,
    /// X-LoRA adapter ordering, if configured.
    xlora_order: Option<Ordering>,
    /// When true, no KV cache is used.
    no_kv_cache: bool,
    /// Literal chat template or path override.
    chat_template: Option<String>,
    /// Path override for `tokenizer.json`.
    tokenizer_json: Option<String>,
    /// Target non-granular index for X-LoRA.
    tgt_non_granular_index: Option<usize>,
    // The three RwLock fields below are written during `load_model_from_hf`
    // so that the resolved values are available to later loading stages.
    /// Token source used to authenticate with the Hub.
    token_source: RwLock<Option<TokenSource>>,
    /// Resolved model revision.
    revision: RwLock<Option<String>>,
    /// Resolved UQFF artifact paths.
    from_uqff: RwLock<Option<Vec<PathBuf>>>,
    /// Explicit JINJA chat template override.
    jinja_explicit: Option<String>,
    /// Optional override for the Hugging Face cache directory.
    hf_cache_path: Option<PathBuf>,
}
102
#[derive(Default)]
/// A builder for a loader for a "normal" (non-quantized) model.
pub struct NormalLoaderBuilder {
    /// Model ID; optional because an adapter ordering may supply the base ID.
    model_id: Option<String>,
    /// Normal-model-specific configuration.
    config: NormalSpecificConfig,
    /// X-LoRA adapter model ID (set via `with_xlora`).
    xlora_model_id: Option<String>,
    /// LoRA adapter IDs (set via `with_lora`).
    lora_adapter_ids: Option<Vec<String>>,
    /// Kind of model being loaded; defaults to `ModelKind::Normal` in `new`.
    kind: ModelKind,
    /// X-LoRA adapter ordering (set via `with_xlora`).
    xlora_order: Option<Ordering>,
    /// When true, no KV cache is used.
    no_kv_cache: bool,
    /// Literal chat template or path override.
    chat_template: Option<String>,
    /// Path override for `tokenizer.json`.
    tokenizer_json: Option<String>,
    /// Target non-granular index for X-LoRA.
    tgt_non_granular_index: Option<usize>,
    /// Explicit JINJA chat template override.
    jinja_explicit: Option<String>,
    /// Optional override for the Hugging Face cache directory.
    hf_cache_path: Option<PathBuf>,
}
119
#[derive(Clone, Default)]
/// Config specific to loading a normal model.
pub struct NormalSpecificConfig {
    /// Prompt chunk size; `DEFAULT_PROMPT_CHUNK_SIZE` is applied when `None`.
    pub prompt_chunksize: Option<NonZeroUsize>,
    /// Optional per-layer device/ISQ topology.
    pub topology: Option<Topology>,
    /// ISQ organization: whole model or MoE experts only.
    pub organization: IsqOrganization,
    /// When set, serialize the quantized model to this UQFF path.
    pub write_uqff: Option<PathBuf>,
    /// When set, load pre-quantized weights from these UQFF files.
    pub from_uqff: Option<Vec<PathBuf>>,
    /// Pre-computed imatrix data file for ISQ. Mutually exclusive with
    /// `calibration_file`.
    pub imatrix: Option<PathBuf>,
    /// Plain-text calibration data used to collect an imatrix at load time.
    /// Mutually exclusive with `imatrix`.
    pub calibration_file: Option<PathBuf>,
    /// Optional override for the Hugging Face cache directory.
    pub hf_cache_path: Option<PathBuf>,
}
132
133impl NormalLoaderBuilder {
134    pub fn new(
135        config: NormalSpecificConfig,
136        chat_template: Option<String>,
137        tokenizer_json: Option<String>,
138        model_id: Option<String>,
139        no_kv_cache: bool,
140        jinja_explicit: Option<String>,
141    ) -> Self {
142        Self {
143            config,
144            chat_template,
145            tokenizer_json,
146            model_id,
147            kind: ModelKind::Normal,
148            jinja_explicit,
149            no_kv_cache,
150            ..Default::default()
151        }
152    }
153
154    fn with_adapter(
155        mut self,
156        xlora_model_id: String,
157        xlora_order: Ordering,
158        no_kv_cache: bool,
159        tgt_non_granular_index: Option<usize>,
160    ) -> Self {
161        self.xlora_model_id = Some(xlora_model_id);
162        self.xlora_order = Some(xlora_order);
163        self.no_kv_cache = no_kv_cache;
164        self.tgt_non_granular_index = tgt_non_granular_index;
165        self.model_id = if let Some(id) = self.model_id {
166            Some(id)
167        } else {
168            info!(
169                "Using adapter base model ID: `{}`",
170                self.xlora_order.as_ref().unwrap().base_model_id
171            );
172            Some(self.xlora_order.as_ref().unwrap().base_model_id.clone())
173        };
174        self
175    }
176
177    pub fn with_xlora(
178        mut self,
179        xlora_model_id: String,
180        xlora_order: Ordering,
181        no_kv_cache: bool,
182        tgt_non_granular_index: Option<usize>,
183    ) -> Self {
184        self.kind = ModelKind::Adapter {
185            adapter: AdapterKind::XLora,
186        };
187        self.with_adapter(
188            xlora_model_id,
189            xlora_order,
190            no_kv_cache,
191            tgt_non_granular_index,
192        )
193    }
194
195    pub fn with_lora(mut self, lora_adapter_ids: Vec<String>) -> Self {
196        self.kind = ModelKind::Adapter {
197            adapter: AdapterKind::Lora,
198        };
199        self.lora_adapter_ids = Some(lora_adapter_ids);
200        self
201    }
202
203    pub fn hf_cache_path(mut self, hf_cache_path: PathBuf) -> Self {
204        self.hf_cache_path = Some(hf_cache_path);
205        self
206    }
207
208    /// If the loader type is not specified, loader type is automatically determined from the
209    /// `architectures` array in the config.
210    pub fn build(self, loader_tp: Option<NormalLoaderType>) -> anyhow::Result<Box<dyn Loader>> {
211        let loader: Box<dyn NormalModelLoader> = match loader_tp {
212            Some(NormalLoaderType::Mistral) => Box::new(MistralLoader),
213            Some(NormalLoaderType::Gemma) => Box::new(GemmaLoader),
214            Some(NormalLoaderType::Llama) => Box::new(LlamaLoader),
215            Some(NormalLoaderType::Mixtral) => Box::new(MixtralLoader),
216            Some(NormalLoaderType::Phi2) => Box::new(Phi2Loader),
217            Some(NormalLoaderType::Phi3) => Box::new(Phi3Loader),
218            Some(NormalLoaderType::Qwen2) => Box::new(Qwen2Loader),
219            Some(NormalLoaderType::Gemma2) => Box::new(Gemma2Loader),
220            Some(NormalLoaderType::Starcoder2) => Box::new(Starcoder2Loader),
221            Some(NormalLoaderType::Phi3_5MoE) => Box::new(Phi3_5MoELoader),
222            Some(NormalLoaderType::DeepSeekV2) => Box::new(DeepSeekV2Loader),
223            Some(NormalLoaderType::DeepSeekV3) => Box::new(DeepSeekV3Loader),
224            Some(NormalLoaderType::Qwen3) => Box::new(Qwen3Loader),
225            Some(NormalLoaderType::Qwen3Moe) => Box::new(Qwen3MoELoader),
226            None => Box::new(AutoNormalLoader),
227        };
228        Ok(Box::new(NormalLoader {
229            inner: loader,
230            model_id: self.model_id.unwrap(),
231            config: self.config,
232            xlora_model_id: self.xlora_model_id,
233            lora_adapter_ids: self.lora_adapter_ids,
234            kind: self.kind,
235            xlora_order: self.xlora_order,
236            no_kv_cache: self.no_kv_cache,
237            chat_template: self.chat_template,
238            tokenizer_json: self.tokenizer_json,
239            tgt_non_granular_index: self.tgt_non_granular_index,
240            jinja_explicit: self.jinja_explicit,
241            token_source: RwLock::new(None),
242            revision: RwLock::new(None),
243            from_uqff: RwLock::new(None),
244            hf_cache_path: self.hf_cache_path,
245        }))
246    }
247}
248
249impl Loader for NormalLoader {
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    /// Resolve model artifacts from the Hugging Face Hub (or local cache),
    /// record the token source / revision / UQFF paths on `self`, then
    /// delegate to `load_model_from_path`.
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
        // Install the HF cache location process-wide. `get_or_init` means a
        // cache installed by an earlier loader wins over this `hf_cache_path`.
        let cache = self
            .hf_cache_path
            .clone()
            .map(Cache::new)
            .unwrap_or_default();
        GLOBAL_HF_CACHE.get_or_init(|| cache);

        // Resolve all model artifact paths (config, weights, tokenizer, ...).
        // The final flag tells the macro whether UQFF artifacts are expected.
        let paths: anyhow::Result<Box<dyn ModelPaths>> = get_paths!(
            LocalModelPaths,
            &token_source,
            revision.clone(),
            self,
            None,
            None,
            silent,
            self.config.from_uqff.is_some()
        );
        // If UQFF sources were configured, resolve their concrete paths now and
        // stash them for the loading stage to pick up.
        if let Some(from_uqff) = self.config.from_uqff.clone() {
            *self.from_uqff.write().unwrap() = Some(get_uqff_paths!(&from_uqff, self, silent));
        }
        // Remember the credentials and revision so later stages use the same
        // source as this resolution did.
        *self
            .token_source
            .write()
            .expect("Failed to write to token source") = Some(token_source);
        *self.revision.write().expect("Failed to write to revision") = revision;
        self.load_model_from_path(
            &paths?,
            dtype,
            device,
            silent,
            mapper,
            in_situ_quant,
            paged_attn_config,
        )
    }
297
298    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
299    fn load_model_from_path(
300        &self,
301        paths: &Box<dyn ModelPaths>,
302        dtype: &dyn TryIntoDType,
303        device: &Device,
304        silent: bool,
305        mut mapper: DeviceMapSetting,
306        mut in_situ_quant: Option<IsqType>,
307        mut paged_attn_config: Option<PagedAttentionConfig>,
308    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
309        let config = std::fs::read_to_string(paths.get_config_filename())?;
310
311        if !self.inner.supports_paged_attention(&config)? {
312            paged_attn_config = None;
313        }
314
315        // Apply default prompt size here
316        let prompt_chunksize = self
317            .config
318            .prompt_chunksize
319            .unwrap_or(DEFAULT_PROMPT_CHUNK_SIZE.try_into().unwrap())
320            .get();
321
322        info!("Prompt chunk size is {prompt_chunksize}.",);
323
324        let use_nccl = mistralrs_quant::distributed::use_nccl();
325
326        let available_devices = if let Ok(payload) = env::var(distributed::IS_DAEMON_FLAG) {
327            let payload: WorkerTransferData = serde_json::from_str(&payload)?;
328            let WorkerTransferData::Init { id: _, worker_rank } = payload;
329            vec![candle_core::Device::new_cuda(worker_rank + 1)?]
330        } else if use_nccl {
331            vec![candle_core::Device::new_cuda(0)?]
332        } else {
333            device_map::get_all_similar_devices(device)?
334        };
335        let device = if use_nccl || cfg!(feature = "ring") {
336            available_devices[0].clone()
337        } else {
338            device.clone()
339        };
340
341        // If auto, convert to Map if not using nccl
342        if use_nccl || cfg!(feature = "ring") {
343            mapper = DeviceMapSetting::DummyNccl {
344                nm_device: available_devices[0].clone(),
345            };
346        } else if let DeviceMapSetting::Auto(params) = mapper.clone() {
347            // Initial dtype
348            let dtype = dtype.try_into_dtype(&available_devices.iter().collect::<Vec<_>>())?;
349
350            // Disable ISQ if we are loading a prequantized model.
351            if QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)? != 1 {
352                in_situ_quant = None;
353            }
354
355            // ISQ or UQFF: quantized path
356            // Match logic below where UQFF has priority
357            let (layer_sizes_in_bytes, non_mapped_size_in_bytes, total_model_size_in_bytes) =
358                if let Some(serialized) = &*self.from_uqff.read().unwrap() {
359                    let weight_pack_factor = {
360                        let ser_artifacts = unsafe {
361                            candle_core::safetensors::MmapedSafetensors::multi(serialized)?
362                        };
363                        let mut total_pack_factors = 0;
364                        let total_tensors = ser_artifacts.tensors().len();
365                        for (_, artifact) in ser_artifacts.tensors() {
366                            let artifact = artifact.data();
367                            // NOTE(EricLBuehler): isq type is ALWAYS byte 4 (5th) of the tensor.
368                            let isq_type = artifact[mistralrs_quant::UQFF_QUANT_TYPE_OFFSET];
369                            let pack_factor = match QuantizedSerdeType::try_from(isq_type as usize)?
370                            {
371                                QuantizedSerdeType::Hqq => {
372                                    HqqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
373                                        .pack_factor(dtype)
374                                }
375                                QuantizedSerdeType::Gguf => {
376                                    GgufMatMul::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
377                                        .pack_factor(dtype)
378                                }
379                                QuantizedSerdeType::Fp8 => IsqType::F8E4M3.pack_factor(dtype),
380                                QuantizedSerdeType::Unquant => 1,
381                                QuantizedSerdeType::Afq => {
382                                    AfqLayer::get_isq_type_from_uqff(Cow::Borrowed(artifact))?
383                                        .pack_factor(dtype)
384                                }
385                            };
386                            total_pack_factors += pack_factor;
387                        }
388
389                        total_pack_factors / total_tensors
390                    };
391
392                    let layer_sizes_in_bytes =
393                        self.inner
394                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
395                    let non_mapped_size_in_bytes =
396                        self.inner
397                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
398                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
399                    (
400                        layer_sizes_in_bytes,
401                        non_mapped_size_in_bytes,
402                        layer_sizes_sum + non_mapped_size_in_bytes,
403                    )
404                } else if let Some(isq) = in_situ_quant {
405                    let weight_pack_factor = isq.pack_factor(dtype);
406                    let layer_sizes_in_bytes =
407                        self.inner
408                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
409                    let non_mapped_size_in_bytes =
410                        self.inner
411                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
412                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
413                    (
414                        layer_sizes_in_bytes,
415                        non_mapped_size_in_bytes,
416                        layer_sizes_sum + non_mapped_size_in_bytes,
417                    )
418                } else {
419                    // Be sure to get the weight pack factor here; we might be loading a prequantized model.
420                    let weight_pack_factor =
421                        QuantizationConfigShim::get_quant_config_pack_factor(&config, dtype)?;
422                    let layer_sizes_in_bytes =
423                        self.inner
424                            .layer_sizes_in_bytes(&config, dtype, weight_pack_factor)?;
425                    let non_mapped_size_in_bytes =
426                        self.inner
427                            .non_mapped_size_in_bytes(&config, dtype, weight_pack_factor)?;
428                    let layer_sizes_sum = layer_sizes_in_bytes.iter().sum::<usize>();
429                    (
430                        layer_sizes_in_bytes,
431                        non_mapped_size_in_bytes,
432                        layer_sizes_sum + non_mapped_size_in_bytes,
433                    )
434                };
435
436            let new = auto_device_map::get_device_layers(
437                &*self.inner,
438                &config,
439                self.inner.num_layers(&config)?,
440                layer_sizes_in_bytes,
441                non_mapped_size_in_bytes,
442                total_model_size_in_bytes,
443                &available_devices,
444                dtype,
445                &params,
446                prompt_chunksize,
447                paged_attn_config.as_ref(),
448            )?;
449            mapper = DeviceMapSetting::Map(new);
450        }
451
452        let pipeline_mapper = mapper.into_mapper(
453            self.inner.num_layers(&config)?,
454            &device,
455            self.config.topology.as_ref(),
456        )?;
457        let mapper = mapper.into_mapper(
458            self.inner.num_layers(&config)?,
459            &device,
460            self.config.topology.as_ref(),
461        )?;
462        let mut layer_devices = Vec::new();
463        for layer in 0..self.inner.num_layers(&config)? {
464            let device = mapper.device_for(layer, false).cloned();
465            layer_devices.push(device);
466        }
467        let dtype = mapper.get_min_dtype(dtype)?;
468
469        // TODO: PagedAttention is not supported with CPU for now.
470        // This check is not really necessary because `get_device_layers` should prevent it.
471        let mapping_uses_cpu = mapper.get_unique_devices().iter().any(Device::is_cpu);
472        if mapping_uses_cpu {
473            warn!("Device mapping contains a mix of GPU and CPU. There is no CPU support for PagedAttention, disabling PagedAttention.");
474            paged_attn_config = None;
475        }
476
477        info!("Model config: {:?}", self.inner.get_config_repr(&config)?);
478        if crate::using_flash_attn() {
479            once_log_info("FlashAttention is enabled.");
480        }
481
482        // Logic for ISQ here: if no calibration (i.e imatrix), then allow immediate ISQ. Otherwise, back to normal.
483        let mut loading_isq = if self.config.imatrix.is_none()
484            && self.config.calibration_file.is_none()
485            && !device.is_cuda()
486            && self.config.write_uqff.is_none()
487            && in_situ_quant.is_some()
488        {
489            let predicates = if matches!(self.config.organization, IsqOrganization::MoeExpertsOnly)
490            {
491                self.inner.immediate_isq_predicates_moqe(&config)?
492            } else {
493                self.inner.immediate_isq_predicates(&config)?
494            };
495            info!("Applying ISQ to {in_situ_quant:?}");
496            if predicates.is_empty() {
497                warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
498            }
499            mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
500            false
501        } else {
502            in_situ_quant.is_some()
503        };
504
505        if let Some(ref topology) = self.config.topology {
506            loading_isq |= topology
507                .0
508                .iter()
509                .any(|layer| layer.as_ref().is_some_and(|layer| layer.isq.is_some()));
510        }
511
512        if self.config.imatrix.is_some() && self.config.calibration_file.is_some() {
513            anyhow::bail!(
514                "`imatrix` and `calibration_file` were both specified, this is not allowed."
515            );
516        }
517
518        // Load onto the regular device if not using isq or if the calibration file is specified
519        let load_device = if !loading_isq || self.config.calibration_file.is_some() {
520            loading_isq = false;
521            device.clone()
522        } else {
523            Device::Cpu
524        };
525
526        let is_xlora = self.kind.is_adapted_and(|a| a.is_x_lora());
527
528        let attention_mechanism = if paged_attn_config.is_some() {
529            AttentionImplementation::PagedAttention
530        } else {
531            AttentionImplementation::Eager
532        };
533
534        let multi_progress = Arc::new(MultiProgress::new());
535
536        let mut model = if use_nccl || cfg!(feature = "ring") {
537            let (mapper, sharded_vb) = distributed::prepare_distributed_mapper(
538                dtype,
539                &device,
540                &available_devices,
541                silent,
542                &config,
543                loading_isq,
544                self.config.from_uqff.is_some(),
545                self.config.organization,
546                &*self.inner,
547                paths.as_ref(),
548            )?;
549
550            // Special case for where things can be more optimially loaded.
551            match self.kind {
552                ModelKind::Normal => normal_model_loader_sharded!(
553                    sharded_vb,
554                    config,
555                    self.inner,
556                    mapper,
557                    loading_isq,
558                    device.clone(),
559                    attention_mechanism,
560                    multi_progress.clone(),
561                ),
562                ModelKind::Adapter {
563                    adapter: AdapterKind::XLora,
564                } => xlora_model_loader!(
565                    paths,
566                    Some(dtype),
567                    &load_device,
568                    layer_devices.clone(),
569                    config,
570                    self.inner,
571                    silent,
572                    mapper,
573                    loading_isq,
574                    device.clone(),
575                    multi_progress.clone(),
576                ),
577                ModelKind::Adapter {
578                    adapter: AdapterKind::Lora,
579                } => lora_model_loader!(
580                    paths,
581                    Some(dtype),
582                    &load_device,
583                    layer_devices.clone(),
584                    config,
585                    self.inner,
586                    silent,
587                    mapper,
588                    loading_isq,
589                    self.config.from_uqff.is_some(),
590                    device.clone(),
591                    attention_mechanism,
592                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
593                    multi_progress.clone(),
594                ),
595                _ => unreachable!(),
596            }
597        } else {
598            match self.kind {
599                ModelKind::Normal => normal_model_loader!(
600                    paths,
601                    Some(dtype),
602                    &load_device,
603                    layer_devices.clone(),
604                    config,
605                    self.inner,
606                    silent,
607                    mapper,
608                    loading_isq,
609                    self.config.from_uqff.is_some(),
610                    device.clone(),
611                    attention_mechanism,
612                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
613                    multi_progress.clone(),
614                ),
615                ModelKind::Adapter {
616                    adapter: AdapterKind::XLora,
617                } => xlora_model_loader!(
618                    paths,
619                    Some(dtype),
620                    &load_device,
621                    layer_devices.clone(),
622                    config,
623                    self.inner,
624                    silent,
625                    mapper,
626                    loading_isq,
627                    device.clone(),
628                    multi_progress.clone(),
629                ),
630                ModelKind::Adapter {
631                    adapter: AdapterKind::Lora,
632                } => lora_model_loader!(
633                    paths,
634                    Some(dtype),
635                    &load_device,
636                    layer_devices.clone(),
637                    config,
638                    self.inner,
639                    silent,
640                    mapper,
641                    loading_isq,
642                    self.config.from_uqff.is_some(),
643                    device.clone(),
644                    attention_mechanism,
645                    matches!(self.config.organization, IsqOrganization::MoeExpertsOnly),
646                    multi_progress.clone(),
647                ),
648                _ => unreachable!(),
649            }
650        };
651
652        let tokenizer = get_tokenizer(paths.get_tokenizer_filename(), None)?;
653        let gen_conf: Option<GenerationConfig> = paths.get_gen_conf_filename().map(|f| {
654            serde_json::from_str(&fs::read_to_string(f).unwrap())
655                .expect("bos_token_id/eos_token_id missing in generation_config.json")
656        });
657
658        let chat_template_explicit = paths
659            .get_chat_template_explicit()
660            .as_ref()
661            .map(|x| x.to_string_lossy().to_string());
662        let chat_template = get_chat_template(
663            paths,
664            self.jinja_explicit.as_ref(),
665            chat_template_explicit.as_ref(),
666            self.chat_template.as_ref(),
667            None,
668        );
669
670        if let Some(calibration_file) = &self.config.calibration_file {
671            let calibration_data = std::fs::read_to_string(calibration_file)?;
672            // Tokenize, don't add bos yet
673            let tokens = tokenizer
674                .encode_fast(calibration_data, false)
675                .map_err(anyhow::Error::msg)?
676                .get_ids()
677                .to_vec();
678            info!(
679                "Collecting imatrix from calibration file `{}` of {} tokens.",
680                calibration_file.display(),
681                tokens.len()
682            );
683            let bos_toks = chat_template.bos_tok().map(|b| vec![b]).unwrap_or_default();
684            let bos_tok_id = tokenizer
685                .token_to_id(&bos_toks[0])
686                .expect("Somehow the bos token is not present.");
687
688            match self.config.organization {
689                IsqOrganization::Default => model.begin_track_stats()?,
690                IsqOrganization::MoeExpertsOnly => model.begin_track_stats_moe_experts_only()?,
691            }
692
693            const CHUNK_SIZE: usize = 1024;
694            let n_chunks = tokens.len().div_ceil(CHUNK_SIZE);
695            let start = Instant::now();
696            for (i, chunk) in tokens.chunks(CHUNK_SIZE).enumerate() {
697                let chunk = [vec![bos_tok_id], chunk.to_vec()].concat();
698                let chunk_len = chunk.len();
699
700                let start = Instant::now();
701                let inputs = make_prompt_chunk(
702                    0,
703                    vec![&chunk],
704                    &[0],
705                    &load_device,
706                    None,
707                    false,
708                    None,
709                    Some(pipeline_mapper.as_ref()),
710                )?;
711
712                model.forward(
713                    &inputs.input.to_device(model.device())?,
714                    &inputs.positions,
715                    inputs.context_lens.clone(),
716                    inputs.position_ids.clone(),
717                    None,
718                    &inputs.flash_meta.clone(),
719                )?;
720
721                match model.cache_mut() {
722                    EitherCache::Full(full) => {
723                        for layer in &mut *full.lock() {
724                            *layer = None
725                        }
726                    }
727                    EitherCache::Normal(normal) => {
728                        for layer in &mut *normal.lock().unwrap().0 {
729                            layer.reset();
730                        }
731                    }
732                }
733
734                let end = Instant::now();
735                info!(
736                    "Processed chunk {}/{n_chunks} ({chunk_len} tokens), {:.2}s",
737                    i + 1,
738                    end.duration_since(start).as_secs_f32()
739                );
740            }
741            load_device.synchronize()?;
742            let end = Instant::now();
743            info!(
744                "Finished collecting imatrix in {:.2}s",
745                end.duration_since(start).as_secs_f32()
746            );
747        }
748
749        // Only if loading from UQFF
750        if (loading_isq || self.config.topology.is_some()) && self.config.from_uqff.is_none() {
751            let imatrix_source = match (
752                self.config.imatrix.as_ref(),
753                self.config.calibration_file.is_some(),
754            ) {
755                (None, false) => None,
756                (Some(file), false) => Some(ImatrixDataSource::File(file)),
757                (None, true) => Some(ImatrixDataSource::Collected),
758                (Some(_), true) => unreachable!(),
759            };
760
761            info!("Applying ISQ to all ranks.");
762
763            let multi_progress = Arc::new(MultiProgress::new());
764
765            model.quantize(
766                in_situ_quant,
767                model.device().clone(),
768                self.config.topology.as_ref(),
769                silent,
770                imatrix_source,
771                self.config.organization,
772                self.config.write_uqff.as_ref(),
773                UqffFullSer {
774                    tokenizer: &tokenizer,
775                    template_filename: paths.get_template_filename(),
776                    generation_config: paths.get_gen_conf_filename(),
777                    config: config.clone(),
778                    processor_filename: &None,
779                    preprocessor_filename: &None,
780                },
781                multi_progress.clone(),
782            )?;
783        } else if let Some(from_uqff) = &*self.from_uqff.read().unwrap() {
784            model.load_from_artifacts(
785                device.clone(),
786                self.config.topology.as_ref(),
787                silent,
788                from_uqff,
789            )?;
790        }
791
792        let paged_attn_config = if matches!(
793            self.kind,
794            ModelKind::Adapter {
795                adapter: AdapterKind::XLora
796            }
797        ) {
798            warn!(
799                "Adapter parallel_models do not currently support PagedAttention, running without"
800            );
801            None
802        } else {
803            paged_attn_config
804        };
805
806        let (cache_config, cache_engine) = if let Some(paged_attn_config) = paged_attn_config {
807            let cache_config = calculate_cache_config(
808                paged_attn_config.mem_gpu,
809                paged_attn_config.mem_cpu,
810                paged_attn_config.block_size,
811                dtype,
812                model.config(),
813                &device,
814                &pipeline_mapper
815                    .get_unique_devices()
816                    .into_iter()
817                    .map(Some)
818                    .collect::<Vec<_>>(),
819                silent,
820            )?;
821
822            let mut layer_devices = Vec::new();
823            for layer in 0..self.inner.num_layers(&config)? {
824                let device = model.get_layers().1.device_for(layer, false).cloned();
825                layer_devices.push(device);
826            }
827            let cache_engine = CacheEngine::new(
828                model.config(),
829                &cache_config,
830                dtype,
831                model.device(),
832                layer_devices.clone(),
833            )?;
834
835            (Some(cache_config), Some(cache_engine))
836        } else {
837            (None, None)
838        };
839
840        let max_seq_len = model.max_seq_len();
841        let llg_factory = build_llg_factory(tokenizer.clone())?;
842        let num_hidden_layers = match model.cache() {
843            EitherCache::Full(full) => full.lock().len(),
844            EitherCache::Normal(normal) => normal.lock().unwrap().0.len(),
845        };
846        let eos = calculate_eos_tokens(&chat_template, gen_conf, &tokenizer);
847        let sliding_window = model.config().sliding_window;
848        let model_metadata = Arc::new(model.config().clone());
849
850        Ok(Arc::new(Mutex::new(NormalPipeline {
851            model,
852            tokenizer: tokenizer.into(),
853            no_kv_cache: self.no_kv_cache,
854            chat_template: Arc::new(chat_template),
855            non_granular_state: self.tgt_non_granular_index.map(|tgt_non_granular_index| {
856                NonGranularState {
857                    non_granular_index: Arc::new(Mutex::new(0)),
858                    tgt_non_granular_index,
859                }
860            }),
861            model_id: self.model_id.clone(),
862            metadata: Arc::new(GeneralMetadata {
863                max_seq_len,
864                llg_factory: Some(llg_factory),
865                no_kv_cache: self.no_kv_cache,
866                no_prefix_cache: is_xlora,
867                num_hidden_layers,
868                eos_tok: eos,
869                kind: self.kind.clone(),
870                is_xlora,
871                activation_dtype: dtype,
872                sliding_window,
873                cache_config,
874                cache_engine,
875                prompt_chunksize: Some(NonZero::new(prompt_chunksize).unwrap()),
876                model_metadata: Some(model_metadata),
877                modalities: Modalities {
878                    input: vec![SupportedModality::Text],
879                    output: vec![SupportedModality::Text],
880                },
881            }),
882            topology: self.config.topology.clone(),
883            silent,
884            organization: self.config.organization,
885            template_filename: paths.get_template_filename().clone(),
886            generation_config: paths.get_gen_conf_filename().cloned(),
887            config,
888            imatrix: self.config.imatrix.clone(),
889            mapper: pipeline_mapper,
890        })))
891    }
892
893    fn get_id(&self) -> String {
894        self.model_id.clone()
895    }
896
    /// Returns the kind of model (plain / adapter / quantized …) this loader builds.
    fn get_kind(&self) -> ModelKind {
        self.kind.clone()
    }
900}
901
902impl PreProcessingMixin for NormalPipeline {
903    fn get_chat_template(&self) -> Option<Arc<ChatTemplate>> {
904        Some(self.chat_template.clone())
905    }
906    fn get_input_processor_config(&self) -> Option<Arc<dyn Any>> {
907        None
908    }
909}
910
911impl IsqPipelineMixin for NormalPipeline {
912    fn re_isq_model(&mut self, dtype: IsqType) -> Result<()> {
913        let device = self.device().clone();
914        let multi_progress = Arc::new(MultiProgress::new());
915        self.model.quantize(
916            Some(dtype),
917            device.clone(),
918            self.topology.as_ref(),
919            self.silent,
920            self.imatrix.as_ref().map(ImatrixDataSource::File),
921            self.organization,
922            None,
923            UqffFullSer {
924                tokenizer: &self.tokenizer,
925                template_filename: &self.template_filename,
926                generation_config: self.generation_config.as_ref(),
927                config: self.config.clone(),
928                processor_filename: &None,
929                preprocessor_filename: &None,
930            },
931            multi_progress.clone(),
932        )?;
933        Ok(())
934    }
935}
936
937impl CacheManagerMixin for NormalPipeline {
938    fn clone_in_cache(&self, seqs: &mut [&mut Sequence]) {
939        if matches!(self.model.cache(), EitherCache::Full(_)) {
940            FullCacheManager.clone_in_cache(self, seqs, false)
941        } else {
942            NormalCacheManager.clone_in_cache(self, seqs, false)
943        }
944    }
945    fn clone_out_cache(&self, seqs: &mut [&mut Sequence]) {
946        if matches!(self.model.cache(), EitherCache::Full(_)) {
947            FullCacheManager.clone_out_cache(self, seqs, false)
948        } else {
949            NormalCacheManager.clone_out_cache(self, seqs, false)
950        }
951    }
952    fn set_none_cache(
953        &self,
954        seqs: &mut [&mut Sequence],
955        reset_non_granular: bool,
956        modify_draft_cache: bool,
957        load_preallocated_cache: bool,
958    ) {
959        if matches!(self.model.cache(), EitherCache::Full(_)) {
960            FullCacheManager.set_none_cache(self, seqs, modify_draft_cache, false);
961        } else {
962            NormalCacheManager.set_none_cache(
963                self,
964                seqs,
965                modify_draft_cache,
966                load_preallocated_cache,
967            );
968        }
969        if reset_non_granular {
970            self.reset_non_granular_state()
971        }
972    }
973    fn cache(&self) -> &EitherCache {
974        self.model.cache()
975    }
976}
977
978impl MetadataMixin for NormalPipeline {
979    fn device(&self) -> Device {
980        self.model.device().clone()
981    }
982    fn tokenizer(&self) -> Option<Arc<Tokenizer>> {
983        Some(self.tokenizer.clone())
984    }
985    fn name(&self) -> String {
986        self.model_id.clone()
987    }
988    fn reset_non_granular_state(&self) {
989        if let Some(s) = self.non_granular_state.as_ref() {
990            *self.cache().full().get_scalings_cache() = None;
991            *get_mut_arcmutex!(s.non_granular_index) = 0;
992        }
993    }
994    fn get_metadata(&self) -> Arc<GeneralMetadata> {
995        self.metadata.clone()
996    }
997    fn device_mapper(&self) -> Option<&dyn DeviceMapper> {
998        Some(&*self.mapper)
999    }
1000}
1001
#[async_trait::async_trait]
impl Pipeline for NormalPipeline {
    /// Runs one forward pass over the batched `ModelInputs`, returning either
    /// raw logits or logits destined for causal-generation sampling.
    ///
    /// # Panics
    /// Panics if `inputs` is not a boxed `ModelInputs` (programming error in
    /// the input processor wiring).
    fn forward_inputs(
        &mut self,
        inputs: Box<dyn Any>,
        return_raw_logits: bool,
    ) -> Result<ForwardInputsResult, candle_core::Error> {
        let ModelInputs {
            input_ids,
            input_ids_full,
            seqlen_offsets,
            seqlen_offsets_full,
            context_lens,
            position_ids,
            paged_attn_meta,
            flash_meta,
            flash_meta_full,
        } = *inputs.downcast().expect("Downcast failed.");
        let metadata = self.get_metadata();
        // PagedAttention metadata must be present exactly when a cache engine
        // exists; the two mismatch arms surface configuration errors early.
        let paged_attn_meta = match (&metadata.cache_engine, &paged_attn_meta) {
            (Some(cache_engine), Some(meta)) => Some((cache_engine, meta)),
            (Some(_), None) => {
                // This can happen if Rust-side user code is wrong
                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
            }
            (None, Some(_)) => {
                // This should never happen but we handle it anyway
                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
            }
            (None, None) => None,
        };
        let logits = match self.model.is_xlora() {
            false => {
                // Pair each request's metadata with a handle to the engine's
                // KV cache for the plain (non-X-LoRA) forward path.
                let paged_attn_meta = paged_attn_meta
                    .as_ref()
                    .map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));

                self.model.forward(
                    &input_ids,
                    &seqlen_offsets,
                    context_lens,
                    position_ids,
                    paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
                    &flash_meta,
                )?
            }
            // X-LoRA path: the "full" variants fall back to the regular
            // inputs when absent (e.g. non-granular scalings).
            true => self.model.xlora_forward(
                &input_ids,
                input_ids_full.as_ref().unwrap_or(&input_ids),
                &seqlen_offsets,
                seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
                self.no_kv_cache,
                &self.non_granular_state,
                context_lens,
                position_ids,
                &flash_meta,
                flash_meta_full.as_ref().unwrap_or(&flash_meta),
            )?,
        };
        if return_raw_logits {
            Ok(ForwardInputsResult::RawLogits { logits })
        } else {
            Ok(ForwardInputsResult::CausalGeneration { logits })
        }
    }
    /// Samples next tokens for each sequence from the given logits.
    async fn sample_causal_gen(
        &self,
        seqs: &mut [&mut Sequence],
        logits: Vec<Tensor>,
        prefix_cacher: &mut PrefixCacheManagerV2,
        disable_eos_stop: bool,
        rng: Arc<std::sync::Mutex<Isaac64Rng>>,
    ) -> Result<(), candle_core::Error> {
        sample_and_add_toks(self, seqs, logits, prefix_cacher, disable_eos_stop, rng).await
    }
    fn category(&self) -> ModelCategory {
        ModelCategory::Text
    }
}
1081
1082impl AnyMoePipelineMixin for NormalPipeline {
1083    fn amoe_finish_training(&mut self, gate_model_id: Option<String>) -> candle_core::Result<()> {
1084        self.model.finish_training(gate_model_id)
1085    }
1086    fn amoe_layer_vars(&self) -> Vec<Vec<Var>> {
1087        self.model.get_vars()
1088    }
1089    fn amoe_base_model_trainable_params(&self) -> usize {
1090        self.model.trainable_params()
1091    }
1092    fn amoe_take_cached_gating_outputs(&mut self) -> Vec<Tensor> {
1093        self.model.take_cached_gating_outputs()
1094    }
1095    fn amoe_create_layers(
1096        &mut self,
1097        model_ids: Vec<String>,
1098        token: &TokenSource,
1099        revision: Option<String>,
1100        match_regex: &str,
1101        config: crate::amoe::AnyMoeConfig,
1102        dtype: candle_core::DType,
1103        dev: &Device,
1104        (prefix, mlp): (String, String),
1105        layers: Vec<usize>,
1106        expert_type: AnyMoeExpertType,
1107        silent: bool,
1108        gate_model_id: Option<String>,
1109    ) -> candle_core::Result<()> {
1110        let mut vbs = Vec::new();
1111        // Precompile regex here
1112        let regex = Regex::new(match_regex).map_err(candle_core::Error::msg)?;
1113        for model_id in model_ids {
1114            let model_id_str = &model_id;
1115            let model_id = Path::new(&model_id);
1116
1117            let api = {
1118                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1119                let mut api = ApiBuilder::from_cache(cache)
1120                    .with_progress(!silent)
1121                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1122                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1123                    api = api.with_cache_dir(x.into());
1124                }
1125                api.build().map_err(candle_core::Error::msg)?
1126            };
1127            let revision = revision.clone().unwrap_or("main".to_string());
1128            let api = api.repo(Repo::with_revision(
1129                model_id_str.clone(),
1130                RepoType::Model,
1131                revision.clone(),
1132            ));
1133
1134            let mut filenames = vec![];
1135            for rfilename in
1136                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1137            {
1138                filenames.push(api_get_file!(api, &rfilename, model_id));
1139            }
1140
1141            let regex = regex.clone();
1142            let match_regex_clone = match_regex.to_string();
1143            let layers_clone = layers.clone();
1144            let vb = from_mmaped_safetensors(
1145                filenames,
1146                vec![],
1147                Some(dtype),
1148                dev,
1149                vec![None],
1150                silent,
1151                None,
1152                move |key| {
1153                    if regex.is_match(&key) {
1154                        // Idx of the last char of the layer id, +1
1155                        // Assumes N.MLP
1156                        let last_layer_idx = key.find(&match_regex_clone).unwrap() - 1;
1157                        let first_layer_idx = key[..last_layer_idx].rfind('.').unwrap();
1158                        let layer_n = key[first_layer_idx + 1..last_layer_idx]
1159                            .parse::<usize>()
1160                            .unwrap();
1161                        layers_clone.contains(&layer_n) || layers_clone.is_empty()
1162                    } else {
1163                        false
1164                    }
1165                },
1166                Arc::new(|_| DeviceForLoadTensor::Base),
1167            )?;
1168            vbs.push(vb);
1169        }
1170
1171        let gate_vb = if let Some(gate_model_id) = gate_model_id {
1172            let model_id_str = &gate_model_id;
1173            let model_id = Path::new(&gate_model_id);
1174
1175            let api = {
1176                let cache = GLOBAL_HF_CACHE.get().cloned().unwrap_or_default();
1177                let mut api = ApiBuilder::from_cache(cache)
1178                    .with_progress(!silent)
1179                    .with_token(get_token(token).map_err(candle_core::Error::msg)?);
1180                if let Ok(x) = std::env::var("HF_HUB_CACHE") {
1181                    api = api.with_cache_dir(x.into());
1182                }
1183                api.build().map_err(candle_core::Error::msg)?
1184            };
1185            let revision = revision.clone().unwrap_or("main".to_string());
1186            let api = api.repo(Repo::with_revision(
1187                model_id_str.clone(),
1188                RepoType::Model,
1189                revision.clone(),
1190            ));
1191
1192            let mut gate_filenames = vec![];
1193            for rfilename in
1194                api_dir_list!(api, model_id, true).filter(|x| x.ends_with(".safetensors"))
1195            {
1196                gate_filenames.push(api_get_file!(api, &rfilename, model_id));
1197            }
1198            assert_eq!(
1199                gate_filenames.len(),
1200                1,
1201                "Gate model ID must contain only one .safetensors file"
1202            );
1203
1204            let vb = from_mmaped_safetensors(
1205                gate_filenames.clone(),
1206                vec![],
1207                Some(dtype),
1208                dev,
1209                vec![None],
1210                silent,
1211                None,
1212                |_| true,
1213                Arc::new(|_| DeviceForLoadTensor::Base),
1214            )?;
1215            info!(
1216                "Loaded gating layers from `{}`",
1217                gate_filenames[0].display()
1218            );
1219            Some(vb)
1220        } else {
1221            None
1222        };
1223
1224        self.model.create_anymoe_layers(
1225            vbs.clone(),
1226            config.clone(),
1227            (prefix.clone(), mlp.clone()),
1228            layers.clone(),
1229            expert_type.clone(),
1230            gate_vb.clone(),
1231        )?;
1232
1233        Ok(())
1234    }
1235    fn amoe_supported(&self) -> bool {
1236        self.model.amoe_supported()
1237    }
1238}