| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276 |
- # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import argparse
- import json
- import os
- import shutil
- import torch
- """
- Sample usage:
- ```
- python src/transformers/models/llama/convert_llama_weights_to_hf.py \
- --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
- ```
- Thereafter, models can be loaded via:
- ```
- tokenizer = transformers.LLaMATokenizer.from_pretrained("/output/path/tokenizer/")
- model = transformers.LLaMAForCausalLM.from_pretrained("/output/path/llama-7b/")
- ```
- """
# Value written as "intermediate_size" into the generated config.json,
# keyed by the model size given on the command line.
INTERMEDIATE_SIZE_MAP = {"7B": 11008, "13B": 13824, "30B": 17920, "65B": 22016}
# Number of "consolidated.NN.pth" checkpoint files read for each model size.
NUM_SHARDS = {"7B": 1, "13B": 2, "30B": 4, "65B": 8}
def read_json(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's locale encoding.

    Args:
        path: Path of the JSON file to read.

    Returns:
        The deserialized Python object (whatever ``json.load`` produces).
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def write_json(text, path):
    """Serialize ``text`` as JSON into the file at ``path``.

    Despite its name, ``text`` is the object handed to ``json.dump`` (the
    callers in this file pass dicts), not a pre-rendered string. The file is
    opened with an explicit UTF-8 encoding so the output does not depend on
    the platform's locale encoding. An existing file is overwritten.

    Args:
        text: JSON-serializable object to write.
        path: Destination file path.
    """
    with open(path, "w", encoding="utf-8") as f:
        json.dump(text, f)
def write_model(model_path, input_base_path, model_size):
    """Convert original LLaMA checkpoint files into a Hugging Face style checkpoint.

    Reads ``params.json`` and the ``consolidated.NN.pth`` files from
    ``input_base_path``, renames the tensors (concatenating them across
    shards for every size other than "7B"), and writes into ``model_path``:

    * one ``pytorch_model-XXXXX-of-YYYYY.bin`` file per layer, plus a final
      file holding ``model.embed_tokens.weight``, ``model.norm.weight`` and
      ``lm_head.weight``;
    * ``pytorch_model.bin.index.json`` mapping every tensor name to its file;
    * ``config.json`` and ``generation_config.json``.

    Args:
        model_path: Output directory; created if it does not exist.
        input_base_path: Directory containing ``params.json`` and the
            ``consolidated.NN.pth`` checkpoint files.
        model_size: A key of ``INTERMEDIATE_SIZE_MAP`` / ``NUM_SHARDS``,
            e.g. ``"7B"``.

    Raises:
        ValueError: If ``model_size`` is not a known model size.
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if model_size not in INTERMEDIATE_SIZE_MAP:
        raise ValueError(
            f"Unknown model_size {model_size!r}; expected one of {list(INTERMEDIATE_SIZE_MAP)}"
        )
    os.makedirs(model_path, exist_ok=True)

    params = read_json(os.path.join(input_base_path, "params.json"))
    num_shards = NUM_SHARDS[model_size]
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    n_heads_per_shard = n_heads // num_shards
    dim = params["dim"]
    dims_per_head = dim // n_heads
    base = 10000.0
    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))

    # One output file per layer plus one for the embeddings / norm / LM head.
    total_files = n_layers + 1
    unsharded = model_size == "7B"

    # permute for sliced rotary: view the (dim, dim) matrix as
    # (n_heads, dims_per_head // 2, 2, dim), swap the two middle axes and
    # flatten back to (dim, dim).
    def permute(w):
        return w.view(n_heads, dims_per_head // 2, 2, dim).transpose(1, 2).reshape(dim, dim)

    # Load weights
    # NOTE(security): torch.load unpickles the checkpoint files; only run this
    # script on checkpoints obtained from a trusted source.
    if unsharded:
        # Single file, not sharded.
        # (The sharded implementation would also work, but this is simpler.)
        loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
    else:
        # Sharded: one state dict per consolidated.NN.pth file.
        loaded = [
            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
            for i in range(num_shards)
        ]

    def shard_tensors(key):
        # The tensor stored under ``key`` in every shard, in shard order.
        return [loaded[i][key] for i in range(num_shards)]

    def merge_heads(key):
        # View each shard's slice as (heads in shard, dims_per_head, dim),
        # stack the shards along the head axis and flatten to (dim, dim).
        per_shard = [t.view(n_heads_per_shard, dims_per_head, dim) for t in shard_tensors(key)]
        return torch.cat(per_shard, dim=0).reshape(dim, dim)

    index_dict = {"weight_map": {}}
    total_bytes = 0

    def save_part(part_state_dict, part_filename):
        # Record every tensor of this file in the index, add its real byte
        # size to the running total, then write the file.
        nonlocal total_bytes
        for name, tensor in part_state_dict.items():
            index_dict["weight_map"][name] = part_filename
            total_bytes += tensor.numel() * tensor.element_size()
        torch.save(part_state_dict, os.path.join(model_path, part_filename))

    for layer_i in range(n_layers):
        filename = f"pytorch_model-{layer_i + 1:05d}-of-{total_files:05d}.bin"
        src = f"layers.{layer_i}"
        dst = f"model.layers.{layer_i}"
        if unsharded:
            state_dict = {
                f"{dst}.self_attn.q_proj.weight": permute(loaded[f"{src}.attention.wq.weight"]),
                f"{dst}.self_attn.k_proj.weight": permute(loaded[f"{src}.attention.wk.weight"]),
                f"{dst}.self_attn.v_proj.weight": loaded[f"{src}.attention.wv.weight"],
                f"{dst}.self_attn.o_proj.weight": loaded[f"{src}.attention.wo.weight"],
                f"{dst}.mlp.gate_proj.weight": loaded[f"{src}.feed_forward.w1.weight"],
                f"{dst}.mlp.down_proj.weight": loaded[f"{src}.feed_forward.w2.weight"],
                f"{dst}.mlp.up_proj.weight": loaded[f"{src}.feed_forward.w3.weight"],
                f"{dst}.input_layernorm.weight": loaded[f"{src}.attention_norm.weight"],
                f"{dst}.post_attention_layernorm.weight": loaded[f"{src}.ffn_norm.weight"],
            }
        else:
            state_dict = {
                # Norm weights are taken from shard 0 only
                # (presumably identical in every shard -- TODO confirm).
                f"{dst}.input_layernorm.weight": loaded[0][f"{src}.attention_norm.weight"],
                f"{dst}.post_attention_layernorm.weight": loaded[0][f"{src}.ffn_norm.weight"],
                f"{dst}.self_attn.q_proj.weight": permute(merge_heads(f"{src}.attention.wq.weight")),
                f"{dst}.self_attn.k_proj.weight": permute(merge_heads(f"{src}.attention.wk.weight")),
                f"{dst}.self_attn.v_proj.weight": merge_heads(f"{src}.attention.wv.weight"),
                f"{dst}.self_attn.o_proj.weight": torch.cat(
                    shard_tensors(f"{src}.attention.wo.weight"), dim=1
                ),
                f"{dst}.mlp.gate_proj.weight": torch.cat(
                    shard_tensors(f"{src}.feed_forward.w1.weight"), dim=0
                ),
                f"{dst}.mlp.down_proj.weight": torch.cat(
                    shard_tensors(f"{src}.feed_forward.w2.weight"), dim=1
                ),
                f"{dst}.mlp.up_proj.weight": torch.cat(
                    shard_tensors(f"{src}.feed_forward.w3.weight"), dim=0
                ),
            }

        state_dict[f"{dst}.self_attn.rotary_emb.inv_freq"] = inv_freq
        save_part(state_dict, filename)

    # Final file: embeddings, final norm and output projection.
    filename = f"pytorch_model-{total_files:05d}-of-{total_files:05d}.bin"
    if unsharded:
        state_dict = {
            "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
            "model.norm.weight": loaded["norm.weight"],
            "lm_head.weight": loaded["output.weight"],
        }
    else:
        state_dict = {
            "model.norm.weight": loaded[0]["norm.weight"],
            "model.embed_tokens.weight": torch.cat(shard_tensors("tok_embeddings.weight"), dim=1),
            "lm_head.weight": torch.cat(shard_tensors("output.weight"), dim=0),
        }
    save_part(state_dict, filename)

    # Write configs
    # total_size is the summed byte size of every saved tensor, using each
    # tensor's own element size rather than assuming 2 bytes per element.
    index_dict["metadata"] = {"total_size": total_bytes}
    write_json(index_dict, os.path.join(model_path, "pytorch_model.bin.index.json"))

    # NOTE(review): the token ids and vocab_size below are hard-coded rather
    # than read from the tokenizer -- confirm they match tokenizer.model.
    # NOTE(review): pad_token_id is -1 here but 0 in generation_config.json --
    # confirm the difference is intended.
    config_out = {
        "architectures": ["LLaMAForCausalLM"],
        "bos_token_id": 0,
        "eos_token_id": 1,
        "hidden_act": "silu",
        "hidden_size": params["dim"],
        "intermediate_size": INTERMEDIATE_SIZE_MAP[model_size],
        "initializer_range": 0.02,
        "max_sequence_length": 2048,
        "model_type": "llama",
        "num_attention_heads": params["n_heads"],
        "num_hidden_layers": params["n_layers"],
        "pad_token_id": -1,
        "rms_norm_eps": params["norm_eps"],
        "torch_dtype": "float16",
        "transformers_version": "4.27.0.dev0",
        "use_cache": True,
        "vocab_size": 32000,
    }
    write_json(config_out, os.path.join(model_path, "config.json"))

    generation_config = {
        "_from_model_config": True,
        "bos_token_id": 0,
        "eos_token_id": 1,
        "pad_token_id": 0,
        "transformers_version": "4.27.0.dev0",
    }
    write_json(generation_config, os.path.join(model_path, "generation_config.json"))
def write_tokenizer(tokenizer_path, input_tokenizer_path):
    """Populate ``tokenizer_path`` with the tokenizer files.

    Creates the directory if needed, then writes an empty
    ``special_tokens_map.json``, a ``tokenizer_config.json`` and a copy of
    ``input_tokenizer_path`` named ``tokenizer.model``.

    Args:
        tokenizer_path: Output directory for the tokenizer files.
        input_tokenizer_path: Path of the source ``tokenizer.model`` file.
    """
    os.makedirs(tokenizer_path, exist_ok=True)

    special_tokens_file = os.path.join(tokenizer_path, "special_tokens_map.json")
    write_json({}, special_tokens_file)

    # NOTE(review): bos/eos/unk tokens are written as empty strings --
    # confirm this is intended.
    tokenizer_config = {
        "bos_token": "",
        "eos_token": "",
        "model_max_length": int(1e30),
        "tokenizer_class": "LLaMATokenizer",
        "unk_token": "",
    }
    config_file = os.path.join(tokenizer_path, "tokenizer_config.json")
    write_json(tokenizer_config, config_file)

    model_copy = os.path.join(tokenizer_path, "tokenizer.model")
    shutil.copyfile(input_tokenizer_path, model_copy)
def main():
    """Parse the command line and convert the model weights and tokenizer.

    The converted model is written to ``<output_dir>/llama-<model_size>``
    (lower-cased) and the tokenizer files to ``<output_dir>/tokenizer``.
    All three options are required; argparse reports a usage error and
    exits if any is missing, instead of failing later on a ``None`` path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        required=True,
        help="Location of LLaMA weights, which contains tokenizer.model and model folders",
    )
    parser.add_argument(
        "--model_size",
        required=True,
        choices=["7B", "13B", "30B", "65B"],
        help="Size of the model to convert; also the name of the sub-folder of --input_dir to read",
    )
    parser.add_argument(
        "--output_dir",
        required=True,
        help="Location to write HF model and tokenizer",
    )
    args = parser.parse_args()

    model_dir_name = f"llama-{args.model_size}".lower()
    write_model(
        model_path=os.path.join(args.output_dir, model_dir_name),
        input_base_path=os.path.join(args.input_dir, args.model_size),
        model_size=args.model_size,
    )
    write_tokenizer(
        tokenizer_path=os.path.join(args.output_dir, "tokenizer"),
        input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
    )


if __name__ == "__main__":
    main()
|