Skip to content

Commit cc7fc54

Browse files
[textanalytics] add custom named entities bespoke method (Azure#24995)
* initial work * docs,samples,linting * expose TA poller * doc fix + add poller metadata tests * add missing recordings
1 parent de11ecf commit cc7fc54

File tree

31 files changed

+2898
-285
lines changed

31 files changed

+2898
-285
lines changed

sdk/textanalytics/azure-ai-textanalytics/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### Features Added
66

7+
- Added `begin_recognize_custom_entities` client method to recognize custom named entities in documents.
8+
79
### Breaking Changes
810

911
- Removed the Extractive Text Summarization feature and related models: `ExtractSummaryAction`, `ExtractSummaryResult`, and `SummarySentence`. To access this beta feature, install the `5.2.0b4` version of the client library.

sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
AnalyzeHealthcareEntitiesAction,
5959
)
6060

61-
from ._lro import AnalyzeHealthcareEntitiesLROPoller, AnalyzeActionsLROPoller
61+
from ._lro import AnalyzeHealthcareEntitiesLROPoller, AnalyzeActionsLROPoller, TextAnalyticsLROPoller
6262

6363
__all__ = [
6464
"TextAnalyticsApiVersion",
@@ -114,6 +114,7 @@
114114
"ClassifyDocumentResult",
115115
"ClassificationCategory",
116116
"AnalyzeHealthcareEntitiesAction",
117+
"TextAnalyticsLROPoller",
117118
]
118119

119120
__version__ = VERSION

sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_lro.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import functools
88
import json
99
import datetime
10-
from typing import Any, Optional
10+
from typing import Any, Optional, MutableMapping
1111
from urllib.parse import urlencode
1212
from azure.core.polling._poller import PollingReturnType
1313
from azure.core.exceptions import HttpResponseError
@@ -228,6 +228,9 @@ def from_continuation_token( # type: ignore
228228
continuation_token: str,
229229
**kwargs: Any
230230
) -> "AnalyzeHealthcareEntitiesLROPoller": # type: ignore
231+
"""
232+
:meta private:
233+
"""
231234
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
232235
continuation_token, **kwargs
233236
)
@@ -457,6 +460,50 @@ def from_continuation_token( # type: ignore
457460
continuation_token: str,
458461
**kwargs: Any
459462
) -> "AnalyzeActionsLROPoller": # type: ignore
463+
"""
464+
:meta private:
465+
"""
466+
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
467+
continuation_token, **kwargs
468+
)
469+
polling_method._lro_algorithms = [ # pylint: disable=protected-access
470+
TextAnalyticsOperationResourcePolling(
471+
show_stats=initial_response.context.options["show_stats"]
472+
)
473+
]
474+
return cls(
475+
client,
476+
initial_response,
477+
functools.partial(deserialization_callback, initial_response),
478+
polling_method
479+
)
480+
481+
482+
class TextAnalyticsLROPoller(LROPoller[PollingReturnType]):
483+
def polling_method(self) -> AnalyzeActionsLROPollingMethod:
484+
"""Return the polling method associated to this poller."""
485+
return self._polling_method # type: ignore
486+
487+
@property
488+
def details(self) -> MutableMapping[str, Any]:
489+
return {
490+
"id": self.polling_method().id,
491+
"created_on": self.polling_method().created_on,
492+
"expires_on": self.polling_method().expires_on,
493+
"display_name": self.polling_method().display_name,
494+
"last_modified_on": self.polling_method().last_modified_on,
495+
}
496+
497+
@classmethod
498+
def from_continuation_token( # type: ignore
499+
cls,
500+
polling_method: AnalyzeActionsLROPollingMethod,
501+
continuation_token: str,
502+
**kwargs: Any
503+
) -> "TextAnalyticsLROPoller": # type: ignore
504+
"""
505+
:meta private:
506+
"""
460507
client, initial_response, deserialization_callback = polling_method.from_continuation_token(
461508
continuation_token, **kwargs
462509
)

sdk/textanalytics/azure-ai-textanalytics/azure/ai/textanalytics/_response_handlers.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ def order_lro_results(doc_id_order, combined):
9696
def prepare_result(func):
9797
def choose_wrapper(*args, **kwargs):
9898
def wrapper(
99-
response, obj, response_headers, ordering_function
100-
): # pylint: disable=unused-argument
99+
response, obj, _, ordering_function
100+
):
101101
if hasattr(obj, "results"):
102102
obj = obj.results # language API compat
103103

