|
@@ -62,8 +62,8 @@ def generate_prompt(data_point):
|
|
|
MICRO_BATCH_SIZE = 4 # this could actually be 5 but i like powers of 2
|
|
MICRO_BATCH_SIZE = 4 # this could actually be 5 but i like powers of 2
|
|
|
BATCH_SIZE = 128
|
|
BATCH_SIZE = 128
|
|
|
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
|
|
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
|
|
|
-EPOCHS = 3 # from the result
|
|
|
|
|
-LEARNING_RATE = 2e-5 # also from the result
|
|
|
|
|
|
|
+EPOCHS = 1 # we don't need 3 tbh
|
|
|
|
|
+LEARNING_RATE = 3e-4 # the karpathy constant
|
|
|
CUTOFF_LEN = 256 # 256 accounts for about 96% of the data
|
|
CUTOFF_LEN = 256 # 256 accounts for about 96% of the data
|
|
|
|
|
|
|
|
data = data.shuffle().map(
|
|
data = data.shuffle().map(
|