diff --git a/CMakeLists.txt b/CMakeLists.txt index a7070781..ff16f8c5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -98,6 +98,7 @@ set(SOURCES gemma/vit.h gemma/weights.cc gemma/weights.h + gemma/load_safetensors.cc io/blob_store.cc io/blob_store.h io/fields.cc @@ -105,6 +106,8 @@ set(SOURCES io/io_win.cc io/io.cc io/io.h + io/safetensors.cc + io/safetensors.h ops/dot-inl.h ops/matmul_static_bf16.cc ops/matmul_static_f32.cc @@ -160,7 +163,7 @@ set_property(TARGET libgemma PROPERTY CXX_STANDARD 17) set_target_properties(libgemma PROPERTIES PREFIX "") set_property(TARGET libgemma PROPERTY POSITION_INDEPENDENT_CODE ON) target_include_directories(libgemma PUBLIC ./) -target_link_libraries(libgemma hwy hwy_contrib sentencepiece-static) +target_link_libraries(libgemma hwy hwy_contrib sentencepiece-static nlohmann_json::nlohmann_json) target_include_directories(libgemma PUBLIC ${sentencepiece_SOURCE_DIR}) target_compile_definitions(libgemma PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) target_compile_options(libgemma PRIVATE $<$:-Wno-deprecated-declarations>) diff --git a/gemma/attention.cc b/gemma/attention.cc index 117b533e..e36575f5 100644 --- a/gemma/attention.cc +++ b/gemma/attention.cc @@ -183,7 +183,12 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, // heads that share the same key and value heads. const size_t kHeadGroups = layer_config.heads / layer_config.kv_heads; - const size_t cache_layer_size = layer_config.CacheLayerSize(); + // Cumulative KV cache column offset for this layer. For models with uniform + // qkv_dim this equals layer_idx * CacheLayerSize(), but for Gemma 4 + // dual-attention each layer may have a different qkv_dim and thus different + // CacheLayerSize(), so we must sum previous layers explicitly. + const size_t layer_kv_offset = + activations.config.KVCacheLayerOffset(layer_idx); const size_t seq_len = static_cast(activations.div_seq_len.GetDivisor()); // All layers should have the same number of heads. @@ -218,7 +223,7 @@ void DotSoftmaxWeightedSum(const size_t num_tokens, const size_t layer_idx, // Make strided read-only views into the kv cache for // this query and head. const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2; - const size_t kv_head_offset = layer_idx * cache_layer_size + head_offset; + const size_t kv_head_offset = layer_kv_offset + head_offset; MatPtrT k("k_view", Extents2D(seq_len, qkv_dim)); k.SetPtr(kv_cache.Row(0) + kv_head_offset, kv_cache.Stride()); MatPtrT v("v_view", Extents2D(seq_len, qkv_dim)); @@ -257,7 +262,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, const LayerConfig& layer_config = layer.layer_config; const size_t qkv_dim = layer_config.qkv_dim; const size_t kv_heads = layer_config.kv_heads; - const size_t cache_layer_size = layer_config.CacheLayerSize(); + // Cumulative KV cache column offset for this layer (handles variable + // qkv_dim per layer in Gemma 4 dual-attention). + const size_t layer_kv_offset = + activations.config.KVCacheLayerOffset(layer_idx); // The original qkv_einsum_w has shape [(heads + kv_heads * 2), qkv_dim, // model_dim], which we reshaped to (heads + kv_heads * 2) * qkv_dim rows. @@ -276,7 +284,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, const size_t cache_pos = activations.div_seq_len.Remainder(qbatch.Pos(qi) + batch_idx); env.row_ptrs[0][interleaved_idx] = reinterpret_cast( - qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_idx * cache_layer_size); + qbatch.KV(qi).kv_cache.Row(cache_pos) + layer_kv_offset); } kv_rows.AttachRowPtrs(env.row_ptrs[0].get()); CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2, @@ -297,7 +305,7 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx, const size_t cache_pos = activations.div_seq_len.Remainder(pos); auto& kv_cache = qbatch.KV(qi).kv_cache; KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) + - layer_idx * cache_layer_size + + layer_kv_offset + head * qkv_dim * 2; HWY_ALIGN float kv_f32[2 * kMaxQKVDim]; diff --git a/gemma/configs.cc b/gemma/configs.cc index 8856203a..10392219 100644 --- a/gemma/configs.cc +++ b/gemma/configs.cc @@ -430,6 +430,93 @@ static ModelConfig ConfigGemma3_270M() { return config; } +static ModelConfig ConfigBaseGemmaV4() { + ModelConfig config = ConfigNoSSM(); + config.att_cap = 0.0f; + config.final_cap = 30.0f; // Gemma 4 uses final logit softcapping. + config.eos_id = 1; + config.secondary_eos_id = 106; + return config; +} + +// Builds per-layer configs for Gemma 4, which has two distinct qkv_dims: +// SWA layers use kQKVDimSWA=256, full-attention layers use kQKVDimFull=512. +// The sliding_window_pattern is all models: every (pattern_period-1)th layer +// (0-indexed) is full-attention (False), the rest are SWA (True). +static void BuildGemma4LayerConfigs(ModelConfig& config, + uint32_t num_layers, + uint32_t pattern_period, + uint32_t heads, + uint32_t kv_heads, + const std::vector& ff_per_layer, + uint32_t swa_window) { + static constexpr uint32_t kQKVDimSWA = 256; + static constexpr uint32_t kQKVDimFull = 512; + config.num_layers = num_layers; + config.attention_window_sizes.resize(num_layers); + config.layer_configs.reserve(num_layers); + for (uint32_t i = 0; i < num_layers; ++i) { + const bool is_full = ((i % pattern_period) == (pattern_period - 1)); + LayerConfig lc; + lc.model_dim = config.model_dim; + lc.ff_hidden_dim = ff_per_layer[i]; + lc.heads = heads; + lc.kv_heads = kv_heads; + lc.qkv_dim = is_full ? kQKVDimFull : kQKVDimSWA; + lc.optimized_gating = true; + lc.post_norm = PostNormType::Scale; + lc.use_qk_norm = true; + config.layer_configs.push_back(lc); + config.attention_window_sizes[i] = + is_full ? config.max_seq_len : swa_window; + } +} + +// Gemma 4 E2B ("nano"): 35 layers, model_dim=1536. +// Sliding window pattern TTTTF repeated 7 times (7 full-att, 28 SWA). +// FFN: first 15 layers=6144, last 20 layers=12288. +static ModelConfig ConfigGemma4_E2B() { + ModelConfig config = ConfigBaseGemmaV4(); + config.display_name = "Gemma4_E2B"; + config.model = Model::GEMMA4_E2B; + config.wrapping = PromptWrapping::GEMMA_VLM; + config.model_dim = 1536; + config.vocab_size = kGemmaV3VocabSize; + config.max_seq_len = 128 * 1024; // 131072 + config.per_layer_embd_dim = 256; + config.query_scale = QueryScaleType::SqrtKeySize; + + // Per-layer FFN: layers 0-14 use 6144, layers 15-34 use 12288. + std::vector ff(35); + for (uint32_t i = 0; i < 35; ++i) ff[i] = (i < 15) ? 6144 : 12288; + + // Pattern: TTTTF (period=5), SWA window=512 tokens. + BuildGemma4LayerConfigs(config, 35, 5, 8, 1, ff, 512); + return config; +} + +// Gemma 4 E4B ("turbo"): 42 layers, model_dim=2560. +// Sliding window pattern TTTTTF repeated 7 times (7 full-att, 35 SWA). +// FFN: uniform 10240 across all layers. +static ModelConfig ConfigGemma4_E4B() { + ModelConfig config = ConfigBaseGemmaV4(); + config.display_name = "Gemma4_E4B"; + config.model = Model::GEMMA4_E4B; + config.wrapping = PromptWrapping::GEMMA_VLM; + config.model_dim = 2560; + config.vocab_size = kGemmaV3VocabSize; + config.max_seq_len = 128 * 1024; // 131072 + config.per_layer_embd_dim = 256; + config.query_scale = QueryScaleType::SqrtKeySize; + + // Per-layer FFN: uniform 10240 across all 42 layers. + std::vector ff(42, 10240); + + // Pattern: TTTTTF (period=6), SWA window=512 tokens. + BuildGemma4LayerConfigs(config, 42, 6, 8, 2, ff, 512); + return config; +} + static ModelConfig ConfigFromModel(Model model) { switch (model) { case Model::GEMMA2_2B: @@ -456,6 +543,10 @@ static ModelConfig ConfigFromModel(Model model) { return ConfigGemma3_27B(); case Model::GEMMA3_270M: return ConfigGemma3_270M(); + case Model::GEMMA4_E2B: + return ConfigGemma4_E2B(); + case Model::GEMMA4_E4B: + return ConfigGemma4_E4B(); default: HWY_ABORT("Model type %d unknown.", static_cast(model)); } @@ -489,6 +580,10 @@ const char* ModelPrefix(Model model) { return "gemma3-27b"; case Model::GEMMA3_270M: return "gemma3-270m"; + case Model::GEMMA4_E2B: + return "gemma4-e2b"; + case Model::GEMMA4_E4B: + return "gemma4-e4b"; default: HWY_ABORT("Model type %d unknown.", static_cast(model)); } diff --git a/gemma/configs.h b/gemma/configs.h index 275f3744..13c243c0 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -177,6 +177,8 @@ enum class Model { GEMMA3_12B, GEMMA3_27B, GEMMA3_270M, + GEMMA4_E2B, + GEMMA4_E4B, kSentinel, }; @@ -188,7 +190,8 @@ const char* ModelPrefix(Model model); // This is used for deducing the PromptWrapping for pre-2025 BlobStore. static inline bool IsVLM(Model model) { return model == Model::GEMMA3_4B || model == Model::GEMMA3_1B || - model == Model::GEMMA3_12B || model == Model::GEMMA3_27B; + model == Model::GEMMA3_12B || model == Model::GEMMA3_27B || + model == Model::GEMMA4_E2B || model == Model::GEMMA4_E4B; } static inline bool IsPaliGemma(Model model) { @@ -383,6 +386,7 @@ struct ModelConfig : public IFields { internal.VisitFields(visitor); + visitor(per_layer_embd_dim); // Append new fields here, then update `python/configs.cc`. } @@ -431,8 +435,21 @@ struct ModelConfig : public IFields { } size_t KVCacheCols() const { - const size_t num_layers = layer_configs.size(); - return num_layers * layer_configs[0].CacheLayerSize(); + size_t total = 0; + for (const auto& lc : layer_configs) { + total += lc.CacheLayerSize(); + } + return total; + } + + // Returns the column offset into the KV cache row for the given layer. + // Handles models with variable qkv_dim per layer (e.g. Gemma 4 dual-attention). + size_t KVCacheLayerOffset(size_t layer_idx) const { + size_t offset = 0; + for (size_t i = 0; i < layer_idx && i < layer_configs.size(); ++i) { + offset += layer_configs[i].CacheLayerSize(); + } + return offset; } bool IsEOS(int id) const { return (id == eos_id || id == secondary_eos_id); } @@ -453,6 +470,8 @@ struct ModelConfig : public IFields { uint32_t model_dim = 0; uint32_t vocab_size = 0; uint32_t max_seq_len = 0; + // Per-layer input embedding dimension (Gemma 4+). 0 means not used. + uint32_t per_layer_embd_dim = 0; // We no longer set nor use this: config_converter is not able to set this, // and only pre-2025 format stores scales, and we do not require advance diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 7991c354..120b1768 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -637,17 +637,29 @@ HWY_EXPORT(GenerateImageTokensT); Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference, ThreadingContext& ctx) - : reader_(loader.weights), - model_(reader_, loader.tokenizer, loader.wrapping), + : reader_(std::make_unique(loader.weights)), + model_(*reader_, loader.tokenizer, loader.wrapping), weights_(model_.Config()), chat_template_(model_.Tokenizer(), model_.Config().model), inference_(inference), aes_ctr_engine_(inference.deterministic) { // Negligible CPU time in the ctor body (except ReadFromBlobs). - weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference, + weight_read_mode_ = weights_.ReadFromBlobs(model_, *reader_, loader, inference, mat_owners_, ctx); // Read everything into memory, or `weights_.mapped_` keeps the mapping alive. - reader_.CloseFile(); + reader_->CloseFile(); +} + +Gemma::Gemma(const ModelConfig& config, const Path& tokenizer_path, + const Path& safetensors_dir, const InferenceArgs& inference, + ThreadingContext& ctx) + : model_(config, tokenizer_path), + weights_(model_.Config()), + chat_template_(model_.Tokenizer(), model_.Config().model), + inference_(inference), + aes_ctr_engine_(inference.deterministic) { + weights_.LoadFromSafetensors(safetensors_dir.path, mat_owners_, ctx); + weight_read_mode_ = WeightsPtrs::Mode::kRead; } Gemma::~Gemma() = default; diff --git a/gemma/gemma.h b/gemma/gemma.h index 771cd1cb..105538f1 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -16,6 +16,7 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ +#include #include #include @@ -237,6 +238,12 @@ class Gemma { // may reference the same, or other `ThreadingContext` via `MatMulEnv`. Gemma(const LoaderArgs& loader, const InferenceArgs& inference, ThreadingContext& ctx); + + // Loads BF16 weights directly from a HuggingFace safetensors directory, + // bypassing any BlobStore conversion step (no precision loss from conversion). + Gemma(const ModelConfig& config, const Path& tokenizer_path, + const Path& safetensors_dir, const InferenceArgs& inference, + ThreadingContext& ctx); ~Gemma(); const ModelConfig& Config() const { return model_.Config(); } @@ -273,7 +280,7 @@ class Gemma { MatMulEnv& env) const; private: - BlobReader reader_; + std::unique_ptr reader_; ModelStore model_; std::vector mat_owners_; WeightsPtrs weights_; diff --git a/gemma/gemma4_test.cc b/gemma/gemma4_test.cc new file mode 100644 index 00000000..ff38072f --- /dev/null +++ b/gemma/gemma4_test.cc @@ -0,0 +1,197 @@ +// Standalone test for Gemma 4 model configs. +// Validates layer counts, qkv_dim per layer, KV cache layout, and attention +// window sizes against the values observed in the GGUF metadata. +#include +#include +#include + +#include "gemma/configs.h" +#include "compression/types.h" + +namespace gcpp { + +static void PrintFail(const char* msg) { + fprintf(stderr, "FAIL: %s\n", msg); +} + +static bool TestGemma4E2B() { + fprintf(stderr, "\n=== Gemma 4 E2B (nano) ===\n"); + ModelConfig cfg(Model::GEMMA4_E2B, Type::kSFP, ChooseWrapping(Model::GEMMA4_E2B)); + + bool ok = true; + + // Basic dimensions (from GGUF: block_count=35, embedding_length=1536) + if (cfg.num_layers != 35) { PrintFail("E2B num_layers != 35"); ok = false; } + if (cfg.model_dim != 1536) { PrintFail("E2B model_dim != 1536"); ok = false; } + if (cfg.vocab_size != 262144) { PrintFail("E2B vocab_size != 262144"); ok = false; } + if (cfg.per_layer_embd_dim != 256) { PrintFail("E2B per_layer_embd_dim != 256"); ok = false; } + fprintf(stderr, " layers=%u model_dim=%u vocab=%u per_layer_embd=%u\n", + cfg.num_layers, cfg.model_dim, cfg.vocab_size, cfg.per_layer_embd_dim); + + // Pattern TTTTF x7: SWA layers at indices NOT in {4,9,14,19,24,29,34} + // Full-att (global) layers: indices where i%5==4 + size_t swa_count = 0, full_count = 0; + for (size_t i = 0; i < cfg.num_layers; ++i) { + bool is_global = cfg.IsGlobalLayer(i); + if (is_global) { + ++full_count; + if (cfg.layer_configs[i].qkv_dim != 512) { + PrintFail("E2B full-att layer qkv_dim != 512"); ok = false; + } + } else { + ++swa_count; + if (cfg.layer_configs[i].qkv_dim != 256) { + PrintFail("E2B SWA layer qkv_dim != 256"); ok = false; + } + } + } + fprintf(stderr, " SWA layers=%zu full-att layers=%zu\n", swa_count, full_count); + if (swa_count != 28) { PrintFail("E2B SWA count != 28"); ok = false; } + if (full_count != 7) { PrintFail("E2B full-att count != 7"); ok = false; } + + // FFN: first 15 layers = 6144, last 20 = 12288 + for (size_t i = 0; i < 15; ++i) { + if (cfg.layer_configs[i].ff_hidden_dim != 6144) { + fprintf(stderr, " FAIL: E2B layer %zu ff_hidden_dim=%u (want 6144)\n", + i, cfg.layer_configs[i].ff_hidden_dim); + ok = false; + } + } + for (size_t i = 15; i < 35; ++i) { + if (cfg.layer_configs[i].ff_hidden_dim != 12288) { + fprintf(stderr, " FAIL: E2B layer %zu ff_hidden_dim=%u (want 12288)\n", + i, cfg.layer_configs[i].ff_hidden_dim); + ok = false; + } + } + fprintf(stderr, " FFN: layers 0-14=%u, layers 15-34=%u\n", + cfg.layer_configs[0].ff_hidden_dim, cfg.layer_configs[15].ff_hidden_dim); + + // KV cache: verify KVCacheLayerOffset is non-uniform (SWA < full-att layers) + size_t offset_layer4 = cfg.KVCacheLayerOffset(4); + size_t size_layer4 = cfg.layer_configs[4].CacheLayerSize(); + // Layers 0-3 are SWA (qkv_dim=256, kv_heads=1 → size=512 each) + // So offset for layer 4 should be 4*512=2048, not 4*1024=4096 + fprintf(stderr, " KV offset[4]=%zu layer[4].CacheLayerSize=%zu\n", + offset_layer4, size_layer4); + if (offset_layer4 != 4 * 512u) { + fprintf(stderr, " FAIL: E2B KVCacheLayerOffset(4)=%zu (want %u)\n", + offset_layer4, 4*512u); + ok = false; + } + if (size_layer4 != 1024) { + fprintf(stderr, " FAIL: E2B layer[4].CacheLayerSize=%zu (want 1024)\n", size_layer4); + ok = false; + } + + // Total KV cache cols + size_t kv_cols = cfg.KVCacheCols(); + // 28 SWA * 512 + 7 full-att * 1024 = 14336 + 7168 = 21504 + fprintf(stderr, " KVCacheCols=%zu (expected 21504)\n", kv_cols); + if (kv_cols != 21504) { PrintFail("E2B KVCacheCols wrong"); ok = false; } + + fprintf(stderr, " E2B: %s\n", ok ? "PASS" : "FAIL"); + return ok; +} + +static bool TestGemma4E4B() { + fprintf(stderr, "\n=== Gemma 4 E4B (turbo) ===\n"); + ModelConfig cfg(Model::GEMMA4_E4B, Type::kSFP, ChooseWrapping(Model::GEMMA4_E4B)); + + bool ok = true; + + // Basic dimensions (from GGUF: block_count=42, embedding_length=2560) + if (cfg.num_layers != 42) { PrintFail("E4B num_layers != 42"); ok = false; } + if (cfg.model_dim != 2560) { PrintFail("E4B model_dim != 2560"); ok = false; } + if (cfg.vocab_size != 262144) { PrintFail("E4B vocab_size != 262144"); ok = false; } + fprintf(stderr, " layers=%u model_dim=%u vocab=%u\n", + cfg.num_layers, cfg.model_dim, cfg.vocab_size); + + // Pattern TTTTTF x7: full-att at indices 5,11,17,23,29,35,41 + size_t swa_count = 0, full_count = 0; + for (size_t i = 0; i < cfg.num_layers; ++i) { + bool is_global = cfg.IsGlobalLayer(i); + if (is_global) { + ++full_count; + if (cfg.layer_configs[i].qkv_dim != 512) { + PrintFail("E4B full-att layer qkv_dim != 512"); ok = false; + } + } else { + ++swa_count; + if (cfg.layer_configs[i].qkv_dim != 256) { + PrintFail("E4B SWA layer qkv_dim != 256"); ok = false; + } + } + } + fprintf(stderr, " SWA layers=%zu full-att layers=%zu\n", swa_count, full_count); + if (swa_count != 35) { PrintFail("E4B SWA count != 35"); ok = false; } + if (full_count != 7) { PrintFail("E4B full-att count != 7"); ok = false; } + + // FFN: uniform 10240 + for (size_t i = 0; i < 42; ++i) { + if (cfg.layer_configs[i].ff_hidden_dim != 10240) { + fprintf(stderr, " FAIL: E4B layer %zu ff_hidden_dim=%u (want 10240)\n", + i, cfg.layer_configs[i].ff_hidden_dim); + ok = false; + break; + } + } + fprintf(stderr, " FFN (uniform): %u\n", cfg.layer_configs[0].ff_hidden_dim); + + // KV cache: E4B has kv_heads=2 + // SWA: kv_heads=2, qkv_dim=256 → CacheLayerSize = 2*256*2 = 1024 + // Full-att: kv_heads=2, qkv_dim=512 → CacheLayerSize = 2*512*2 = 2048 + size_t offset_layer5 = cfg.KVCacheLayerOffset(5); + size_t size_layer5 = cfg.layer_configs[5].CacheLayerSize(); + // Layers 0-4 are SWA → offset = 5*1024 = 5120 + fprintf(stderr, " KV offset[5]=%zu layer[5].CacheLayerSize=%zu\n", + offset_layer5, size_layer5); + if (offset_layer5 != 5 * 1024u) { + fprintf(stderr, " FAIL: E4B KVCacheLayerOffset(5)=%zu (want %u)\n", + offset_layer5, 5*1024u); + ok = false; + } + if (size_layer5 != 2048) { + fprintf(stderr, " FAIL: E4B layer[5].CacheLayerSize=%zu (want 2048)\n", size_layer5); + ok = false; + } + + // Total KV: 35 SWA * 1024 + 7 full-att * 2048 = 35840 + 14336 = 50176 + size_t kv_cols = cfg.KVCacheCols(); + fprintf(stderr, " KVCacheCols=%zu (expected 50176)\n", kv_cols); + if (kv_cols != 50176) { PrintFail("E4B KVCacheCols wrong"); ok = false; } + + fprintf(stderr, " E4B: %s\n", ok ? "PASS" : "FAIL"); + return ok; +} + +static bool TestSerializeRoundTrip(Model model) { + ModelConfig cfg(model, Type::kSFP, ChooseWrapping(model)); + const std::vector serialized = cfg.Write(); + ModelConfig deserialized; + const IFields::ReadResult result = + deserialized.Read(hwy::Span(serialized), /*pos=*/0); + bool ok = true; + if (result.pos != serialized.size()) { PrintFail("serialized size mismatch"); ok = false; } + if (!deserialized.TestEqual(cfg, /*print=*/true)) { PrintFail("TestEqual failed"); ok = false; } + if (deserialized.model != model) { PrintFail("model mismatch after deserialize"); ok = false; } + return ok; +} + +} // namespace gcpp + +int main() { + bool all_ok = true; + all_ok &= gcpp::TestGemma4E2B(); + all_ok &= gcpp::TestGemma4E4B(); + + fprintf(stderr, "\n=== Serialize round-trip ===\n"); + bool rt_e2b = gcpp::TestSerializeRoundTrip(gcpp::Model::GEMMA4_E2B); + bool rt_e4b = gcpp::TestSerializeRoundTrip(gcpp::Model::GEMMA4_E4B); + fprintf(stderr, " E2B round-trip: %s\n", rt_e2b ? "PASS" : "FAIL"); + fprintf(stderr, " E4B round-trip: %s\n", rt_e4b ? "PASS" : "FAIL"); + all_ok &= rt_e2b && rt_e4b; + + fprintf(stderr, "\n=== Overall: %s ===\n\n", all_ok ? "PASS" : "FAIL"); + return all_ok ? 0 : 1; +} diff --git a/gemma/load_safetensors.cc b/gemma/load_safetensors.cc new file mode 100644 index 00000000..39f60aac --- /dev/null +++ b/gemma/load_safetensors.cc @@ -0,0 +1,263 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Implementation of WeightsPtrs::LoadFromSafetensors. +// Loads HuggingFace safetensors weights for Gemma 4 (E2B / E4B) directly +// into gemma.cpp's weight layout, with no BlobStore conversion step. + +#include "gemma/weights.h" + +#include +#include +#include +#include +#include + +#include "gemma/configs.h" +#include "gemma/tensor_info.h" +#include "io/safetensors.h" +#include "util/allocator.h" +#include "util/mat.h" +#include "util/threading_context.h" +#include "compression/types.h" +#include "hwy/aligned_allocator.h" +#include "hwy/base.h" + +namespace gcpp { + +namespace { + +// Returns a HF layer tensor name for layer `i`. +// Gemma 4 multimodal checkpoints nest the LM under model.language_model. +static inline std::string LN(const char* tail, size_t i) { + return "model.language_model.layers." + std::to_string(i) + "." + tail; +} + +// Validates that the safetensor entry has the expected element count. +static void ValidateShape(const SafetensorEntry& e, const char* hf_name, + size_t expected_elems) { + const uint64_t actual = SafetensorNumElems(e.shape); + if (actual != expected_elems) { + HWY_ABORT("safetensors: '%s' expected %zu elems, got %zu", hf_name, + expected_elems, static_cast(actual)); + } + if (e.dtype != "BF16") { + HWY_ABORT("safetensors: '%s' dtype '%s' not supported (need BF16)", + hf_name, e.dtype.c_str()); + } +} + +// Looks up `hf_name`, validates shape, allocates `mat` as BF16/kPacked, and +// reads the contiguous BF16 bytes directly into the allocated memory. +static void AllocAndReadDirect(MatPtr& mat, std::vector& owners, + const Allocator& alloc, + const SafetensorsIndex& idx, + const char* hf_name) { + const SafetensorEntry* e = idx.Find(hf_name); + if (!e) { + HWY_ABORT("safetensors: required tensor '%s' not found", hf_name); + } + const size_t expected = mat.Rows() * mat.Cols(); + ValidateShape(*e, hf_name, expected); + mat.SetType(Type::kBF16); + owners.emplace_back(); + owners.back().AllocateFor(mat, alloc, MatPadding::kPacked); + if (!idx.ReadTensor(*e, mat.Packed())) { + HWY_ABORT("safetensors: failed to read '%s'", hf_name); + } +} + +// Reads two contiguous HF tensors (rows × cols each) and concatenates them +// vertically into `mat` (2*rows × cols). +static void AllocAndReadConcat2(MatPtr& mat, std::vector& owners, + const Allocator& alloc, + const SafetensorsIndex& idx, + const char* hf_a, const char* hf_b) { + const SafetensorEntry* ea = idx.Find(hf_a); + const SafetensorEntry* eb = idx.Find(hf_b); + if (!ea) HWY_ABORT("safetensors: tensor '%s' not found", hf_a); + if (!eb) HWY_ABORT("safetensors: tensor '%s' not found", hf_b); + + const size_t rows_a = ea->num_bytes / (mat.Cols() * 2); // 2 bytes/BF16 + const size_t rows_b = eb->num_bytes / (mat.Cols() * 2); + if (rows_a + rows_b != mat.Rows()) { + HWY_ABORT( + "safetensors: concat '%s'(%zu) + '%s'(%zu) rows != expected %zu", + hf_a, rows_a, hf_b, rows_b, mat.Rows()); + } + if (ea->dtype != "BF16") + HWY_ABORT("safetensors: '%s' dtype '%s' not BF16", hf_a, ea->dtype.c_str()); + if (eb->dtype != "BF16") + HWY_ABORT("safetensors: '%s' dtype '%s' not BF16", hf_b, eb->dtype.c_str()); + + mat.SetType(Type::kBF16); + owners.emplace_back(); + owners.back().AllocateFor(mat, alloc, MatPadding::kPacked); + + uint8_t* dst = static_cast(mat.Packed()); + if (!idx.ReadTensor(*ea, dst)) HWY_ABORT("safetensors: read failed: '%s'", hf_a); + if (!idx.ReadTensor(*eb, dst + ea->num_bytes)) + HWY_ABORT("safetensors: read failed: '%s'", hf_b); +} + +// Reads three contiguous HF tensors and concatenates vertically into `mat`. +static void AllocAndReadConcat3(MatPtr& mat, std::vector& owners, + const Allocator& alloc, + const SafetensorsIndex& idx, + const char* hf_a, const char* hf_b, + const char* hf_c) { + const SafetensorEntry* ea = idx.Find(hf_a); + const SafetensorEntry* eb = idx.Find(hf_b); + const SafetensorEntry* ec = idx.Find(hf_c); + if (!ea) HWY_ABORT("safetensors: tensor '%s' not found", hf_a); + if (!eb) HWY_ABORT("safetensors: tensor '%s' not found", hf_b); + if (!ec) HWY_ABORT("safetensors: tensor '%s' not found", hf_c); + + const uint64_t total_bytes = ea->num_bytes + eb->num_bytes + ec->num_bytes; + const uint64_t expected_bytes = mat.Rows() * mat.Cols() * 2; + if (total_bytes != expected_bytes) { + HWY_ABORT( + "safetensors: concat3 '%s'+'%s'+'%s' bytes %" PRIu64 + " != expected %" PRIu64, + hf_a, hf_b, hf_c, total_bytes, expected_bytes); + } + if (ea->dtype != "BF16" || eb->dtype != "BF16" || ec->dtype != "BF16") { + HWY_ABORT("safetensors: expected BF16 for Q/K/V projections"); + } + + mat.SetType(Type::kBF16); + owners.emplace_back(); + owners.back().AllocateFor(mat, alloc, MatPadding::kPacked); + + uint8_t* dst = static_cast(mat.Packed()); + if (!idx.ReadTensor(*ea, dst)) HWY_ABORT("safetensors: read failed: '%s'", hf_a); + dst += ea->num_bytes; + if (!idx.ReadTensor(*eb, dst)) HWY_ABORT("safetensors: read failed: '%s'", hf_b); + dst += eb->num_bytes; + if (!idx.ReadTensor(*ec, dst)) HWY_ABORT("safetensors: read failed: '%s'", hf_c); +} + +// Loads per-layer token embedding from HF shape [V, L*D] into gemma shape +// [L*D, V] by transposing (simple matrix transpose of a [V, L*D] matrix). +// HF: hf[v, l*D+d] → gemma row (l*D+d), col v. +static void LoadPerLayerEmbd(MatPtr& mat, std::vector& owners, + const Allocator& alloc, + const SafetensorsIndex& idx, + size_t num_layers, size_t vocab_size, + size_t embd_dim) { + // Gemma 4 multimodal: "model.language_model.embed_tokens_per_layer.weight" + // Shape [V, L*D] in HF; gemma stores it as [L*D, V]. + const char* hf_name = + "model.language_model.embed_tokens_per_layer.weight"; + const SafetensorEntry* e = idx.Find(hf_name); + if (!e) HWY_ABORT("safetensors: '%s' not found", hf_name); + ValidateShape(*e, hf_name, num_layers * vocab_size * embd_dim); + + // Read HF [V, L*D] into a temp buffer. + const size_t total_elems = vocab_size * num_layers * embd_dim; + auto tmp = hwy::AllocateAligned(total_elems); // BF16 = uint16 + if (!idx.ReadTensor(*e, tmp.get())) { + HWY_ABORT("safetensors: read failed: '%s'", hf_name); + } + + // Allocate gemma [L*D, V] (rows=L*D, cols=V) as packed BF16. + mat.SetType(Type::kBF16); + owners.emplace_back(); + owners.back().AllocateFor(mat, alloc, MatPadding::kPacked); + + // Transpose: gemma[l*D+d, v] = HF[v, l*D+d] + uint16_t* dst = static_cast(mat.Packed()); + const uint16_t* src = tmp.get(); + const size_t LD = num_layers * embd_dim; + for (size_t v = 0; v < vocab_size; ++v) { + for (size_t ld = 0; ld < LD; ++ld) { + dst[ld * vocab_size + v] = src[v * LD + ld]; + } + } +} + +} // namespace + +void WeightsPtrs::LoadFromSafetensors(const std::string& dir, + std::vector& mat_owners, + ThreadingContext& ctx) { + const Allocator& alloc = ctx.allocator; + SafetensorsIndex idx(dir); + const ModelConfig& cfg = config_; + + // ── Global tensors ──────────────────────────────────────────────────────── + AllocAndReadDirect(embedder_input_embedding, mat_owners, alloc, idx, + "model.language_model.embed_tokens.weight"); + AllocAndReadDirect(final_norm_scale, mat_owners, alloc, idx, + "model.language_model.norm.weight"); + + if (cfg.per_layer_embd_dim > 0) { + LoadPerLayerEmbd(per_layer_input_embedding, mat_owners, alloc, idx, + cfg.num_layers, cfg.vocab_size, cfg.per_layer_embd_dim); + } + + // ── Per-layer tensors ───────────────────────────────────────────────────── + for (size_t i = 0; i < cfg.num_layers; ++i) { + const LayerConfig& lc = cfg.layer_configs[i]; + LayerWeightsPtrs& lw = *GetLayer(i); + + // Norm scales (shape: [model_dim] → rows=1, cols=model_dim). + AllocAndReadDirect(lw.pre_attention_norm_scale, mat_owners, alloc, idx, + LN("input_layernorm.weight", i).c_str()); + AllocAndReadDirect(lw.post_attention_norm_scale, mat_owners, alloc, idx, + LN("post_attention_layernorm.weight", i).c_str()); + AllocAndReadDirect(lw.pre_ffw_norm_scale, mat_owners, alloc, idx, + LN("pre_feedforward_layernorm.weight", i).c_str()); + AllocAndReadDirect(lw.post_ffw_norm_scale, mat_owners, alloc, idx, + LN("post_feedforward_layernorm.weight", i).c_str()); + + if (lc.use_qk_norm) { + AllocAndReadDirect(lw.query_norm_scale, mat_owners, alloc, idx, + LN("self_attn.q_norm.weight", i).c_str()); + AllocAndReadDirect(lw.key_norm_scale, mat_owners, alloc, idx, + LN("self_attn.k_norm.weight", i).c_str()); + } + + // Attention: Q + K + V → qkv_einsum_w [(heads+2*kv_heads)*qkv_dim, model_dim] + AllocAndReadConcat3( + lw.qkv_einsum_w, mat_owners, alloc, idx, + LN("self_attn.q_proj.weight", i).c_str(), + LN("self_attn.k_proj.weight", i).c_str(), + LN("self_attn.v_proj.weight", i).c_str()); + + // Output projection: att_weights [model_dim, heads*qkv_dim] (direct). + // HF o_proj.weight shape is already [model_dim, heads*qkv_dim]. ✓ + AllocAndReadDirect(lw.att_weights, mat_owners, alloc, idx, + LN("self_attn.o_proj.weight", i).c_str()); + + // FFN: gate + up → gating_einsum_w [2*ff_hidden_dim, model_dim] + AllocAndReadConcat2( + lw.gating_einsum_w, mat_owners, alloc, idx, + LN("mlp.gate_proj.weight", i).c_str(), + LN("mlp.up_proj.weight", i).c_str()); + + // FFN down: linear_w [model_dim, ff_hidden_dim] (direct). + AllocAndReadDirect(lw.linear_w, mat_owners, alloc, idx, + LN("mlp.down_proj.weight", i).c_str()); + } + + // ── Fixup (splits qkv/gating, verifies att_weights) ───────────────────── + Fixup(mat_owners, ctx); + + fprintf(stderr, "[safetensors] loaded %zu layers from %s\n", + cfg.num_layers, dir.c_str()); +} + +} // namespace gcpp diff --git a/gemma/model_store.cc b/gemma/model_store.cc index 2f3e1ecb..b4beff65 100644 --- a/gemma/model_store.cc +++ b/gemma/model_store.cc @@ -394,6 +394,11 @@ ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path, HWY_ASSERT(key_idx_.size() == mat_ptrs_.size()); } +ModelStore::ModelStore(const ModelConfig& config, const Path& tokenizer_path) + : config_(config), + tokenizer_(tokenizer_path.Empty() ? kMockTokenizer + : ReadFileToString(tokenizer_path)) {} + ModelStore::~ModelStore() { // Sanity check: ensure all scales were consumed. HWY_ASSERT(scales_consumed_ == scales_.size()); diff --git a/gemma/model_store.h b/gemma/model_store.h index b4d63ad2..a1458f4e 100644 --- a/gemma/model_store.h +++ b/gemma/model_store.h @@ -53,6 +53,10 @@ class ModelStore { // used for pre-2025 files. ModelStore(BlobReader& reader, const Path& tokenizer_path = Path(), Tristate wrapping = Tristate::kDefault); + + // Constructs from an in-memory config and a tokenizer file path, without + // requiring a BlobStore. Used for direct safetensors loading. + ModelStore(const ModelConfig& config, const Path& tokenizer_path); ~ModelStore(); const ModelConfig& Config() const { diff --git a/gemma/run.cc b/gemma/run.cc index 7e2059fd..0a642647 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -255,7 +256,24 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading, ThreadingContext ctx(threading); MatMulEnv env(ctx); if (inference.verbosity >= 3) env.print_best = true; - const Gemma gemma(loader, inference, ctx); + + // Two load paths: safetensors directory or BlobStore .sbs file. + const bool use_safetensors = !loader.safetensors.Empty(); + if (use_safetensors && loader.model_spec.empty()) { + HWY_ABORT( + "--safetensors requires --model_spec, e.g. 'gemma4-e2b-BF16-it'."); + } + + std::unique_ptr gemma_ptr; + if (use_safetensors) { + const ModelConfig st_config(loader.model_spec); + gemma_ptr = std::make_unique(st_config, loader.tokenizer, + loader.safetensors, inference, ctx); + } else { + gemma_ptr = std::make_unique(loader, inference, ctx); + } + const Gemma& gemma = *gemma_ptr; + KVCache kv_cache(gemma.Config(), inference, ctx.allocator); if (inference.verbosity >= 1) { diff --git a/gemma/tensor_info.cc b/gemma/tensor_info.cc index 05f829bb..d7f30107 100644 --- a/gemma/tensor_info.cc +++ b/gemma/tensor_info.cc @@ -42,6 +42,18 @@ void TensorInfoRegistry::AddModelTensors(const ModelConfig& config) { .shape = {config.model_dim}, .min_size = Type::kBF16, }); + // Per-layer token embeddings (Gemma 4+). Shape: [num_layers * embd_dim, vocab]. + if (config.per_layer_embd_dim > 0) { + Add(no_suffix, + { + .base_name = "per_layer_embd", + .source_names = {"per_layer_token_embd.weight"}, + .axes = {0, 1}, + .shape = {config.num_layers * config.per_layer_embd_dim, + config.vocab_size}, + .min_size = Type::kBF16, + }); + } Add(no_suffix, { .base_name = "enc_norm_bias", .source_names = {"img/Transformer/encoder_norm/bias"}, diff --git a/gemma/weights.h b/gemma/weights.h index 36618694..8985a5e9 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -276,6 +276,9 @@ struct WeightsPtrs { tensors_(config_), finder_("", tensors_), // no suffix because these are per-model. embedder_input_embedding(finder_("c_embedding")), + per_layer_input_embedding(config.per_layer_embd_dim > 0 + ? finder_("per_layer_embd") + : MatPtr()), final_norm_scale(finder_("c_final_norm")), vit_encoder_norm_bias(finder_("enc_norm_bias")), vit_encoder_norm_scale(finder_("enc_norm_scale")), @@ -306,6 +309,8 @@ struct WeightsPtrs { // TODO: switch to SFP? MatPtr embedder_input_embedding; + // Per-layer input embeddings (Gemma 4+). Shape: [num_layers * embd_dim, vocab]. + MatPtr per_layer_input_embedding; MatPtr final_norm_scale; // at least BF16. // Vit parts. @@ -341,6 +346,9 @@ struct WeightsPtrs { LayerWeightsPtrs* other_layer1 = nullptr; LayerWeightsPtrs* other_layer2 = nullptr; func(TENSOR_ARGS(embedder_input_embedding, kMustRead)); + if (config_.per_layer_embd_dim > 0) { + func(TENSOR_ARGS(per_layer_input_embedding, kMustRead)); + } func(TENSOR_ARGS(final_norm_scale, kMustRead)); if (!config_.vit_config.layer_configs.empty()) { // Vit parts. @@ -409,6 +417,12 @@ struct WeightsPtrs { const LoaderArgs& loader, const InferenceArgs& inference, std::vector& mat_owners, ThreadingContext& ctx); + // Loads BF16 tensor data directly from a HuggingFace safetensors directory + // (e.g. google/gemma-4-e2b-it) without any BlobStore conversion step. + void LoadFromSafetensors(const std::string& dir, + std::vector& mat_owners, + ThreadingContext& ctx); + // Adds one blob for each tensor's data and returns all serialized MatPtr. std::vector AddTensorDataToWriter(BlobWriter& writer) const; diff --git a/io/safetensors.cc b/io/safetensors.cc new file mode 100644 index 00000000..dffb14e1 --- /dev/null +++ b/io/safetensors.cc @@ -0,0 +1,141 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "io/safetensors.h" + +#include +#include + +#include +#include +#include +#include + +#include "io/io.h" +#include "hwy/base.h" +#include + +namespace gcpp { + +namespace { + +// Reads a uint64_t from a little-endian byte buffer. +inline uint64_t ReadLE64(const uint8_t* p) { + uint64_t v = 0; + for (int i = 0; i < 8; ++i) v |= static_cast(p[i]) << (8 * i); + return v; +} + +// Returns sorted list of *.safetensors files in a directory. +std::vector FindSafetensorsFiles(const std::string& dir) { + std::vector paths; + namespace fs = std::filesystem; + if (!fs::is_directory(dir)) { + HWY_ABORT("safetensors: '%s' is not a directory", dir.c_str()); + } + for (const auto& entry : fs::directory_iterator(dir)) { + if (!entry.is_regular_file()) continue; + const std::string ext = entry.path().extension().string(); + if (ext == ".safetensors") { + paths.push_back(entry.path().string()); + } + } + if (paths.empty()) { + HWY_ABORT("safetensors: no *.safetensors files found in '%s'", dir.c_str()); + } + std::sort(paths.begin(), paths.end()); + return paths; +} + +// Parses a single safetensors header JSON and populates `entries`. +// `shard_idx` is the index of this file in SafetensorsIndex::shards_. +// Returns the data_offset (= 8 + header_size). +uint64_t ParseSafetensorsHeader( + const File& file, size_t shard_idx, + std::unordered_map* entries, + std::vector* names) { + // Read 8-byte header size. + uint8_t size_buf[8]; + if (!file.Read(0, 8, size_buf)) { + HWY_ABORT("safetensors: failed to read header size"); + } + const uint64_t header_size = ReadLE64(size_buf); + if (header_size == 0 || header_size > 100 * 1024 * 1024) { + HWY_ABORT("safetensors: implausible header_size %" PRIu64, header_size); + } + + // Read JSON header. + std::string header_json(static_cast(header_size), '\0'); + if (!file.Read(8, header_size, &header_json[0])) { + HWY_ABORT("safetensors: failed to read header JSON"); + } + + const uint64_t data_offset = 8 + header_size; + + nlohmann::json j = nlohmann::json::parse(header_json); + for (auto& [name, val] : j.items()) { + if (name == "__metadata__") continue; + SafetensorEntry entry; + entry.dtype = val["dtype"].get(); + for (const auto& d : val["shape"]) { + entry.shape.push_back(d.get()); + } + const auto& offsets = val["data_offsets"]; + const uint64_t rel_start = offsets[0].get(); + const uint64_t rel_end = offsets[1].get(); + entry.file_offset = data_offset + rel_start; + entry.num_bytes = rel_end - rel_start; + entry.shard_idx = shard_idx; + if (entries->find(name) != entries->end()) { + HWY_ABORT("safetensors: duplicate tensor name '%s'", name.c_str()); + } + names->push_back(name); + (*entries)[name] = std::move(entry); + } + return data_offset; +} + +} // namespace + +SafetensorsIndex::SafetensorsIndex(const std::string& dir) { + const std::vector paths = FindSafetensorsFiles(dir); + shards_.resize(paths.size()); + for (size_t i = 0; i < paths.size(); ++i) { + shards_[i].file = OpenFileOrAbort(Path(paths[i]), "r"); + shards_[i].data_offset = ParseSafetensorsHeader( + *shards_[i].file, i, &entries_, &names_); + } + fprintf(stderr, "[safetensors] indexed %zu tensors from %zu shard(s) in %s\n", + entries_.size(), shards_.size(), dir.c_str()); +} + +const SafetensorEntry* SafetensorsIndex::Find(const std::string& name) const { + const auto it = entries_.find(name); + return it == entries_.end() ? nullptr : &it->second; +} + +bool SafetensorsIndex::ReadTensor(const SafetensorEntry& entry, + void* out) const { + const Shard& shard = shards_[entry.shard_idx]; + return shard.file->Read(entry.file_offset, entry.num_bytes, out); +} + +void SafetensorsIndex::PrintNames() const { + for (const auto& n : names_) { + fprintf(stderr, " %s\n", n.c_str()); + } +} + +} // namespace gcpp diff --git a/io/safetensors.h b/io/safetensors.h new file mode 100644 index 00000000..3676cf86 --- /dev/null +++ b/io/safetensors.h @@ -0,0 +1,96 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Parses the safetensors format used by HuggingFace models. +// Format: [8-byte header_size LE][header_size bytes JSON][tensor data ...] +// Multiple sharded files (model-NNNNN-of-MMMMM.safetensors) are supported. + +#ifndef THIRD_PARTY_GEMMA_CPP_IO_SAFETENSORS_H_ +#define THIRD_PARTY_GEMMA_CPP_IO_SAFETENSORS_H_ + +#include +#include + +#include +#include +#include +#include + +#include "io/io.h" // File, Path +#include "hwy/base.h" + +namespace gcpp { + +// Metadata for a single tensor in a safetensors file. +struct SafetensorEntry { + std::string dtype; // "BF16", "F32", "F16", "I8", etc. + std::vector shape; // dimension sizes (may be empty for scalars) + uint64_t file_offset; // absolute byte offset in shard file + uint64_t num_bytes; // total byte size of tensor data + size_t shard_idx; // index into SafetensorsIndex::shards_ +}; + +// Opens one or more *.safetensors files and provides random tensor access. +// Supports sharded models (model-00001-of-00002.safetensors, etc.). +class SafetensorsIndex { + public: + // Scans `dir` for all *.safetensors files, parses their headers, and builds + // a unified tensor index. Aborts if no files are found or parsing fails. + explicit SafetensorsIndex(const std::string& dir); + + // Returns nullptr if `name` is not found in any shard. + const SafetensorEntry* Find(const std::string& name) const; + + // All tensor names across all shards. + const std::vector& Names() const { return names_; } + + // Reads `entry.num_bytes` bytes into `out`. Returns true on success. + bool ReadTensor(const SafetensorEntry& entry, void* out) const; + + // For debugging: prints all indexed tensor names to stderr. + void PrintNames() const; + + struct Shard { + std::unique_ptr file; + uint64_t data_offset; // = 8 + header_size + }; + + private: + std::vector shards_; + std::unordered_map entries_; + std::vector names_; +}; + +// Returns total number of elements given a shape vector. +inline uint64_t SafetensorNumElems(const std::vector& shape) { + if (shape.empty()) return 1; + uint64_t n = 1; + for (uint64_t d : shape) n *= d; + return n; +} + +// Returns bytes per element for a safetensors dtype string. +inline size_t SafetensorDtypeBytes(const std::string& dtype) { + if (dtype == "BF16" || dtype == "F16") return 2; + if (dtype == "F32" || dtype == "I32" || dtype == "U32") return 4; + if (dtype == "F64" || dtype == "I64" || dtype == "U64") return 8; + if (dtype == "I8" || dtype == "U8" || dtype == "BOOL") return 1; + if (dtype == "I16" || dtype == "U16") return 2; + return 0; +} + +} // namespace gcpp + +#endif // THIRD_PARTY_GEMMA_CPP_IO_SAFETENSORS_H_