From 7eba1d70b54872a7cc748e2e56243c02be4f1122 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Benjamin=20H=C3=A4ttasch?= Date: Tue, 15 May 2018 12:45:51 +0200 Subject: [PATCH] First draft for simple benchmark NL to SQL Translator Accuracy on Training and Dev Set: ~0.29 --- requirements.txt | 2 ++ translate.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 translate.py diff --git a/requirements.txt b/requirements.txt index 0838eaf..de25e18 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,5 @@ tqdm records babel tabulate + +jsonlines diff --git a/translate.py b/translate.py new file mode 100644 index 0000000..b065cbc --- /dev/null +++ b/translate.py @@ -0,0 +1,35 @@ +import jsonlines + +TABLES = "data/dev.tables.jsonl" +QUERIES = "data/dev.jsonl" +PREDICTIONS = "test/my.pred.dev.jsonl" + +tables = {} + +# Read tables and data +with jsonlines.open(TABLES) as input_tables: + for table in input_tables: + tables[table["id"]] = table + +# Open queries and output file +with jsonlines.open(QUERIES) as input_queries, jsonlines.open(PREDICTIONS, mode='w') as predictions_file: + for query in input_queries: + question = query["question"].lower() + table = tables[query["table_id"]] + + # Determine selection + sel = -1 + for i, field in enumerate(table["header"]): + if field.lower() in question: + sel = i + + # Determine conditions + conds = [] + for row in table["rows"]: + for i, cell in enumerate(row): + if isinstance(cell, str): + if cell.lower() in question: + conds.append([i, 0, cell]) + + prediction = {"query": {"sel": sel, "conds": conds, "agg": 0}, "error": ""} + predictions_file.write(prediction)