Skip to content

Minor fixes to tools (prepare_data validators) #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 33 additions & 16 deletions openai/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def common_prompt_suffix_validator(df):
if suffix_option == " ->":
if df.prompt.str.contains("\n").any():
continue
if df.prompt.str.contains(suffix_option).any():
if df.prompt.str.contains(suffix_option, regex=False).any():
continue
suggested_suffix = suffix_option
break
Expand All @@ -202,7 +202,11 @@ def add_suffix(x, suffix):
)
if len(common_suffix) > 10:
immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
if df.prompt.str[: -len(common_suffix)].str.contains(common_suffix).any():
if (
df.prompt.str[: -len(common_suffix)]
.str.contains(common_suffix, regex=False)
.any()
):
immediate_msg += f"\n WARNING: Some of your prompts contain the suffix `{common_suffix}` more than once. We strongly suggest that you review your prompts and add a unique suffix"

else:
Expand Down Expand Up @@ -271,11 +275,15 @@ def common_completion_prefix_validator(df):
MAX_PREFIX_LEN = 5

common_prefix = get_common_xfix(df.completion, xfix="prefix")
ws_prefix = len(common_prefix) > 0 and common_prefix[0] == " "
if len(common_prefix) < MAX_PREFIX_LEN:
return Remediation(name="common_prefix")

def remove_common_prefix(x, prefix):
def remove_common_prefix(x, prefix, ws_prefix):
x["completion"] = x["completion"].str[len(prefix) :]
if ws_prefix:
# keep the single whitespace as prefix
x["completion"] = " " + x["completion"]
return x

if (df.completion == common_prefix).all():
Expand All @@ -286,7 +294,7 @@ def remove_common_prefix(x, prefix):
optional_msg = f"Remove prefix `{common_prefix}` from all completions"

def optional_fn(x):
return remove_common_prefix(x, common_prefix)
return remove_common_prefix(x, common_prefix, ws_prefix)

return Remediation(
name="common_completion_prefix",
Expand All @@ -305,6 +313,15 @@ def common_completion_suffix_validator(df):
optional_msg = None
optional_fn = None

ft_type = infer_task_type(df)
if ft_type == "open-ended generation" or ft_type == "classification":
return Remediation(name="common_suffix")

common_suffix = get_common_xfix(df.completion, xfix="suffix")
if (df.completion == common_suffix).all():
error_msg = f"All completions are identical: `{common_suffix}`\nEnsure completions are different, otherwise the model will just repeat `{common_suffix}`"
return Remediation(name="common_suffix", error_msg=error_msg)

# Find a suffix which is not contained within the completion otherwise
suggested_suffix = " [END]"
suffix_options = [
Expand All @@ -319,33 +336,28 @@ def common_completion_suffix_validator(df):
"%%%",
]
for suffix_option in suffix_options:
if df.completion.str.contains(suffix_option).any():
if df.completion.str.contains(suffix_option, regex=False).any():
continue
suggested_suffix = suffix_option
break
display_suggested_suffix = suggested_suffix.replace("\n", "\\n")

ft_type = infer_task_type(df)
if ft_type == "open-ended generation" or ft_type == "classification":
return Remediation(name="common_suffix")

def add_suffix(x, suffix):
x["completion"] += suffix
return x

common_suffix = get_common_xfix(df.completion, xfix="suffix")
if (df.completion == common_suffix).all():
error_msg = f"All completions are identical: `{common_suffix}`\nEnsure completions are different, otherwise the model will just repeat `{common_suffix}`"
return Remediation(name="common_suffix", error_msg=error_msg)

if common_suffix != "":
common_suffix_new_line_handled = common_suffix.replace("\n", "\\n")
immediate_msg = (
f"\n- All completions end with suffix `{common_suffix_new_line_handled}`"
)
if len(common_suffix) > 10:
immediate_msg += f". This suffix seems very long. Consider replacing with a shorter suffix, such as `{display_suggested_suffix}`"
if df.completion.str[: -len(common_suffix)].str.contains(common_suffix).any():
if (
df.completion.str[: -len(common_suffix)]
.str.contains(common_suffix, regex=False)
.any()
):
immediate_msg += f"\n WARNING: Some of your completions contain the suffix `{common_suffix}` more than once. We suggest that you review your completions and add a unique ending"

else:
Expand Down Expand Up @@ -617,8 +629,13 @@ def write_out_file(df, fname, any_remediations):
# Add -v VALID_FILE if we split the file into train / valid
files_string = ("s" if split else "") + " to `" + ("` and `".join(outfnames))
valid_string = f' -v "{outfnames[1]}"' if split else ""
separator_reminder = (
""
if len(common_prompt_suffix_new_line_handled) == 0
else f"After you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt."
)
sys.stdout.write(
f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{outfnames[0]}"{valid_string}{packing_param}\n\nAfter you’ve fine-tuned a model, remember that your prompt has to end with the indicator string `{common_prompt_suffix_new_line_handled}` for the model to start generating completions, rather than continuing with the prompt.{optional_ending_string}\n'
f'\nWrote modified file{files_string}`\nFeel free to take a look!\n\nNow use that file when fine-tuning:\n> openai api fine_tunes.create -t "{outfnames[0]}"{valid_string}{packing_param}\n\n{separator_reminder}{optional_ending_string}\n'
)
else:
sys.stdout.write("Aborting... did not write the file\n")
Expand Down
2 changes: 1 addition & 1 deletion openai/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = "0.9.3"
VERSION = "0.9.4"