1
- # Copyright (c) 2023, 2024 , Oracle and/or its affiliates.
1
+ # Copyright (c) 2023, 2025 , Oracle and/or its affiliates.
2
2
#
3
3
# This program is free software; you can redistribute it and/or modify
4
4
# it under the terms of the GNU General Public License, version 2.0, as
59
59
)
60
60
61
61
from .. import version
62
+ from .._scripting import get_local_infile_filenames
62
63
from ..constants import (
63
64
ClientFlag ,
64
65
FieldType ,
81
82
WriteTimeoutError ,
82
83
get_exception ,
83
84
)
85
+ from ..protocol import EOF_STATUS , ERR_STATUS , LOCAL_INFILE_STATUS , OK_STATUS
84
86
from ..types import (
85
87
BinaryProtocolType ,
86
88
DescriptionType ,
@@ -219,7 +221,7 @@ async def _do_handshake(self) -> None:
219
221
"""Get the handshake from the MySQL server."""
220
222
packet = await self ._socket .read ()
221
223
logger .debug ("Protocol::Handshake packet: %s" , packet )
222
- if packet [4 ] == 255 :
224
+ if packet [4 ] == ERR_STATUS :
223
225
raise get_exception (packet )
224
226
225
227
self ._handshake = self ._protocol .parse_handshake (packet )
@@ -368,11 +370,11 @@ def _handle_ok(self, packet: bytes) -> OkPacketType:
368
370
packet, an error will be raised. If the packet is neither an OK or an Error
369
371
packet, InterfaceError will be raised.
370
372
"""
371
- if packet [4 ] == 0 :
373
+ if packet [4 ] == OK_STATUS :
372
374
ok_pkt = self ._protocol .parse_ok (packet )
373
375
self ._handle_server_status (ok_pkt ["status_flag" ])
374
376
return ok_pkt
375
- if packet [4 ] == 255 :
377
+ if packet [4 ] == ERR_STATUS :
376
378
raise get_exception (packet )
377
379
raise InterfaceError ("Expected OK packet" )
378
380
@@ -393,11 +395,11 @@ def _handle_eof(self, packet: bytes) -> EofPacketType:
393
395
packet, an error will be raised. If the packet is neither and OK or an Error
394
396
packet, InterfaceError will be raised.
395
397
"""
396
- if packet [4 ] == 254 :
398
+ if packet [4 ] == EOF_STATUS :
397
399
eof = self ._protocol .parse_eof (packet )
398
400
self ._handle_server_status (eof ["status_flag" ])
399
401
return eof
400
- if packet [4 ] == 255 :
402
+ if packet [4 ] == ERR_STATUS :
401
403
raise get_exception (packet )
402
404
raise InterfaceError ("Expected EOF packet" )
403
405
@@ -409,7 +411,49 @@ async def _handle_load_data_infile(
409
411
write_timeout : Optional [int ] = None ,
410
412
) -> OkPacketType :
411
413
"""Handle a LOAD DATA INFILE LOCAL request."""
414
+ if self ._local_infile_filenames is None :
415
+ self ._local_infile_filenames = get_local_infile_filenames (self ._query )
416
+ if not self ._local_infile_filenames :
417
+ raise InterfaceError (
418
+ "No `LOCAL INFILE` statements found in the client's request. "
419
+ "Check your request includes valid `LOCAL INFILE` statements."
420
+ )
421
+ elif not self ._local_infile_filenames :
422
+ raise InterfaceError (
423
+ "Got more `LOCAL INFILE` responses than number of `LOCAL INFILE` "
424
+ "statements specified in the client's request. Please, report this "
425
+ "issue to the development team."
426
+ )
427
+
412
428
file_name = os .path .abspath (filename )
429
+ file_name_from_request = os .path .abspath (self ._local_infile_filenames .popleft ())
430
+
431
+ # Verify the file location specified by `filename` from client's request exists
432
+ if not os .path .exists (file_name_from_request ):
433
+ raise InterfaceError (
434
+ f"Location specified by filename { file_name_from_request } "
435
+ "from client's request does not exist."
436
+ )
437
+
438
+ # Verify the file location specified by `filename` from server's response exists
439
+ if not os .path .exists (file_name ):
440
+ raise InterfaceError (
441
+ f"Location specified by filename { file_name } from server's "
442
+ "response does not exist."
443
+ )
444
+
445
+ # Verify the `filename` specified by server's response matches the one from
446
+ # the client's request.
447
+ try :
448
+ if not os .path .samefile (file_name , file_name_from_request ):
449
+ raise InterfaceError (
450
+ f"Filename { file_name } from the server's response is not the same "
451
+ f"as filename { file_name_from_request } from the "
452
+ "client's request."
453
+ )
454
+ except OSError as err :
455
+ raise InterfaceError from err
456
+
413
457
if os .path .islink (file_name ):
414
458
raise OperationalError ("Use of symbolic link is not allowed" )
415
459
if not self ._allow_local_infile and not self ._allow_local_infile_in_path :
@@ -478,16 +522,16 @@ async def _handle_result(
478
522
"""
479
523
if not packet or len (packet ) < 4 :
480
524
raise InterfaceError ("Empty response" )
481
- if packet [4 ] == 0 :
525
+ if packet [4 ] == OK_STATUS :
482
526
return self ._handle_ok (packet )
483
- if packet [4 ] == 251 :
527
+ if packet [4 ] == LOCAL_INFILE_STATUS :
484
528
filename = packet [5 :].decode ()
485
529
return await self ._handle_load_data_infile (
486
530
filename , read_timeout , write_timeout
487
531
)
488
- if packet [4 ] == 254 :
532
+ if packet [4 ] == EOF_STATUS :
489
533
return self ._handle_eof (packet )
490
- if packet [4 ] == 255 :
534
+ if packet [4 ] == ERR_STATUS :
491
535
raise get_exception (packet )
492
536
493
537
# We have a text result set
@@ -520,9 +564,9 @@ def _handle_binary_ok(self, packet: bytes) -> Dict[str, int]:
520
564
521
565
Returns a dict()
522
566
"""
523
- if packet [4 ] == 0 :
567
+ if packet [4 ] == OK_STATUS :
524
568
return self ._protocol .parse_binary_prepare_ok (packet )
525
- if packet [4 ] == 255 :
569
+ if packet [4 ] == ERR_STATUS :
526
570
raise get_exception (packet )
527
571
raise InterfaceError ("Expected Binary OK packet" )
528
572
@@ -545,11 +589,11 @@ async def _handle_binary_result(
545
589
"""
546
590
if not packet or len (packet ) < 4 :
547
591
raise InterfaceError ("Empty response" )
548
- if packet [4 ] == 0 :
592
+ if packet [4 ] == OK_STATUS :
549
593
return self ._handle_ok (packet )
550
- if packet [4 ] == 254 :
594
+ if packet [4 ] == EOF_STATUS :
551
595
return self ._handle_eof (packet )
552
- if packet [4 ] == 255 :
596
+ if packet [4 ] == ERR_STATUS :
553
597
raise get_exception (packet )
554
598
555
599
# We have a binary result set
@@ -965,6 +1009,11 @@ async def cmd_query(
965
1009
if isinstance (query , str ):
966
1010
query = query .encode ()
967
1011
query = bytearray (query )
1012
+
1013
+ # Set/Reset internal state related to query execution
1014
+ self ._query = query
1015
+ self ._local_infile_filenames = None
1016
+
968
1017
# Prepare query attrs
969
1018
charset = self ._charset .name if self ._charset .name != "utf8mb4" else "utf8"
970
1019
packet = bytearray ()
0 commit comments