Phi-3 Mini Mixed Q8K / Q4K Quantization

Layer-aware mixed-precision compression for Microsoft's 3.8B-parameter model — CPU-only inference in Rust

Rust 1.75+ · Candle 0.3.0 · Q8K + Q4K Mixed · Phi-3 Mini 3.8B · SafeTensors · Docker

1. Project Overview

This project implements a complete end-to-end pipeline that compresses Microsoft Phi-3 Mini 4K-Instruct from its original BF16 representation into a mixed Q8K / Q4K quantized format, then runs it as a streaming conversational chatbot — entirely on CPU, with no GPU required.

Unlike uniform quantization (applying the same bit-width everywhere), this implementation assigns different precisions to different layers based on their sensitivity to quantization noise. The result: near-lossless quality at roughly half the original model size.

The current Rust implementation, benchmark logs, and follow-up experiments live in the nibble GitHub repository, primarily under phi3_standalone. That repository now contains the mixed Q8K/Q4K pipeline, strict packer/loader code, perplexity tooling, and the later Q6K/Q8K128 experiments.

3.8B parameters (Phi-3 Mini) · 1.85× overall compression · 4.1 GB packed model size · 9.6 tok/s observed (CPU)
🎯 Key Innovation

The pipeline uses Q8K (3.5× smaller than F32, <0.01% error) for the attention projections, the gate/up projection, and lm_head. It uses Q4K (7.1× smaller than F32) for the noise-tolerant MLP down-projection. This compresses aggressively where layers can absorb the noise and keeps near-lossless precision where it matters most.

Core Components

⚙️ Mixed Quantization Engine

Layer-aware Q8K/Q4K routing with optional block-wise column permutation for improved block scale utilization

📦 SafeTensors Packer

Merges quantized blocks, F32 norms, and embeddings into a single portable file with auto-detection metadata

🧠 Phi-3 Transformer

Full 32-layer attention + MLP pipeline with RoPE, GQA, KV caching, and RMS normalization

💬 Conversational Chat

Multi-turn streaming chat with sliding-window history, adaptive sampling, and Phi-3 chat template

2. Why Mix Two Quantization Formats?

Not all layers in a transformer contribute equally to output quality. Running everything at the same bit-width is either wasteful (8-bit everywhere when some layers could handle 4-bit) or destructive (4-bit everywhere when some layers cannot tolerate the noise).

Q8K for Attention & Gate/Up Projections

The query/key/value projection (qkv_proj), the output projection (o_proj), and the gated MLP projection (gate_up_proj) operate close to the residual stream and directly shape the attention scores and gated activations. Quantization errors in these layers compound across all 32 blocks. Q8K provides ~3.5× compression relative to F32 with RMSE typically below 1×10⁻⁴, making it the safe default.

Q4K for the MLP Down-Projection

The down-projection (down_proj) takes the gated MLP activation, silu(gate) × up, and maps it from intermediate_size = 8192 back to hidden_size = 3072. By the time data reaches down_proj, the SiLU gate has pushed many activations close to zero. A weight's quantization error only reaches the output in proportion to the activation it multiplies, so small activations mask it. This makes down_proj naturally more tolerant of quantization noise. Q4K achieves ~7.1× compression relative to F32 here with acceptable RMSE (typically <1×10⁻³).

F32 for Norms & Embeddings

Layer norms contain only hidden_size = 3072 floats each (12 KB per norm tensor), so compressing them saves essentially nothing while risking numerical instability. The embedding table is read one row per token rather than multiplied as a matrix, and block-level quantization would degrade individual token representations unnecessarily.

📊 Compression Summary
Format            Bytes / 256 values   Compression vs F32   Bits / weight
─────────────────────────────────────────────────────────────────────────
F32 (reference)   1024                 1.00×                32.0
Q8K               292                  3.51×                9.125
Q4K               144                  7.11×                4.50
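The byte counts above follow directly from the block layouts in the next section. The sketch below recomputes them using only std. The constants restate those field layouts and are not imported from candle:

```rust
/// Values per k-quant super-block (QK_K = 256 in GGML/candle).
const QK_K: usize = 256;

/// Q8K: f32 scale d + 256 × i8 quants + 16 × i16 group sums.
const Q8K_BYTES: usize = 4 + QK_K + 2 * (QK_K / 16);

/// Q4K: f16 d + f16 dmin + 12 bytes of packed 6-bit sub-block scales/mins
/// + 256 four-bit quants packed two per byte.
const Q4K_BYTES: usize = 2 + 2 + 12 + QK_K / 2;

/// Bits stored per weight for a block of `bytes` covering QK_K values.
fn bits_per_weight(bytes: usize) -> f64 {
    (bytes * 8) as f64 / QK_K as f64
}

/// Compression ratio relative to F32 (4 bytes per weight).
fn ratio_vs_f32(bytes: usize) -> f64 {
    (QK_K * 4) as f64 / bytes as f64
}
```

292 × 8 / 256 comes out to exactly 9.125 bits per weight, and 1024 / 292 ≈ 3.51.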

3. Block Format Anatomy — Q8K vs Q4K

Both formats operate on blocks of 256 values (QK_K = 256 in candle's GGML implementation). Per-block scaling dramatically reduces error compared to per-tensor scaling, because the scale adapts to the local magnitude range of each 256-value chunk.

Q8K Block (292 bytes → 256 weights)

┌─────────────────────────────────────────────────────────────┐
│ BlockQ8K (292 bytes → 256 values)                           │
│                                                             │
│ d      : f32 (4 bytes)          block scale                 │
│ qs     : [i8; 256] (256 bytes)  quantized signed ints       │
│ bsums  : [i16; 16] (32 bytes)   16-value group sums         │
└─────────────────────────────────────────────────────────────┘

Encoding algorithm:

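  1. Find the value with the largest magnitude in the block, keeping its sign (call it max; amax = |max|).
  2. Compute the scale: d = -max / 128, so that max itself maps exactly to -128.
  3. Each float w is stored as round(w / d), capped at 127 (stored range [-128, 127]).
  4. The 16 bsums entries hold the sum of each group of 16 quantized values; the k-quant dot-product kernels use them to apply sub-block offsets without re-summing.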

Decoding: w_original ≈ qs[i] × d
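The sketch below walks through one Q8K block round trip using only std. Q8KBlock, quantize_block_q8k and dequantize_block_q8k are hypothetical stand-ins, not candle's BlockQ8K API. The scale rule follows GGML's reference quantizer as we understand it: the largest-magnitude value maps to -128, and every other quant is capped at 127.

```rust
/// Values per k-quant super-block (QK_K in GGML/candle).
const QK_K: usize = 256;

/// Hypothetical std-only mirror of the Q8K block layout (not candle's type).
struct Q8KBlock {
    d: f32,                 // block scale
    qs: [i8; QK_K],         // quantized values
    bsums: [i16; QK_K / 16], // sum of each group of 16 quants
}

/// Quantize one 256-value block: the signed extreme maps to -128,
/// everything else uses the same scale and is capped at +127.
fn quantize_block_q8k(xs: &[f32; QK_K]) -> Q8KBlock {
    let mut max = 0f32; // signed value with the largest |x|
    let mut amax = 0f32;
    for &x in xs.iter() {
        if x.abs() > amax {
            amax = x.abs();
            max = x;
        }
    }
    let mut block = Q8KBlock { d: 0.0, qs: [0; QK_K], bsums: [0; QK_K / 16] };
    if amax == 0.0 {
        return block; // all-zero block: d = 0 and all quants 0
    }
    let iscale = -128.0 / max;
    for (q, &x) in block.qs.iter_mut().zip(xs.iter()) {
        *q = (iscale * x).round().min(127.0) as i8;
    }
    for (j, sum) in block.bsums.iter_mut().enumerate() {
        *sum = block.qs[j * 16..(j + 1) * 16].iter().map(|&q| q as i16).sum();
    }
    block.d = 1.0 / iscale;
    block
}

/// Decode: w ≈ qs[i] × d.
fn dequantize_block_q8k(block: &Q8KBlock) -> [f32; QK_K] {
    let mut out = [0f32; QK_K];
    for (o, &q) in out.iter_mut().zip(block.qs.iter()) {
        *o = q as f32 * block.d;
    }
    out
}
```

The only exception to the error bound is a value of opposite sign that ties the extreme, which gets capped at 127. Every other reconstructed weight lands within |d| / 2 of the original.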

Q4K Block (144 bytes → 256 weights)

┌─────────────────────────────────────────────────────────────┐
│ BlockQ4K (144 bytes → 256 values)                           │
│                                                             │
│ d      : f16 (2 bytes)          main block scale            │
│ dmin   : f16 (2 bytes)          min-offset scale            │
│ scales : [u8; 12] (12 bytes)    8 × (scale, min), 6-bit     │
│ qs     : [u8; 128] (128 bytes)  4-bit quants, 2 per byte    │
└─────────────────────────────────────────────────────────────┘

Encoding algorithm:

  1. Split 256 values into 8 sub-blocks of 32.
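  2. Each sub-block gets its own (scale, min) quantization pair. The pair is itself stored as 6-bit integers (sc_j, m_j) relative to the block-level d and dmin: sub_scale = d × sc_j and sub_min = -dmin × m_j.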
  3. Each weight is stored as a 4-bit unsigned integer (0–15): qs[i] = clamp(round((w - sub_min) / sub_scale), 0, 15)
  4. Two 4-bit values are packed into each byte of qs.

Decoding: w_original ≈ qs[i] × sub_scale + sub_min
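The sketch below handles one 32-value Q4K sub-block in simplified form, using only std. It deliberately skips two GGML details. First, the sub-block scale and min stay as f32 here, whereas GGML stores them as 6-bit values relative to d and dmin and searches for a better-fitting scale. Second, the adjacent-pair nibble layout is illustrative, not GGML's actual interleaving. The function names are hypothetical.

```rust
/// Simplified asymmetric 4-bit quantization of one 32-value sub-block.
/// Returns (sub_scale, sub_min, packed nibbles).
fn quantize_subblock_4bit(xs: &[f32; 32]) -> (f32, f32, [u8; 16]) {
    let lo = xs.iter().copied().fold(f32::INFINITY, f32::min);
    let hi = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // As in GGML's Q4K search, the offset is never positive: sub_min <= 0.
    let sub_min = lo.min(0.0);
    let sub_scale = if hi > sub_min { (hi - sub_min) / 15.0 } else { 0.0 };
    let mut packed = [0u8; 16];
    for (i, &x) in xs.iter().enumerate() {
        let q = if sub_scale == 0.0 {
            0u8
        } else {
            ((x - sub_min) / sub_scale).round().clamp(0.0, 15.0) as u8
        };
        // Illustrative layout: even index -> low nibble, odd index -> high nibble.
        packed[i / 2] |= if i % 2 == 0 { q } else { q << 4 };
    }
    (sub_scale, sub_min, packed)
}

/// Decode: w ≈ q × sub_scale + sub_min.
fn dequantize_subblock_4bit(sub_scale: f32, sub_min: f32, packed: &[u8; 16]) -> [f32; 32] {
    let mut out = [0f32; 32];
    for (i, o) in out.iter_mut().enumerate() {
        let byte = packed[i / 2];
        let q = if i % 2 == 0 { byte & 0x0F } else { byte >> 4 };
        *o = q as f32 * sub_scale + sub_min;
    }
    out
}
```

With 16 levels spread over the sub-block's own range, each reconstructed value sits within sub_scale / 2 of the original.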

⚠️ Why Per-Block Scaling Matters

A single weight row in qkv_proj is 3072 values wide. If you used one scale for the entire row, an outlier at position 500 would determine the scale for all 3072 values, wasting precision on the majority. With QK_K = 256, you get 12 independent scales per row — each tuned to its local value range.
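The effect is easy to measure. The std-only sketch below quantizes a synthetic 3072-value row with one outlier, once with a single scale for the whole row and once with a scale per 256-value block. It uses a simplified symmetric 8-bit rule (d = amax / 127) rather than the exact Q8K rounding, and the data are made up for illustration:

```rust
/// RMSE of symmetric 8-bit quantization using one scale per `block` values.
/// Simplified rule for illustration: d = amax / 127, q = round(x / d).
fn rmse_symmetric_8bit(xs: &[f32], block: usize) -> f64 {
    let mut sq_err = 0f64;
    for chunk in xs.chunks(block) {
        let amax = chunk.iter().fold(0f32, |m, &x| m.max(x.abs()));
        let d = amax / 127.0;
        for &x in chunk {
            let deq = if d == 0.0 { 0.0 } else { (x / d).round() * d };
            sq_err += ((x - deq) as f64).powi(2);
        }
    }
    (sq_err / xs.len() as f64).sqrt()
}
```

Consider a row of small values with a single outlier at position 500. rmse_symmetric_8bit(&row, 256) comes out lower than rmse_symmetric_8bit(&row, 3072), because 11 of the 12 blocks no longer inherit the outlier's scale.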

4. Layer Assignment Policy

For Phi-3 Mini (32 transformer layers, hidden_size=3072, intermediate_size=8192):

Tensor Pattern Count Format Rationale
*.self_attn.qkv_proj.weight 32 Q8K Fused QKV — central to attention quality
*.self_attn.o_proj.weight 32 Q8K Attention output → residual stream
*.mlp.gate_up_proj.weight 32 Q8K Fused gate+up, pre-activation signal
*.mlp.down_proj.weight 32 Q4K Post-gating; naturally noise-tolerant
lm_head.weight 1 Q8K Logits — errors affect token selection
model.embed_tokens.weight 1 F32 Row look-up (not a matmul); no block benefit
*.layernorm.weight, model.norm 65 F32 Tiny (12 KB each); high sensitivity

Totals: 97 Q8K tensors + 32 Q4K tensors + 66 F32 tensors (195 tensors in all)

quantize_q8k.rs — Routing Logic
// Returns true if this weight should be quantized at all
fn is_target_weight(name: &str) -> bool {
    name.ends_with(".weight")
        && !name.contains("embed_tokens")
        && !name.contains("norm")
}

// Returns true if this layer goes into Q4K (all others → Q8K)
fn is_q4k_layer(name: &str) -> bool {
    name.contains("mlp.down_proj")
}
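As a quick check of the totals, the two predicates can be run over the tensor names of a Phi-3 Mini checkpoint. This is a std-only sketch. The names below assume the Hugging Face Phi-3 naming scheme, and Format, route, phi3_tensor_names and count_formats are hypothetical helpers, not part of quantize_q8k.rs:

```rust
/// Same routing predicates as quantize_q8k.rs above.
fn is_target_weight(name: &str) -> bool {
    name.ends_with(".weight") && !name.contains("embed_tokens") && !name.contains("norm")
}

fn is_q4k_layer(name: &str) -> bool {
    name.contains("mlp.down_proj")
}

/// Storage format chosen for a tensor name.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Format {
    Q8K,
    Q4K,
    F32,
}

fn route(name: &str) -> Format {
    if !is_target_weight(name) {
        Format::F32
    } else if is_q4k_layer(name) {
        Format::Q4K
    } else {
        Format::Q8K
    }
}

/// Assumed tensor names of a Phi-3 Mini checkpoint (Hugging Face naming).
fn phi3_tensor_names(num_layers: usize) -> Vec<String> {
    let mut names = vec![
        "model.embed_tokens.weight".to_string(),
        "model.norm.weight".to_string(),
        "lm_head.weight".to_string(),
    ];
    for i in 0..num_layers {
        for suffix in [
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "mlp.gate_up_proj.weight",
            "mlp.down_proj.weight",
            "input_layernorm.weight",
            "post_attention_layernorm.weight",
        ] {
            names.push(format!("model.layers.{i}.{suffix}"));
        }
    }
    names
}

/// Count (Q8K, Q4K, F32) tensors in a list of names.
fn count_formats(names: &[String]) -> (usize, usize, usize) {
    names.iter().fold((0, 0, 0), |(q8, q4, f), n| match route(n) {
        Format::Q8K => (q8 + 1, q4, f),
        Format::Q4K => (q8, q4 + 1, f),
        Format::F32 => (q8, q4, f + 1),
    })
}
```

For 32 layers this yields 195 names, split 97 / 32 / 66.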

5. The Three-Stage Pipeline

End-to-End Quantization & Inference Pipeline

Original model (BF16 SafeTensors, ~7.6 GB)
  → Stage 1: quantize_q8k.rs → Q8K + Q4K block files
  → Stage 2: pack_q8k_safetensors.rs → packed .safetensors (one per input shard)
  → Stage 3: loader.rs → model.rs → streaming chat

Stage 1 output:
quantized/
├── *.q8k        (header + blocks)
├── *.q4k        (header + blocks)
├── *.q8k_meta
├── *.q4k_meta
└── *.perm       (optional)

Stage 2 output (packed-model.safetensors):
  model.embed_tokens.weight            (F32)
  model.norm.weight                    (F32)
  layers.0.qkv_proj.weight.q8k         (U8 blob)
  layers.0.qkv_proj.weight.q8k_meta    (I32)
  layers.0.down_proj.weight.q4k        (U8 blob)
  layers.0.down_proj.weight.q4k_meta   (I32)
  ... (layers 1-31)
  lm_head.weight.q8k                   (U8 blob)

6. Stage 1 — Quantization

The quantizer reads the original HuggingFace model, converts weights to F32, routes each layer to Q8K or Q4K, and writes individual binary files with validation metrics.

quantize_q8k.rs — Q8K Row-by-Row Quantization
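use candle::quantized::k_quants::{BlockQ8K, GgmlType, QK_K};

fn quantize_rows_q8k(rows: usize, k: usize, data: &[f32]) -> candle::Result<Vec<BlockQ8K>> {
    if k % QK_K != 0 || data.len() != rows * k {
        candle::bail!("expected rows of a multiple of {} values, got k = {}", QK_K, k);
    }
    let blocks_per_row = k / QK_K;  // e.g. 3072/256 = 12 blocks per row
    let mut blocks = vec![BlockQ8K::zeros(); rows * blocks_per_row];

    for r in 0..rows {
        let row = &data[r * k..(r + 1) * k];                                   // one output row
        let dst = &mut blocks[r * blocks_per_row..(r + 1) * blocks_per_row];  // exactly this row's blocks
        BlockQ8K::from_float(row, dst)?;                                       // GGML kernel: F32 → Q8K
    }

    Ok(blocks)
}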

BlockQ8K::from_float is candle's port of GGML's reference Q8K quantizer, in the k_quants module. For each 256-value block it finds the value with the largest magnitude and sets d = -max / 128. It then rounds each scaled value to i8 (capped at 127) and computes the 16 bsums group sums needed by the optimized dot-product kernels.

quantize_q8k.rs — Q4K Row-by-Row Quantization
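use candle::quantized::k_quants::{BlockQ4K, GgmlType, QK_K};

fn quantize_rows_q4k(rows: usize, k: usize, data: &[f32]) -> candle::Result<Vec<BlockQ4K>> {
    if k % QK_K != 0 || data.len() != rows * k {
        candle::bail!("expected rows of a multiple of {} values, got k = {}", QK_K, k);
    }
    let blocks_per_row = k / QK_K;  // e.g. 8192/256 = 32 blocks per down_proj row
    let mut blocks = vec![BlockQ4K::zeros(); rows * blocks_per_row];

    for r in 0..rows {
        let row = &data[r * k..(r + 1) * k];
        let dst = &mut blocks[r * blocks_per_row..(r + 1) * blocks_per_row];  // exactly this row's blocks
        BlockQ4K::from_float(row, dst)?;   // GGML kernel: F32 → Q4K
    }

    Ok(blocks)
}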

Optional Column Permutation

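An optional preprocessing step reorders the columns of a weight matrix before quantizing. Within each 64-column group, columns are sorted by descending L2 norm so that the highest-magnitude columns come first. Because 64 divides QK_K = 256, no column ever leaves its 256-value block. The motivation is that every scale is set by the largest absolute value it covers, so when large and small values share a scale, the small values lose precision. Reordering inside a block cannot change the single Q8K block scale. Any benefit therefore comes from grouping similar-magnitude columns under the same finer-grained scale, such as Q4K's 32-value sub-block scales.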

Block-Wise Permutation Algorithm
fn build_block_wise_permutation(rows: usize, k: usize, data: &[f32]) -> Vec<usize> {
    // Step 1: Compute L2 norm of each column across all rows
    let mut col_norms = vec![0f64; k];
    for r in 0..rows {
        for c in 0..k {
            let val = data[r * k + c] as f64;
            col_norms[c] += val * val;
        }
    }
    for norm in &mut col_norms { *norm = norm.sqrt(); }

    // Step 2: Divide into blocks of 64 columns
    let block_size = 64;
    let mut global_perm = vec![0usize; k];

    // Step 3: Sort within each block by descending L2 norm
    for block_start in (0..k).step_by(block_size) {
        let block_end = (block_start + block_size).min(k);
        let mut indices: Vec<usize> = (block_start..block_end).collect();
        indices.sort_by(|&a, &b|
            col_norms[b].total_cmp(&col_norms[a])  // descending; NaN-safe, unlike partial_cmp().unwrap()
        );
        for (dst, &src) in indices.iter().enumerate() {
            global_perm[block_start + dst] = src;
        }
    }

    global_perm  // perm[dst] = src
}
💡 Permutation Scope

Columns never cross 64-column block boundaries. This preserves cache locality during the matmul and limits the overhead of the per-inference permutation pass. For attention layers, Q/K/V share the same permutation (computed from q_proj) to keep attention scores consistent.
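Permuting the input at inference time is what makes the stored column reorder transparent. If a weight row is stored with perm[dst] = src and the activation vector is reordered with the same perm, each row's dot product is unchanged apart from floating-point summation order. This is a std-only sketch; permute, dot and stays_within_blocks are hypothetical helpers, not the project's apply_permutation_to_input:

```rust
/// Reorder a vector so that out[dst] = v[perm[dst]] (perm[dst] = src).
fn permute(v: &[f32], perm: &[usize]) -> Vec<f32> {
    perm.iter().map(|&src| v[src]).collect()
}

/// Plain F32 dot product.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// True if every column stays inside its own `block`-wide group.
fn stays_within_blocks(perm: &[usize], block: usize) -> bool {
    perm.iter().enumerate().all(|(dst, &src)| dst / block == src / block)
}
```

For instance, a permutation that reverses each 64-column group passes stays_within_blocks(&perm, 64). Likewise, dot(&permute(&w, &perm), &permute(&x, &perm)) matches dot(&w, &x) to within rounding.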

Validation Pipeline

After quantizing each layer, three validation metrics are computed by dequantizing back to F32 via BlockQ8K::to_float (or BlockQ4K::to_float):

Validation Metrics
use candle::quantized::k_quants::GgmlType;

// Generic over the block type, so the same check serves Q8K and Q4K layers.
fn compute_quantization_error_detailed<T: GgmlType>(
    original: &[f32],
    blocks: &[T],
    rows: usize,
    k: usize,
) -> candle::Result<(f64, f64, f64)> {
    let mut dequantized = vec![0f32; rows * k];
    T::to_float(blocks, &mut dequantized)?;   // fallible in candle, so propagate the error

    let mut l2_error = 0f64;
    let mut max_error = 0f64;
    let mut relative_error_sum = 0f64;
    let mut relative_count = 0usize;          // weights large enough for a relative error

    for (&orig, &deq) in original.iter().zip(dequantized.iter()) {
        let err = (orig - deq).abs() as f64;
        l2_error += err * err;
        max_error = max_error.max(err);
        if orig.abs() > 1e-10 {
            relative_error_sum += err / orig.abs() as f64;
            relative_count += 1;
        }
    }

    let rmse = (l2_error / (rows * k) as f64).sqrt();
    // Average only over the weights that contributed a relative error
    let mean_relative_error = relative_error_sum / relative_count.max(1) as f64;

    Ok((rmse, max_error, mean_relative_error))
}

7. Stage 2 — Packing

The packer merges all quantized block files, metadata, and unquantized F32 tensors into one .safetensors file per input shard. This simplifies deployment: the inference binary needs only one file, or two for the two-shard Phi-3 checkpoint.

Q8K Packing Layout

quantized/name.q8k   (24-byte Q8KHeader + raw BlockQ8K bytes)
        ↓
Packed model keys:
  name.q8k        → U8 blob (header + blocks, both present)
  name.q8k_meta   → I32[3] = [out, k, header_size_in_bytes]
  name.perm       → U8 blob (optional, if permutation was used)

Q4K Packing Layout

quantized/name.q4k   (24-byte header + raw BlockQ4K bytes)
        ↓
Packed model keys:
  name.q4k        → U8 blob (raw BlockQ4K bytes only, header stripped)
  name.q4k_meta   → I32[2] = [out, k]

The header is stripped for Q4K because the SafeTensors format already encodes byte length via its own shape metadata. The name.q4k_meta tensor provides out and k so the loader can reconstruct the block layout.
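These layouts give the loader a simple sanity check: the blob length must equal header + out × (k / 256) × block size. The following is a std-only sketch. The function names and String errors are hypothetical, not the project's strict loader:

```rust
const QK_K: usize = 256;
const Q8K_BLOCK_BYTES: usize = 292;
const Q4K_BLOCK_BYTES: usize = 144;

/// Bytes of block data expected for an [out, k] matrix.
fn expected_block_bytes(out: usize, k: usize, block_bytes: usize) -> Result<usize, String> {
    if k % QK_K != 0 {
        return Err(format!("k = {k} is not a multiple of {QK_K}"));
    }
    Ok(out * (k / QK_K) * block_bytes)
}

/// Split a packed `name.q8k` blob into (header, block bytes) using
/// `name.q8k_meta` = [out, k, header_size_in_bytes].
fn split_q8k_blob(blob: &[u8], meta: [i32; 3]) -> Result<(&[u8], &[u8]), String> {
    if meta.iter().any(|&v| v < 0) {
        return Err("negative value in q8k_meta".to_string());
    }
    let [out, k, header] = meta.map(|v| v as usize);
    let want = header + expected_block_bytes(out, k, Q8K_BLOCK_BYTES)?;
    if blob.len() != want {
        return Err(format!("q8k blob is {} bytes, expected {want}", blob.len()));
    }
    Ok(blob.split_at(header))
}

/// Check a headerless `name.q4k` blob against `name.q4k_meta` = [out, k].
fn check_q4k_blob(blob: &[u8], meta: [i32; 2]) -> Result<(), String> {
    if meta.iter().any(|&v| v < 0) {
        return Err("negative value in q4k_meta".to_string());
    }
    let want = expected_block_bytes(meta[0] as usize, meta[1] as usize, Q4K_BLOCK_BYTES)?;
    if blob.len() != want {
        return Err(format!("q4k blob is {} bytes, expected {want}", blob.len()));
    }
    Ok(())
}
```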

Auto-Detection at Load Time

loader.rs — Format Detection
let load_q = |name: &str| -> candle::Result<QuantLinear> {
    let q4k_meta = format!("{}.q4k_meta", name);
    // Try Q4K first (down_proj layers)
    if st.names().iter().any(|n| **n == q4k_meta) {
        return QuantLinear::load_q4k_from_packed_safetensors(&st, name);
    }
    // Fall back to Q8K
    QuantLinear::load_from_packed_safetensors(&st, name)
};

The decision is made per-layer by probing for the presence of .q4k_meta. Multi-shard loading searches both shard files in sequence until the tensor is found.

8. Stage 3 — Inference & On-the-Fly Dequantization

The QuantLinear struct is the central runtime data structure. One instance exists for every quantized weight matrix. It dispatches to the correct GGML matmul kernel based on whether the layer is Q8K or Q4K.

quant_linear.rs — Core Data Structures
pub enum QuantBlocks {
    Q8K(Vec<BlockQ8K>),
    Q4K(Vec<BlockQ4K>),
}

pub struct QuantLinear {
    pub(crate) blocks: QuantBlocks,        // Q8K or Q4K block data
    pub(crate) out:    usize,              // number of output features
    pub(crate) k:      usize,              // number of input features
    pub(crate) name:   String,             // tensor name (for diagnostics)
    pub(crate) perm:   Option<Vec<usize>>, // column permutation (or None)
}

The Forward Pass

quant_linear.rs — On-the-Fly Dequantization
fn forward_2d(&self, x2d: &Tensor, b: usize, k_in: usize) -> candle::Result<Tensor> {
    let x_f32 = x2d.to_dtype(DType::F32)?.contiguous()?;
    let mut x_vec = x_f32.flatten_all()?.to_vec1::<f32>()?;

    // Apply column permutation if this layer was permutation-quantized
    if let Some(ref perm) = self.perm {
        x_vec = self.apply_permutation_to_input(&x_vec, b, perm);
    }

    let mut out_buf = vec![0f32; b * self.out];
    match &self.blocks {
        QuantBlocks::Q8K(blocks) => {
            matmul::<BlockQ8K>((b, self.k, self.out), &x_vec, blocks, &mut out_buf)?
        }
        QuantBlocks::Q4K(blocks) => {
            matmul::<BlockQ4K>((b, self.k, self.out), &x_vec, blocks, &mut out_buf)?
        }
    }

    Tensor::from_vec(out_buf, (b, self.out), x2d.device())
}
🚀 Critical Insight: Weights Are Never Fully Dequantized

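The GGML-style matmul in candle's k_quants never expands a weight matrix to F32. It first quantizes the activation row to Q8K, then walks the weights one block at a time. Each block gets an integer multiply-accumulate over the quantized values, scaled once per block (for Q4K, per sub-block) by the weight and activation scales. No dequantized copy of the weights is ever written to RAM. The only extra buffer is the Q8K copy of the activation row, which for down_proj (k = 8192) is 32 blocks × 292 bytes ≈ 9.3 KB per token.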
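The block-at-a-time idea can be illustrated with std only. Block8, quantize and vec_dot_blocks are hypothetical, and the symmetric amax / 127 scale is a simplification of Q8K. The sketch quantizes both operands and does one integer multiply-accumulate per block, much like the Q8K × Q8K dot product in candle's k_quants:

```rust
const QK_K: usize = 256;

/// Simplified symmetric 8-bit block: scale d and 256 i8 quants.
struct Block8 {
    d: f32,
    qs: [i8; QK_K],
}

/// Quantize a vector (length a multiple of QK_K) block by block.
fn quantize(xs: &[f32]) -> Vec<Block8> {
    xs.chunks_exact(QK_K)
        .map(|chunk| {
            let amax = chunk.iter().fold(0f32, |m, &x| m.max(x.abs()));
            let d = amax / 127.0;
            let mut qs = [0i8; QK_K];
            if d > 0.0 {
                for (q, &x) in qs.iter_mut().zip(chunk) {
                    *q = (x / d).round() as i8;
                }
            }
            Block8 { d, qs }
        })
        .collect()
}

/// Dot product block by block: integer multiply-accumulate on the quants,
/// then one float multiply by both scales per block. No dequantized weight
/// vector is ever built.
fn vec_dot_blocks(w: &[Block8], x: &[Block8]) -> f32 {
    w.iter()
        .zip(x)
        .map(|(wb, xb)| {
            let isum: i32 = wb
                .qs
                .iter()
                .zip(xb.qs.iter())
                .map(|(&a, &b)| a as i32 * b as i32)
                .sum();
            isum as f32 * wb.d * xb.d
        })
        .sum()
}
```

The result matches "dequantize, then dot" up to rounding, which is why the full F32 weight matrix never needs to exist.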

9. Phi-3 Transformer Architecture

The model implements the full Phi-3 Mini 4K-Instruct transformer: 32 identical blocks, each containing causal self-attention with RoPE and a gated MLP with SiLU activation. Activations, norms, RoPE, softmax and residual additions stay in F32. Only the weight matrices stay quantized, and each matmul reads them directly from their Q8K/Q4K blocks.

model.rs — Transformer Block
pub struct Block {
    pub rms_1: Tensor,              // input layernorm weights (F32)
    pub attn:  CausalSelfAttention, // Q/K/V proj + O proj (Q8K)
    pub rms_2: Tensor,              // post-attention layernorm (F32)
    pub mlp:   Mlp,                 // gate_up (Q8K) + down (Q4K)
}

impl Block {
    pub fn forward(&self, x: &Tensor, index_pos: usize,
                   block_idx: usize, cache: &mut Cache) -> candle::Result<Tensor> {
        let residual = x;
        let x = rms_norm(x, &self.rms_1, 1e-5)?;
        let x = (self.attn.forward(&x, index_pos, block_idx, cache)? + residual)?;
        let residual = &x;
        let x = rms_norm(&x, &self.rms_2, 1e-5)?;
        let x = (self.mlp.forward(&x)? + residual)?;
        Ok(x)
    }
}

Gated MLP with Mixed Precision

model.rs — MLP (Q8K gate_up + Q4K down)
pub struct Mlp {
    pub gate_up_proj: QuantLinear,  // Q8K — fused gate + up projection
    pub down_proj:    QuantLinear,  // Q4K — down projection
    pub hidden_size:  usize,        // width of each half of gate_up: intermediate_size = 8192, not the model's 3072
}

impl Mlp {
    pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
        let gate_up = self.gate_up_proj.forward(x)?;           // Q8K matmul
        let gate = gate_up.narrow(D::Minus1, 0, self.hidden_size)?;
        let up   = gate_up.narrow(D::Minus1, self.hidden_size, self.hidden_size)?;
        let x = (silu(&gate)? * up)?;                          // SiLU gating
        self.down_proj.forward(&x)                              // Q4K matmul
    }
}

The SiLU gating between gate_up_proj (Q8K) and down_proj (Q4K) acts as a natural smoother. It pushes many activations close to zero and reduces the effective dynamic range of the input to down_proj. This is why Q4K works well here despite its lower precision.
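The std-only scalar sketch below applies this gating step to one row. It splits the fused gate_up output at the intermediate size, the way Mlp::forward does with narrow(). gated_activation is a hypothetical helper; in the project this runs on candle tensors. SiLU does not reach exactly zero for negative inputs: silu(-10.0) is about -4.5e-4, and the function bottoms out near -0.28.

```rust
/// SiLU(x) = x * sigmoid(x) = x / (1 + e^-x).
fn silu(x: f32) -> f32 {
    x / (1.0 + (-x).exp())
}

/// Split a fused row [gate | up] at `intermediate` and return silu(gate) * up.
fn gated_activation(gate_up: &[f32], intermediate: usize) -> Vec<f32> {
    assert_eq!(gate_up.len(), 2 * intermediate);
    let (gate, up) = gate_up.split_at(intermediate);
    gate.iter().zip(up).map(|(&g, &u)| silu(g) * u).collect()
}
```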

Causal Self-Attention with RoPE

model.rs — Attention Mechanism
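pub fn forward(&self, x: &Tensor, index_pos: usize,
               block_idx: usize, cache: &mut Cache) -> candle::Result<Tensor> {
    let (b_sz, seq_len, _hidden_size) = x.dims3()?;
    // Phi-3 Mini: 32 query heads, 32 KV heads, head_dim = 96 (field names illustrative)
    let (n_h, n_kv, head_dim) = (self.num_heads, self.num_kv_heads, self.head_dim);
    let q_size = n_h * head_dim;    // 3072
    let kv_size = n_kv * head_dim;  // 3072

    // Fused QKV projection (single Q8K matmul)
    let qkv = self.qkv_proj.forward(x)?;

    // Split into Q, K, V and reshape to (batch, heads, seq, head_dim)
    let q = qkv.narrow(D::Minus1, 0, q_size)?
        .reshape((b_sz, seq_len, n_h, head_dim))?.transpose(1, 2)?.contiguous()?;
    let k = qkv.narrow(D::Minus1, q_size, kv_size)?
        .reshape((b_sz, seq_len, n_kv, head_dim))?.transpose(1, 2)?.contiguous()?;
    let mut v = qkv.narrow(D::Minus1, q_size + kv_size, kv_size)?
        .reshape((b_sz, seq_len, n_kv, head_dim))?.transpose(1, 2)?.contiguous()?;

    // Apply Rotary Position Embeddings
    let q = self.apply_rotary_emb(&q, index_pos, cache)?;
    let mut k = self.apply_rotary_emb(&k, index_pos, cache)?;

    // Concatenate with KV cache along the sequence axis (dim 2)
    if cache.use_kv_cache {
        if let Some((cache_k, cache_v)) = &cache.kvs[block_idx] {
            k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
            v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
        }
        cache.kvs[block_idx] = Some((k.clone(), v.clone()));
    }

    // Scaled dot-product scores: (batch, heads, seq, kv_len)
    let att = (q.matmul(&k.t()?)? / (head_dim as f64).sqrt())?;
    // ... causal masking + softmax + weighted sum of V, reshaped to y: (batch, seq, hidden) ...
    self.o_proj.forward(&y)  // Q8K output projection
}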

10. KV Cache & Rotary Position Embeddings

The Cache struct serves three purposes: storing past Key/Value tensors, pre-computing RoPE sin/cos tables, and generating causal attention masks.

cache.rs — KV Cache Structure
#[derive(Debug, Clone)]
pub struct Cache {
    pub masks: HashMap<(usize, usize), Tensor>,  // cached causal masks
    pub use_kv_cache: bool,
    pub kvs: Vec<Option<(Tensor, Tensor)>>,       // 32 slots: one (K,V) per layer
    pub cos: Tensor,                               // [4096, 48] RoPE cosine table
    pub sin: Tensor,                               // [4096, 48] RoPE sine table
    pub device: Device,
}

RoPE Pre-computation

Rotary Position Embeddings encode position by rotating query/key vectors in 2D subspaces. The rotation angle for dimension pair i at position m is:

$$\theta_i(m) = m \cdot 10000^{-2i/d}, \qquad i = 0, 1, \dots, d/2 - 1, \quad d = \text{head\_dim} = 96$$

The cache pre-computes cos(θ) and sin(θ) for all 4096 positions and all 48 dimension pairs at startup:

cache.rs — RoPE Table Construction
let theta = 10000.0f32;
let inv_freq: Vec<f32> = (0..head_dim)
    .step_by(2)
    .map(|i| 1.0 / theta.powf(i as f32 / head_dim as f32))
    .collect();
let inv_freq_len = inv_freq.len();  // head_dim / 2 = 48 pairs

let inv_freq = Tensor::from_vec(inv_freq, (inv_freq_len,), device)?;
let t = Tensor::arange(0u32, max_seq_len as u32, device)?
    .to_dtype(DType::F32)?
    .reshape((max_seq_len, 1))?;

let freqs = t.matmul(&inv_freq.reshape((1, inv_freq.elem_count()))?)?;
let cos = freqs.cos()?.to_dtype(dtype)?;  // [4096, 48]
let sin = freqs.sin()?.to_dtype(dtype)?;  // [4096, 48]

At inference time, the model slices the relevant rows via cache.cos.narrow(0, index_pos, seq_len) — no trigonometry is computed during generation.
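The table construction and the per-position rotation can be sketched without candle. This version assumes the "rotate-half" pairing used by the Hugging Face Phi-3 code, where dimension i is paired with i + head_dim/2. rope_tables and apply_rope are hypothetical names, and the tables are plain Vec<Vec<f32>> instead of tensors:

```rust
/// Precompute cos[m][i] and sin[m][i] for position m and pair i,
/// with angle m * theta^(-2i / head_dim).
fn rope_tables(max_seq_len: usize, head_dim: usize, theta: f32) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let inv_freq: Vec<f32> = (0..head_dim)
        .step_by(2)
        .map(|i| 1.0 / theta.powf(i as f32 / head_dim as f32))
        .collect();
    let mut cos = Vec::with_capacity(max_seq_len);
    let mut sin = Vec::with_capacity(max_seq_len);
    for m in 0..max_seq_len {
        let angles: Vec<f32> = inv_freq.iter().map(|f| m as f32 * f).collect();
        cos.push(angles.iter().map(|a| a.cos()).collect());
        sin.push(angles.iter().map(|a| a.sin()).collect());
    }
    (cos, sin)
}

/// Rotate one head vector in place; pair i is (x[i], x[i + head_dim/2]).
fn apply_rope(x: &mut [f32], cos: &[f32], sin: &[f32]) {
    let half = x.len() / 2;
    for i in 0..half {
        let (a, b) = (x[i], x[i + half]);
        x[i] = a * cos[i] - b * sin[i];
        x[i + half] = a * sin[i] + b * cos[i];
    }
}
```

Position 0 leaves a vector unchanged. Every other position only rotates pairs, so the vector's norm is preserved.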

Why the Cache Matters for Performance

💾 KV Cache Memory Per Token

num_layers × 2 × num_kv_heads × head_dim × 4 bytes
= 32 × 2 × 32 × 96 × 4 = 786,432 bytes ≈ 0.75 MB / token

At full context (4096 tokens) that is 3072 MiB (3 GiB, about 3.2 GB). The sliding-window trimmer in conversation.rs caps history at 3072 tokens to prevent exhausting RAM.
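The sizing formula can be written as a small std-only helper (names hypothetical). The MB figures in this document behave like MiB, since 786,432 bytes is exactly 0.75 × 2^20.

```rust
/// Bytes of K and V stored per token across all layers.
fn kv_bytes_per_token(num_layers: usize, num_kv_heads: usize, head_dim: usize, bytes_per_elem: usize) -> usize {
    num_layers * 2 * num_kv_heads * head_dim * bytes_per_elem
}

/// Cache size in MiB (2^20 bytes) after `tokens` positions.
fn kv_mib(tokens: usize, per_token: usize) -> f64 {
    (tokens * per_token) as f64 / (1024.0 * 1024.0)
}
```

With Phi-3 Mini's values (32 layers, 32 KV heads, head_dim 96, F32), 4096 tokens come to 3,221,225,472 bytes, i.e. 3072 MiB.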

11. Conversational Loop & Sampling

The main loop reads user input from stdin, formats it with Phi-3's chat template (<|user|>\n...\n<|end|>\n<|assistant|>\n), tokenizes, runs autoregressive generation, and decodes the output token-by-token with streaming.

Adaptive Sampling Strategy

Per-turn sampling parameters are chosen based on heuristic keyword detection:

main.rs — Sampling Configuration
let is_code_question = user_input.contains("code")
    || user_input.contains("example")
    || user_input.contains("function")
    || user_input.contains("loop");

let (sampling, repeat_pen, rep_last_n) = if is_code_question {
    // Low temperature → deterministic, precise code
    (Sampling::TopKThenTopP { k: 25, p: 0.95, temperature: 0.35 }, 1.3, 128)
} else {
    // Higher temperature → natural conversation
    (Sampling::TopKThenTopP { k: 50, p: 0.9,  temperature: 0.6 },  1.35, 256)
};
Parameter Code Mode Normal Mode Effect
k (top-k) 25 50 Vocab narrowed to top K tokens
p (nucleus) 0.95 0.90 Cumulative probability cutoff
temperature 0.35 0.60 Lower = more deterministic
repeat_pen 1.30 1.35 Penalty on recently generated tokens
rep_last_n 128 256 How many recent tokens the penalty looks back over
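To show what k, p and temperature each do, here is a std-only sketch of top-k-then-top-p sampling. sample_top_k_top_p is hypothetical and is not candle's LogitsProcessor. It takes a caller-supplied uniform draw u instead of an RNG and omits the repeat penalty. The order it uses is temperature, softmax, top-k, renormalize, then top-p:

```rust
/// Pick a token id from non-empty `logits` using temperature, top-k, then
/// top-p (nucleus) filtering. `u` is a uniform draw in [0, 1).
fn sample_top_k_top_p(logits: &[f32], k: usize, p: f32, temperature: f32, u: f32) -> usize {
    // Temperature-scaled softmax weights (max-subtracted for stability).
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<(usize, f32)> = logits
        .iter()
        .enumerate()
        .map(|(i, &l)| (i, ((l - max) / temperature).exp()))
        .collect();

    // Top-k: keep the k most likely tokens, then renormalize.
    probs.sort_by(|a, b| b.1.total_cmp(&a.1));
    probs.truncate(k.max(1));
    let total: f32 = probs.iter().map(|&(_, w)| w).sum();
    for entry in probs.iter_mut() {
        entry.1 /= total;
    }

    // Top-p: keep the smallest prefix whose cumulative probability reaches p.
    let mut cumulative = 0.0;
    let mut cutoff = probs.len();
    for (j, &(_, w)) in probs.iter().enumerate() {
        cumulative += w;
        if cumulative >= p {
            cutoff = j + 1;
            break;
        }
    }
    probs.truncate(cutoff);

    // Inverse-CDF draw over the renormalized survivors.
    let kept: f32 = probs.iter().map(|&(_, w)| w).sum();
    let mut acc = 0.0;
    for &(id, w) in &probs {
        acc += w / kept;
        if u < acc {
            return id;
        }
    }
    probs[probs.len() - 1].0
}
```

Lower temperature sharpens the softmax before either cutoff is applied, which is why the code-mode settings behave more deterministically.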

First Turn vs. Subsequent Turns

A crucial optimization: on the first turn, the full prompt (system + user message) is fed through all 32 layers. On subsequent turns, only the new user tokens are processed — the KV cache already holds all previous key/value projections. This is why global_position keeps growing across turns:

Turn 1: prompt   = 150 tokens → global_position = 150
        generate   200 tokens → global_position = 350
Turn 2: user msg =  20 tokens → global_position = 370
        generate   180 tokens → global_position = 550
Turn 3: user msg =  25 tokens → global_position = 575
        generate   250 tokens → global_position = 825
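The position bookkeeping above, and the memory column of the runtime-memory table in Section 12, can be reproduced with a tiny std-only tracker. PositionTracker is hypothetical. The 0.75 MiB per token figure comes from the KV-cache formula in Section 10:

```rust
/// Tracks the global KV-cache position across chat turns.
#[derive(Default)]
struct PositionTracker {
    global_position: usize,
}

impl PositionTracker {
    /// Advance by `n` tokens: prompt tokens fed in, or tokens generated.
    /// Both are appended to the KV cache, so both move the position.
    fn advance(&mut self, n: usize) -> usize {
        self.global_position += n;
        self.global_position
    }

    /// "reset" clears history and the KV cache, so positions start over.
    fn reset(&mut self) {
        self.global_position = 0;
    }

    /// F32 cache size in MiB at 0.75 MiB per token (Phi-3 Mini).
    fn cache_mib(&self) -> f64 {
        self.global_position as f64 * 0.75
    }
}
```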

Sliding Window History

conversation.rs — Context Trimming
pub fn apply_sliding_window(&mut self, tokenizer: &Tokenizer) -> candle::Result<()> {
    // Token count of a piece of text (tokenizers errors mapped into candle::Error)
    let count = |text: String| -> candle::Result<usize> {
        Ok(tokenizer.encode(text, false).map_err(candle::Error::msg)?.get_ids().len())
    };

    let tokens = count(self.format_prompt(tokenizer)?)?;
    if tokens <= self.max_history_tokens { return Ok(()); }

    // Keep system prompt, then add recent messages from newest to oldest
    let mut kept = vec![self.messages[0].clone()];
    let mut current_tokens = count(self.messages[0].content.clone())?;

    for msg in self.messages.iter().skip(1).rev() {
        let msg_tokens = count(msg.content.clone())?;
        if current_tokens + msg_tokens > self.max_history_tokens { break; }
        kept.insert(1, msg.clone());  // insert right after the system prompt: keeps chronological order
        current_tokens += msg_tokens;
    }

    self.messages = kept;
    Ok(())
}
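The trimming rule itself can be tested without a tokenizer. This std-only sketch takes precomputed per-message token counts, with index 0 as the system prompt. sliding_window is a hypothetical helper. For simplicity it compares the plain sum against the budget, whereas apply_sliding_window first counts the fully formatted prompt including template tokens:

```rust
/// Indices of messages that survive the sliding window, in chronological
/// order. Index 0 (system prompt) is always kept; then messages are added
/// newest to oldest until the next one would overflow `max_tokens`.
fn sliding_window(token_counts: &[usize], max_tokens: usize) -> Vec<usize> {
    let total: usize = token_counts.iter().sum();
    if total <= max_tokens {
        return (0..token_counts.len()).collect();
    }
    let mut kept = vec![0];
    let mut used = token_counts[0];
    for idx in (1..token_counts.len()).rev() {
        if used + token_counts[idx] > max_tokens {
            break;
        }
        kept.insert(1, idx);
        used += token_counts[idx];
    }
    kept
}
```

For example, with counts [50, 100, 200, 300, 400] and a budget of 800, the system prompt and the two newest messages survive.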

12. Memory & Performance Numbers

Model File Sizes

Format File Size Compression Notes
Original BF16 ~7.6 GB HuggingFace download
Mixed Q8K/Q4K (this project) ~4.1 GB 1.85× Deployed as artemr87/phi3-mixed
Pure GGUF Q4_K_M ~2.2 GB 3.45× Smaller but lower quality

Per-Layer Weight Memory

Layer Shape Format Memory (decimal MB/GB)
qkv_proj × 32 [9216, 3072] Q8K ~1.0 GB
o_proj × 32 [3072, 3072] Q8K ~344 MB
gate_up_proj × 32 [16384, 3072] Q8K ~1.8 GB
down_proj × 32 [3072, 8192] Q4K ~453 MB
lm_head [32064, 3072] Q8K ~112 MB
embed_tokens [32064, 3072] F32 ~394 MB

Runtime Memory Growth

Action                          Position  Cache tokens  Memory
──────────────────────────────────────────────────────────────────
Startup                         0         0             0.0 MB
User turn 1 (prefill 150 tok)   150       150           112.5 MB
  + generate 200 tokens         350       350           262.5 MB
User turn 2 (send 20 tok)       370       370           277.5 MB
  + generate 180 tokens         550       550           412.5 MB
User turn 3 (send 25 tok)       575       575           431.3 MB
  + generate 250 tokens         825       825           618.8 MB
User types "reset"              0         0             0.0 MB

Inference Speed (Observed CPU Run)

Run Speed Notes
Short interactive prompt 9.6 tok/s Baseline Q8K/Q4K shards; prompt: What is the size of the Sun?
Longer conversations varies KV cache and prompt length grow, so per-response speed is workload-dependent

13. Deployment & Usage

Build from Source

The source is maintained in github.com/artem1984A/nibble. The Phi-3 mixed-quantization binaries are in the phi3_standalone crate.

Step 1 — Quantize the Weights
git clone https://github.com/artem1984A/nibble.git
cd nibble/phi3_standalone

cargo build --release \
  --bin quantize_q8k \
  --bin pack_q8k_safetensors \
  --bin phi3-mixed-quant

SNAP=~/.cache/huggingface/hub/models--microsoft--Phi-3-mini-4k-instruct/snapshots/<rev>

# Standard quantization:
./target/release/quantize_q8k \
    "$SNAP/model-00001-of-00002.safetensors" \
    ./quantized/
./target/release/quantize_q8k \
    "$SNAP/model-00002-of-00002.safetensors" \
    ./quantized/

# Optional experimental column permutation:
cargo build --release --features experimental-perm --bin quantize_q8k
CANDLE_Q8K_PERMUTE=1 ./target/release/quantize_q8k \
    "$SNAP/model-00001-of-00002.safetensors" \
    ./quantized-perm/
CANDLE_Q8K_PERMUTE=1 ./target/release/quantize_q8k \
    "$SNAP/model-00002-of-00002.safetensors" \
    ./quantized-perm/
Step 2 — Pack into a Single File
./target/release/pack_q8k_safetensors \
    "$SNAP/model-00001-of-00002.safetensors" \
    ./quantized/ \
    ./packed-shard1.safetensors

./target/release/pack_q8k_safetensors \
    "$SNAP/model-00002-of-00002.safetensors" \
    ./quantized/ \
    ./packed-shard2.safetensors
Step 3 — Run Inference
./target/release/phi3-mixed-quant \
    ./packed-shard1.safetensors \
    ./packed-shard2.safetensors

Docker Deployment

Docker Run
docker pull artemr87/phi3-mixed:1.0.1
docker run -it --rm \
    -v /path/to/phi3-mixed-shard1.safetensors:/app/shard1.safetensors:ro \
    -v /path/to/phi3-mixed-shard2.safetensors:/app/shard2.safetensors:ro \
    -v /path/to/tokenizer.json:/app/tokenizer.json:ro \
    artemr87/phi3-mixed:1.0.1

Chat Commands

Input Effect
(any text) Send a message; model streams a response
reset Clear conversation history and KV cache
exit Quit the program
📊 Stats Line

After each response, the system prints diagnostics:
[Pos: 329 | Cache: 329 tok (246.8 MB) | Speed: 9.6 t/s | Hist: 3 msgs]

File Reference

File Role
quantize_q8k.rs Stage 1: converts BF16 weights to Q8K/Q4K block files
pack_q8k_safetensors.rs Stage 2: merges quantized files + F32 into one SafeTensors
types.rs Shared types: Q8KHeader, Phi3Config, conversions
quant_linear.rs Runtime QuantLinear: on-the-fly matmul dispatch
loader.rs Builds the Phi3 model from packed SafeTensors
model.rs Transformer: attention, RoPE, MLP, RMS norm, blocks
cache.rs KV cache, RoPE tables, causal mask generation
conversation.rs Sliding-window history, Phi-3 chat template
main.rs Entry point: CLI, model load, chat loop, streaming

🙏 Acknowledgments

This project builds on the work of the open-source community:

  • Hugging Face Candle — Minimalist ML framework for Rust with optimized GGML quantization kernels
  • Microsoft Phi-3 Team — Compact yet powerful 3.8B instruction-tuned language model
  • llama.cpp — Inspiration for Q8K/Q4K block quantization schemes and GGML format
  • SafeTensors — Secure, zero-copy tensor serialization format
  • Rust Language — Memory-safe systems programming without garbage collection

📚 Explore More

Dive deeper into the implementation, inspect the Rust source and benchmark logs, or read the follow-up Q8K128 experiment.