
Fix linters (#185)

* install isort

* isort .

* whoops

* fix black
Eric J. Wang 3 years ago
parent commit dbd04f3560
2 changed files with 13 additions and 13 deletions
  1. .github/workflows/lint.yml (+3 -3)
  2. finetune.py (+10 -10)

.github/workflows/lint.yml (+3 -3)

@@ -25,10 +25,10 @@ jobs:
           python-version: 3.8
 
       - name: Install Python dependencies
-        run: pip install black black[jupyter] flake8
+        run: pip install black black[jupyter] flake8 isort
 
       - name: lint isort
-        run: isort --check --diff
+        run: isort --check --diff .
 
       - name: lint black
-        run: black --check --diff
+        run: black --check --diff .
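A note on why the trailing "." matters: black invoked with no path argument prints "No Path provided. Nothing to do" and exits 0, so before this fix the lint steps could pass without actually scanning any files; pointing both tools at "." makes them walk the repository root. Below is a minimal sketch for running the same two checks locally, assuming isort and black are installed as in the pip install step above:

import subprocess
import sys

# Mirror the two CI lint steps; "." makes both tools scan the repo root,
# which is exactly what this commit adds in lint.yml.
CHECKS = [
    ["isort", "--check", "--diff", "."],
    ["black", "--check", "--diff", "."],
]

exit_code = 0
for cmd in CHECKS:
    print("$", " ".join(cmd))
    # Keep going after a failure so both linters get to report.
    if subprocess.run(cmd).returncode != 0:
        exit_code = 1

sys.exit(exit_code)

One caveat worth knowing: isort's default import style can disagree with black's, so many projects also set isort's black-compatible profile (profile = "black" under [tool.isort] in pyproject.toml); this commit relies on the defaults being compatible here.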

finetune.py (+10 -10)

@@ -54,8 +54,8 @@ def train(
     # wandb params
     wandb_project: str = "",
     wandb_run_name: str = "",
-    wandb_watch: str = "", # options: false | gradients | all
-    wandb_log_model: str = "", # options: false | true
+    wandb_watch: str = "",  # options: false | gradients | all
+    wandb_log_model: str = "",  # options: false | true
     resume_from_checkpoint: str = None,  # either training checkpoint or final adapter
 ):
     print(
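The first finetune.py hunk is purely cosmetic: black, following PEP 8, wants at least two spaces before an inline "#" comment. You can reproduce the rewrite in memory through black's Python API (a small sketch; format_str and Mode are black's public in-memory entry points):

import black

# One of the pre-fix lines from this diff, with one space before "#".
src = 'wandb_watch: str = "" # options: false | gradients | all\n'

# format_str reformats a source string with black's default style
# (88-column lines, double quotes, two spaces before inline comments).
print(black.format_str(src, mode=black.Mode()), end="")
# Prints: wandb_watch: str = ""  # options: false | gradients | all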
@@ -94,16 +94,16 @@ def train(
         gradient_accumulation_steps = gradient_accumulation_steps // world_size
 
     # Check if parameter passed or if set within environ
-    use_wandb = len(wandb_project) > 0 or \
-                ("WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
+    use_wandb = len(wandb_project) > 0 or (
+        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
+    )
     # Only overwrite environ if wandb param passed
-    if len(wandb_project) > 0: 
-        os.environ['WANDB_PROJECT'] = wandb_project
+    if len(wandb_project) > 0:
+        os.environ["WANDB_PROJECT"] = wandb_project
     if len(wandb_watch) > 0:
-        os.environ['WANDB_WATCH'] = wandb_watch
+        os.environ["WANDB_WATCH"] = wandb_watch
     if len(wandb_log_model) > 0:
-        os.environ['WANDB_LOG_MODEL'] = wandb_log_model
-
+        os.environ["WANDB_LOG_MODEL"] = wandb_log_model
 
     model = LlamaForCausalLM.from_pretrained(
         base_model,
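The bigger hunk is likewise a pure reformat (a backslash continuation becomes parentheses, single quotes become double), but the logic it preserves deserves a spelled-out reading: an explicit CLI argument enables wandb and overrides the environment, while an already-exported WANDB_PROJECT enables it otherwise. A standalone sketch of that precedence rule, with should_use_wandb as a hypothetical helper name and the body lifted from the diff:

import os


def should_use_wandb(wandb_project: str = "") -> bool:
    """Mirror the use_wandb decision in train() above."""
    # Either the caller passed a project name, or one is already exported.
    use_wandb = len(wandb_project) > 0 or (
        "WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0
    )
    # Overwrite the environment only when a parameter was actually passed,
    # so a pre-exported WANDB_PROJECT is never clobbered by the default "".
    if len(wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = wandb_project
    return use_wandb


assert should_use_wandb("alpaca-lora") is True
assert os.environ["WANDB_PROJECT"] == "alpaca-lora"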
@@ -231,7 +231,7 @@ def train(
             ddp_find_unused_parameters=False if ddp else None,
             group_by_length=group_by_length,
             report_to="wandb" if use_wandb else None,
-            run_name=wandb_run_name if use_wandb else None
+            run_name=wandb_run_name if use_wandb else None,
         ),
         data_collator=transformers.DataCollatorForSeq2Seq(
             tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
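The last hunk only appends a trailing comma: black adds one after the final argument of any call it keeps split across lines, so adding another argument later shows up as a one-line diff. A toy illustration with a hypothetical stand-in for transformers.TrainingArguments:

def training_arguments(**kwargs):
    # Hypothetical stand-in; just echoes its keyword arguments.
    return kwargs


args = training_arguments(
    report_to="wandb",
    run_name="alpaca-lora",  # black adds this trailing comma on split calls
)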