import argparse
import ast
import base64
+ import contextlib
import json
+ import os
import re
import shutil
import subprocess
import sys
import tempfile
import time
import uuid

+ from cryptography import fernet
import google.auth
from google.cloud import storage
from google.oauth2 import service_account
- from googleapiclient import discovery
- from googleapiclient import errors
-
+ from googleapiclient import discovery, errors
+ from kubernetes import client, config
+ from mysql import connector
+ import six
+ from six.moves import configparser

DEFAULT_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]
@@ -370,8 +375,112 @@ def export_data(sql_client, project, instance, gcs_bucket_name, filename):
    wait_sql_operation(sql_client, project, operation.get("name"))


+ def get_fernet_key(composer_env):
+     print("Retrieving fernet key for Composer Environment {}.".format(
+         composer_env.get('name', '')))
+     gke_cluster_resource = composer_env.get("config", {}).get("gkeCluster")
+     project_zone_cluster = re.match(
+         "projects/([^/]*)/zones/([^/]*)/clusters/([^/]*)", gke_cluster_resource
+     ).groups()
+     tmp_dir_name = None
+     try:
+         print("Getting cluster credentials {} to retrieve fernet key.".format(
+             gke_cluster_resource))
+         tmp_dir_name = tempfile.mkdtemp()
+         kubeconfig_file = tmp_dir_name + "/config"
+         os.environ["KUBECONFIG"] = kubeconfig_file
+         if subprocess.call(
+             [
+                 "gcloud",
+                 "container",
+                 "clusters",
+                 "get-credentials",
+                 project_zone_cluster[2],
+                 "--zone",
+                 project_zone_cluster[1],
+                 "--project",
+                 project_zone_cluster[0]
+             ]
+         ):
+             print("Failed to retrieve cluster credentials: {}.".format(
+                 gke_cluster_resource))
+             sys.exit(1)
+
+         kubernetes_client = client.CoreV1Api(
+             api_client=config.new_client_from_config(
+                 config_file=kubeconfig_file))
+         airflow_configmap = kubernetes_client.read_namespaced_config_map(
+             "airflow-configmap", "default")
+         config_str = airflow_configmap.data['airflow.cfg']
+         with contextlib.closing(six.StringIO(config_str)) as config_buffer:
+             config_parser = configparser.ConfigParser()
+             config_parser.readfp(config_buffer)
+             return config_parser.get("core", "fernet_key")
+     except Exception as exc:
+         print("Failed to get fernet key for cluster: {}.".format(str(exc)))
+         sys.exit(1)
+     finally:
+         if tmp_dir_name:
+             shutil.rmtree(tmp_dir_name)
+
+
+ def reencrypt_variables_connections(old_fernet_key_str, new_fernet_key_str):
+     old_fernet_key = fernet.Fernet(old_fernet_key_str.encode("utf-8"))
+     new_fernet_key = fernet.Fernet(new_fernet_key_str.encode("utf-8"))
+     db = connector.connect(
+         host="127.0.0.1",
+         user="root",
+         database="airflow-db",
+     )
+     variable_cursor = db.cursor()
+     variable_cursor.execute("SELECT id, val, is_encrypted FROM variable")
+     rows = variable_cursor.fetchall()
+     for row in rows:
+         id = row[0]
+         val = row[1]
+         is_encrypted = row[2]
+         if is_encrypted:
+             updated_val = new_fernet_key.encrypt(
+                 old_fernet_key.decrypt(bytes(val))).decode()
+             variable_cursor.execute(
+                 "UPDATE variable SET val=%s WHERE id=%s", (updated_val, id))
+     db.commit()
+
+     conn_cursor = db.cursor()
+     conn_cursor.execute(
+         "SELECT id, password, extra, is_encrypted, is_extra_encrypted FROM "
+         "connection")
+     rows = conn_cursor.fetchall()
+     for row in rows:
+         id = row[0]
+         password = row[1]
+         extra = row[2]
+         is_encrypted = row[3]
+         is_extra_encrypted = row[4]
+         if is_encrypted:
+             updated_password = new_fernet_key.encrypt(
+                 old_fernet_key.decrypt(bytes(password))).decode()
+             conn_cursor.execute(
+                 "UPDATE connection SET password=%s WHERE id=%s",
+                 (updated_password, id))
+         if is_extra_encrypted:
+             updated_extra = new_fernet_key.encrypt(
+                 old_fernet_key.decrypt(bytes(extra))).decode()
+             conn_cursor.execute(
+                 "UPDATE connection SET extra=%s WHERE id=%s",
+                 (updated_extra, id))
+     db.commit()
+
+
def import_data(
-     sql_client, service_account_key, project, instance, gcs_bucket, filename
+     sql_client,
+     service_account_key,
+     project,
+     instance,
+     gcs_bucket,
+     filename,
+     old_fernet_key,
+     new_fernet_key
):
    tmp_dir_name = None
    fuse_dir = None
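For reference, the re.match(...).groups() call in get_fernet_key simply splits the GKE cluster resource name into its project, zone, and cluster parts; a quick standalone illustration (the resource name below is invented):

    import re

    # Hypothetical value of the config.gkeCluster field of a Composer Environment.
    resource = "projects/my-project/zones/us-central1-a/clusters/my-gke-cluster"
    project, zone, cluster = re.match(
        "projects/([^/]*)/zones/([^/]*)/clusters/([^/]*)", resource).groups()
    # project == "my-project", zone == "us-central1-a", cluster == "my-gke-cluster",
    # i.e. project_zone_cluster[0], [1] and [2] in the function above.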
@@ -383,7 +492,6 @@ def import_data(
        if subprocess.call(["mkdir", fuse_dir]):
            print("Failed to make temporary subdir {}.".format(fuse_dir))
            sys.exit(1)
-         print(str(["gcsfuse", gcs_bucket, fuse_dir]))
        if subprocess.call(["gcsfuse", gcs_bucket, fuse_dir]):
            print(
                "Failed to fuse bucket {} with temp local directory {}".format(
@@ -424,9 +532,11 @@ def import_data(
        ):
            print("Failed to import database.")
            sys.exit(1)
+         print("Reencrypting variables and connections.")
+         reencrypt_variables_connections(old_fernet_key, new_fernet_key)
        print("Database import succeeded.")
-     except Exception:
-         print("Failed to copy database.")
+     except Exception as exc:
+         print("Failed to copy database: {}".format(str(exc)))
        sys.exit(1)
    finally:
        if proxy_subprocess:
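At its core, the re-encryption step invoked after the import is a decrypt-with-the-old-key / encrypt-with-the-new-key round trip over each encrypted value. A minimal standalone sketch of that pattern with the cryptography library (keys generated on the spot purely for illustration; in the script they come from each environment's airflow.cfg):

    from cryptography import fernet

    # Illustrative keys only; the script reads [core] fernet_key from each environment.
    old_key = fernet.Fernet(fernet.Fernet.generate_key())
    new_key = fernet.Fernet(fernet.Fernet.generate_key())

    # Value as it would sit in the metadata database, encrypted with the old key.
    stored = old_key.encrypt(b"some-connection-password")

    # Re-key: decrypt with the old key, then encrypt the plaintext with the new one.
    rekeyed = new_key.encrypt(old_key.decrypt(stored)).decode()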
@@ -522,6 +632,9 @@ def copy_database(project, existing_env, new_env, running_as_service_account):
            gcs_db_dump_bucket.name,
            "db_dump.sql",
        )
+         print("Obtaining fernet keys for Composer Environments.")
+         old_fernet_key = get_fernet_key(existing_env)
+         new_fernet_key = get_fernet_key(new_env)
        print("Preparing database import to new Environment.")
        import_data(
            sql_client,
@@ -530,6 +643,8 @@ def copy_database(project, existing_env, new_env, running_as_service_account):
            new_sql_instance,
            gcs_db_dump_bucket.name,
            "db_dump.sql",
+             old_fernet_key,
+             new_fernet_key,
        )
    finally:
        if gke_service_account_key:
@@ -542,7 +657,7 @@ def copy_database(project, existing_env, new_env, running_as_service_account):
            )
        if gcs_db_dump_bucket:
            print("Deleting temporary Cloud Storage bucket.")
-             # delete_bucket(gcs_db_dump_bucket)
+             delete_bucket(gcs_db_dump_bucket)


def copy_gcs_bucket(existing_env, new_env):
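The cleanup call re-enabled above uses the script's existing delete_bucket helper, which is defined outside this diff. As a rough, hypothetical sketch of what that kind of cleanup looks like with the google-cloud-storage client, assuming the dump bucket only holds the single db_dump.sql object:

    def delete_bucket_sketch(gcs_bucket):
        # force=True removes any remaining objects (the SQL dump) before
        # deleting the bucket itself; it only works for small buckets.
        gcs_bucket.delete(force=True)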