@@ -280,7 +280,7 @@ def classify_document_result(
280280

281281

282282
def healthcare_extract_page_data(
283-
doc_id_order, obj, response_headers, health_job_state
283+
doc_id_order, obj, health_job_state
284284
): # pylint: disable=unused-argument
285285
return (
286286
health_job_state.next_link,
@@ -289,7 +289,7 @@ def healthcare_extract_page_data(
289289
health_job_state.results
290290
if hasattr(health_job_state, "results")
291291
else health_job_state.tasks.items[0].results,
292-
response_headers,
292+
{},
293293
lro=True
294294
),
295295
)
@@ -382,7 +382,7 @@ def get_ordered_errors(tasks_obj, task_name, doc_id_order):
382382
raise ValueError("Unexpected response from service - no errors for missing action results.")
383383

384384

385-
def _get_doc_results(task, doc_id_order, response_headers, returned_tasks_object):
385+
def _get_doc_results(task, doc_id_order, returned_tasks_object):
386386
returned_tasks = returned_tasks_object.tasks
387387
current_task_type, task_name = task
388388
deserialization_callback = _get_deserialization_callback_from_task_type(
@@ -401,18 +401,25 @@ def _get_doc_results(task, doc_id_order, response_headers, returned_tasks_object
401401
if response_task_to_deserialize.results is None:
402402
return get_ordered_errors(returned_tasks_object, task_name, doc_id_order)
403403
return deserialization_callback(
404-
doc_id_order, response_task_to_deserialize.results, response_headers, lro=True
404+
doc_id_order, response_task_to_deserialize.results, {}, lro=True
405405
)
406406

407407

408-
def get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state):
408+
def get_iter_items(doc_id_order, task_order, bespoke, analyze_job_state):
409409
iter_items = defaultdict(list) # map doc id to action results
410410
returned_tasks_object = analyze_job_state
411+
412+
if bespoke:
413+
return _get_doc_results(
414+
task_order[0],
415+
doc_id_order,
416+
returned_tasks_object,
417+
)
418+
411419
for task in task_order:
412420
results = _get_doc_results(
413421
task,
414422
doc_id_order,
415-
response_headers,
416423
returned_tasks_object,
417424
)
418425
for result in results:
@@ -422,11 +429,11 @@ def get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state
422429

423430

424431
def analyze_extract_page_data(
425-
doc_id_order, task_order, response_headers, analyze_job_state
432+
doc_id_order, task_order, bespoke, analyze_job_state
426433
):
427434
# return next link, list of
428435
iter_items = get_iter_items(
429-
doc_id_order, task_order, response_headers, analyze_job_state
436+
doc_id_order, task_order, bespoke, analyze_job_state
430437
)
431438
return analyze_job_state.next_link, iter_items
432439

@@ -456,14 +463,14 @@ def lro_get_next_page(
456463

457464

458465
def healthcare_paged_result(
459-
doc_id_order, health_status_callback, _, obj, response_headers, show_stats=False
460-
): # pylint: disable=unused-argument
466+
doc_id_order, health_status_callback, _, obj, show_stats=False
467+
):
461468
return ItemPaged(
462469
functools.partial(
463470
lro_get_next_page, health_status_callback, obj, show_stats=show_stats
464471
),
465472
functools.partial(
466-
healthcare_extract_page_data, doc_id_order, obj, response_headers
473+
healthcare_extract_page_data, doc_id_order, obj
467474
),
468475
)
469476

@@ -474,14 +481,38 @@ def analyze_paged_result(
474481
analyze_status_callback,
475482
_,
476483
obj,
477-
response_headers,
478484
show_stats=False,
479-
): # pylint: disable=unused-argument
485+
bespoke=False
486+
):
480487
return ItemPaged(
481488
functools.partial(
482489
lro_get_next_page, analyze_status_callback, obj, show_stats=show_stats
483490
),
484491
functools.partial(
485-
analyze_extract_page_data, doc_id_order, task_order, response_headers
492+
analyze_extract_page_data, doc_id_order, task_order, bespoke
486493
),
487494
)
495+
496+
497+
def _get_result_from_continuation_token(
498+
client, continuation_token, poller_type, polling_method, callback, bespoke=False
499+
):
500+
def result_callback(initial_response, pipeline_response):
501+
doc_id_order = initial_response.context.options["doc_id_order"]
502+
show_stats = initial_response.context.options["show_stats"]
503+
task_id_order = initial_response.context.options.get("task_id_order")
504+
return callback(
505+
pipeline_response,
506+
None,
507+
doc_id_order,
508+
task_id_order=task_id_order,
509+
show_stats=show_stats,
510+
bespoke=bespoke
511+
)
512+
513+
return poller_type.from_continuation_token(
514+
polling_method=polling_method,
515+
client=client,
516+
deserialization_callback=result_callback,
517+
continuation_token=continuation_token
518+
)

0 commit comments

Comments
 (0)