diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index d55d2897c..729be53a5 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -968,6 +968,501 @@ mod tests { use crate::orm::sampling::Sampling; use crate::orm::Hyperparams; + #[pg_test] + #[ignore = "requires model download"] + fn readme_intro_translation() { + let sql = "SELECT pgml.transform( + 'translation_en_to_fr', + inputs => ARRAY[ + 'Welcome to the future!', + 'Where have you been all this time?' + ] + ) AS french;"; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + {"translation_text": "Bienvenue à l'avenir!"}, + {"translation_text": "Où êtes-vous allé tout ce temps?"} + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_intro_sentiment_analysis() { + let sql = "SELECT pgml.transform( + task => 'text-classification', + inputs => ARRAY[ + 'I love how amazingly simple ML has become!', + 'I hate doing mundane and thankless tasks. ☹️' + ] + ) AS positivity;"; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + {"label": "POSITIVE", "score": 0.9995759129524232}, + {"label": "NEGATIVE", "score": 0.9903519749641418} + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_sentiment_analysis_specific_model() { + let sql = r#"SELECT pgml.transform( + inputs => ARRAY[ + 'I love how amazingly simple ML has become!', + 'I hate doing mundane and thankless tasks. ☹️' + ], + task => '{"task": "text-classification", + "model": "finiteautomata/bertweet-base-sentiment-analysis" + }'::JSONB + ) AS positivity;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + {"label": "POS", "score": 0.992932200431826}, + {"label": "NEG", "score": 0.975599765777588} + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_sentiment_analysis_industry_specific_model() { + let sql = r#"SELECT pgml.transform( + inputs => ARRAY[ + 'Stocks rallied and the British pound gained.', + 'Stocks making the biggest moves midday: Nvidia, Palantir and more' + ], + task => '{"task": "text-classification", + "model": "ProsusAI/finbert" + }'::JSONB + ) AS market_sentiment;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + {"label": "positive", "score": 0.8983612656593323}, + {"label": "neutral", "score": 0.8062630891799927} + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_nli() { + let sql = r#"SELECT pgml.transform( + inputs => ARRAY[ + 'A soccer game with multiple males playing. Some men are playing a sport.' + ], + task => '{"task": "text-classification", + "model": "roberta-large-mnli" + }'::JSONB + ) AS nli;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + {"label": "ENTAILMENT", "score": 0.98837411403656} + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_qnli() { + let sql = r#"SELECT pgml.transform( + inputs => ARRAY[ + 'Where is the capital of France?, Paris is the capital of France.' + ], + task => '{"task": "text-classification", + "model": "cross-encoder/qnli-electra-base" + }'::JSONB + ) AS qnli;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + {"label": "LABEL_0", "score": 0.9978110194206238} + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_qqp() { + let sql = r#"SELECT pgml.transform( + inputs => ARRAY[ + 'Which city is the capital of France?, Where is the capital of France?' + ], + task => '{"task": "text-classification", + "model": "textattack/bert-base-uncased-QQP" + }'::JSONB + ) AS qqp;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + {"label": "LABEL_0", "score": 0.9988721013069152} + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_grammatical_correctness() { + let sql = r#"SELECT pgml.transform( + inputs => ARRAY[ + 'I will walk to home when I went through the bus.' + ], + task => '{"task": "text-classification", + "model": "textattack/distilbert-base-uncased-CoLA" + }'::JSONB + ) AS grammatical_correctness;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + {"label": "LABEL_1", "score": 0.9576480388641356} + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_zeroshot_classification() { + let sql = r#"SELECT pgml.transform( + inputs => ARRAY[ + 'I have a problem with my iphone that needs to be resolved asap!!' + ], + task => '{ + "task": "zero-shot-classification", + "model": "facebook/bart-large-mnli" + }'::JSONB, + args => '{ + "candidate_labels": ["urgent", "not urgent", "phone", "tablet", "computer"] + }'::JSONB + ) AS zero_shot;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + { + "labels": ["urgent", "phone", "computer", "not urgent", "tablet"], + "scores": [0.503635, 0.47879, 0.012600, 0.002655, 0.002308], + "sequence": "I have a problem with my iphone that needs to be resolved asap!!" + } + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_token_classification_ner() { + let sql = r#"SELECT pgml.transform( + inputs => ARRAY[ + 'I am Omar and I live in New York City.' + ], + task => 'token-classification' + ) as ner;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([[ + {"end": 9, "word": "Omar", "index": 3, "score": 0.997110, "start": 5, "entity": "I-PER"}, + {"end": 27, "word": "New", "index": 8, "score": 0.999372, "start": 24, "entity": "I-LOC"}, + {"end": 32, "word": "York", "index": 9, "score": 0.999355, "start": 28, "entity": "I-LOC"}, + {"end": 37, "word": "City", "index": 10, "score": 0.999431, "start": 33, "entity": "I-LOC"} + ]]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_token_classification_pos() { + let sql = r#"select pgml.transform( + inputs => array [ + 'I live in Amsterdam.' + ], + task => '{"task": "token-classification", + "model": "vblagoje/bert-english-uncased-finetuned-pos" + }'::JSONB + ) as pos;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([[ + {"end": 1, "word": "i", "index": 1, "score": 0.999, "start": 0, "entity": "PRON"}, + {"end": 6, "word": "live", "index": 2, "score": 0.998, "start": 2, "entity": "VERB"}, + {"end": 9, "word": "in", "index": 3, "score": 0.999, "start": 7, "entity": "ADP"}, + {"end": 19, "word": "amsterdam", "index": 4, "score": 0.998, "start": 10, "entity": "PROPN"}, + {"end": 20, "word": ".", "index": 5, "score": 0.999, "start": 19, "entity": "PUNCT"} + ]]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_translation() { + let sql = r#"select pgml.transform( + inputs => array[ + 'How are you?' + ], + task => '{"task": "translation", + "model": "Helsinki-NLP/opus-mt-en-fr" + }'::JSONB + );"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + {"translation_text": "Comment allez-vous ?"} + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_summarization() { + let sql = r#"select pgml.transform( + task => '{"task": "summarization", + "model": "sshleifer/distilbart-cnn-12-6" + }'::JSONB, + inputs => array[ + 'Paris is the capital and most populous city of France, with an estimated population of 2,175,601 residents as of 2018, in an area of more than 105 square kilometres (41 square miles). The City of Paris is the centre and seat of government of the region and province of Île-de-France, or Paris Region, which has an estimated population of 12,174,880, or about 18 percent of the population of France as of 2017.' + ] + );"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + {"summary_text": " Paris is the capital and most populous city of France, with an estimated population of 2,175,601 residents as of 2018 . The city is the centre and seat of government of the region and province of Île-de-France, or Paris Region . Paris Region has an estimated 18 percent of the population of France as of 2017 ."} + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_summarization_min_max_length() { + let sql = r#"select pgml.transform( + task => '{"task": "summarization", + "model": "sshleifer/distilbart-cnn-12-6" + }'::JSONB, + inputs => array[ + 'Paris is the capital and most populous city of France, with an estimated population of 2,175,601 residents as of 2018, in an area of more than 105 square kilometres (41 square miles). The City of Paris is the centre and seat of government of the region and province of Île-de-France, or Paris Region, which has an estimated population of 12,174,880, or about 18 percent of the population of France as of 2017.' + ], + args => '{ + "min_length" : 20, + "max_length" : 70 + }'::JSONB + );"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + {"summary_text": " Paris is the capital and most populous city of France, with an estimated population of 2,175,601 residents as of 2018 . City of Paris is centre and seat of government of the region and province of Île-de-France, or Paris Region, which has an estimated 12,174,880, or about 18 percent"} + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_question_answering() { + let sql = r#"SELECT pgml.transform( + 'question-answering', + inputs => ARRAY[ + '{ + "question": "Where do I live?", + "context": "My name is Merve and I live in İstanbul." + }' + ] + ) AS answer;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!({ + "end" : 39, + "score" : 0.9538117051124572, + "start" : 31, + "answer": "İstanbul" + }); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_text_generation() { + let sql = r#"SELECT pgml.transform( + task => 'text-generation', + inputs => ARRAY[ + 'Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone' + ] + ) AS answer;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + [ + {"generated_text": "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, and eight for the Dragon-lords in their halls of blood.\n\nEach of the guild-building systems is one-man"} + ] + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_text_generation_specific_model() { + let sql = r#"SELECT pgml.transform( + task => '{ + "task" : "text-generation", + "model" : "gpt2-medium" + }'::JSONB, + inputs => ARRAY[ + 'Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone' + ] + ) AS answer;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + [{"generated_text": "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone.\n\nThis place has a deep connection to the lore of ancient Elven civilization. It is home to the most ancient of artifacts,"}] + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_text_generation_max_length() { + let sql = r#"SELECT pgml.transform( + task => '{ + "task" : "text-generation", + "model" : "gpt2-medium" + }'::JSONB, + inputs => ARRAY[ + 'Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone' + ], + args => '{ + "max_length" : 200 + }'::JSONB + ) AS answer;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + [{"generated_text": "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, Three for the Dwarfs and the Elves, One for the Gnomes of the Mines, and Two for the Elves of Dross.\"\n\nHobbits: The Fellowship is the first book of J.R.R. Tolkien's story-cycle, and began with his second novel - The Two Towers - and ends in The Lord of the Rings.\n\n\nIt is a non-fiction novel, so there is no copyright claim on some parts of the story but the actual text of the book is copyrighted by author J.R.R. Tolkien.\n\n\nThe book has been classified into two types: fantasy novels and children's books\n\nHobbits: The Fellowship is the first book of J.R.R. Tolkien's story-cycle, and began with his second novel - The Two Towers - and ends in The Lord of the Rings.It"}] + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_text_generation_num_return_sequences() { + let sql = r#"SELECT pgml.transform( + task => '{ + "task" : "text-generation", + "model" : "gpt2-medium" + }'::JSONB, + inputs => ARRAY[ + 'Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone' + ], + args => '{ + "num_return_sequences" : 3 + }'::JSONB + ) AS answer;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + [ + {"generated_text": "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, and Thirteen for the human-men in their hall of fire.\n\nAll of us, our families, and our people"}, + {"generated_text": "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, and the tenth for a King! As each of these has its own special story, so I have written them into the game."}, + {"generated_text": "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone… What's left in the end is your heart's desire after all!\n\nHans: (Trying to be brave)"} + ] + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_text_generation_beams_stopping() { + let sql = r#"SELECT pgml.transform( + task => '{ + "task" : "text-generation", + "model" : "gpt2-medium" + }'::JSONB, + inputs => ARRAY[ + 'Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone' + ], + args => '{ + "num_beams" : 5, + "early_stopping" : true + }'::JSONB + ) AS answer;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([[ + {"generated_text": "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, Nine for the Dwarves in their caverns of ice, Ten for the Elves in their caverns of fire, Eleven for the"} + ]]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_text_generation_temperature() { + let sql = r#"SELECT pgml.transform( + task => '{ + "task" : "text-generation", + "model" : "gpt2-medium" + }'::JSONB, + inputs => ARRAY[ + 'Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone' + ], + args => '{ + "do_sample" : true, + "temperature" : 0.9 + }'::JSONB + ) AS answer;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([[{"generated_text": "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, and Thirteen for the Giants and Men of S.A.\n\nThe First Seven-Year Time-Traveling Trilogy is"}]]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_text_generation_top_p() { + let sql = r#"SELECT pgml.transform( + task => '{ + "task" : "text-generation", + "model" : "gpt2-medium" + }'::JSONB, + inputs => ARRAY[ + 'Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone' + ], + args => '{ + "do_sample" : true, + "top_p" : 0.8 + }'::JSONB + ) AS answer;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([[{"generated_text": "Three Rings for the Elven-kings under the sky, Seven for the Dwarf-lords in their halls of stone, Four for the Elves of the forests and fields, and Three for the Dwarfs and their warriors.\" ―Lord Rohan [src"}]]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_text_text_generation() { + let sql = r#"SELECT pgml.transform( + task => '{ + "task" : "text2text-generation" + }'::JSONB, + inputs => ARRAY[ + 'translate from English to French: I''m very happy' + ] + ) AS answer;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + {"generated_text": "Je suis très heureux"} + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn readme_nlp_fill_mask() { + let sql = r#"SELECT pgml.transform( + task => '{ + "task" : "fill-mask" + }'::JSONB, + inputs => ARRAY[ + 'Paris is the of France.' + + ] + ) AS answer;"#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([ + {"score": 0.679, "token": 812, "sequence": "Paris is the capital of France.", "token_str": " capital"}, + {"score": 0.051, "token": 32357, "sequence": "Paris is the birthplace of France.", "token_str": " birthplace"}, + {"score": 0.038, "token": 1144, "sequence": "Paris is the heart of France.", "token_str": " heart"}, + {"score": 0.024, "token": 29778, "sequence": "Paris is the envy of France.", "token_str": " envy"}, + {"score": 0.022, "token": 1867, "sequence": "Paris is the Capital of France.", "token_str": " Capital"} + ]); + assert_eq!(got, want); + } + + #[pg_test] + #[ignore = "requires model download"] + fn template() { + let sql = r#""#; + let got = Spi::get_one::(sql).unwrap().unwrap().0; + let want = serde_json::json!([]); + assert_eq!(got, want); + } + #[pg_test] fn test_project_lifecycle() { let project = Project::create("test", Task::regression);