conversion.py

# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os
import shutil

import torch
  19. """
  20. Sample usage:
  21. ```
  22. python src/transformers/models/llama/convert_llama_weights_to_hf.py \
  23. --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path
  24. ```
  25. Thereafter, models can be loaded via:
  26. ```
  27. tokenizer = transformers.LLaMATokenizer.from_pretrained("/output/path/tokenizer/")
  28. model = transformers.LLaMAForCausalLM.from_pretrained("/output/path/llama-7b/")
  29. ```
  30. """
INTERMEDIATE_SIZE_MAP = {
    "7B": 11008,
    "13B": 13824,
    "30B": 17920,
    "65B": 22016,
}
NUM_SHARDS = {
    "7B": 1,
    "13B": 2,
    "30B": 4,
    "65B": 8,
}
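# NUM_SHARDS mirrors how the original checkpoints were distributed: one
# consolidated.*.pth file per tensor-parallel rank used at that model size.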


def read_json(path):
    with open(path, "r") as f:
        return json.load(f)


def write_json(text, path):
    with open(path, "w") as f:
        json.dump(text, f)


def write_model(model_path, input_base_path, model_size):
    assert model_size in INTERMEDIATE_SIZE_MAP
    os.makedirs(model_path, exist_ok=True)

    params = read_json(os.path.join(input_base_path, "params.json"))
    num_shards = NUM_SHARDS[model_size]
    n_layers = params["n_layers"]
    n_heads = params["n_heads"]
    n_heads_per_shard = n_heads // num_shards
    dim = params["dim"]
    dims_per_head = dim // n_heads
    base = 10000.0
    inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
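    # (inv_freq above holds the rotary-embedding inverse frequencies,
    # theta_j = base ** (-2j / dims_per_head) for j = 0 .. dims_per_head // 2 - 1,
    # as in the RoPE formulation.)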

    # permute for sliced rotary
    def permute(w):
        return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
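    # The original checkpoints store each head's rotary dimensions interleaved
    # (pairs (0, 1), (2, 3), ...), while the HF implementation rotates the
    # first and second halves of each head. The transpose above reorders each
    # head's rows from interleaved pairs to (evens..., odds...), so both
    # layouts compute the same attention scores.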

    # Load weights
    if model_size == "7B":
        # Not sharded
        # (The sharded implementation would also work, but this is simpler.)
        loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
    else:
        # Sharded
        loaded = [
            torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
            for i in range(num_shards)
        ]
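    # Each consolidated.*.pth holds one tensor-parallel slice of the model;
    # the per-layer loop below re-concatenates the sliced weights. Note that
    # all shards are loaded into CPU memory at once, which is memory-hungry
    # for the larger model sizes.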
    param_count = 0
    index_dict = {"weight_map": {}}
    for layer_i in range(n_layers):
        filename = "pytorch_model-{:05d}-of-{:05d}.bin".format(
            layer_i + 1,
            n_layers + 1,
        )
        if model_size == "7B":
            # Unsharded
            state_dict = {
                f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
                    loaded[f"layers.{layer_i}.attention.wq.weight"]
                ),
                f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
                    loaded[f"layers.{layer_i}.attention.wk.weight"]
                ),
                f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"],
                f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"],
                f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"],
                f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"],
                f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"],
                f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"],
                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"],
            }
        else:
            # Sharded
            state_dict = {
                f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][f"layers.{layer_i}.attention_norm.weight"],
                f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][
                    f"layers.{layer_i}.ffn_norm.weight"
                ],
            }
            state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
                torch.cat(
                    [
                        loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim)
                        for i in range(num_shards)
                    ],
                    dim=0,
                ).reshape(dim, dim)
            )
            state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
                torch.cat(
                    [
                        loaded[i][f"layers.{layer_i}.attention.wk.weight"].view(n_heads_per_shard, dims_per_head, dim)
                        for i in range(num_shards)
                    ],
                    dim=0,
                ).reshape(dim, dim)
            )
            state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
                [
                    loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(n_heads_per_shard, dims_per_head, dim)
                    for i in range(num_shards)
                ],
                dim=0,
            ).reshape(dim, dim)
            state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
                [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1
            )
            state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
                [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
            )
            state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
                [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1
            )
            state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
                [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
            )
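            # Column-parallel weights (wq, wk, wv, w1, w3) are split across
            # shards along their output dimension and re-joined with dim=0;
            # row-parallel weights (wo, w2) are split along their input
            # dimension and re-joined with dim=1.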
        state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
        for k, v in state_dict.items():
            index_dict["weight_map"][k] = filename
            param_count += v.numel()
        torch.save(state_dict, os.path.join(model_path, filename))
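    # One .bin file per transformer layer, plus a final file (below) for the
    # embeddings, final norm, and LM head: n_layers + 1 files in total, all
    # tracked in the pytorch_model.bin.index.json weight map.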
    filename = "pytorch_model-{:05d}-of-{:05d}.bin".format(
        n_layers + 1,
        n_layers + 1,
    )
    if model_size == "7B":
        # Unsharded
        state_dict = {
            "model.embed_tokens.weight": loaded["tok_embeddings.weight"],
            "model.norm.weight": loaded["norm.weight"],
            "lm_head.weight": loaded["output.weight"],
        }
    else:
        state_dict = {
            "model.norm.weight": loaded[0]["norm.weight"],
            "model.embed_tokens.weight": torch.cat(
                [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1
            ),
            "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0),
        }
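        # tok_embeddings is sharded along the embedding (hidden) dimension,
        # hence dim=1; the output head is sharded along the vocab dimension,
        # hence dim=0. norm.weight is replicated, so any shard's copy works.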
    for k, v in state_dict.items():
        index_dict["weight_map"][k] = filename
        param_count += v.numel()
    torch.save(state_dict, os.path.join(model_path, filename))

    # Write configs
    index_dict["metadata"] = {"total_size": param_count * 2}
    write_json(index_dict, os.path.join(model_path, "pytorch_model.bin.index.json"))
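    # total_size is in bytes: param_count * 2 assumes every tensor is stored
    # in a 16-bit dtype (2 bytes per parameter), matching torch_dtype below.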
    config_out = {
        "architectures": ["LLaMAForCausalLM"],
        "bos_token_id": 0,
        "eos_token_id": 1,
        "hidden_act": "silu",
        "hidden_size": params["dim"],
        "intermediate_size": INTERMEDIATE_SIZE_MAP[model_size],
        "initializer_range": 0.02,
        "max_sequence_length": 2048,
        "model_type": "llama",
        "num_attention_heads": params["n_heads"],
        "num_hidden_layers": params["n_layers"],
        "pad_token_id": -1,
        "rms_norm_eps": params["norm_eps"],
        "torch_dtype": "float16",
        "transformers_version": "4.27.0.dev0",
        "use_cache": True,
        "vocab_size": 32000,
    }
    write_json(
        config_out,
        os.path.join(model_path, "config.json"),
    )
    generation_config = {
        "_from_model_config": True,
        "bos_token_id": 0,
        "eos_token_id": 1,
        "pad_token_id": 0,
        "transformers_version": "4.27.0.dev0",
    }
    write_json(
        generation_config,
        os.path.join(model_path, "generation_config.json"),
    )


def write_tokenizer(tokenizer_path, input_tokenizer_path):
    os.makedirs(tokenizer_path, exist_ok=True)
    write_json({}, os.path.join(tokenizer_path, "special_tokens_map.json"))
    write_json(
        {
            "bos_token": "",
            "eos_token": "",
            "model_max_length": int(1e30),
            "tokenizer_class": "LLaMATokenizer",
            "unk_token": "",
        },
        os.path.join(tokenizer_path, "tokenizer_config.json"),
    )
    shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model"))
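    # The SentencePiece model file is copied verbatim; no conversion is needed
    # because LLaMATokenizer wraps sentencepiece directly.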


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        help="Location of LLaMA weights, which contains tokenizer.model and model folders",
    )
    parser.add_argument(
        "--model_size",
        choices=["7B", "13B", "30B", "65B"],
    )
    parser.add_argument(
        "--output_dir",
        help="Location to write HF model and tokenizer",
    )
    args = parser.parse_args()
    write_model(
        model_path=os.path.join(args.output_dir, "llama-{}".format(args.model_size).lower()),
        input_base_path=os.path.join(args.input_dir, args.model_size),
        model_size=args.model_size,
    )
    write_tokenizer(
        tokenizer_path=os.path.join(args.output_dir, "tokenizer"),
        input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
    )


if __name__ == "__main__":
    main()