Model Loading: From File to Device
Updated 2026-04-15
Series context: This is article #2 in the llama.cpp source code deep-dive series, focusing on the model file loading pipeline and the backend device assignment mechanism. If you haven’t read the series overview and #1 Tool Panorama & GGUF Binary Parsing, we recommend building the big picture first before diving into this chapter.
Part A: Model Loading
With the knowledge of GGUF format in hand, let’s see how llama.cpp turns a GGUF file into a model ready for inference.
Loading Overview
Stage 1: no_alloc Parsing of the GGUF Header
```cpp
// llama-model-loader.cpp constructor
struct gguf_init_params params = {
    /*.no_alloc = */ true, // Key: don't allocate tensor data memory
    /*.ctx      = */ &ctx,
};

metadata = gguf_init_from_file(fname.c_str(), params);
```
no_alloc = true tells the GGUF parser: only create the ggml_tensor structs — don’t read or allocate data memory. At this point, memory overhead is minimal:
```cpp
// gguf.cpp — no_alloc mode
if (params.no_alloc) {
    // Each tensor only needs struct overhead (~400 bytes)
    const size_t overhead = n_tensors * ggml_tensor_overhead();
    mem_size = overhead; // 1000 tensors ≈ 400 KB
}
```
Compare this to no_alloc = false mode, which requires overhead + data_size (several GB) — the difference is enormous.
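To make the difference concrete, here is a minimal standalone sketch, assuming the headers from a recent llama.cpp/ggml checkout (gguf.h, ggml.h). It opens a file in no_alloc mode and compares the struct-only overhead against the data size that no_alloc = false would additionally require:

```cpp
#include <cstdio>
#include "ggml.h"
#include "gguf.h"

// Open a GGUF file in no_alloc mode and compare the struct-only
// overhead against the total tensor data size (which is NOT allocated).
int main(int argc, char ** argv) {
    if (argc < 2) { fprintf(stderr, "usage: %s model.gguf\n", argv[0]); return 1; }

    ggml_context * ctx = nullptr;
    gguf_init_params params = {
        /*.no_alloc =*/ true,  // structs only, no tensor data
        /*.ctx      =*/ &ctx,
    };
    gguf_context * gctx = gguf_init_from_file(argv[1], params);
    if (!gctx) { fprintf(stderr, "failed to parse %s\n", argv[1]); return 1; }

    size_t data_size = 0;
    for (ggml_tensor * t = ggml_get_first_tensor(ctx); t; t = ggml_get_next_tensor(ctx, t)) {
        data_size += ggml_nbytes(t); // what no_alloc = false would need on top
    }

    const size_t n_tensors = (size_t) gguf_get_n_tensors(gctx);
    printf("tensors:         %zu\n", n_tensors);
    printf("struct overhead: %zu bytes\n", n_tensors * ggml_tensor_overhead());
    printf("tensor data:     %zu bytes (not allocated)\n", data_size);

    gguf_free(gctx);
    ggml_free(ctx);
    return 0;
}
```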
Stage 2: Building weights_map
After parsing the header, all tensors are iterated to build a name-to-tensor index:
```cpp
// llama-model-loader.cpp constructor
for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
    std::string tensor_name = std::string(cur->name);

    // Check for duplicate names
    if (weights_map.find(tensor_name) != weights_map.end()) {
        throw std::runtime_error("invalid model: tensor is duplicated");
    }

    n_elements += ggml_nelements(cur);
    n_bytes    += ggml_nbytes(cur);

    // Record the file, file index, metadata context, and tensor pointer
    weights_map.emplace(tensor_name, llama_tensor_weight(file, 0, metadata, cur));
}
```
At this point, each tensor object’s state looks like this:
| Field | Value | Status |
|---|---|---|
| `type` | Q4_K | Set, read from the GGUF header |
| `ne[4]` | shape | Set |
| `nb[4]` | stride | Computed from type and shape |
| `name` | `"blk.0.attn_q.weight"` | Set |
| `data` | `NULL` | Data not yet loaded |
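Later stages resolve tensors by name against this map. A simplified, hypothetical sketch of that lookup pattern (the real loader wraps it in member functions of llama_model_loader; tensor_weight_ref and require_tensor are illustrative names, not llama.cpp API):

```cpp
#include <map>
#include <stdexcept>
#include <string>
#include "ggml.h"

// Hypothetical, simplified stand-in for llama_tensor_weight.
struct tensor_weight_ref {
    size_t        offs;   // offset of the tensor data within its file
    ggml_tensor * tensor; // struct from the no_alloc pass, data == NULL
};

// Look up a tensor by name, failing loudly if the model is incomplete.
static ggml_tensor * require_tensor(
        const std::map<std::string, tensor_weight_ref> & weights_map,
        const std::string & name) {
    auto it = weights_map.find(name);
    if (it == weights_map.end()) {
        throw std::runtime_error("missing tensor: " + name);
    }
    return it->second.tensor;
}

// usage: ggml_tensor * wq = require_tensor(weights_map, "blk.0.attn_q.weight");
```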
Stage 3: Architecture Identification
The model architecture is read from KV metadata:
```cpp
// llama-model-loader.cpp
get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false);
// arch_name = "llama" or "qwen3" or "gpt2" etc.

llm_kv = LLM_KV(llm_arch_from_string(arch_name));
```
llm_arch_from_string() maps the string to an llm_arch enum value (e.g., LLM_ARCH_QWEN3). All subsequent hyperparameter reading and compute graph construction dispatch based on this enum value.
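The dispatch itself is an ordinary switch over the enum. A hypothetical sketch of the pattern, using a stand-in enum so it compiles standalone (the real switches live in llama-model.cpp):

```cpp
#include <stdexcept>

// Stand-in for llm_arch; the real enum has one value per supported architecture.
enum llm_arch_sketch { ARCH_LLAMA, ARCH_QWEN3, ARCH_GPT2 };

// Every architecture-dependent step follows this shape: switch on the enum
// once, then run architecture-specific logic.
void load_hparams_for(llm_arch_sketch arch) {
    switch (arch) {
        case ARCH_LLAMA: /* read llama-specific KV pairs (rope, ffn, ...) */ break;
        case ARCH_QWEN3: /* qwen3-specific hyperparameters */               break;
        case ARCH_GPT2:  /* gpt2-specific hyperparameters */                break;
        default: throw std::runtime_error("unknown model architecture");
    }
}
```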
Stage 4: mmap vs read
There are two ways to load model weights:
mmap Mode (Default)
```cpp
// llama-model-loader.cpp — init_mappings()
std::unique_ptr<llama_mmap> mapping = std::make_unique<llama_mmap>(file, prefetch ? -1 : 0, is_numa);
```
mmap maps the file’s data blob into the process’s virtual address space without immediately reading data into physical memory. The OS loads data from disk only when the corresponding page is first accessed (demand paging).
```cpp
// load_data_for() — mmap mode
if (use_mmap) {
    if (cur->data == nullptr) {
        // Point tensor->data directly to the mapped address + offset
        cur->data = (uint8_t *) mapping->addr() + w.offs;
    } else {
        // If the tensor already has a buffer (e.g., GPU buffer), memcpy
        memcpy(cur->data, (uint8_t *) mapping->addr() + w.offs, ggml_nbytes(cur));
    }
}
```
Advantages of mmap:
- Zero-copy: CPU tensors reference the mapped address directly, no extra memory allocation needed
- On-demand loading: Pages that are never accessed don’t consume physical memory
- Multi-process sharing: When multiple llama.cpp instances load the same model, the OS shares physical pages
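The same zero-copy effect can be reproduced outside llama.cpp with a few lines of POSIX code. A minimal sketch (Linux/macOS, error handling abbreviated); the pointer arithmetic at the end mirrors mapping->addr() + w.offs:

```cpp
#include <cstddef>
#include <cstdint>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

// Map a whole file read-only. No read() happens up front; the kernel
// faults pages in on first access (demand paging).
uint8_t * map_file(const char * path, size_t * size_out) {
    int fd = open(path, O_RDONLY);
    if (fd < 0) return nullptr;

    struct stat st;
    if (fstat(fd, &st) != 0) { close(fd); return nullptr; }
    *size_out = (size_t) st.st_size;

    void * base = mmap(nullptr, *size_out, PROT_READ, MAP_SHARED, fd, 0);
    close(fd); // the mapping remains valid after close
    return base == MAP_FAILED ? nullptr : (uint8_t *) base;
}

// usage (the offset is hypothetical):
//   size_t size;
//   uint8_t * base = map_file("model.gguf", &size);
//   void * tensor_data = base + /* data section offset + w.offs */ 0x1234;
```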
read Mode (--no-mmap)
```cpp
// load_data_for() — read mode
GGML_ASSERT(cur->data != nullptr);           // Caller must pre-allocate the buffer
file->seek(w.offs, SEEK_SET);                // Seek to the offset in the file
file->read_raw(cur->data, ggml_nbytes(cur)); // Read directly into the buffer
```
read mode is appropriate when:
- The system doesn’t support mmap (certain embedded platforms)
- The model is larger than physical RAM, in which case mmap may cause frequent page faults; read mode combined with --mlock can avoid this
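For comparison, a sketch of the read path using POSIX pread; the caller must have allocated the destination buffer, exactly as the GGML_ASSERT above demands (read_tensor_data is an illustrative helper, not llama.cpp API):

```cpp
#include <cstdint>
#include <sys/types.h>
#include <unistd.h>

// Read nbytes of tensor data at the recorded file offset into a
// caller-provided buffer, retrying on short reads.
bool read_tensor_data(int fd, void * dst, size_t nbytes, off_t offs) {
    uint8_t * out  = (uint8_t *) dst;
    size_t    done = 0;
    while (done < nbytes) {
        ssize_t n = pread(fd, out + done, nbytes - done, offs + (off_t) done);
        if (n <= 0) return false; // error or unexpected EOF
        done += (size_t) n;
    }
    return true;
}
```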
Stage 5: Sharded Models
For very large models (e.g., 70B+), the GGUF file is split into multiple shard files. llama_model_loader detects the number of shards via the LLM_KV_SPLIT_COUNT KV value, then loads all shards sequentially:
```
model-00001-of-00003.gguf ← Primary file (full GGUF: header + KV metadata + tensor info + tensor data)
model-00002-of-00003.gguf ← Additional file (also a full GGUF, with its own header + tensor info + tensor data)
model-00003-of-00003.gguf ← Additional file (same as above)
```
All tensors from every shard are registered into the same weights_map, making the split completely transparent to higher layers.
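The shard naming convention itself is just a format string. A standalone sketch that reproduces the file names above (shard_paths is an illustrative helper):

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Generate "prefix-00001-of-0000N.gguf" ... for an N-way split model.
std::vector<std::string> shard_paths(const std::string & prefix, int split_count) {
    std::vector<std::string> paths;
    for (int i = 1; i <= split_count; ++i) {
        char buf[512];
        snprintf(buf, sizeof(buf), "%s-%05d-of-%05d.gguf", prefix.c_str(), i, split_count);
        paths.emplace_back(buf);
    }
    return paths;
}

// shard_paths("model", 3) -> {"model-00001-of-00003.gguf", ..., "model-00003-of-00003.gguf"}
```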
Part B: Backend Initialization and Device Assignment
Once the model weights are loaded, the next step is to initialize the compute backends and decide which device each layer’s tensors reside on.
Backend Discovery and Initialization
llama.cpp’s backend initialization happens in the llama_context constructor, following this priority order:
```cpp
// llama-context.cpp — constructor

// 1. GPU backends
for (auto * dev : model.devices) {
    ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
    backends.emplace_back(backend);
}

// 2. ACCEL backends (e.g., BLAS)
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
    ggml_backend_dev_t dev = ggml_backend_dev_get(i);
    if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
        backends.emplace_back(ggml_backend_dev_init(dev, nullptr));
    }
}

// 3. CPU backend (always present as fallback)
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
backends.emplace_back(backend_cpu);
```
The CPU backend is always present — even if all layers are offloaded to GPU, certain operations (such as tokenization and sampling) still run on the CPU.
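For reference, a small standalone program that enumerates the registered devices the constructor iterates over (assumes ggml was built with the desired backends compiled in):

```cpp
#include <cstdio>
#include "ggml-backend.h"

// List every backend device ggml knows about: GPUs, ACCEL devices, CPU.
int main() {
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        printf("device %zu: %s (%s)\n", i,
               ggml_backend_dev_name(dev),
               ggml_backend_dev_description(dev));
    }
    return 0;
}
```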
n_gpu_layers Layer Assignment
The -ngl (--n-gpu-layers) parameter determines how many layers are offloaded to the GPU. The core calculation is in load_tensors():
```cpp
// llama-model.cpp — load_tensors()
const int i_gpu_start = std::max(int(n_layer) + 1 - n_gpu_layers, 0);
```
This means the model has n_layer + 1 offloadable units (the transformer layers plus the output layer), counted from the end: the last n_gpu_layers of them go to the GPU, while the earlier layers stay on the CPU.
For example, with n_layer=32 and n_gpu_layers=24, i_gpu_start = 32 + 1 - 24 = 9:

- Layers 0-8 on CPU (9 layers)
- Layers 9-31 + Output on GPU (23 layers + output = 24)
When n_gpu_layers=33 (greater than or equal to n_layer+1), i_gpu_start = 0:

- All Layers 0-31 + Output on GPU
Note: The input layer (embedding) always stays on the CPU, since offloading the embedding yields very little benefit and isn't worth the GPU VRAM:
```cpp
// llama-model.cpp
// there is very little benefit to offloading the input layer, so always keep it on the CPU
pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list };
```
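The boundary logic is easy to verify in isolation. A standalone simulation using the example numbers from above (n_layer = 32, n_gpu_layers = 24; nothing is read from a real model):

```cpp
#include <algorithm>
#include <cstdio>

// Simulate the CPU/GPU boundary: the last n_gpu_layers of the
// n_layer + 1 offloadable units (layers + output) go to the GPU.
int main() {
    const int n_layer      = 32;
    const int n_gpu_layers = 24;
    const int i_gpu_start  = std::max(n_layer + 1 - n_gpu_layers, 0);

    for (int il = 0; il < n_layer; ++il) {
        printf("layer %2d -> %s\n", il, il < i_gpu_start ? "CPU" : "GPU");
    }
    printf("output   -> %s\n", n_gpu_layers > 0 ? "GPU" : "CPU");
    return 0;
}
```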
Multi-GPU Assignment (tensor_split)
When multiple GPUs are available, the --tensor-split parameter controls the layer distribution ratio. If the user doesn’t specify it, the default allocation is proportional to each GPU’s free VRAM:
```cpp
// llama-model.cpp — load_tensors()
if (all_zero) {
    // Default: proportional to free VRAM
    for (size_t i = 0; i < n_devices(); ++i) {
        size_t total, free;
        ggml_backend_dev_memory(devices[i], &free, &total);
        splits[i] = free;
    }
} else {
    // Use user-specified ratios
    std::copy(tensor_split, tensor_split + n_devices(), splits.begin());
}

// Normalize to a cumulative distribution
float split_sum = 0.0f;
for (auto & s : splits) { split_sum += s; s = split_sum; }
for (auto & s : splits) { s /= split_sum; }
```
Then upper_bound binary search maps each layer to the corresponding GPU:
```cpp
// Look up the layer's position in the cumulative distribution to determine which GPU it belongs to
const int layer_gpu = std::upper_bound(
        splits.begin(), splits.begin() + n_devices(),
        float(il - i_gpu_start) / act_gpu_layers
    ) - splits.begin();
```
For example, two GPUs with 12 GB and 8 GB of free VRAM, 24 layers to offload:
```
Normalized splits = [0.6, 1.0]
Layers 0-14  → GPU 0 (15 layers, roughly the first 60%)
Layers 15-23 → GPU 1 (9 layers, roughly the last 40%)
```

(Layer 14 sits at relative position 14/24 ≈ 0.583, which is still below the 0.6 boundary, so it lands on GPU 0 rather than GPU 1.)
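The assignment is easy to verify in isolation; a standalone simulation of this exact example (12 GB and 8 GB free, 24 offloaded layers):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Simulate tensor_split: normalize free VRAM into a cumulative
// distribution, then place each offloaded layer with upper_bound.
int main() {
    std::vector<float> splits = {12.0f, 8.0f}; // free VRAM per GPU, in GB
    const int act_gpu_layers = 24;

    float sum = 0.0f;
    for (auto & s : splits) { sum += s; s = sum; } // {12, 20}
    for (auto & s : splits) { s /= sum; }          // {0.6, 1.0}

    for (int il_rel = 0; il_rel < act_gpu_layers; ++il_rel) {
        const int gpu = (int) (std::upper_bound(splits.begin(), splits.end(),
                                   (float) il_rel / act_gpu_layers) - splits.begin());
        printf("offloaded layer %2d -> GPU %d\n", il_rel, gpu);
    }
    return 0;
}
```

Running it prints 15 layers on GPU 0 and 9 on GPU 1, matching the breakdown above.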
Buffer Type Selection
Each device may support multiple buffer types (CUDA device memory, host pinned memory, and so on). select_buft() iterates through the candidate list, testing whether each buffer type can execute the target operation:
```cpp
// llama-model.cpp
template <typename F>
static ggml_backend_buffer_type_t select_buft(const buft_list_t & buft_list, const F & fn) {
    for (const auto & cur : buft_list) {
        ggml_backend_dev_t cur_dev = cur.first;
        ggml_backend_buffer_type_t cur_buft = cur.second;
        if (buft_supported(cur_buft, cur_dev, fn)) {
            return cur_buft; // Return the first supported buffer type
        }
    }
    throw std::runtime_error("no suitable buffer type found");
}
```
When building a GPU’s buft_list, the CPU buffer type is appended at the end as a fallback:
```cpp
buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split);

// Append CPU buffer type as fallback
buft_list.insert(buft_list.end(), cpu_buft_list.begin(), cpu_buft_list.end());
```
This means: if the GPU doesn’t support a particular operation, it automatically falls back to CPU execution.
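To illustrate what the fn callback looks like, here is a hypothetical usage sketch; buft_list and n_embd are assumed to be in scope, and the representative op built per tensor varies in the real code:

```cpp
// Ask select_buft for a buffer type whose device can run a matrix
// multiplication involving the tensor. buft_supported() calls the
// lambda inside a scratch context and checks the resulting op.
ggml_backend_buffer_type_t buft = select_buft(buft_list,
    [&](ggml_context * ctx) {
        ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
        ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, 1);
        return ggml_mul_mat(ctx, a, b);
    });
```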
Auto-fit
When -ngl auto (the default) is specified, llama.cpp automatically estimates VRAM requirements and adjusts n_gpu_layers and ctx_size to fit the model within available VRAM.
The auto-fit logic is implemented in llama_params_fit_impl() (src/llama.cpp), ensuring users don’t need to manually calculate VRAM to load models correctly.
Summary
This article traced the complete path from opening a GGUF file to having a model ready for inference in llama.cpp:
- no_alloc parsing: Only reads the header without allocating tensor data memory — minimal overhead
- weights_map construction: Builds a name index for each tensor; data pointers remain NULL
- Architecture identification: Reads the model architecture string from KV metadata and maps it to an enum value
- mmap vs read: Default mmap provides zero-copy, on-demand loading; --no-mmap serves special scenarios
- Sharded models: Multiple split files are unified into a single weights_map
- Backend initialization: GPU → ACCEL → CPU priority order, with CPU always present as fallback
- n_gpu_layers: i_gpu_start = max(n_layer+1 - ngl, 0) determines the CPU/GPU boundary
- Multi-GPU tensor_split: Normalized cumulative distribution + upper_bound binary search
- Buffer type selection: Iterates the candidate list; falls back to CPU when the GPU doesn’t support an operation
- Auto-fit: Automatically adjusts n_gpu_layers and ctx_size to fit available VRAM
The next article, #3 Warm-up and Tokenization, covers the first step of the inference pipeline: how input text becomes a token sequence.