diff --git a/.repo-metadata.json b/.repo-metadata.json index ccd2629f3..00f9bb7dc 100644 --- a/.repo-metadata.json +++ b/.repo-metadata.json @@ -4,7 +4,7 @@ "product_documentation": "https://cloud.google.com/compute/", "client_documentation": "https://googleapis.dev/python/compute/latest", "issue_tracker": "https://issuetracker.google.com/issues/new?component=187134&template=0", - "release_level": "beta", + "release_level": "alpha", "language": "python", "repo": "googleapis/python-compute", "distribution_name": "google-cloud-compute", diff --git a/.trampolinerc b/.trampolinerc index c7d663ae9..383b6ec89 100644 --- a/.trampolinerc +++ b/.trampolinerc @@ -18,7 +18,6 @@ required_envvars+=( "STAGING_BUCKET" "V2_STAGING_BUCKET" - "NOX_SESSION" ) # Add env vars which are passed down into the container here. diff --git a/CHANGELOG.md b/CHANGELOG.md index e309376cb..319353939 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,22 @@ # Changelog +## [0.2.0](https://www.github.com/googleapis/python-compute/compare/v0.1.0...v0.2.0) (2021-02-11) + + +### Features + +* run synthtool to pick up mtls feature ([#6](https://www.github.com/googleapis/python-compute/issues/6)) ([3abec21](https://www.github.com/googleapis/python-compute/commit/3abec21a1d5b1384779c48b899f23ba18ca0ddb3)) + + +### Bug Fixes + +* don't use integers for enums in json encoding ([a3685b5](https://www.github.com/googleapis/python-compute/commit/a3685b5a03a75256d2d00b89dcc8fda34596edde)) +* fix body encoding for rest transport ([#17](https://www.github.com/googleapis/python-compute/issues/17)) ([a3685b5](https://www.github.com/googleapis/python-compute/commit/a3685b5a03a75256d2d00b89dcc8fda34596edde)) +* regenerate the client lib ([#9](https://www.github.com/googleapis/python-compute/issues/9)) ([b9def52](https://www.github.com/googleapis/python-compute/commit/b9def52a47067804d5b79e867fb3ff895f8f4c21)) +* set development status classifier to alpha ([#2](https://www.github.com/googleapis/python-compute/issues/2)) ([54814f8](https://www.github.com/googleapis/python-compute/commit/54814f8ad15b8f8dff051c7c7819bc4a7b8e099f)) +* stabilize order of query_params ([a3685b5](https://www.github.com/googleapis/python-compute/commit/a3685b5a03a75256d2d00b89dcc8fda34596edde)) +* update paging implementation to handle unconventional pagination ([a3685b5](https://www.github.com/googleapis/python-compute/commit/a3685b5a03a75256d2d00b89dcc8fda34596edde)) + ## 0.1.0 (2021-01-08) diff --git a/docs/compute_v1/accelerator_types.rst b/docs/compute_v1/accelerator_types.rst index 8788df128..5eb8673b0 100644 --- a/docs/compute_v1/accelerator_types.rst +++ b/docs/compute_v1/accelerator_types.rst @@ -4,3 +4,8 @@ AcceleratorTypes .. automodule:: google.cloud.compute_v1.services.accelerator_types :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.accelerator_types.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/addresses.rst b/docs/compute_v1/addresses.rst index ff9f390e0..0579da503 100644 --- a/docs/compute_v1/addresses.rst +++ b/docs/compute_v1/addresses.rst @@ -4,3 +4,8 @@ Addresses .. automodule:: google.cloud.compute_v1.services.addresses :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.addresses.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/autoscalers.rst b/docs/compute_v1/autoscalers.rst index b1934613f..7e7acb14c 100644 --- a/docs/compute_v1/autoscalers.rst +++ b/docs/compute_v1/autoscalers.rst @@ -4,3 +4,8 @@ Autoscalers .. automodule:: google.cloud.compute_v1.services.autoscalers :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.autoscalers.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/backend_buckets.rst b/docs/compute_v1/backend_buckets.rst index 5ae49505b..6b71dfbb6 100644 --- a/docs/compute_v1/backend_buckets.rst +++ b/docs/compute_v1/backend_buckets.rst @@ -4,3 +4,8 @@ BackendBuckets .. automodule:: google.cloud.compute_v1.services.backend_buckets :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.backend_buckets.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/backend_services.rst b/docs/compute_v1/backend_services.rst index 4ae802216..9d387953b 100644 --- a/docs/compute_v1/backend_services.rst +++ b/docs/compute_v1/backend_services.rst @@ -4,3 +4,8 @@ BackendServices .. automodule:: google.cloud.compute_v1.services.backend_services :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.backend_services.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/disk_types.rst b/docs/compute_v1/disk_types.rst index 685a6985e..a4853da43 100644 --- a/docs/compute_v1/disk_types.rst +++ b/docs/compute_v1/disk_types.rst @@ -4,3 +4,8 @@ DiskTypes .. automodule:: google.cloud.compute_v1.services.disk_types :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.disk_types.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/disks.rst b/docs/compute_v1/disks.rst index b3979d150..154429e8a 100644 --- a/docs/compute_v1/disks.rst +++ b/docs/compute_v1/disks.rst @@ -4,3 +4,8 @@ Disks .. automodule:: google.cloud.compute_v1.services.disks :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.disks.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/external_vpn_gateways.rst b/docs/compute_v1/external_vpn_gateways.rst index 0945902af..ff28cc2ea 100644 --- a/docs/compute_v1/external_vpn_gateways.rst +++ b/docs/compute_v1/external_vpn_gateways.rst @@ -4,3 +4,8 @@ ExternalVpnGateways .. automodule:: google.cloud.compute_v1.services.external_vpn_gateways :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.external_vpn_gateways.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/firewalls.rst b/docs/compute_v1/firewalls.rst index 81ac2d166..ef6049d7c 100644 --- a/docs/compute_v1/firewalls.rst +++ b/docs/compute_v1/firewalls.rst @@ -4,3 +4,8 @@ Firewalls .. automodule:: google.cloud.compute_v1.services.firewalls :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.firewalls.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/forwarding_rules.rst b/docs/compute_v1/forwarding_rules.rst index 7d16f9de5..08b543649 100644 --- a/docs/compute_v1/forwarding_rules.rst +++ b/docs/compute_v1/forwarding_rules.rst @@ -4,3 +4,8 @@ ForwardingRules .. automodule:: google.cloud.compute_v1.services.forwarding_rules :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.forwarding_rules.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/global_addresses.rst b/docs/compute_v1/global_addresses.rst index 7ca81d5c5..8755a1855 100644 --- a/docs/compute_v1/global_addresses.rst +++ b/docs/compute_v1/global_addresses.rst @@ -4,3 +4,8 @@ GlobalAddresses .. automodule:: google.cloud.compute_v1.services.global_addresses :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.global_addresses.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/global_forwarding_rules.rst b/docs/compute_v1/global_forwarding_rules.rst index b4a2a8381..a54eb4b0a 100644 --- a/docs/compute_v1/global_forwarding_rules.rst +++ b/docs/compute_v1/global_forwarding_rules.rst @@ -4,3 +4,8 @@ GlobalForwardingRules .. automodule:: google.cloud.compute_v1.services.global_forwarding_rules :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.global_forwarding_rules.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/global_network_endpoint_groups.rst b/docs/compute_v1/global_network_endpoint_groups.rst index c4962be83..430a202f2 100644 --- a/docs/compute_v1/global_network_endpoint_groups.rst +++ b/docs/compute_v1/global_network_endpoint_groups.rst @@ -4,3 +4,8 @@ GlobalNetworkEndpointGroups .. automodule:: google.cloud.compute_v1.services.global_network_endpoint_groups :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.global_network_endpoint_groups.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/global_operations.rst b/docs/compute_v1/global_operations.rst index c9280a004..6c4cef346 100644 --- a/docs/compute_v1/global_operations.rst +++ b/docs/compute_v1/global_operations.rst @@ -4,3 +4,8 @@ GlobalOperations .. automodule:: google.cloud.compute_v1.services.global_operations :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.global_operations.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/global_organization_operations.rst b/docs/compute_v1/global_organization_operations.rst index 5a415898b..2fb034517 100644 --- a/docs/compute_v1/global_organization_operations.rst +++ b/docs/compute_v1/global_organization_operations.rst @@ -4,3 +4,8 @@ GlobalOrganizationOperations .. automodule:: google.cloud.compute_v1.services.global_organization_operations :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.global_organization_operations.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/health_checks.rst b/docs/compute_v1/health_checks.rst index cfa091319..68e2841a2 100644 --- a/docs/compute_v1/health_checks.rst +++ b/docs/compute_v1/health_checks.rst @@ -4,3 +4,8 @@ HealthChecks .. automodule:: google.cloud.compute_v1.services.health_checks :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.health_checks.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/images.rst b/docs/compute_v1/images.rst index 3ad9732dd..318fd8073 100644 --- a/docs/compute_v1/images.rst +++ b/docs/compute_v1/images.rst @@ -4,3 +4,8 @@ Images .. automodule:: google.cloud.compute_v1.services.images :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.images.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/instance_group_managers.rst b/docs/compute_v1/instance_group_managers.rst index aed206b3b..0010acecf 100644 --- a/docs/compute_v1/instance_group_managers.rst +++ b/docs/compute_v1/instance_group_managers.rst @@ -4,3 +4,8 @@ InstanceGroupManagers .. automodule:: google.cloud.compute_v1.services.instance_group_managers :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.instance_group_managers.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/instance_groups.rst b/docs/compute_v1/instance_groups.rst index 35a4b178d..094150adb 100644 --- a/docs/compute_v1/instance_groups.rst +++ b/docs/compute_v1/instance_groups.rst @@ -4,3 +4,8 @@ InstanceGroups .. automodule:: google.cloud.compute_v1.services.instance_groups :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.instance_groups.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/instance_templates.rst b/docs/compute_v1/instance_templates.rst index 4f0effd30..35a6f2f09 100644 --- a/docs/compute_v1/instance_templates.rst +++ b/docs/compute_v1/instance_templates.rst @@ -4,3 +4,8 @@ InstanceTemplates .. automodule:: google.cloud.compute_v1.services.instance_templates :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.instance_templates.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/instances.rst b/docs/compute_v1/instances.rst index 2d63480da..f8ba5f34b 100644 --- a/docs/compute_v1/instances.rst +++ b/docs/compute_v1/instances.rst @@ -4,3 +4,8 @@ Instances .. automodule:: google.cloud.compute_v1.services.instances :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.instances.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/interconnect_attachments.rst b/docs/compute_v1/interconnect_attachments.rst index 0aa910125..3356e3459 100644 --- a/docs/compute_v1/interconnect_attachments.rst +++ b/docs/compute_v1/interconnect_attachments.rst @@ -4,3 +4,8 @@ InterconnectAttachments .. automodule:: google.cloud.compute_v1.services.interconnect_attachments :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.interconnect_attachments.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/interconnect_locations.rst b/docs/compute_v1/interconnect_locations.rst index f78790ca7..bfb2d7c87 100644 --- a/docs/compute_v1/interconnect_locations.rst +++ b/docs/compute_v1/interconnect_locations.rst @@ -4,3 +4,8 @@ InterconnectLocations .. automodule:: google.cloud.compute_v1.services.interconnect_locations :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.interconnect_locations.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/interconnects.rst b/docs/compute_v1/interconnects.rst index 7324ff16c..d33ff5d3c 100644 --- a/docs/compute_v1/interconnects.rst +++ b/docs/compute_v1/interconnects.rst @@ -4,3 +4,8 @@ Interconnects .. automodule:: google.cloud.compute_v1.services.interconnects :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.interconnects.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/licenses.rst b/docs/compute_v1/licenses.rst index f3807c747..4684df852 100644 --- a/docs/compute_v1/licenses.rst +++ b/docs/compute_v1/licenses.rst @@ -4,3 +4,8 @@ Licenses .. automodule:: google.cloud.compute_v1.services.licenses :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.licenses.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/machine_types.rst b/docs/compute_v1/machine_types.rst index aad7d7378..768b2d449 100644 --- a/docs/compute_v1/machine_types.rst +++ b/docs/compute_v1/machine_types.rst @@ -4,3 +4,8 @@ MachineTypes .. automodule:: google.cloud.compute_v1.services.machine_types :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.machine_types.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/network_endpoint_groups.rst b/docs/compute_v1/network_endpoint_groups.rst index cc230e4cb..3a1a44dee 100644 --- a/docs/compute_v1/network_endpoint_groups.rst +++ b/docs/compute_v1/network_endpoint_groups.rst @@ -4,3 +4,8 @@ NetworkEndpointGroups .. automodule:: google.cloud.compute_v1.services.network_endpoint_groups :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.network_endpoint_groups.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/networks.rst b/docs/compute_v1/networks.rst index ba328fd42..91273a803 100644 --- a/docs/compute_v1/networks.rst +++ b/docs/compute_v1/networks.rst @@ -4,3 +4,8 @@ Networks .. automodule:: google.cloud.compute_v1.services.networks :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.networks.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/node_groups.rst b/docs/compute_v1/node_groups.rst index ef9e7e5a5..95b5c580b 100644 --- a/docs/compute_v1/node_groups.rst +++ b/docs/compute_v1/node_groups.rst @@ -4,3 +4,8 @@ NodeGroups .. automodule:: google.cloud.compute_v1.services.node_groups :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.node_groups.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/node_templates.rst b/docs/compute_v1/node_templates.rst index 1ef57c0c8..d6b5c39db 100644 --- a/docs/compute_v1/node_templates.rst +++ b/docs/compute_v1/node_templates.rst @@ -4,3 +4,8 @@ NodeTemplates .. automodule:: google.cloud.compute_v1.services.node_templates :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.node_templates.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/node_types.rst b/docs/compute_v1/node_types.rst index 7ee2ecbce..f5723f0fc 100644 --- a/docs/compute_v1/node_types.rst +++ b/docs/compute_v1/node_types.rst @@ -4,3 +4,8 @@ NodeTypes .. automodule:: google.cloud.compute_v1.services.node_types :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.node_types.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/packet_mirrorings.rst b/docs/compute_v1/packet_mirrorings.rst index 4f6a68cb0..77d59a217 100644 --- a/docs/compute_v1/packet_mirrorings.rst +++ b/docs/compute_v1/packet_mirrorings.rst @@ -4,3 +4,8 @@ PacketMirrorings .. automodule:: google.cloud.compute_v1.services.packet_mirrorings :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.packet_mirrorings.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/projects.rst b/docs/compute_v1/projects.rst index 7df931447..deb8641a1 100644 --- a/docs/compute_v1/projects.rst +++ b/docs/compute_v1/projects.rst @@ -4,3 +4,8 @@ Projects .. automodule:: google.cloud.compute_v1.services.projects :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.projects.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_autoscalers.rst b/docs/compute_v1/region_autoscalers.rst index df6b30705..b491e22e2 100644 --- a/docs/compute_v1/region_autoscalers.rst +++ b/docs/compute_v1/region_autoscalers.rst @@ -4,3 +4,8 @@ RegionAutoscalers .. automodule:: google.cloud.compute_v1.services.region_autoscalers :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_autoscalers.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_backend_services.rst b/docs/compute_v1/region_backend_services.rst index 3a19c6104..bbfda5cf4 100644 --- a/docs/compute_v1/region_backend_services.rst +++ b/docs/compute_v1/region_backend_services.rst @@ -4,3 +4,8 @@ RegionBackendServices .. automodule:: google.cloud.compute_v1.services.region_backend_services :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_backend_services.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_commitments.rst b/docs/compute_v1/region_commitments.rst index c9e747a37..141ad66e7 100644 --- a/docs/compute_v1/region_commitments.rst +++ b/docs/compute_v1/region_commitments.rst @@ -4,3 +4,8 @@ RegionCommitments .. automodule:: google.cloud.compute_v1.services.region_commitments :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_commitments.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_disk_types.rst b/docs/compute_v1/region_disk_types.rst index 0eb6fc286..f46e6258e 100644 --- a/docs/compute_v1/region_disk_types.rst +++ b/docs/compute_v1/region_disk_types.rst @@ -4,3 +4,8 @@ RegionDiskTypes .. automodule:: google.cloud.compute_v1.services.region_disk_types :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_disk_types.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_disks.rst b/docs/compute_v1/region_disks.rst index 8812bc142..689b917a2 100644 --- a/docs/compute_v1/region_disks.rst +++ b/docs/compute_v1/region_disks.rst @@ -4,3 +4,8 @@ RegionDisks .. automodule:: google.cloud.compute_v1.services.region_disks :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_disks.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_health_check_services.rst b/docs/compute_v1/region_health_check_services.rst index 3020a1d36..8c342c0ad 100644 --- a/docs/compute_v1/region_health_check_services.rst +++ b/docs/compute_v1/region_health_check_services.rst @@ -4,3 +4,8 @@ RegionHealthCheckServices .. automodule:: google.cloud.compute_v1.services.region_health_check_services :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_health_check_services.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_health_checks.rst b/docs/compute_v1/region_health_checks.rst index bc916ce1a..631de0dfc 100644 --- a/docs/compute_v1/region_health_checks.rst +++ b/docs/compute_v1/region_health_checks.rst @@ -4,3 +4,8 @@ RegionHealthChecks .. automodule:: google.cloud.compute_v1.services.region_health_checks :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_health_checks.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_instance_group_managers.rst b/docs/compute_v1/region_instance_group_managers.rst index 7cc14091d..fd3af5961 100644 --- a/docs/compute_v1/region_instance_group_managers.rst +++ b/docs/compute_v1/region_instance_group_managers.rst @@ -4,3 +4,8 @@ RegionInstanceGroupManagers .. automodule:: google.cloud.compute_v1.services.region_instance_group_managers :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_instance_group_managers.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_instance_groups.rst b/docs/compute_v1/region_instance_groups.rst index 22161a072..473dc65e9 100644 --- a/docs/compute_v1/region_instance_groups.rst +++ b/docs/compute_v1/region_instance_groups.rst @@ -4,3 +4,8 @@ RegionInstanceGroups .. automodule:: google.cloud.compute_v1.services.region_instance_groups :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_instance_groups.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_network_endpoint_groups.rst b/docs/compute_v1/region_network_endpoint_groups.rst index 74ccc559b..d16d391cf 100644 --- a/docs/compute_v1/region_network_endpoint_groups.rst +++ b/docs/compute_v1/region_network_endpoint_groups.rst @@ -4,3 +4,8 @@ RegionNetworkEndpointGroups .. automodule:: google.cloud.compute_v1.services.region_network_endpoint_groups :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_network_endpoint_groups.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_notification_endpoints.rst b/docs/compute_v1/region_notification_endpoints.rst index 6924c1768..4baf0ced5 100644 --- a/docs/compute_v1/region_notification_endpoints.rst +++ b/docs/compute_v1/region_notification_endpoints.rst @@ -4,3 +4,8 @@ RegionNotificationEndpoints .. automodule:: google.cloud.compute_v1.services.region_notification_endpoints :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_notification_endpoints.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_operations.rst b/docs/compute_v1/region_operations.rst index 9bf0248ad..be2ef7bf9 100644 --- a/docs/compute_v1/region_operations.rst +++ b/docs/compute_v1/region_operations.rst @@ -4,3 +4,8 @@ RegionOperations .. automodule:: google.cloud.compute_v1.services.region_operations :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_operations.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_ssl_certificates.rst b/docs/compute_v1/region_ssl_certificates.rst index 1b45f988b..6de5c7c4b 100644 --- a/docs/compute_v1/region_ssl_certificates.rst +++ b/docs/compute_v1/region_ssl_certificates.rst @@ -4,3 +4,8 @@ RegionSslCertificates .. automodule:: google.cloud.compute_v1.services.region_ssl_certificates :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_ssl_certificates.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_target_http_proxies.rst b/docs/compute_v1/region_target_http_proxies.rst index 867748824..f157d3bb7 100644 --- a/docs/compute_v1/region_target_http_proxies.rst +++ b/docs/compute_v1/region_target_http_proxies.rst @@ -4,3 +4,8 @@ RegionTargetHttpProxies .. automodule:: google.cloud.compute_v1.services.region_target_http_proxies :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_target_http_proxies.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_target_https_proxies.rst b/docs/compute_v1/region_target_https_proxies.rst index 0eb410395..3b5c715b4 100644 --- a/docs/compute_v1/region_target_https_proxies.rst +++ b/docs/compute_v1/region_target_https_proxies.rst @@ -4,3 +4,8 @@ RegionTargetHttpsProxies .. automodule:: google.cloud.compute_v1.services.region_target_https_proxies :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_target_https_proxies.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/region_url_maps.rst b/docs/compute_v1/region_url_maps.rst index 2a7d37904..e01a03f0d 100644 --- a/docs/compute_v1/region_url_maps.rst +++ b/docs/compute_v1/region_url_maps.rst @@ -4,3 +4,8 @@ RegionUrlMaps .. automodule:: google.cloud.compute_v1.services.region_url_maps :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.region_url_maps.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/regions.rst b/docs/compute_v1/regions.rst index d547521f6..f1128793b 100644 --- a/docs/compute_v1/regions.rst +++ b/docs/compute_v1/regions.rst @@ -4,3 +4,8 @@ Regions .. automodule:: google.cloud.compute_v1.services.regions :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.regions.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/reservations.rst b/docs/compute_v1/reservations.rst index f9913af5e..14f4dce81 100644 --- a/docs/compute_v1/reservations.rst +++ b/docs/compute_v1/reservations.rst @@ -4,3 +4,8 @@ Reservations .. automodule:: google.cloud.compute_v1.services.reservations :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.reservations.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/resource_policies.rst b/docs/compute_v1/resource_policies.rst index eceb888d6..6ca37afc2 100644 --- a/docs/compute_v1/resource_policies.rst +++ b/docs/compute_v1/resource_policies.rst @@ -4,3 +4,8 @@ ResourcePolicies .. automodule:: google.cloud.compute_v1.services.resource_policies :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.resource_policies.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/routers.rst b/docs/compute_v1/routers.rst index b35977a2c..dac1975e4 100644 --- a/docs/compute_v1/routers.rst +++ b/docs/compute_v1/routers.rst @@ -4,3 +4,8 @@ Routers .. automodule:: google.cloud.compute_v1.services.routers :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.routers.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/routes.rst b/docs/compute_v1/routes.rst index b6c784132..8a14f50e6 100644 --- a/docs/compute_v1/routes.rst +++ b/docs/compute_v1/routes.rst @@ -4,3 +4,8 @@ Routes .. automodule:: google.cloud.compute_v1.services.routes :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.routes.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/security_policies.rst b/docs/compute_v1/security_policies.rst index bf60efa18..c4bc38690 100644 --- a/docs/compute_v1/security_policies.rst +++ b/docs/compute_v1/security_policies.rst @@ -4,3 +4,8 @@ SecurityPolicies .. automodule:: google.cloud.compute_v1.services.security_policies :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.security_policies.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/snapshots.rst b/docs/compute_v1/snapshots.rst index 8265d6d19..1af9eed27 100644 --- a/docs/compute_v1/snapshots.rst +++ b/docs/compute_v1/snapshots.rst @@ -4,3 +4,8 @@ Snapshots .. automodule:: google.cloud.compute_v1.services.snapshots :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.snapshots.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/ssl_certificates.rst b/docs/compute_v1/ssl_certificates.rst index ca48ca54c..c63ade587 100644 --- a/docs/compute_v1/ssl_certificates.rst +++ b/docs/compute_v1/ssl_certificates.rst @@ -4,3 +4,8 @@ SslCertificates .. automodule:: google.cloud.compute_v1.services.ssl_certificates :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.ssl_certificates.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/ssl_policies.rst b/docs/compute_v1/ssl_policies.rst index 094179f70..ddb6a78b6 100644 --- a/docs/compute_v1/ssl_policies.rst +++ b/docs/compute_v1/ssl_policies.rst @@ -4,3 +4,8 @@ SslPolicies .. automodule:: google.cloud.compute_v1.services.ssl_policies :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.ssl_policies.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/subnetworks.rst b/docs/compute_v1/subnetworks.rst index 1469d7ba4..c99a21d05 100644 --- a/docs/compute_v1/subnetworks.rst +++ b/docs/compute_v1/subnetworks.rst @@ -4,3 +4,8 @@ Subnetworks .. automodule:: google.cloud.compute_v1.services.subnetworks :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.subnetworks.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/target_grpc_proxies.rst b/docs/compute_v1/target_grpc_proxies.rst index d36a5ed3d..302c72398 100644 --- a/docs/compute_v1/target_grpc_proxies.rst +++ b/docs/compute_v1/target_grpc_proxies.rst @@ -4,3 +4,8 @@ TargetGrpcProxies .. automodule:: google.cloud.compute_v1.services.target_grpc_proxies :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.target_grpc_proxies.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/target_http_proxies.rst b/docs/compute_v1/target_http_proxies.rst index dd082c59b..ce0cd5a42 100644 --- a/docs/compute_v1/target_http_proxies.rst +++ b/docs/compute_v1/target_http_proxies.rst @@ -4,3 +4,8 @@ TargetHttpProxies .. automodule:: google.cloud.compute_v1.services.target_http_proxies :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.target_http_proxies.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/target_https_proxies.rst b/docs/compute_v1/target_https_proxies.rst index d05206031..0331656fa 100644 --- a/docs/compute_v1/target_https_proxies.rst +++ b/docs/compute_v1/target_https_proxies.rst @@ -4,3 +4,8 @@ TargetHttpsProxies .. automodule:: google.cloud.compute_v1.services.target_https_proxies :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.target_https_proxies.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/target_instances.rst b/docs/compute_v1/target_instances.rst index c9421cf2d..840d1903e 100644 --- a/docs/compute_v1/target_instances.rst +++ b/docs/compute_v1/target_instances.rst @@ -4,3 +4,8 @@ TargetInstances .. automodule:: google.cloud.compute_v1.services.target_instances :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.target_instances.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/target_pools.rst b/docs/compute_v1/target_pools.rst index 719aa3661..3898ef46d 100644 --- a/docs/compute_v1/target_pools.rst +++ b/docs/compute_v1/target_pools.rst @@ -4,3 +4,8 @@ TargetPools .. automodule:: google.cloud.compute_v1.services.target_pools :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.target_pools.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/target_ssl_proxies.rst b/docs/compute_v1/target_ssl_proxies.rst index 21b5d9edb..f7a2d95ce 100644 --- a/docs/compute_v1/target_ssl_proxies.rst +++ b/docs/compute_v1/target_ssl_proxies.rst @@ -4,3 +4,8 @@ TargetSslProxies .. automodule:: google.cloud.compute_v1.services.target_ssl_proxies :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.target_ssl_proxies.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/target_tcp_proxies.rst b/docs/compute_v1/target_tcp_proxies.rst index 2f259b309..66b3d186f 100644 --- a/docs/compute_v1/target_tcp_proxies.rst +++ b/docs/compute_v1/target_tcp_proxies.rst @@ -4,3 +4,8 @@ TargetTcpProxies .. automodule:: google.cloud.compute_v1.services.target_tcp_proxies :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.target_tcp_proxies.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/target_vpn_gateways.rst b/docs/compute_v1/target_vpn_gateways.rst index f9690ebc9..f1f42c8ad 100644 --- a/docs/compute_v1/target_vpn_gateways.rst +++ b/docs/compute_v1/target_vpn_gateways.rst @@ -4,3 +4,8 @@ TargetVpnGateways .. automodule:: google.cloud.compute_v1.services.target_vpn_gateways :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.target_vpn_gateways.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/url_maps.rst b/docs/compute_v1/url_maps.rst index 259390224..4c82c849c 100644 --- a/docs/compute_v1/url_maps.rst +++ b/docs/compute_v1/url_maps.rst @@ -4,3 +4,8 @@ UrlMaps .. automodule:: google.cloud.compute_v1.services.url_maps :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.url_maps.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/vpn_gateways.rst b/docs/compute_v1/vpn_gateways.rst index b5921da6f..d08ca6794 100644 --- a/docs/compute_v1/vpn_gateways.rst +++ b/docs/compute_v1/vpn_gateways.rst @@ -4,3 +4,8 @@ VpnGateways .. automodule:: google.cloud.compute_v1.services.vpn_gateways :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.vpn_gateways.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/vpn_tunnels.rst b/docs/compute_v1/vpn_tunnels.rst index e513996c8..91549354b 100644 --- a/docs/compute_v1/vpn_tunnels.rst +++ b/docs/compute_v1/vpn_tunnels.rst @@ -4,3 +4,8 @@ VpnTunnels .. automodule:: google.cloud.compute_v1.services.vpn_tunnels :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.vpn_tunnels.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/zone_operations.rst b/docs/compute_v1/zone_operations.rst index 83847994f..1b966fd77 100644 --- a/docs/compute_v1/zone_operations.rst +++ b/docs/compute_v1/zone_operations.rst @@ -4,3 +4,8 @@ ZoneOperations .. automodule:: google.cloud.compute_v1.services.zone_operations :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.zone_operations.pagers + :members: + :inherited-members: diff --git a/docs/compute_v1/zones.rst b/docs/compute_v1/zones.rst index d56b0f489..55af2eafc 100644 --- a/docs/compute_v1/zones.rst +++ b/docs/compute_v1/zones.rst @@ -4,3 +4,8 @@ Zones .. automodule:: google.cloud.compute_v1.services.zones :members: :inherited-members: + + +.. automodule:: google.cloud.compute_v1.services.zones.pagers + :members: + :inherited-members: diff --git a/google/cloud/compute_v1/__init__.py b/google/cloud/compute_v1/__init__.py index 9d541fb82..367e6c4ce 100644 --- a/google/cloud/compute_v1/__init__.py +++ b/google/cloud/compute_v1/__init__.py @@ -1152,6 +1152,7 @@ "AcceleratorType", "AcceleratorTypeAggregatedList", "AcceleratorTypeList", + "AcceleratorTypesClient", "AcceleratorTypesScopedList", "Accelerators", "AccessConfig", @@ -1360,7 +1361,6 @@ "DiskTypesClient", "DiskTypesScopedList", "DisksAddResourcePoliciesRequest", - "DisksClient", "DisksRemoveResourcePoliciesRequest", "DisksResizeRequest", "DisksScopedList", @@ -2271,5 +2271,5 @@ "ZoneSetLabelsRequest", "ZoneSetPolicyRequest", "ZonesClient", - "AcceleratorTypesClient", + "DisksClient", ) diff --git a/google/cloud/compute_v1/services/accelerator_types/client.py b/google/cloud/compute_v1/services/accelerator_types/client.py index 4ed3def1d..b55bab926 100644 --- a/google/cloud/compute_v1/services/accelerator_types/client.py +++ b/google/cloud/compute_v1/services/accelerator_types/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.accelerator_types import pagers from google.cloud.compute_v1.types import compute from .transports.base import AcceleratorTypesTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -338,7 +335,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.AcceleratorTypeAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of accelerator types. Args: @@ -359,7 +356,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.AcceleratorTypeAggregatedList: + google.cloud.compute_v1.services.accelerator_types.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -392,6 +392,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -497,7 +503,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.AcceleratorTypeList: + ) -> pagers.ListPager: r"""Retrieves a list of accelerator types that are available to the specified project. @@ -526,8 +532,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.AcceleratorTypeList: + google.cloud.compute_v1.services.accelerator_types.pagers.ListPager: Contains a list of accelerator types. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -561,6 +571,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/accelerator_types/pagers.py b/google/cloud/compute_v1/services/accelerator_types/pagers.py new file mode 100644 index 000000000..284b35519 --- /dev/null +++ b/google/cloud/compute_v1/services/accelerator_types/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.AcceleratorTypeAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.AcceleratorTypeAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.AcceleratorTypeAggregatedList], + request: compute.AggregatedListAcceleratorTypesRequest, + response: compute.AcceleratorTypeAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListAcceleratorTypesRequest): + The initial request object. + response (google.cloud.compute_v1.types.AcceleratorTypeAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListAcceleratorTypesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.AcceleratorTypeAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.AcceleratorTypesScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.AcceleratorTypesScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.AcceleratorTypeList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.AcceleratorTypeList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.AcceleratorTypeList], + request: compute.ListAcceleratorTypesRequest, + response: compute.AcceleratorTypeList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListAcceleratorTypesRequest): + The initial request object. + response (google.cloud.compute_v1.types.AcceleratorTypeList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListAcceleratorTypesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.AcceleratorTypeList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.AcceleratorType]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/accelerator_types/transports/rest.py b/google/cloud/compute_v1/services/accelerator_types/transports/rest.py index 942f0cfa0..d60589b40 100644 --- a/google/cloud/compute_v1/services/accelerator_types/transports/rest.py +++ b/google/cloud/compute_v1/services/accelerator_types/transports/rest.py @@ -55,7 +55,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -74,8 +74,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -90,6 +91,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -123,12 +126,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -141,6 +144,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.AcceleratorTypeAggregatedList.from_json(response.content) @@ -198,6 +204,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.AcceleratorType.from_json(response.content) @@ -233,11 +242,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -250,6 +259,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.AcceleratorTypeList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/addresses/client.py b/google/cloud/compute_v1/services/addresses/client.py index d7d7076e8..fd38a15d3 100644 --- a/google/cloud/compute_v1/services/addresses/client.py +++ b/google/cloud/compute_v1/services/addresses/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.addresses import pagers from google.cloud.compute_v1.types import compute from .transports.base import AddressesTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -334,7 +331,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.AddressAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of addresses. Args: @@ -355,7 +352,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.AddressAggregatedList: + google.cloud.compute_v1.services.addresses.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -388,6 +388,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -726,7 +732,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.AddressList: + ) -> pagers.ListPager: r"""Retrieves a list of addresses contained within the specified region. @@ -752,8 +758,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.AddressList: + google.cloud.compute_v1.services.addresses.pagers.ListPager: Contains a list of addresses. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -787,6 +797,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/addresses/pagers.py b/google/cloud/compute_v1/services/addresses/pagers.py new file mode 100644 index 000000000..53c4a0b49 --- /dev/null +++ b/google/cloud/compute_v1/services/addresses/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.AddressAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.AddressAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.AddressAggregatedList], + request: compute.AggregatedListAddressesRequest, + response: compute.AddressAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListAddressesRequest): + The initial request object. + response (google.cloud.compute_v1.types.AddressAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListAddressesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.AddressAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.AddressesScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.AddressesScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.AddressList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.AddressList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.AddressList], + request: compute.ListAddressesRequest, + response: compute.AddressList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListAddressesRequest): + The initial request object. + response (google.cloud.compute_v1.types.AddressList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListAddressesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.AddressList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Address]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/addresses/transports/rest.py b/google/cloud/compute_v1/services/addresses/transports/rest.py index b840e8821..cd3f80cf6 100644 --- a/google/cloud/compute_v1/services/addresses/transports/rest.py +++ b/google/cloud/compute_v1/services/addresses/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.AddressAggregatedList.from_json(response.content) @@ -219,6 +225,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -295,6 +304,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Address.from_json(response.content) @@ -351,7 +363,9 @@ def insert( # Jsonify the request body body = compute.Address.to_json( - request.address_resource, including_default_value_fields=False + request.address_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -374,7 +388,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -410,11 +427,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -427,6 +444,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.AddressList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/autoscalers/client.py b/google/cloud/compute_v1/services/autoscalers/client.py index ee99c1d62..a110b9dbd 100644 --- a/google/cloud/compute_v1/services/autoscalers/client.py +++ b/google/cloud/compute_v1/services/autoscalers/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.autoscalers import pagers from google.cloud.compute_v1.types import compute from .transports.base import AutoscalersTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -334,7 +331,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.AutoscalerAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of autoscalers. Args: @@ -355,7 +352,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.AutoscalerAggregatedList: + google.cloud.compute_v1.services.autoscalers.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -388,6 +388,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -715,7 +721,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.AutoscalerList: + ) -> pagers.ListPager: r"""Retrieves a list of autoscalers contained within the specified zone. @@ -742,9 +748,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.AutoscalerList: + google.cloud.compute_v1.services.autoscalers.pagers.ListPager: Contains a list of Autoscaler resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -779,6 +788,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/autoscalers/pagers.py b/google/cloud/compute_v1/services/autoscalers/pagers.py new file mode 100644 index 000000000..7fab36d45 --- /dev/null +++ b/google/cloud/compute_v1/services/autoscalers/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.AutoscalerAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.AutoscalerAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.AutoscalerAggregatedList], + request: compute.AggregatedListAutoscalersRequest, + response: compute.AutoscalerAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListAutoscalersRequest): + The initial request object. + response (google.cloud.compute_v1.types.AutoscalerAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListAutoscalersRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.AutoscalerAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.AutoscalersScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.AutoscalersScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.AutoscalerList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.AutoscalerList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.AutoscalerList], + request: compute.ListAutoscalersRequest, + response: compute.AutoscalerList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListAutoscalersRequest): + The initial request object. + response (google.cloud.compute_v1.types.AutoscalerList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListAutoscalersRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.AutoscalerList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Autoscaler]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/autoscalers/transports/rest.py b/google/cloud/compute_v1/services/autoscalers/transports/rest.py index 5784ae6a4..e5cc0414a 100644 --- a/google/cloud/compute_v1/services/autoscalers/transports/rest.py +++ b/google/cloud/compute_v1/services/autoscalers/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.AutoscalerAggregatedList.from_json(response.content) @@ -219,6 +225,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -288,6 +297,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Autoscaler.from_json(response.content) @@ -344,7 +356,9 @@ def insert( # Jsonify the request body body = compute.Autoscaler.to_json( - request.autoscaler_resource, including_default_value_fields=False + request.autoscaler_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -367,7 +381,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -406,11 +423,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -423,6 +440,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.AutoscalerList.from_json(response.content) @@ -479,7 +499,9 @@ def patch( # Jsonify the request body body = compute.Autoscaler.to_json( - request.autoscaler_resource, including_default_value_fields=False + request.autoscaler_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -491,8 +513,8 @@ def patch( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "autoscaler": request.autoscaler, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -503,7 +525,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -561,7 +586,9 @@ def update( # Jsonify the request body body = compute.Autoscaler.to_json( - request.autoscaler_resource, including_default_value_fields=False + request.autoscaler_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -573,8 +600,8 @@ def update( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "autoscaler": request.autoscaler, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -585,7 +612,10 @@ def update( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.put(url, json=body,) + response = self._session.put(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/backend_buckets/client.py b/google/cloud/compute_v1/services/backend_buckets/client.py index 62ba89113..bafee1faf 100644 --- a/google/cloud/compute_v1/services/backend_buckets/client.py +++ b/google/cloud/compute_v1/services/backend_buckets/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.backend_buckets import pagers from google.cloud.compute_v1.types import compute from .transports.base import BackendBucketsTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -842,7 +839,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.BackendBucketList: + ) -> pagers.ListPager: r"""Retrieves the list of BackendBucket resources available to the specified project. @@ -864,9 +861,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.BackendBucketList: + google.cloud.compute_v1.services.backend_buckets.pagers.ListPager: Contains a list of BackendBucket resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -899,6 +899,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/backend_buckets/pagers.py b/google/cloud/compute_v1/services/backend_buckets/pagers.py new file mode 100644 index 000000000..a71b74ff8 --- /dev/null +++ b/google/cloud/compute_v1/services/backend_buckets/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.BackendBucketList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.BackendBucketList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.BackendBucketList], + request: compute.ListBackendBucketsRequest, + response: compute.BackendBucketList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListBackendBucketsRequest): + The initial request object. + response (google.cloud.compute_v1.types.BackendBucketList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListBackendBucketsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.BackendBucketList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.BackendBucket]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/backend_buckets/transports/rest.py b/google/cloud/compute_v1/services/backend_buckets/transports/rest.py index d9a112a2b..c896529b1 100644 --- a/google/cloud/compute_v1/services/backend_buckets/transports/rest.py +++ b/google/cloud/compute_v1/services/backend_buckets/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def add_signed_url_key( self, @@ -143,7 +146,9 @@ def add_signed_url_key( # Jsonify the request body body = compute.SignedUrlKey.to_json( - request.signed_url_key_resource, including_default_value_fields=False + request.signed_url_key_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -168,7 +173,10 @@ def add_signed_url_key( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -248,6 +256,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -327,6 +338,9 @@ def delete_signed_url_key( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -381,6 +395,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.BackendBucket.from_json(response.content) @@ -437,7 +454,9 @@ def insert( # Jsonify the request body body = compute.BackendBucket.to_json( - request.backend_bucket_resource, including_default_value_fields=False + request.backend_bucket_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -460,7 +479,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -499,11 +521,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -516,6 +538,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.BackendBucketList.from_json(response.content) @@ -572,7 +597,9 @@ def patch( # Jsonify the request body body = compute.BackendBucket.to_json( - request.backend_bucket_resource, including_default_value_fields=False + request.backend_bucket_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -597,7 +624,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -655,7 +685,9 @@ def update( # Jsonify the request body body = compute.BackendBucket.to_json( - request.backend_bucket_resource, including_default_value_fields=False + request.backend_bucket_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -680,7 +712,10 @@ def update( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.put(url, json=body,) + response = self._session.put(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/backend_services/client.py b/google/cloud/compute_v1/services/backend_services/client.py index 0c79ed5ef..01c186923 100644 --- a/google/cloud/compute_v1/services/backend_services/client.py +++ b/google/cloud/compute_v1/services/backend_services/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.backend_services import pagers from google.cloud.compute_v1.types import compute from .transports.base import BackendServicesTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -447,7 +444,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.BackendServiceAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves the list of all BackendService resources, regional and global, available to the specified project. @@ -471,9 +468,12 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.BackendServiceAggregatedList: + google.cloud.compute_v1.services.backend_services.pagers.AggregatedListPager: Contains a list of BackendServicesScopedList. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -506,6 +506,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1020,7 +1026,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.BackendServiceList: + ) -> pagers.ListPager: r"""Retrieves the list of BackendService resources available to the specified project. @@ -1042,9 +1048,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.BackendServiceList: + google.cloud.compute_v1.services.backend_services.pagers.ListPager: Contains a list of BackendService resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1077,6 +1086,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/backend_services/pagers.py b/google/cloud/compute_v1/services/backend_services/pagers.py new file mode 100644 index 000000000..5b3c487f8 --- /dev/null +++ b/google/cloud/compute_v1/services/backend_services/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.BackendServiceAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.BackendServiceAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.BackendServiceAggregatedList], + request: compute.AggregatedListBackendServicesRequest, + response: compute.BackendServiceAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListBackendServicesRequest): + The initial request object. + response (google.cloud.compute_v1.types.BackendServiceAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListBackendServicesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.BackendServiceAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.BackendServicesScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.BackendServicesScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.BackendServiceList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.BackendServiceList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.BackendServiceList], + request: compute.ListBackendServicesRequest, + response: compute.BackendServiceList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListBackendServicesRequest): + The initial request object. + response (google.cloud.compute_v1.types.BackendServiceList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListBackendServicesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.BackendServiceList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.BackendService]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/backend_services/transports/rest.py b/google/cloud/compute_v1/services/backend_services/transports/rest.py index b67fd2cf1..080c74112 100644 --- a/google/cloud/compute_v1/services/backend_services/transports/rest.py +++ b/google/cloud/compute_v1/services/backend_services/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def add_signed_url_key( self, @@ -143,7 +146,9 @@ def add_signed_url_key( # Jsonify the request body body = compute.SignedUrlKey.to_json( - request.signed_url_key_resource, including_default_value_fields=False + request.signed_url_key_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -168,7 +173,10 @@ def add_signed_url_key( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -207,12 +215,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -225,6 +233,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.BackendServiceAggregatedList.from_json(response.content) @@ -303,6 +314,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -382,6 +396,9 @@ def delete_signed_url_key( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -452,6 +469,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.BackendService.from_json(response.content) @@ -482,6 +502,7 @@ def get_health( body = compute.ResourceGroupReference.to_json( request.resource_group_reference_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -504,7 +525,10 @@ def get_health( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.BackendServiceGroupHealth.from_json(response.content) @@ -562,7 +586,9 @@ def insert( # Jsonify the request body body = compute.BackendService.to_json( - request.backend_service_resource, including_default_value_fields=False + request.backend_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -585,7 +611,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -624,11 +653,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -641,6 +670,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.BackendServiceList.from_json(response.content) @@ -697,7 +729,9 @@ def patch( # Jsonify the request body body = compute.BackendService.to_json( - request.backend_service_resource, including_default_value_fields=False + request.backend_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -722,7 +756,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -782,6 +819,7 @@ def set_security_policy( body = compute.SecurityPolicyReference.to_json( request.security_policy_reference_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -806,7 +844,10 @@ def set_security_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -864,7 +905,9 @@ def update( # Jsonify the request body body = compute.BackendService.to_json( - request.backend_service_resource, including_default_value_fields=False + request.backend_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -889,7 +932,10 @@ def update( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.put(url, json=body,) + response = self._session.put(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/disk_types/client.py b/google/cloud/compute_v1/services/disk_types/client.py index 461af0978..296c597b3 100644 --- a/google/cloud/compute_v1/services/disk_types/client.py +++ b/google/cloud/compute_v1/services/disk_types/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.disk_types import pagers from google.cloud.compute_v1.types import compute from .transports.base import DiskTypesTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -334,7 +331,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.DiskTypeAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of disk types. Args: @@ -355,7 +352,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.DiskTypeAggregatedList: + google.cloud.compute_v1.services.disk_types.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -388,6 +388,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -504,7 +510,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.DiskTypeList: + ) -> pagers.ListPager: r"""Retrieves a list of disk types available to the specified project. @@ -532,8 +538,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.DiskTypeList: + google.cloud.compute_v1.services.disk_types.pagers.ListPager: Contains a list of disk types. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -567,6 +577,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/disk_types/pagers.py b/google/cloud/compute_v1/services/disk_types/pagers.py new file mode 100644 index 000000000..5c0dbb19d --- /dev/null +++ b/google/cloud/compute_v1/services/disk_types/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.DiskTypeAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.DiskTypeAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.DiskTypeAggregatedList], + request: compute.AggregatedListDiskTypesRequest, + response: compute.DiskTypeAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListDiskTypesRequest): + The initial request object. + response (google.cloud.compute_v1.types.DiskTypeAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListDiskTypesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.DiskTypeAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.DiskTypesScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.DiskTypesScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.DiskTypeList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.DiskTypeList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.DiskTypeList], + request: compute.ListDiskTypesRequest, + response: compute.DiskTypeList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListDiskTypesRequest): + The initial request object. + response (google.cloud.compute_v1.types.DiskTypeList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListDiskTypesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.DiskTypeList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.DiskType]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/disk_types/transports/rest.py b/google/cloud/compute_v1/services/disk_types/transports/rest.py index 47ddb0815..212eeadbf 100644 --- a/google/cloud/compute_v1/services/disk_types/transports/rest.py +++ b/google/cloud/compute_v1/services/disk_types/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.DiskTypeAggregatedList.from_json(response.content) @@ -208,6 +214,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.DiskType.from_json(response.content) @@ -242,11 +251,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -259,6 +268,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.DiskTypeList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/disks/client.py b/google/cloud/compute_v1/services/disks/client.py index 17c1f1ee7..c5dad0ae3 100644 --- a/google/cloud/compute_v1/services/disks/client.py +++ b/google/cloud/compute_v1/services/disks/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.disks import pagers from google.cloud.compute_v1.types import compute from .transports.base import DisksTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -456,7 +453,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.DiskAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of persistent disks. Args: @@ -477,7 +474,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.DiskAggregatedList: + google.cloud.compute_v1.services.disks.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -510,6 +510,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1110,7 +1116,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.DiskList: + ) -> pagers.ListPager: r"""Retrieves a list of persistent disks contained within the specified zone. @@ -1138,8 +1144,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.DiskList: + google.cloud.compute_v1.services.disks.pagers.ListPager: A list of Disk resources. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -1173,6 +1183,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/disks/pagers.py b/google/cloud/compute_v1/services/disks/pagers.py new file mode 100644 index 000000000..665157ce9 --- /dev/null +++ b/google/cloud/compute_v1/services/disks/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.DiskAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.DiskAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.DiskAggregatedList], + request: compute.AggregatedListDisksRequest, + response: compute.DiskAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListDisksRequest): + The initial request object. + response (google.cloud.compute_v1.types.DiskAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListDisksRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.DiskAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.DisksScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.DisksScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.DiskList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.DiskList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.DiskList], + request: compute.ListDisksRequest, + response: compute.DiskList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListDisksRequest): + The initial request object. + response (google.cloud.compute_v1.types.DiskList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListDisksRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.DiskList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Disk]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/disks/transports/rest.py b/google/cloud/compute_v1/services/disks/transports/rest.py index 33a33de36..a45810aef 100644 --- a/google/cloud/compute_v1/services/disks/transports/rest.py +++ b/google/cloud/compute_v1/services/disks/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def add_resource_policies( self, @@ -145,6 +148,7 @@ def add_resource_policies( body = compute.DisksAddResourcePoliciesRequest.to_json( request.disks_add_resource_policies_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -170,7 +174,10 @@ def add_resource_policies( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -207,12 +214,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -225,6 +232,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.DiskAggregatedList.from_json(response.content) @@ -281,7 +291,9 @@ def create_snapshot( # Jsonify the request body body = compute.Snapshot.to_json( - request.snapshot_resource, including_default_value_fields=False + request.snapshot_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -296,8 +308,8 @@ def create_snapshot( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "guestFlush": request.guest_flush, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -308,7 +320,10 @@ def create_snapshot( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -388,6 +403,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -456,6 +474,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Disk.from_json(response.content) @@ -557,6 +578,9 @@ def get_iam_policy( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Policy.from_json(response.content) @@ -612,7 +636,9 @@ def insert( # Jsonify the request body body = compute.Disk.to_json( - request.disk_resource, including_default_value_fields=False + request.disk_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -636,7 +662,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -672,11 +701,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -689,6 +718,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.DiskList.from_json(response.content) @@ -747,6 +779,7 @@ def remove_resource_policies( body = compute.DisksRemoveResourcePoliciesRequest.to_json( request.disks_remove_resource_policies_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -772,7 +805,10 @@ def remove_resource_policies( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -829,7 +865,9 @@ def resize( # Jsonify the request body body = compute.DisksResizeRequest.to_json( - request.disks_resize_request_resource, including_default_value_fields=False + request.disks_resize_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -855,7 +893,10 @@ def resize( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -937,6 +978,7 @@ def set_iam_policy( body = compute.ZoneSetPolicyRequest.to_json( request.zone_set_policy_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -960,7 +1002,10 @@ def set_iam_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Policy.from_json(response.content) @@ -1020,6 +1065,7 @@ def set_labels( body = compute.ZoneSetLabelsRequest.to_json( request.zone_set_labels_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1045,7 +1091,10 @@ def set_labels( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1077,6 +1126,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1100,7 +1150,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/external_vpn_gateways/client.py b/google/cloud/compute_v1/services/external_vpn_gateways/client.py index 3d33fb8d8..f541a900f 100644 --- a/google/cloud/compute_v1/services/external_vpn_gateways/client.py +++ b/google/cloud/compute_v1/services/external_vpn_gateways/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.external_vpn_gateways import pagers from google.cloud.compute_v1.types import compute from .transports.base import ExternalVpnGatewaysTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -630,7 +627,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.ExternalVpnGatewayList: + ) -> pagers.ListPager: r"""Retrieves the list of ExternalVpnGateway available to the specified project. @@ -652,9 +649,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.ExternalVpnGatewayList: + google.cloud.compute_v1.services.external_vpn_gateways.pagers.ListPager: Response to the list request, and contains a list of externalVpnGateways. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -687,6 +687,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/external_vpn_gateways/pagers.py b/google/cloud/compute_v1/services/external_vpn_gateways/pagers.py new file mode 100644 index 000000000..f60ab239f --- /dev/null +++ b/google/cloud/compute_v1/services/external_vpn_gateways/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.ExternalVpnGatewayList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.ExternalVpnGatewayList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.ExternalVpnGatewayList], + request: compute.ListExternalVpnGatewaysRequest, + response: compute.ExternalVpnGatewayList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListExternalVpnGatewaysRequest): + The initial request object. + response (google.cloud.compute_v1.types.ExternalVpnGatewayList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListExternalVpnGatewaysRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.ExternalVpnGatewayList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.ExternalVpnGateway]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/external_vpn_gateways/transports/rest.py b/google/cloud/compute_v1/services/external_vpn_gateways/transports/rest.py index ced89adcd..43da3694e 100644 --- a/google/cloud/compute_v1/services/external_vpn_gateways/transports/rest.py +++ b/google/cloud/compute_v1/services/external_vpn_gateways/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -165,6 +168,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -228,6 +234,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ExternalVpnGateway.from_json(response.content) @@ -284,7 +293,9 @@ def insert( # Jsonify the request body body = compute.ExternalVpnGateway.to_json( - request.external_vpn_gateway_resource, including_default_value_fields=False + request.external_vpn_gateway_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -307,7 +318,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -346,11 +360,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -363,6 +377,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ExternalVpnGatewayList.from_json(response.content) @@ -421,6 +438,7 @@ def set_labels( body = compute.GlobalSetLabelsRequest.to_json( request.global_set_labels_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -441,7 +459,10 @@ def set_labels( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -473,6 +494,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -493,7 +515,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/firewalls/client.py b/google/cloud/compute_v1/services/firewalls/client.py index 22b098b20..c55222076 100644 --- a/google/cloud/compute_v1/services/firewalls/client.py +++ b/google/cloud/compute_v1/services/firewalls/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.firewalls import pagers from google.cloud.compute_v1.types import compute from .transports.base import FirewallsTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -608,7 +605,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.FirewallList: + ) -> pagers.ListPager: r"""Retrieves the list of firewall rules available to the specified project. @@ -629,8 +626,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.FirewallList: + google.cloud.compute_v1.services.firewalls.pagers.ListPager: Contains a list of firewalls. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -662,6 +663,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/firewalls/pagers.py b/google/cloud/compute_v1/services/firewalls/pagers.py new file mode 100644 index 000000000..ff2937ad2 --- /dev/null +++ b/google/cloud/compute_v1/services/firewalls/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.FirewallList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.FirewallList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.FirewallList], + request: compute.ListFirewallsRequest, + response: compute.FirewallList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListFirewallsRequest): + The initial request object. + response (google.cloud.compute_v1.types.FirewallList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListFirewallsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.FirewallList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Firewall]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/firewalls/transports/rest.py b/google/cloud/compute_v1/services/firewalls/transports/rest.py index 9839e3674..2018ae10d 100644 --- a/google/cloud/compute_v1/services/firewalls/transports/rest.py +++ b/google/cloud/compute_v1/services/firewalls/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -163,6 +166,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -213,6 +219,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Firewall.from_json(response.content) @@ -269,7 +278,9 @@ def insert( # Jsonify the request body body = compute.Firewall.to_json( - request.firewall_resource, including_default_value_fields=False + request.firewall_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -292,7 +303,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -328,11 +342,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -345,6 +359,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.FirewallList.from_json(response.content) @@ -401,7 +418,9 @@ def patch( # Jsonify the request body body = compute.Firewall.to_json( - request.firewall_resource, including_default_value_fields=False + request.firewall_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -424,7 +443,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -482,7 +504,9 @@ def update( # Jsonify the request body body = compute.Firewall.to_json( - request.firewall_resource, including_default_value_fields=False + request.firewall_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -505,7 +529,10 @@ def update( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.put(url, json=body,) + response = self._session.put(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/forwarding_rules/client.py b/google/cloud/compute_v1/services/forwarding_rules/client.py index bd77765de..f8f13b563 100644 --- a/google/cloud/compute_v1/services/forwarding_rules/client.py +++ b/google/cloud/compute_v1/services/forwarding_rules/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.forwarding_rules import pagers from google.cloud.compute_v1.types import compute from .transports.base import ForwardingRulesTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -336,7 +333,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.ForwardingRuleAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of forwarding rules. Args: @@ -357,7 +354,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.ForwardingRuleAggregatedList: + google.cloud.compute_v1.services.forwarding_rules.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -390,6 +390,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -729,7 +735,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.ForwardingRuleList: + ) -> pagers.ListPager: r"""Retrieves a list of ForwardingRule resources available to the specified project and region. @@ -758,9 +764,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.ForwardingRuleList: + google.cloud.compute_v1.services.forwarding_rules.pagers.ListPager: Contains a list of ForwardingRule resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -795,6 +804,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/forwarding_rules/pagers.py b/google/cloud/compute_v1/services/forwarding_rules/pagers.py new file mode 100644 index 000000000..0320a5ac5 --- /dev/null +++ b/google/cloud/compute_v1/services/forwarding_rules/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.ForwardingRuleAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.ForwardingRuleAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.ForwardingRuleAggregatedList], + request: compute.AggregatedListForwardingRulesRequest, + response: compute.ForwardingRuleAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListForwardingRulesRequest): + The initial request object. + response (google.cloud.compute_v1.types.ForwardingRuleAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListForwardingRulesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.ForwardingRuleAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.ForwardingRulesScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.ForwardingRulesScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.ForwardingRuleList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.ForwardingRuleList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.ForwardingRuleList], + request: compute.ListForwardingRulesRequest, + response: compute.ForwardingRuleList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListForwardingRulesRequest): + The initial request object. + response (google.cloud.compute_v1.types.ForwardingRuleList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListForwardingRulesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.ForwardingRuleList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.ForwardingRule]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/forwarding_rules/transports/rest.py b/google/cloud/compute_v1/services/forwarding_rules/transports/rest.py index cdc77e8e8..ee40d69f9 100644 --- a/google/cloud/compute_v1/services/forwarding_rules/transports/rest.py +++ b/google/cloud/compute_v1/services/forwarding_rules/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ForwardingRuleAggregatedList.from_json(response.content) @@ -219,6 +225,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -290,6 +299,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ForwardingRule.from_json(response.content) @@ -346,7 +358,9 @@ def insert( # Jsonify the request body body = compute.ForwardingRule.to_json( - request.forwarding_rule_resource, including_default_value_fields=False + request.forwarding_rule_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -369,7 +383,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -408,11 +425,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -425,6 +442,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ForwardingRuleList.from_json(response.content) @@ -481,7 +501,9 @@ def patch( # Jsonify the request body body = compute.ForwardingRule.to_json( - request.forwarding_rule_resource, including_default_value_fields=False + request.forwarding_rule_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -507,7 +529,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -565,7 +590,9 @@ def set_target( # Jsonify the request body body = compute.TargetReference.to_json( - request.target_reference_resource, including_default_value_fields=False + request.target_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -591,7 +618,10 @@ def set_target( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/global_addresses/client.py b/google/cloud/compute_v1/services/global_addresses/client.py index 8ceee74af..618993128 100644 --- a/google/cloud/compute_v1/services/global_addresses/client.py +++ b/google/cloud/compute_v1/services/global_addresses/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.global_addresses import pagers from google.cloud.compute_v1.types import compute from .transports.base import GlobalAddressesTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -640,7 +637,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.AddressList: + ) -> pagers.ListPager: r"""Retrieves a list of global addresses. Args: @@ -661,8 +658,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.AddressList: + google.cloud.compute_v1.services.global_addresses.pagers.ListPager: Contains a list of addresses. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -694,6 +695,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/global_addresses/pagers.py b/google/cloud/compute_v1/services/global_addresses/pagers.py new file mode 100644 index 000000000..2428e6bed --- /dev/null +++ b/google/cloud/compute_v1/services/global_addresses/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.AddressList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.AddressList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.AddressList], + request: compute.ListGlobalAddressesRequest, + response: compute.AddressList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListGlobalAddressesRequest): + The initial request object. + response (google.cloud.compute_v1.types.AddressList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListGlobalAddressesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.AddressList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Address]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/global_addresses/transports/rest.py b/google/cloud/compute_v1/services/global_addresses/transports/rest.py index 7d05ded55..b9c2f0796 100644 --- a/google/cloud/compute_v1/services/global_addresses/transports/rest.py +++ b/google/cloud/compute_v1/services/global_addresses/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -163,6 +166,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -237,6 +243,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Address.from_json(response.content) @@ -293,7 +302,9 @@ def insert( # Jsonify the request body body = compute.Address.to_json( - request.address_resource, including_default_value_fields=False + request.address_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -316,7 +327,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -353,11 +367,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -370,6 +384,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.AddressList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/global_forwarding_rules/client.py b/google/cloud/compute_v1/services/global_forwarding_rules/client.py index 083b8062f..a9f7c1a2d 100644 --- a/google/cloud/compute_v1/services/global_forwarding_rules/client.py +++ b/google/cloud/compute_v1/services/global_forwarding_rules/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.global_forwarding_rules import pagers from google.cloud.compute_v1.types import compute from .transports.base import GlobalForwardingRulesTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -637,7 +634,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.ForwardingRuleList: + ) -> pagers.ListPager: r"""Retrieves a list of GlobalForwardingRule resources available to the specified project. @@ -659,9 +656,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.ForwardingRuleList: + google.cloud.compute_v1.services.global_forwarding_rules.pagers.ListPager: Contains a list of ForwardingRule resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -694,6 +694,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/global_forwarding_rules/pagers.py b/google/cloud/compute_v1/services/global_forwarding_rules/pagers.py new file mode 100644 index 000000000..f606d1c7e --- /dev/null +++ b/google/cloud/compute_v1/services/global_forwarding_rules/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.ForwardingRuleList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.ForwardingRuleList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.ForwardingRuleList], + request: compute.ListGlobalForwardingRulesRequest, + response: compute.ForwardingRuleList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListGlobalForwardingRulesRequest): + The initial request object. + response (google.cloud.compute_v1.types.ForwardingRuleList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListGlobalForwardingRulesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.ForwardingRuleList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.ForwardingRule]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/global_forwarding_rules/transports/rest.py b/google/cloud/compute_v1/services/global_forwarding_rules/transports/rest.py index 15e5b3076..7d26ef92b 100644 --- a/google/cloud/compute_v1/services/global_forwarding_rules/transports/rest.py +++ b/google/cloud/compute_v1/services/global_forwarding_rules/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -165,6 +168,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -235,6 +241,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ForwardingRule.from_json(response.content) @@ -291,7 +300,9 @@ def insert( # Jsonify the request body body = compute.ForwardingRule.to_json( - request.forwarding_rule_resource, including_default_value_fields=False + request.forwarding_rule_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -314,7 +325,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -353,11 +367,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -370,6 +384,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ForwardingRuleList.from_json(response.content) @@ -426,7 +443,9 @@ def patch( # Jsonify the request body body = compute.ForwardingRule.to_json( - request.forwarding_rule_resource, including_default_value_fields=False + request.forwarding_rule_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -451,7 +470,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -509,7 +531,9 @@ def set_target( # Jsonify the request body body = compute.TargetReference.to_json( - request.target_reference_resource, including_default_value_fields=False + request.target_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -534,7 +558,10 @@ def set_target( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/global_network_endpoint_groups/client.py b/google/cloud/compute_v1/services/global_network_endpoint_groups/client.py index ced1e3bbd..d4a70db19 100644 --- a/google/cloud/compute_v1/services/global_network_endpoint_groups/client.py +++ b/google/cloud/compute_v1/services/global_network_endpoint_groups/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.global_network_endpoint_groups import pagers from google.cloud.compute_v1.types import compute from .transports.base import GlobalNetworkEndpointGroupsTransport, DEFAULT_CLIENT_INFO @@ -270,21 +271,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -327,7 +324,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -888,7 +885,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NetworkEndpointGroupList: + ) -> pagers.ListPager: r"""Retrieves the list of network endpoint groups that are located in the specified project. @@ -910,7 +907,10 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NetworkEndpointGroupList: + google.cloud.compute_v1.services.global_network_endpoint_groups.pagers.ListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -943,6 +943,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -955,7 +961,7 @@ def list_network_endpoints( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NetworkEndpointGroupsListNetworkEndpoints: + ) -> pagers.ListNetworkEndpointsPager: r"""Lists the network endpoints in the specified network endpoint group. @@ -986,7 +992,10 @@ def list_network_endpoints( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NetworkEndpointGroupsListNetworkEndpoints: + google.cloud.compute_v1.services.global_network_endpoint_groups.pagers.ListNetworkEndpointsPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1025,6 +1034,12 @@ def list_network_endpoints( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListNetworkEndpointsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/global_network_endpoint_groups/pagers.py b/google/cloud/compute_v1/services/global_network_endpoint_groups/pagers.py new file mode 100644 index 000000000..551f20668 --- /dev/null +++ b/google/cloud/compute_v1/services/global_network_endpoint_groups/pagers.py @@ -0,0 +1,155 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NetworkEndpointGroupList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NetworkEndpointGroupList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NetworkEndpointGroupList], + request: compute.ListGlobalNetworkEndpointGroupsRequest, + response: compute.NetworkEndpointGroupList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListGlobalNetworkEndpointGroupsRequest): + The initial request object. + response (google.cloud.compute_v1.types.NetworkEndpointGroupList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListGlobalNetworkEndpointGroupsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NetworkEndpointGroupList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.NetworkEndpointGroup]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListNetworkEndpointsPager: + """A pager for iterating through ``list_network_endpoints`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NetworkEndpointGroupsListNetworkEndpoints` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListNetworkEndpoints`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NetworkEndpointGroupsListNetworkEndpoints` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NetworkEndpointGroupsListNetworkEndpoints], + request: compute.ListNetworkEndpointsGlobalNetworkEndpointGroupsRequest, + response: compute.NetworkEndpointGroupsListNetworkEndpoints, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListNetworkEndpointsGlobalNetworkEndpointGroupsRequest): + The initial request object. + response (google.cloud.compute_v1.types.NetworkEndpointGroupsListNetworkEndpoints): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListNetworkEndpointsGlobalNetworkEndpointGroupsRequest( + request + ) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NetworkEndpointGroupsListNetworkEndpoints]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.NetworkEndpointWithHealthStatus]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/global_network_endpoint_groups/transports/rest.py b/google/cloud/compute_v1/services/global_network_endpoint_groups/transports/rest.py index 7189d0352..eedebf997 100644 --- a/google/cloud/compute_v1/services/global_network_endpoint_groups/transports/rest.py +++ b/google/cloud/compute_v1/services/global_network_endpoint_groups/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def attach_network_endpoints( self, @@ -145,6 +148,7 @@ def attach_network_endpoints( body = compute.GlobalNetworkEndpointGroupsAttachEndpointsRequest.to_json( request.global_network_endpoint_groups_attach_endpoints_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -169,7 +173,10 @@ def attach_network_endpoints( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -249,6 +256,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -307,6 +317,7 @@ def detach_network_endpoints( body = compute.GlobalNetworkEndpointGroupsDetachEndpointsRequest.to_json( request.global_network_endpoint_groups_detach_endpoints_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -331,7 +342,10 @@ def detach_network_endpoints( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -394,6 +408,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NetworkEndpointGroup.from_json(response.content) @@ -452,6 +469,7 @@ def insert( body = compute.NetworkEndpointGroup.to_json( request.network_endpoint_group_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -474,7 +492,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -511,11 +532,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -528,6 +549,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NetworkEndpointGroupList.from_json(response.content) @@ -565,11 +589,11 @@ def list_network_endpoints( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -582,6 +606,9 @@ def list_network_endpoints( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NetworkEndpointGroupsListNetworkEndpoints.from_json( response.content diff --git a/google/cloud/compute_v1/services/global_operations/client.py b/google/cloud/compute_v1/services/global_operations/client.py index 43c593892..c434d8215 100644 --- a/google/cloud/compute_v1/services/global_operations/client.py +++ b/google/cloud/compute_v1/services/global_operations/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.global_operations import pagers from google.cloud.compute_v1.types import compute from .transports.base import GlobalOperationsTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -336,7 +333,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.OperationAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of all operations. Args: @@ -357,7 +354,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.OperationAggregatedList: + google.cloud.compute_v1.services.global_operations.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -390,6 +390,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -580,7 +586,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.OperationList: + ) -> pagers.ListPager: r"""Retrieves a list of Operation resources contained within the specified project. @@ -602,9 +608,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.OperationList: + google.cloud.compute_v1.services.global_operations.pagers.ListPager: Contains a list of Operation resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -637,6 +646,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/global_operations/pagers.py b/google/cloud/compute_v1/services/global_operations/pagers.py new file mode 100644 index 000000000..3acc037f2 --- /dev/null +++ b/google/cloud/compute_v1/services/global_operations/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.OperationAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.OperationAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.OperationAggregatedList], + request: compute.AggregatedListGlobalOperationsRequest, + response: compute.OperationAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListGlobalOperationsRequest): + The initial request object. + response (google.cloud.compute_v1.types.OperationAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListGlobalOperationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.OperationAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.OperationsScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.OperationsScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.OperationList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.OperationList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.OperationList], + request: compute.ListGlobalOperationsRequest, + response: compute.OperationList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListGlobalOperationsRequest): + The initial request object. + response (google.cloud.compute_v1.types.OperationList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListGlobalOperationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.OperationList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Operation]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/global_operations/transports/rest.py b/google/cloud/compute_v1/services/global_operations/transports/rest.py index 81303fbd6..e7afb894a 100644 --- a/google/cloud/compute_v1/services/global_operations/transports/rest.py +++ b/google/cloud/compute_v1/services/global_operations/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.OperationAggregatedList.from_json(response.content) @@ -189,6 +195,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.DeleteGlobalOperationResponse.from_json(response.content) @@ -263,6 +272,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -300,11 +312,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -317,6 +329,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.OperationList.from_json(response.content) @@ -391,6 +406,9 @@ def wait( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/global_organization_operations/client.py b/google/cloud/compute_v1/services/global_organization_operations/client.py index af0eb1b8c..6ca1466bc 100644 --- a/google/cloud/compute_v1/services/global_organization_operations/client.py +++ b/google/cloud/compute_v1/services/global_organization_operations/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.global_organization_operations import pagers from google.cloud.compute_v1.types import compute from .transports.base import GlobalOrganizationOperationsTransport, DEFAULT_CLIENT_INFO @@ -270,21 +271,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -327,7 +324,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -502,7 +499,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.OperationList: + ) -> pagers.ListPager: r"""Retrieves a list of Operation resources contained within the specified organization. @@ -519,9 +516,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.OperationList: + google.cloud.compute_v1.services.global_organization_operations.pagers.ListPager: Contains a list of Operation resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -540,6 +540,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/global_organization_operations/pagers.py b/google/cloud/compute_v1/services/global_organization_operations/pagers.py new file mode 100644 index 000000000..30b420d47 --- /dev/null +++ b/google/cloud/compute_v1/services/global_organization_operations/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.OperationList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.OperationList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.OperationList], + request: compute.ListGlobalOrganizationOperationsRequest, + response: compute.OperationList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListGlobalOrganizationOperationsRequest): + The initial request object. + response (google.cloud.compute_v1.types.OperationList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListGlobalOrganizationOperationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.OperationList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Operation]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/global_organization_operations/transports/rest.py b/google/cloud/compute_v1/services/global_organization_operations/transports/rest.py index 0bf658df7..5909c89a5 100644 --- a/google/cloud/compute_v1/services/global_organization_operations/transports/rest.py +++ b/google/cloud/compute_v1/services/global_organization_operations/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -138,6 +141,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.DeleteGlobalOrganizationOperationResponse.from_json( response.content @@ -216,6 +222,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -253,12 +262,12 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { + "filter": request.filter, + "maxResults": request.max_results, + "orderBy": request.order_by, "pageToken": request.page_token, "parentId": request.parent_id, "returnPartialSuccess": request.return_partial_success, - "filter": request.filter, - "orderBy": request.order_by, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -271,6 +280,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.OperationList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/health_checks/client.py b/google/cloud/compute_v1/services/health_checks/client.py index 72b5ca4b3..de3f512b5 100644 --- a/google/cloud/compute_v1/services/health_checks/client.py +++ b/google/cloud/compute_v1/services/health_checks/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.health_checks import pagers from google.cloud.compute_v1.types import compute from .transports.base import HealthChecksTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -334,7 +331,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.HealthChecksAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves the list of all HealthCheck resources, regional and global, available to the specified project. @@ -358,7 +355,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.HealthChecksAggregatedList: + google.cloud.compute_v1.services.health_checks.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -391,6 +391,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -706,7 +712,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.HealthCheckList: + ) -> pagers.ListPager: r"""Retrieves the list of HealthCheck resources available to the specified project. @@ -728,9 +734,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.HealthCheckList: + google.cloud.compute_v1.services.health_checks.pagers.ListPager: Contains a list of HealthCheck resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -763,6 +772,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/health_checks/pagers.py b/google/cloud/compute_v1/services/health_checks/pagers.py new file mode 100644 index 000000000..6a2d05b87 --- /dev/null +++ b/google/cloud/compute_v1/services/health_checks/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.HealthChecksAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.HealthChecksAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.HealthChecksAggregatedList], + request: compute.AggregatedListHealthChecksRequest, + response: compute.HealthChecksAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListHealthChecksRequest): + The initial request object. + response (google.cloud.compute_v1.types.HealthChecksAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListHealthChecksRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.HealthChecksAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.HealthChecksScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.HealthChecksScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.HealthCheckList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.HealthCheckList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.HealthCheckList], + request: compute.ListHealthChecksRequest, + response: compute.HealthCheckList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListHealthChecksRequest): + The initial request object. + response (google.cloud.compute_v1.types.HealthCheckList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListHealthChecksRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.HealthCheckList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.HealthCheck]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/health_checks/transports/rest.py b/google/cloud/compute_v1/services/health_checks/transports/rest.py index b899c3e08..1c785f29a 100644 --- a/google/cloud/compute_v1/services/health_checks/transports/rest.py +++ b/google/cloud/compute_v1/services/health_checks/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.HealthChecksAggregatedList.from_json(response.content) @@ -216,6 +222,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -290,6 +299,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.HealthCheck.from_json(response.content) @@ -346,7 +358,9 @@ def insert( # Jsonify the request body body = compute.HealthCheck.to_json( - request.health_check_resource, including_default_value_fields=False + request.health_check_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -369,7 +383,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -408,11 +425,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -425,6 +442,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.HealthCheckList.from_json(response.content) @@ -481,7 +501,9 @@ def patch( # Jsonify the request body body = compute.HealthCheck.to_json( - request.health_check_resource, including_default_value_fields=False + request.health_check_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -504,7 +526,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -562,7 +587,9 @@ def update( # Jsonify the request body body = compute.HealthCheck.to_json( - request.health_check_resource, including_default_value_fields=False + request.health_check_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -585,7 +612,10 @@ def update( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.put(url, json=body,) + response = self._session.put(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/images/client.py b/google/cloud/compute_v1/services/images/client.py index 6ceb426f9..3566dfa73 100644 --- a/google/cloud/compute_v1/services/images/client.py +++ b/google/cloud/compute_v1/services/images/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.images import pagers from google.cloud.compute_v1.types import compute from .transports.base import ImagesTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -923,7 +920,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.ImageList: + ) -> pagers.ListPager: r"""Retrieves the list of custom images available to the specified project. Custom images are images you create that belong to your project. This method does not get @@ -950,8 +947,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.ImageList: + google.cloud.compute_v1.services.images.pagers.ListPager: Contains a list of images. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -983,6 +984,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/images/pagers.py b/google/cloud/compute_v1/services/images/pagers.py new file mode 100644 index 000000000..00ab35682 --- /dev/null +++ b/google/cloud/compute_v1/services/images/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.ImageList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.ImageList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.ImageList], + request: compute.ListImagesRequest, + response: compute.ImageList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListImagesRequest): + The initial request object. + response (google.cloud.compute_v1.types.ImageList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListImagesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.ImageList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Image]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/images/transports/rest.py b/google/cloud/compute_v1/services/images/transports/rest.py index a609f5729..79a118003 100644 --- a/google/cloud/compute_v1/services/images/transports/rest.py +++ b/google/cloud/compute_v1/services/images/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -162,6 +165,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -218,7 +224,9 @@ def deprecate( # Jsonify the request body body = compute.DeprecationStatus.to_json( - request.deprecation_status_resource, including_default_value_fields=False + request.deprecation_status_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -241,7 +249,10 @@ def deprecate( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -293,6 +304,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Image.from_json(response.content) @@ -344,6 +358,9 @@ def get_from_family( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Image.from_json(response.content) @@ -442,6 +459,9 @@ def get_iam_policy( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Policy.from_json(response.content) @@ -497,7 +517,9 @@ def insert( # Jsonify the request body body = compute.Image.to_json( - request.image_resource, including_default_value_fields=False + request.image_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -509,8 +531,8 @@ def insert( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "forceCreate": request.force_create, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -521,7 +543,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -557,11 +582,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -574,6 +599,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ImageList.from_json(response.content) @@ -629,7 +657,9 @@ def patch( # Jsonify the request body body = compute.Image.to_json( - request.image_resource, including_default_value_fields=False + request.image_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -652,7 +682,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -734,6 +767,7 @@ def set_iam_policy( body = compute.GlobalSetPolicyRequest.to_json( request.global_set_policy_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -754,7 +788,10 @@ def set_iam_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Policy.from_json(response.content) @@ -814,6 +851,7 @@ def set_labels( body = compute.GlobalSetLabelsRequest.to_json( request.global_set_labels_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -834,7 +872,10 @@ def set_labels( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -866,6 +907,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -886,7 +928,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/instance_group_managers/client.py b/google/cloud/compute_v1/services/instance_group_managers/client.py index 1523237f9..240ed9f14 100644 --- a/google/cloud/compute_v1/services/instance_group_managers/client.py +++ b/google/cloud/compute_v1/services/instance_group_managers/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.instance_group_managers import pagers from google.cloud.compute_v1.types import compute from .transports.base import InstanceGroupManagersTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -481,7 +478,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InstanceGroupManagerAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves the list of managed instance groups and groups them by zone. @@ -503,7 +500,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InstanceGroupManagerAggregatedList: + google.cloud.compute_v1.services.instance_group_managers.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -536,6 +536,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1432,7 +1438,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InstanceGroupManagerList: + ) -> pagers.ListPager: r"""Retrieves a list of managed instance groups that are contained within the specified project and zone. @@ -1461,8 +1467,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InstanceGroupManagerList: + google.cloud.compute_v1.services.instance_group_managers.pagers.ListPager: [Output Only] A list of managed instance groups. + + Iterating over this object will yield results and + resolve additional pages automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -1496,6 +1506,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1509,7 +1525,7 @@ def list_errors( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InstanceGroupManagersListErrorsResponse: + ) -> pagers.ListErrorsPager: r"""Lists all errors thrown by actions on instances for a given managed instance group. The filter and orderBy query parameters are not supported. @@ -1549,7 +1565,10 @@ def list_errors( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InstanceGroupManagersListErrorsResponse: + google.cloud.compute_v1.services.instance_group_managers.pagers.ListErrorsPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1586,6 +1605,12 @@ def list_errors( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListErrorsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1599,7 +1624,7 @@ def list_managed_instances( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InstanceGroupManagersListManagedInstancesResponse: + ) -> pagers.ListManagedInstancesPager: r"""Lists all of the instances in the managed instance group. Each instance in the list has a currentAction, which indicates the action that the managed instance @@ -1641,7 +1666,10 @@ def list_managed_instances( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InstanceGroupManagersListManagedInstancesResponse: + google.cloud.compute_v1.services.instance_group_managers.pagers.ListManagedInstancesPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1680,6 +1708,12 @@ def list_managed_instances( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListManagedInstancesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1693,7 +1727,7 @@ def list_per_instance_configs( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InstanceGroupManagersListPerInstanceConfigsResp: + ) -> pagers.ListPerInstanceConfigsPager: r"""Lists all of the per-instance configs defined for the managed instance group. The orderBy query parameter is not supported. @@ -1731,7 +1765,10 @@ def list_per_instance_configs( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InstanceGroupManagersListPerInstanceConfigsResp: + google.cloud.compute_v1.services.instance_group_managers.pagers.ListPerInstanceConfigsPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1774,6 +1811,12 @@ def list_per_instance_configs( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPerInstanceConfigsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/instance_group_managers/pagers.py b/google/cloud/compute_v1/services/instance_group_managers/pagers.py new file mode 100644 index 000000000..e5200c6f8 --- /dev/null +++ b/google/cloud/compute_v1/services/instance_group_managers/pagers.py @@ -0,0 +1,352 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InstanceGroupManagerAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InstanceGroupManagerAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InstanceGroupManagerAggregatedList], + request: compute.AggregatedListInstanceGroupManagersRequest, + response: compute.InstanceGroupManagerAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListInstanceGroupManagersRequest): + The initial request object. + response (google.cloud.compute_v1.types.InstanceGroupManagerAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListInstanceGroupManagersRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.InstanceGroupManagerAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.InstanceGroupManagersScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.InstanceGroupManagersScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InstanceGroupManagerList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InstanceGroupManagerList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InstanceGroupManagerList], + request: compute.ListInstanceGroupManagersRequest, + response: compute.InstanceGroupManagerList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListInstanceGroupManagersRequest): + The initial request object. + response (google.cloud.compute_v1.types.InstanceGroupManagerList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListInstanceGroupManagersRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.InstanceGroupManagerList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.InstanceGroupManager]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListErrorsPager: + """A pager for iterating through ``list_errors`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InstanceGroupManagersListErrorsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListErrors`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InstanceGroupManagersListErrorsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InstanceGroupManagersListErrorsResponse], + request: compute.ListErrorsInstanceGroupManagersRequest, + response: compute.InstanceGroupManagersListErrorsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListErrorsInstanceGroupManagersRequest): + The initial request object. + response (google.cloud.compute_v1.types.InstanceGroupManagersListErrorsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListErrorsInstanceGroupManagersRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.InstanceGroupManagersListErrorsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.InstanceManagedByIgmError]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListManagedInstancesPager: + """A pager for iterating through ``list_managed_instances`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InstanceGroupManagersListManagedInstancesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``managed_instances`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListManagedInstances`` requests and continue to iterate + through the ``managed_instances`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InstanceGroupManagersListManagedInstancesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., compute.InstanceGroupManagersListManagedInstancesResponse + ], + request: compute.ListManagedInstancesInstanceGroupManagersRequest, + response: compute.InstanceGroupManagersListManagedInstancesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListManagedInstancesInstanceGroupManagersRequest): + The initial request object. + response (google.cloud.compute_v1.types.InstanceGroupManagersListManagedInstancesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListManagedInstancesInstanceGroupManagersRequest( + request + ) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages( + self, + ) -> Iterable[compute.InstanceGroupManagersListManagedInstancesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.ManagedInstance]: + for page in self.pages: + yield from page.managed_instances + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPerInstanceConfigsPager: + """A pager for iterating through ``list_per_instance_configs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InstanceGroupManagersListPerInstanceConfigsResp` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListPerInstanceConfigs`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InstanceGroupManagersListPerInstanceConfigsResp` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InstanceGroupManagersListPerInstanceConfigsResp], + request: compute.ListPerInstanceConfigsInstanceGroupManagersRequest, + response: compute.InstanceGroupManagersListPerInstanceConfigsResp, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListPerInstanceConfigsInstanceGroupManagersRequest): + The initial request object. + response (google.cloud.compute_v1.types.InstanceGroupManagersListPerInstanceConfigsResp): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListPerInstanceConfigsInstanceGroupManagersRequest( + request + ) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages( + self, + ) -> Iterable[compute.InstanceGroupManagersListPerInstanceConfigsResp]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.PerInstanceConfig]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/instance_group_managers/transports/rest.py b/google/cloud/compute_v1/services/instance_group_managers/transports/rest.py index 8ef595399..c3b5f1bc3 100644 --- a/google/cloud/compute_v1/services/instance_group_managers/transports/rest.py +++ b/google/cloud/compute_v1/services/instance_group_managers/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def abandon_instances( self, @@ -145,6 +148,7 @@ def abandon_instances( body = compute.InstanceGroupManagersAbandonInstancesRequest.to_json( request.instance_group_managers_abandon_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -170,7 +174,10 @@ def abandon_instances( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -207,12 +214,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -225,6 +232,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceGroupManagerAggregatedList.from_json(response.content) @@ -284,6 +294,7 @@ def apply_updates_to_instances( body = compute.InstanceGroupManagersApplyUpdatesRequest.to_json( request.instance_group_managers_apply_updates_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -307,7 +318,10 @@ def apply_updates_to_instances( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -367,6 +381,7 @@ def create_instances( body = compute.InstanceGroupManagersCreateInstancesRequest.to_json( request.instance_group_managers_create_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -392,7 +407,10 @@ def create_instances( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -473,6 +491,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -531,6 +552,7 @@ def delete_instances( body = compute.InstanceGroupManagersDeleteInstancesRequest.to_json( request.instance_group_managers_delete_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -556,7 +578,10 @@ def delete_instances( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -617,6 +642,7 @@ def delete_per_instance_configs( body = compute.InstanceGroupManagersDeletePerInstanceConfigsReq.to_json( request.instance_group_managers_delete_per_instance_configs_req_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -640,7 +666,10 @@ def delete_per_instance_configs( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -705,6 +734,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceGroupManager.from_json(response.content) @@ -763,6 +795,7 @@ def insert( body = compute.InstanceGroupManager.to_json( request.instance_group_manager_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -785,7 +818,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -822,11 +858,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -839,6 +875,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceGroupManagerList.from_json(response.content) @@ -877,11 +916,11 @@ def list_errors( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -894,6 +933,9 @@ def list_errors( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceGroupManagersListErrorsResponse.from_json( response.content @@ -934,11 +976,11 @@ def list_managed_instances( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -951,6 +993,9 @@ def list_managed_instances( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceGroupManagersListManagedInstancesResponse.from_json( response.content @@ -991,11 +1036,11 @@ def list_per_instance_configs( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -1008,6 +1053,9 @@ def list_per_instance_configs( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceGroupManagersListPerInstanceConfigsResp.from_json( response.content @@ -1068,6 +1116,7 @@ def patch( body = compute.InstanceGroupManager.to_json( request.instance_group_manager_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1093,7 +1142,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1154,6 +1206,7 @@ def patch_per_instance_configs( body = compute.InstanceGroupManagersPatchPerInstanceConfigsReq.to_json( request.instance_group_managers_patch_per_instance_configs_req_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1179,7 +1232,10 @@ def patch_per_instance_configs( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1239,6 +1295,7 @@ def recreate_instances( body = compute.InstanceGroupManagersRecreateInstancesRequest.to_json( request.instance_group_managers_recreate_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1264,7 +1321,10 @@ def recreate_instances( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1346,6 +1406,9 @@ def resize( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -1404,6 +1467,7 @@ def set_instance_template( body = compute.InstanceGroupManagersSetInstanceTemplateRequest.to_json( request.instance_group_managers_set_instance_template_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1429,7 +1493,10 @@ def set_instance_template( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1489,6 +1556,7 @@ def set_target_pools( body = compute.InstanceGroupManagersSetTargetPoolsRequest.to_json( request.instance_group_managers_set_target_pools_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1514,7 +1582,10 @@ def set_target_pools( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1575,6 +1646,7 @@ def update_per_instance_configs( body = compute.InstanceGroupManagersUpdatePerInstanceConfigsReq.to_json( request.instance_group_managers_update_per_instance_configs_req_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1600,7 +1672,10 @@ def update_per_instance_configs( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/instance_groups/client.py b/google/cloud/compute_v1/services/instance_groups/client.py index 4ba2512fb..b42d42a87 100644 --- a/google/cloud/compute_v1/services/instance_groups/client.py +++ b/google/cloud/compute_v1/services/instance_groups/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.instance_groups import pagers from google.cloud.compute_v1.types import compute from .transports.base import InstanceGroupsTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -466,7 +463,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InstanceGroupAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves the list of instance groups and sorts them by zone. @@ -488,7 +485,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InstanceGroupAggregatedList: + google.cloud.compute_v1.services.instance_groups.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -521,6 +521,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -864,7 +870,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InstanceGroupList: + ) -> pagers.ListPager: r"""Retrieves the list of zonal instance group resources contained within the specified zone. For managed instance groups, use the @@ -896,8 +902,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InstanceGroupList: + google.cloud.compute_v1.services.instance_groups.pagers.ListPager: A list of InstanceGroup resources. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -931,6 +941,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -945,7 +961,7 @@ def list_instances( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InstanceGroupsListInstances: + ) -> pagers.ListInstancesPager: r"""Lists the instances in the specified instance group. The orderBy query parameter is not supported. @@ -987,7 +1003,10 @@ def list_instances( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InstanceGroupsListInstances: + google.cloud.compute_v1.services.instance_groups.pagers.ListInstancesPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1035,6 +1054,12 @@ def list_instances( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListInstancesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/instance_groups/pagers.py b/google/cloud/compute_v1/services/instance_groups/pagers.py new file mode 100644 index 000000000..a19c7f07f --- /dev/null +++ b/google/cloud/compute_v1/services/instance_groups/pagers.py @@ -0,0 +1,218 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InstanceGroupAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InstanceGroupAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InstanceGroupAggregatedList], + request: compute.AggregatedListInstanceGroupsRequest, + response: compute.InstanceGroupAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListInstanceGroupsRequest): + The initial request object. + response (google.cloud.compute_v1.types.InstanceGroupAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListInstanceGroupsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.InstanceGroupAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.InstanceGroupsScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.InstanceGroupsScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InstanceGroupList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InstanceGroupList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InstanceGroupList], + request: compute.ListInstanceGroupsRequest, + response: compute.InstanceGroupList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListInstanceGroupsRequest): + The initial request object. + response (google.cloud.compute_v1.types.InstanceGroupList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListInstanceGroupsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.InstanceGroupList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.InstanceGroup]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListInstancesPager: + """A pager for iterating through ``list_instances`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InstanceGroupsListInstances` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListInstances`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InstanceGroupsListInstances` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InstanceGroupsListInstances], + request: compute.ListInstancesInstanceGroupsRequest, + response: compute.InstanceGroupsListInstances, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListInstancesInstanceGroupsRequest): + The initial request object. + response (google.cloud.compute_v1.types.InstanceGroupsListInstances): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListInstancesInstanceGroupsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.InstanceGroupsListInstances]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.InstanceWithNamedPorts]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/instance_groups/transports/rest.py b/google/cloud/compute_v1/services/instance_groups/transports/rest.py index cf7ecfc85..c1dc1f2f4 100644 --- a/google/cloud/compute_v1/services/instance_groups/transports/rest.py +++ b/google/cloud/compute_v1/services/instance_groups/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def add_instances( self, @@ -145,6 +148,7 @@ def add_instances( body = compute.InstanceGroupsAddInstancesRequest.to_json( request.instance_groups_add_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -170,7 +174,10 @@ def add_instances( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -207,12 +214,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -225,6 +232,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceGroupAggregatedList.from_json(response.content) @@ -304,6 +314,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -373,6 +386,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceGroup.from_json(response.content) @@ -429,7 +445,9 @@ def insert( # Jsonify the request body body = compute.InstanceGroup.to_json( - request.instance_group_resource, including_default_value_fields=False + request.instance_group_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -452,7 +470,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -489,11 +510,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -506,6 +527,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceGroupList.from_json(response.content) @@ -536,6 +560,7 @@ def list_instances( body = compute.InstanceGroupsListInstancesRequest.to_json( request.instance_groups_list_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -550,11 +575,11 @@ def list_instances( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -565,7 +590,10 @@ def list_instances( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.InstanceGroupsListInstances.from_json(response.content) @@ -625,6 +653,7 @@ def remove_instances( body = compute.InstanceGroupsRemoveInstancesRequest.to_json( request.instance_groups_remove_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -650,7 +679,10 @@ def remove_instances( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -710,6 +742,7 @@ def set_named_ports( body = compute.InstanceGroupsSetNamedPortsRequest.to_json( request.instance_groups_set_named_ports_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -735,7 +768,10 @@ def set_named_ports( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/instance_templates/client.py b/google/cloud/compute_v1/services/instance_templates/client.py index 9f4d4cac7..014c7b7e2 100644 --- a/google/cloud/compute_v1/services/instance_templates/client.py +++ b/google/cloud/compute_v1/services/instance_templates/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.instance_templates import pagers from google.cloud.compute_v1.types import compute from .transports.base import InstanceTemplatesTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -752,7 +749,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InstanceTemplateList: + ) -> pagers.ListPager: r"""Retrieves a list of instance templates that are contained within the specified project. @@ -774,8 +771,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InstanceTemplateList: + google.cloud.compute_v1.services.instance_templates.pagers.ListPager: A list of instance templates. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -807,6 +808,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/instance_templates/pagers.py b/google/cloud/compute_v1/services/instance_templates/pagers.py new file mode 100644 index 000000000..5f767a6cb --- /dev/null +++ b/google/cloud/compute_v1/services/instance_templates/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InstanceTemplateList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InstanceTemplateList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InstanceTemplateList], + request: compute.ListInstanceTemplatesRequest, + response: compute.InstanceTemplateList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListInstanceTemplatesRequest): + The initial request object. + response (google.cloud.compute_v1.types.InstanceTemplateList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListInstanceTemplatesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.InstanceTemplateList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.InstanceTemplate]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/instance_templates/transports/rest.py b/google/cloud/compute_v1/services/instance_templates/transports/rest.py index 540f5db1a..97b0a2e8d 100644 --- a/google/cloud/compute_v1/services/instance_templates/transports/rest.py +++ b/google/cloud/compute_v1/services/instance_templates/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -165,6 +168,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -219,6 +225,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceTemplate.from_json(response.content) @@ -317,6 +326,9 @@ def get_iam_policy( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Policy.from_json(response.content) @@ -373,7 +385,9 @@ def insert( # Jsonify the request body body = compute.InstanceTemplate.to_json( - request.instance_template_resource, including_default_value_fields=False + request.instance_template_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -396,7 +410,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -433,11 +450,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -450,6 +467,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceTemplateList.from_json(response.content) @@ -530,6 +550,7 @@ def set_iam_policy( body = compute.GlobalSetPolicyRequest.to_json( request.global_set_policy_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -550,7 +571,10 @@ def set_iam_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Policy.from_json(response.content) @@ -582,6 +606,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -602,7 +627,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/instances/client.py b/google/cloud/compute_v1/services/instances/client.py index 19c6890a4..fd5da4d87 100644 --- a/google/cloud/compute_v1/services/instances/client.py +++ b/google/cloud/compute_v1/services/instances/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.instances import pagers from google.cloud.compute_v1.types import compute from .transports.base import InstancesTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -585,7 +582,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InstanceAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves aggregated list of all of the instances in your project across all regions and zones. @@ -607,7 +604,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InstanceAggregatedList: + google.cloud.compute_v1.services.instances.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -640,6 +640,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1812,7 +1818,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InstanceList: + ) -> pagers.ListPager: r"""Retrieves the list of instances contained within the specified zone. @@ -1840,8 +1846,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InstanceList: + google.cloud.compute_v1.services.instances.pagers.ListPager: Contains a list of instances. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -1875,6 +1885,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1888,7 +1904,7 @@ def list_referrers( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InstanceListReferrers: + ) -> pagers.ListReferrersPager: r"""Retrieves a list of resources that refer to the VM instance specified in the request. For example, if the VM instance is part of a managed or unmanaged instance @@ -1930,9 +1946,12 @@ def list_referrers( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InstanceListReferrers: + google.cloud.compute_v1.services.instances.pagers.ListReferrersPager: Contains a list of instance referrers. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1969,6 +1988,12 @@ def list_referrers( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListReferrersPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/instances/pagers.py b/google/cloud/compute_v1/services/instances/pagers.py new file mode 100644 index 000000000..3907eddb3 --- /dev/null +++ b/google/cloud/compute_v1/services/instances/pagers.py @@ -0,0 +1,218 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InstanceAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InstanceAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InstanceAggregatedList], + request: compute.AggregatedListInstancesRequest, + response: compute.InstanceAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListInstancesRequest): + The initial request object. + response (google.cloud.compute_v1.types.InstanceAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListInstancesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.InstanceAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.InstancesScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.InstancesScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InstanceList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InstanceList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InstanceList], + request: compute.ListInstancesRequest, + response: compute.InstanceList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListInstancesRequest): + The initial request object. + response (google.cloud.compute_v1.types.InstanceList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListInstancesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.InstanceList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Instance]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListReferrersPager: + """A pager for iterating through ``list_referrers`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InstanceListReferrers` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListReferrers`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InstanceListReferrers` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InstanceListReferrers], + request: compute.ListReferrersInstancesRequest, + response: compute.InstanceListReferrers, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListReferrersInstancesRequest): + The initial request object. + response (google.cloud.compute_v1.types.InstanceListReferrers): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListReferrersInstancesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.InstanceListReferrers]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Reference]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/instances/transports/rest.py b/google/cloud/compute_v1/services/instances/transports/rest.py index 0f28fa128..745de97d5 100644 --- a/google/cloud/compute_v1/services/instances/transports/rest.py +++ b/google/cloud/compute_v1/services/instances/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def add_access_config( self, @@ -143,7 +146,9 @@ def add_access_config( # Jsonify the request body body = compute.AccessConfig.to_json( - request.access_config_resource, including_default_value_fields=False + request.access_config_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -158,8 +163,8 @@ def add_access_config( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "networkInterface": request.network_interface, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -170,7 +175,10 @@ def add_access_config( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -230,6 +238,7 @@ def add_resource_policies( body = compute.InstancesAddResourcePoliciesRequest.to_json( request.instances_add_resource_policies_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -255,7 +264,10 @@ def add_resource_policies( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -292,12 +304,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -310,6 +322,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceAggregatedList.from_json(response.content) @@ -366,7 +381,9 @@ def attach_disk( # Jsonify the request body body = compute.AttachedDisk.to_json( - request.attached_disk_resource, including_default_value_fields=False + request.attached_disk_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -381,8 +398,8 @@ def attach_disk( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "forceAttach": request.force_attach, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -393,7 +410,10 @@ def attach_disk( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -474,6 +494,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -540,9 +563,9 @@ def delete_access_config( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, - "networkInterface": request.network_interface, "accessConfig": request.access_config, + "networkInterface": request.network_interface, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -555,6 +578,9 @@ def delete_access_config( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -621,8 +647,8 @@ def detach_disk( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "deviceName": request.device_name, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -635,6 +661,9 @@ def detach_disk( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -689,6 +718,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Instance.from_json(response.content) @@ -727,8 +759,8 @@ def get_guest_attributes( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "variableKey": request.variable_key, "queryPath": request.query_path, + "variableKey": request.variable_key, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -741,6 +773,9 @@ def get_guest_attributes( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.GuestAttributes.from_json(response.content) @@ -842,6 +877,9 @@ def get_iam_policy( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Policy.from_json(response.content) @@ -891,6 +929,9 @@ def get_screenshot( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Screenshot.from_json(response.content) @@ -943,6 +984,9 @@ def get_serial_port_output( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SerialPortOutput.from_json(response.content) @@ -993,6 +1037,9 @@ def get_shielded_instance_identity( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ShieldedInstanceIdentity.from_json(response.content) @@ -1049,7 +1096,9 @@ def insert( # Jsonify the request body body = compute.Instance.to_json( - request.instance_resource, including_default_value_fields=False + request.instance_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1061,8 +1110,8 @@ def insert( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "sourceInstanceTemplate": request.source_instance_template, "requestId": request.request_id, + "sourceInstanceTemplate": request.source_instance_template, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -1073,7 +1122,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1109,11 +1161,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -1126,6 +1178,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceList.from_json(response.content) @@ -1166,11 +1221,11 @@ def list_referrers( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -1183,6 +1238,9 @@ def list_referrers( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceListReferrers.from_json(response.content) @@ -1241,6 +1299,7 @@ def remove_resource_policies( body = compute.InstancesRemoveResourcePoliciesRequest.to_json( request.instances_remove_resource_policies_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1266,7 +1325,10 @@ def remove_resource_policies( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1347,6 +1409,9 @@ def reset( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -1427,6 +1492,9 @@ def set_deletion_protection( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -1494,8 +1562,8 @@ def set_disk_auto_delete( # not required for GCE query_params = { "autoDelete": request.auto_delete, - "requestId": request.request_id, "deviceName": request.device_name, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -1508,6 +1576,9 @@ def set_disk_auto_delete( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -1588,6 +1659,7 @@ def set_iam_policy( body = compute.ZoneSetPolicyRequest.to_json( request.zone_set_policy_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1611,7 +1683,10 @@ def set_iam_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Policy.from_json(response.content) @@ -1671,6 +1746,7 @@ def set_labels( body = compute.InstancesSetLabelsRequest.to_json( request.instances_set_labels_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1696,7 +1772,10 @@ def set_labels( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1756,6 +1835,7 @@ def set_machine_resources( body = compute.InstancesSetMachineResourcesRequest.to_json( request.instances_set_machine_resources_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1781,7 +1861,10 @@ def set_machine_resources( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1841,6 +1924,7 @@ def set_machine_type( body = compute.InstancesSetMachineTypeRequest.to_json( request.instances_set_machine_type_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1866,7 +1950,10 @@ def set_machine_type( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1924,7 +2011,9 @@ def set_metadata( # Jsonify the request body body = compute.Metadata.to_json( - request.metadata_resource, including_default_value_fields=False + request.metadata_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1950,7 +2039,10 @@ def set_metadata( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -2010,6 +2102,7 @@ def set_min_cpu_platform( body = compute.InstancesSetMinCpuPlatformRequest.to_json( request.instances_set_min_cpu_platform_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -2035,7 +2128,10 @@ def set_min_cpu_platform( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -2093,7 +2189,9 @@ def set_scheduling( # Jsonify the request body body = compute.Scheduling.to_json( - request.scheduling_resource, including_default_value_fields=False + request.scheduling_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -2119,7 +2217,10 @@ def set_scheduling( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -2179,6 +2280,7 @@ def set_service_account( body = compute.InstancesSetServiceAccountRequest.to_json( request.instances_set_service_account_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -2204,7 +2306,10 @@ def set_service_account( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -2265,6 +2370,7 @@ def set_shielded_instance_integrity_policy( body = compute.ShieldedInstanceIntegrityPolicy.to_json( request.shielded_instance_integrity_policy_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -2290,7 +2396,10 @@ def set_shielded_instance_integrity_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -2348,7 +2457,9 @@ def set_tags( # Jsonify the request body body = compute.Tags.to_json( - request.tags_resource, including_default_value_fields=False + request.tags_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -2374,7 +2485,10 @@ def set_tags( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -2454,6 +2568,9 @@ def simulate_maintenance_event( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -2533,6 +2650,9 @@ def start( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -2591,6 +2711,7 @@ def start_with_encryption_key( body = compute.InstancesStartWithEncryptionKeyRequest.to_json( request.instances_start_with_encryption_key_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -2616,7 +2737,10 @@ def start_with_encryption_key( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -2696,6 +2820,9 @@ def stop( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -2726,6 +2853,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -2749,7 +2877,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) @@ -2807,7 +2938,9 @@ def update( # Jsonify the request body body = compute.Instance.to_json( - request.instance_resource, including_default_value_fields=False + request.instance_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -2822,9 +2955,9 @@ def update( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "minimalAction": request.minimal_action, "mostDisruptiveAllowedAction": request.most_disruptive_allowed_action, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -2835,7 +2968,10 @@ def update( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.put(url, json=body,) + response = self._session.put(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -2893,7 +3029,9 @@ def update_access_config( # Jsonify the request body body = compute.AccessConfig.to_json( - request.access_config_resource, including_default_value_fields=False + request.access_config_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -2908,8 +3046,8 @@ def update_access_config( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "networkInterface": request.network_interface, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -2920,7 +3058,10 @@ def update_access_config( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -2978,7 +3119,9 @@ def update_display_device( # Jsonify the request body body = compute.DisplayDevice.to_json( - request.display_device_resource, including_default_value_fields=False + request.display_device_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -3004,7 +3147,10 @@ def update_display_device( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -3062,7 +3208,9 @@ def update_network_interface( # Jsonify the request body body = compute.NetworkInterface.to_json( - request.network_interface_resource, including_default_value_fields=False + request.network_interface_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -3077,8 +3225,8 @@ def update_network_interface( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "networkInterface": request.network_interface, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -3089,7 +3237,10 @@ def update_network_interface( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -3150,6 +3301,7 @@ def update_shielded_instance_config( body = compute.ShieldedInstanceConfig.to_json( request.shielded_instance_config_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -3175,7 +3327,10 @@ def update_shielded_instance_config( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/interconnect_attachments/client.py b/google/cloud/compute_v1/services/interconnect_attachments/client.py index 61853775d..8a4ed7dfa 100644 --- a/google/cloud/compute_v1/services/interconnect_attachments/client.py +++ b/google/cloud/compute_v1/services/interconnect_attachments/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.interconnect_attachments import pagers from google.cloud.compute_v1.types import compute from .transports.base import InterconnectAttachmentsTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -338,7 +335,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InterconnectAttachmentAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of interconnect attachments. @@ -360,7 +357,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InterconnectAttachmentAggregatedList: + google.cloud.compute_v1.services.interconnect_attachments.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -395,6 +395,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -715,7 +721,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InterconnectAttachmentList: + ) -> pagers.ListPager: r"""Retrieves the list of interconnect attachments contained within the specified region. @@ -742,10 +748,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InterconnectAttachmentList: + google.cloud.compute_v1.services.interconnect_attachments.pagers.ListPager: Response to the list request, and contains a list of interconnect - attachments. + attachments. Iterating over this object + will yield results and resolve + additional pages automatically. """ # Create or coerce a protobuf request object. @@ -780,6 +788,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/interconnect_attachments/pagers.py b/google/cloud/compute_v1/services/interconnect_attachments/pagers.py new file mode 100644 index 000000000..ae4acf937 --- /dev/null +++ b/google/cloud/compute_v1/services/interconnect_attachments/pagers.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InterconnectAttachmentAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InterconnectAttachmentAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InterconnectAttachmentAggregatedList], + request: compute.AggregatedListInterconnectAttachmentsRequest, + response: compute.InterconnectAttachmentAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListInterconnectAttachmentsRequest): + The initial request object. + response (google.cloud.compute_v1.types.InterconnectAttachmentAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListInterconnectAttachmentsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.InterconnectAttachmentAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__( + self, + ) -> Iterable[Tuple[str, compute.InterconnectAttachmentsScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.InterconnectAttachmentsScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InterconnectAttachmentList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InterconnectAttachmentList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InterconnectAttachmentList], + request: compute.ListInterconnectAttachmentsRequest, + response: compute.InterconnectAttachmentList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListInterconnectAttachmentsRequest): + The initial request object. + response (google.cloud.compute_v1.types.InterconnectAttachmentList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListInterconnectAttachmentsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.InterconnectAttachmentList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.InterconnectAttachment]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/interconnect_attachments/transports/rest.py b/google/cloud/compute_v1/services/interconnect_attachments/transports/rest.py index 2676fe670..4e4f1b9a1 100644 --- a/google/cloud/compute_v1/services/interconnect_attachments/transports/rest.py +++ b/google/cloud/compute_v1/services/interconnect_attachments/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InterconnectAttachmentAggregatedList.from_json(response.content) @@ -219,6 +225,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -275,6 +284,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InterconnectAttachment.from_json(response.content) @@ -333,6 +345,7 @@ def insert( body = compute.InterconnectAttachment.to_json( request.interconnect_attachment_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -344,8 +357,8 @@ def insert( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "validateOnly": request.validate_only, "requestId": request.request_id, + "validateOnly": request.validate_only, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -356,7 +369,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -396,11 +412,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -413,6 +429,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InterconnectAttachmentList.from_json(response.content) @@ -471,6 +490,7 @@ def patch( body = compute.InterconnectAttachment.to_json( request.interconnect_attachment_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -496,7 +516,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/interconnect_locations/client.py b/google/cloud/compute_v1/services/interconnect_locations/client.py index c240ae837..64f9a72a2 100644 --- a/google/cloud/compute_v1/services/interconnect_locations/client.py +++ b/google/cloud/compute_v1/services/interconnect_locations/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.interconnect_locations import pagers from google.cloud.compute_v1.types import compute from .transports.base import InterconnectLocationsTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -422,7 +419,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InterconnectLocationList: + ) -> pagers.ListPager: r"""Retrieves the list of interconnect locations available to the specified project. @@ -444,10 +441,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InterconnectLocationList: + google.cloud.compute_v1.services.interconnect_locations.pagers.ListPager: Response to the list request, and contains a list of interconnect - locations. + locations. Iterating over this object + will yield results and resolve + additional pages automatically. """ # Create or coerce a protobuf request object. @@ -480,6 +479,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/interconnect_locations/pagers.py b/google/cloud/compute_v1/services/interconnect_locations/pagers.py new file mode 100644 index 000000000..176d02915 --- /dev/null +++ b/google/cloud/compute_v1/services/interconnect_locations/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InterconnectLocationList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InterconnectLocationList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InterconnectLocationList], + request: compute.ListInterconnectLocationsRequest, + response: compute.InterconnectLocationList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListInterconnectLocationsRequest): + The initial request object. + response (google.cloud.compute_v1.types.InterconnectLocationList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListInterconnectLocationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.InterconnectLocationList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.InterconnectLocation]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/interconnect_locations/transports/rest.py b/google/cloud/compute_v1/services/interconnect_locations/transports/rest.py index 1e79480f7..7a8600898 100644 --- a/google/cloud/compute_v1/services/interconnect_locations/transports/rest.py +++ b/google/cloud/compute_v1/services/interconnect_locations/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def get( self, @@ -142,6 +145,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InterconnectLocation.from_json(response.content) @@ -180,11 +186,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -197,6 +203,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InterconnectLocationList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/interconnects/client.py b/google/cloud/compute_v1/services/interconnects/client.py index 7d64c21ea..1a28e42b3 100644 --- a/google/cloud/compute_v1/services/interconnects/client.py +++ b/google/cloud/compute_v1/services/interconnects/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.interconnects import pagers from google.cloud.compute_v1.types import compute from .transports.base import InterconnectsTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -690,7 +687,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.InterconnectList: + ) -> pagers.ListPager: r"""Retrieves the list of interconnect available to the specified project. @@ -712,9 +709,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.InterconnectList: + google.cloud.compute_v1.services.interconnects.pagers.ListPager: Response to the list request, and contains a list of interconnects. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -747,6 +747,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/interconnects/pagers.py b/google/cloud/compute_v1/services/interconnects/pagers.py new file mode 100644 index 000000000..4ce5501ed --- /dev/null +++ b/google/cloud/compute_v1/services/interconnects/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.InterconnectList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.InterconnectList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.InterconnectList], + request: compute.ListInterconnectsRequest, + response: compute.InterconnectList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListInterconnectsRequest): + The initial request object. + response (google.cloud.compute_v1.types.InterconnectList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListInterconnectsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.InterconnectList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Interconnect]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/interconnects/transports/rest.py b/google/cloud/compute_v1/services/interconnects/transports/rest.py index e1580a7bc..b56e966f2 100644 --- a/google/cloud/compute_v1/services/interconnects/transports/rest.py +++ b/google/cloud/compute_v1/services/interconnects/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -163,6 +166,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -216,6 +222,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Interconnect.from_json(response.content) @@ -264,6 +273,9 @@ def get_diagnostics( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InterconnectsGetDiagnosticsResponse.from_json(response.content) @@ -320,7 +332,9 @@ def insert( # Jsonify the request body body = compute.Interconnect.to_json( - request.interconnect_resource, including_default_value_fields=False + request.interconnect_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -343,7 +357,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -382,11 +399,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -399,6 +416,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InterconnectList.from_json(response.content) @@ -455,7 +475,9 @@ def patch( # Jsonify the request body body = compute.Interconnect.to_json( - request.interconnect_resource, including_default_value_fields=False + request.interconnect_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -478,7 +500,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/license_codes/client.py b/google/cloud/compute_v1/services/license_codes/client.py index 106e680de..1f3ebec3c 100644 --- a/google/cloud/compute_v1/services/license_codes/client.py +++ b/google/cloud/compute_v1/services/license_codes/client.py @@ -264,21 +264,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +317,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) diff --git a/google/cloud/compute_v1/services/license_codes/transports/rest.py b/google/cloud/compute_v1/services/license_codes/transports/rest.py index fefbbb2d3..bb0746b91 100644 --- a/google/cloud/compute_v1/services/license_codes/transports/rest.py +++ b/google/cloud/compute_v1/services/license_codes/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def get( self, @@ -140,6 +143,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.LicenseCode.from_json(response.content) @@ -170,6 +176,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -190,7 +197,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/licenses/client.py b/google/cloud/compute_v1/services/licenses/client.py index a9643c4ef..58f6f8177 100644 --- a/google/cloud/compute_v1/services/licenses/client.py +++ b/google/cloud/compute_v1/services/licenses/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.licenses import pagers from google.cloud.compute_v1.types import compute from .transports.base import LicensesTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -747,7 +744,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.LicensesListResponse: + ) -> pagers.ListPager: r"""Retrieves the list of licenses available in the specified project. This method does not get any licenses that belong to other projects, including licenses @@ -776,7 +773,10 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.LicensesListResponse: + google.cloud.compute_v1.services.licenses.pagers.ListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -809,6 +809,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/licenses/pagers.py b/google/cloud/compute_v1/services/licenses/pagers.py new file mode 100644 index 000000000..b6c5bd63a --- /dev/null +++ b/google/cloud/compute_v1/services/licenses/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.LicensesListResponse` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.LicensesListResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.LicensesListResponse], + request: compute.ListLicensesRequest, + response: compute.LicensesListResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListLicensesRequest): + The initial request object. + response (google.cloud.compute_v1.types.LicensesListResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListLicensesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.LicensesListResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.License]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/licenses/transports/rest.py b/google/cloud/compute_v1/services/licenses/transports/rest.py index 9c5df4410..cba33ef8d 100644 --- a/google/cloud/compute_v1/services/licenses/transports/rest.py +++ b/google/cloud/compute_v1/services/licenses/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -150,8 +153,8 @@ def delete( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "license": request.license_, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -164,6 +167,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -218,6 +224,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.License.from_json(response.content) @@ -316,6 +325,9 @@ def get_iam_policy( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Policy.from_json(response.content) @@ -372,7 +384,9 @@ def insert( # Jsonify the request body body = compute.License.to_json( - request.license_resource, including_default_value_fields=False + request.license_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -395,7 +409,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -431,11 +448,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -448,6 +465,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.LicensesListResponse.from_json(response.content) @@ -528,6 +548,7 @@ def set_iam_policy( body = compute.GlobalSetPolicyRequest.to_json( request.global_set_policy_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -548,7 +569,10 @@ def set_iam_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Policy.from_json(response.content) @@ -580,6 +604,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -600,7 +625,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/machine_types/client.py b/google/cloud/compute_v1/services/machine_types/client.py index b340321ad..b04341968 100644 --- a/google/cloud/compute_v1/services/machine_types/client.py +++ b/google/cloud/compute_v1/services/machine_types/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.machine_types import pagers from google.cloud.compute_v1.types import compute from .transports.base import MachineTypesTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -334,7 +331,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.MachineTypeAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of machine types. Args: @@ -355,7 +352,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.MachineTypeAggregatedList: + google.cloud.compute_v1.services.machine_types.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -388,6 +388,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -491,7 +497,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.MachineTypeList: + ) -> pagers.ListPager: r"""Retrieves a list of machine types available to the specified project. @@ -520,8 +526,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.MachineTypeList: + google.cloud.compute_v1.services.machine_types.pagers.ListPager: Contains a list of machine types. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -555,6 +565,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/machine_types/pagers.py b/google/cloud/compute_v1/services/machine_types/pagers.py new file mode 100644 index 000000000..8221af685 --- /dev/null +++ b/google/cloud/compute_v1/services/machine_types/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.MachineTypeAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.MachineTypeAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.MachineTypeAggregatedList], + request: compute.AggregatedListMachineTypesRequest, + response: compute.MachineTypeAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListMachineTypesRequest): + The initial request object. + response (google.cloud.compute_v1.types.MachineTypeAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListMachineTypesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.MachineTypeAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.MachineTypesScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.MachineTypesScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.MachineTypeList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.MachineTypeList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.MachineTypeList], + request: compute.ListMachineTypesRequest, + response: compute.MachineTypeList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListMachineTypesRequest): + The initial request object. + response (google.cloud.compute_v1.types.MachineTypeList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListMachineTypesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.MachineTypeList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.MachineType]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/machine_types/transports/rest.py b/google/cloud/compute_v1/services/machine_types/transports/rest.py index 551b2362f..e20b1f06b 100644 --- a/google/cloud/compute_v1/services/machine_types/transports/rest.py +++ b/google/cloud/compute_v1/services/machine_types/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.MachineTypeAggregatedList.from_json(response.content) @@ -195,6 +201,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.MachineType.from_json(response.content) @@ -230,11 +239,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -247,6 +256,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.MachineTypeList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/network_endpoint_groups/client.py b/google/cloud/compute_v1/services/network_endpoint_groups/client.py index eb23f540a..67b7509d0 100644 --- a/google/cloud/compute_v1/services/network_endpoint_groups/client.py +++ b/google/cloud/compute_v1/services/network_endpoint_groups/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.network_endpoint_groups import pagers from google.cloud.compute_v1.types import compute from .transports.base import NetworkEndpointGroupsTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -338,7 +335,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NetworkEndpointGroupAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves the list of network endpoint groups and sorts them by zone. @@ -360,7 +357,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NetworkEndpointGroupAggregatedList: + google.cloud.compute_v1.services.network_endpoint_groups.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -393,6 +393,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1002,7 +1008,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NetworkEndpointGroupList: + ) -> pagers.ListPager: r"""Retrieves the list of network endpoint groups that are located in the specified project and zone. @@ -1032,7 +1038,10 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NetworkEndpointGroupList: + google.cloud.compute_v1.services.network_endpoint_groups.pagers.ListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1067,6 +1076,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1081,7 +1096,7 @@ def list_network_endpoints( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NetworkEndpointGroupsListNetworkEndpoints: + ) -> pagers.ListNetworkEndpointsPager: r"""Lists the network endpoints in the specified network endpoint group. @@ -1125,7 +1140,10 @@ def list_network_endpoints( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NetworkEndpointGroupsListNetworkEndpoints: + google.cloud.compute_v1.services.network_endpoint_groups.pagers.ListNetworkEndpointsPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1175,6 +1193,12 @@ def list_network_endpoints( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListNetworkEndpointsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/network_endpoint_groups/pagers.py b/google/cloud/compute_v1/services/network_endpoint_groups/pagers.py new file mode 100644 index 000000000..6a377efbd --- /dev/null +++ b/google/cloud/compute_v1/services/network_endpoint_groups/pagers.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NetworkEndpointGroupAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NetworkEndpointGroupAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NetworkEndpointGroupAggregatedList], + request: compute.AggregatedListNetworkEndpointGroupsRequest, + response: compute.NetworkEndpointGroupAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListNetworkEndpointGroupsRequest): + The initial request object. + response (google.cloud.compute_v1.types.NetworkEndpointGroupAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListNetworkEndpointGroupsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NetworkEndpointGroupAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.NetworkEndpointGroupsScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.NetworkEndpointGroupsScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NetworkEndpointGroupList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NetworkEndpointGroupList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NetworkEndpointGroupList], + request: compute.ListNetworkEndpointGroupsRequest, + response: compute.NetworkEndpointGroupList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListNetworkEndpointGroupsRequest): + The initial request object. + response (google.cloud.compute_v1.types.NetworkEndpointGroupList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListNetworkEndpointGroupsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NetworkEndpointGroupList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.NetworkEndpointGroup]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListNetworkEndpointsPager: + """A pager for iterating through ``list_network_endpoints`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NetworkEndpointGroupsListNetworkEndpoints` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListNetworkEndpoints`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NetworkEndpointGroupsListNetworkEndpoints` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NetworkEndpointGroupsListNetworkEndpoints], + request: compute.ListNetworkEndpointsNetworkEndpointGroupsRequest, + response: compute.NetworkEndpointGroupsListNetworkEndpoints, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListNetworkEndpointsNetworkEndpointGroupsRequest): + The initial request object. + response (google.cloud.compute_v1.types.NetworkEndpointGroupsListNetworkEndpoints): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListNetworkEndpointsNetworkEndpointGroupsRequest( + request + ) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NetworkEndpointGroupsListNetworkEndpoints]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.NetworkEndpointWithHealthStatus]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/network_endpoint_groups/transports/rest.py b/google/cloud/compute_v1/services/network_endpoint_groups/transports/rest.py index 78a06724a..105a67afe 100644 --- a/google/cloud/compute_v1/services/network_endpoint_groups/transports/rest.py +++ b/google/cloud/compute_v1/services/network_endpoint_groups/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NetworkEndpointGroupAggregatedList.from_json(response.content) @@ -198,6 +204,7 @@ def attach_network_endpoints( body = compute.NetworkEndpointGroupsAttachEndpointsRequest.to_json( request.network_endpoint_groups_attach_endpoints_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -223,7 +230,10 @@ def attach_network_endpoints( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -304,6 +314,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -362,6 +375,7 @@ def detach_network_endpoints( body = compute.NetworkEndpointGroupsDetachEndpointsRequest.to_json( request.network_endpoint_groups_detach_endpoints_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -387,7 +401,10 @@ def detach_network_endpoints( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -451,6 +468,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NetworkEndpointGroup.from_json(response.content) @@ -509,6 +529,7 @@ def insert( body = compute.NetworkEndpointGroup.to_json( request.network_endpoint_group_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -531,7 +552,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -568,11 +592,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -585,6 +609,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NetworkEndpointGroupList.from_json(response.content) @@ -615,6 +642,7 @@ def list_network_endpoints( body = compute.NetworkEndpointGroupsListEndpointsRequest.to_json( request.network_endpoint_groups_list_endpoints_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -629,11 +657,11 @@ def list_network_endpoints( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -644,7 +672,10 @@ def list_network_endpoints( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.NetworkEndpointGroupsListNetworkEndpoints.from_json( @@ -678,6 +709,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -701,7 +733,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/networks/client.py b/google/cloud/compute_v1/services/networks/client.py index 8e7a6406a..e569f7f57 100644 --- a/google/cloud/compute_v1/services/networks/client.py +++ b/google/cloud/compute_v1/services/networks/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.networks import pagers from google.cloud.compute_v1.types import compute from .transports.base import NetworksTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -720,7 +717,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NetworkList: + ) -> pagers.ListPager: r"""Retrieves the list of networks available to the specified project. @@ -741,8 +738,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NetworkList: + google.cloud.compute_v1.services.networks.pagers.ListPager: Contains a list of networks. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -774,6 +775,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -786,7 +793,7 @@ def list_peering_routes( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.ExchangedPeeringRoutesList: + ) -> pagers.ListPeeringRoutesPager: r"""Lists the peering routes exchanged over peering connection. @@ -813,7 +820,10 @@ def list_peering_routes( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.ExchangedPeeringRoutesList: + google.cloud.compute_v1.services.networks.pagers.ListPeeringRoutesPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -848,6 +858,12 @@ def list_peering_routes( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPeeringRoutesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/networks/pagers.py b/google/cloud/compute_v1/services/networks/pagers.py new file mode 100644 index 000000000..3866c6c47 --- /dev/null +++ b/google/cloud/compute_v1/services/networks/pagers.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NetworkList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NetworkList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NetworkList], + request: compute.ListNetworksRequest, + response: compute.NetworkList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListNetworksRequest): + The initial request object. + response (google.cloud.compute_v1.types.NetworkList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListNetworksRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NetworkList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Network]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPeeringRoutesPager: + """A pager for iterating through ``list_peering_routes`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.ExchangedPeeringRoutesList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListPeeringRoutes`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.ExchangedPeeringRoutesList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.ExchangedPeeringRoutesList], + request: compute.ListPeeringRoutesNetworksRequest, + response: compute.ExchangedPeeringRoutesList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListPeeringRoutesNetworksRequest): + The initial request object. + response (google.cloud.compute_v1.types.ExchangedPeeringRoutesList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListPeeringRoutesNetworksRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.ExchangedPeeringRoutesList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.ExchangedPeeringRoute]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/networks/transports/rest.py b/google/cloud/compute_v1/services/networks/transports/rest.py index 8e3d8a9f5..3615fc004 100644 --- a/google/cloud/compute_v1/services/networks/transports/rest.py +++ b/google/cloud/compute_v1/services/networks/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def add_peering( self, @@ -145,6 +148,7 @@ def add_peering( body = compute.NetworksAddPeeringRequest.to_json( request.networks_add_peering_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -167,7 +171,10 @@ def add_peering( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -245,6 +252,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -296,6 +306,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Network.from_json(response.content) @@ -352,7 +365,9 @@ def insert( # Jsonify the request body body = compute.Network.to_json( - request.network_resource, including_default_value_fields=False + request.network_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -375,7 +390,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -411,11 +429,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -428,6 +446,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NetworkList.from_json(response.content) @@ -463,14 +484,14 @@ def list_peering_routes( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, + "direction": request.direction, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, - "region": request.region, + "pageToken": request.page_token, "peeringName": request.peering_name, + "region": request.region, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, - "direction": request.direction, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -483,6 +504,9 @@ def list_peering_routes( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ExchangedPeeringRoutesList.from_json(response.content) @@ -538,7 +562,9 @@ def patch( # Jsonify the request body body = compute.Network.to_json( - request.network_resource, including_default_value_fields=False + request.network_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -561,7 +587,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -621,6 +650,7 @@ def remove_peering( body = compute.NetworksRemovePeeringRequest.to_json( request.networks_remove_peering_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -643,7 +673,10 @@ def remove_peering( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -721,6 +754,9 @@ def switch_to_custom_mode( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -779,6 +815,7 @@ def update_peering( body = compute.NetworksUpdatePeeringRequest.to_json( request.networks_update_peering_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -801,7 +838,10 @@ def update_peering( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/node_groups/client.py b/google/cloud/compute_v1/services/node_groups/client.py index 0cff2a49b..14d60b3fc 100644 --- a/google/cloud/compute_v1/services/node_groups/client.py +++ b/google/cloud/compute_v1/services/node_groups/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.node_groups import pagers from google.cloud.compute_v1.types import compute from .transports.base import NodeGroupsTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -454,7 +451,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NodeGroupAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of node groups. Note: use nodeGroups.listNodes for more details about each group. @@ -477,7 +474,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NodeGroupAggregatedList: + google.cloud.compute_v1.services.node_groups.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -510,6 +510,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1107,7 +1113,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NodeGroupList: + ) -> pagers.ListPager: r"""Retrieves a list of node groups available to the specified project. Note: use nodeGroups.listNodes for more details about each group. @@ -1136,8 +1142,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NodeGroupList: + google.cloud.compute_v1.services.node_groups.pagers.ListPager: Contains a list of nodeGroups. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -1171,6 +1181,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1184,7 +1200,7 @@ def list_nodes( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NodeGroupsListNodes: + ) -> pagers.ListNodesPager: r"""Lists nodes in the node group. Args: @@ -1219,7 +1235,10 @@ def list_nodes( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NodeGroupsListNodes: + google.cloud.compute_v1.services.node_groups.pagers.ListNodesPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1256,6 +1275,12 @@ def list_nodes( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListNodesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/node_groups/pagers.py b/google/cloud/compute_v1/services/node_groups/pagers.py new file mode 100644 index 000000000..3bedd0f97 --- /dev/null +++ b/google/cloud/compute_v1/services/node_groups/pagers.py @@ -0,0 +1,218 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NodeGroupAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NodeGroupAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NodeGroupAggregatedList], + request: compute.AggregatedListNodeGroupsRequest, + response: compute.NodeGroupAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListNodeGroupsRequest): + The initial request object. + response (google.cloud.compute_v1.types.NodeGroupAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListNodeGroupsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NodeGroupAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.NodeGroupsScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.NodeGroupsScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NodeGroupList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NodeGroupList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NodeGroupList], + request: compute.ListNodeGroupsRequest, + response: compute.NodeGroupList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListNodeGroupsRequest): + The initial request object. + response (google.cloud.compute_v1.types.NodeGroupList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListNodeGroupsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NodeGroupList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.NodeGroup]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListNodesPager: + """A pager for iterating through ``list_nodes`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NodeGroupsListNodes` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListNodes`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NodeGroupsListNodes` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NodeGroupsListNodes], + request: compute.ListNodesNodeGroupsRequest, + response: compute.NodeGroupsListNodes, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListNodesNodeGroupsRequest): + The initial request object. + response (google.cloud.compute_v1.types.NodeGroupsListNodes): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListNodesNodeGroupsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NodeGroupsListNodes]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.NodeGroupNode]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/node_groups/transports/rest.py b/google/cloud/compute_v1/services/node_groups/transports/rest.py index f85e27fd6..549476201 100644 --- a/google/cloud/compute_v1/services/node_groups/transports/rest.py +++ b/google/cloud/compute_v1/services/node_groups/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def add_nodes( self, @@ -145,6 +148,7 @@ def add_nodes( body = compute.NodeGroupsAddNodesRequest.to_json( request.node_groups_add_nodes_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -170,7 +174,10 @@ def add_nodes( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -207,12 +214,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -225,6 +232,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NodeGroupAggregatedList.from_json(response.content) @@ -304,6 +314,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -362,6 +375,7 @@ def delete_nodes( body = compute.NodeGroupsDeleteNodesRequest.to_json( request.node_groups_delete_nodes_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -387,7 +401,10 @@ def delete_nodes( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -446,6 +463,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NodeGroup.from_json(response.content) @@ -547,6 +567,9 @@ def get_iam_policy( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Policy.from_json(response.content) @@ -603,7 +626,9 @@ def insert( # Jsonify the request body body = compute.NodeGroup.to_json( - request.node_group_resource, including_default_value_fields=False + request.node_group_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -615,8 +640,8 @@ def insert( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "initialNodeCount": request.initial_node_count, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -627,7 +652,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -664,11 +692,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -681,6 +709,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NodeGroupList.from_json(response.content) @@ -719,11 +750,11 @@ def list_nodes( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -736,6 +767,9 @@ def list_nodes( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NodeGroupsListNodes.from_json(response.content) @@ -792,7 +826,9 @@ def patch( # Jsonify the request body body = compute.NodeGroup.to_json( - request.node_group_resource, including_default_value_fields=False + request.node_group_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -818,7 +854,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -900,6 +939,7 @@ def set_iam_policy( body = compute.ZoneSetPolicyRequest.to_json( request.zone_set_policy_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -923,7 +963,10 @@ def set_iam_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Policy.from_json(response.content) @@ -983,6 +1026,7 @@ def set_node_template( body = compute.NodeGroupsSetNodeTemplateRequest.to_json( request.node_groups_set_node_template_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1008,7 +1052,10 @@ def set_node_template( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1040,6 +1087,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1063,7 +1111,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/node_templates/client.py b/google/cloud/compute_v1/services/node_templates/client.py index 8b08ce4f3..27bff6713 100644 --- a/google/cloud/compute_v1/services/node_templates/client.py +++ b/google/cloud/compute_v1/services/node_templates/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.node_templates import pagers from google.cloud.compute_v1.types import compute from .transports.base import NodeTemplatesTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -334,7 +331,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NodeTemplateAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of node templates. Args: @@ -355,7 +352,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NodeTemplateAggregatedList: + google.cloud.compute_v1.services.node_templates.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -388,6 +388,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -846,7 +852,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NodeTemplateList: + ) -> pagers.ListPager: r"""Retrieves a list of node templates available to the specified project. @@ -875,8 +881,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NodeTemplateList: + google.cloud.compute_v1.services.node_templates.pagers.ListPager: Contains a list of node templates. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -910,6 +920,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/node_templates/pagers.py b/google/cloud/compute_v1/services/node_templates/pagers.py new file mode 100644 index 000000000..16570a828 --- /dev/null +++ b/google/cloud/compute_v1/services/node_templates/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NodeTemplateAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NodeTemplateAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NodeTemplateAggregatedList], + request: compute.AggregatedListNodeTemplatesRequest, + response: compute.NodeTemplateAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListNodeTemplatesRequest): + The initial request object. + response (google.cloud.compute_v1.types.NodeTemplateAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListNodeTemplatesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NodeTemplateAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.NodeTemplatesScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.NodeTemplatesScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NodeTemplateList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NodeTemplateList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NodeTemplateList], + request: compute.ListNodeTemplatesRequest, + response: compute.NodeTemplateList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListNodeTemplatesRequest): + The initial request object. + response (google.cloud.compute_v1.types.NodeTemplateList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListNodeTemplatesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NodeTemplateList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.NodeTemplate]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/node_templates/transports/rest.py b/google/cloud/compute_v1/services/node_templates/transports/rest.py index e1c85b561..840c034dd 100644 --- a/google/cloud/compute_v1/services/node_templates/transports/rest.py +++ b/google/cloud/compute_v1/services/node_templates/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NodeTemplateAggregatedList.from_json(response.content) @@ -219,6 +225,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -274,6 +283,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NodeTemplate.from_json(response.content) @@ -375,6 +387,9 @@ def get_iam_policy( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Policy.from_json(response.content) @@ -431,7 +446,9 @@ def insert( # Jsonify the request body body = compute.NodeTemplate.to_json( - request.node_template_resource, including_default_value_fields=False + request.node_template_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -454,7 +471,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -491,11 +511,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -508,6 +528,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NodeTemplateList.from_json(response.content) @@ -588,6 +611,7 @@ def set_iam_policy( body = compute.RegionSetPolicyRequest.to_json( request.region_set_policy_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -611,7 +635,10 @@ def set_iam_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Policy.from_json(response.content) @@ -643,6 +670,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -666,7 +694,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/node_types/client.py b/google/cloud/compute_v1/services/node_types/client.py index ffbb26eda..c33817efb 100644 --- a/google/cloud/compute_v1/services/node_types/client.py +++ b/google/cloud/compute_v1/services/node_types/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.node_types import pagers from google.cloud.compute_v1.types import compute from .transports.base import NodeTypesTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -334,7 +331,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NodeTypeAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of node types. Args: @@ -355,7 +352,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NodeTypeAggregatedList: + google.cloud.compute_v1.services.node_types.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -388,6 +388,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -492,7 +498,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NodeTypeList: + ) -> pagers.ListPager: r"""Retrieves a list of node types available to the specified project. @@ -520,8 +526,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NodeTypeList: + google.cloud.compute_v1.services.node_types.pagers.ListPager: Contains a list of node types. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -555,6 +565,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/node_types/pagers.py b/google/cloud/compute_v1/services/node_types/pagers.py new file mode 100644 index 000000000..1c5849e80 --- /dev/null +++ b/google/cloud/compute_v1/services/node_types/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NodeTypeAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NodeTypeAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NodeTypeAggregatedList], + request: compute.AggregatedListNodeTypesRequest, + response: compute.NodeTypeAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListNodeTypesRequest): + The initial request object. + response (google.cloud.compute_v1.types.NodeTypeAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListNodeTypesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NodeTypeAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.NodeTypesScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.NodeTypesScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NodeTypeList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NodeTypeList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NodeTypeList], + request: compute.ListNodeTypesRequest, + response: compute.NodeTypeList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListNodeTypesRequest): + The initial request object. + response (google.cloud.compute_v1.types.NodeTypeList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListNodeTypesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NodeTypeList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.NodeType]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/node_types/transports/rest.py b/google/cloud/compute_v1/services/node_types/transports/rest.py index ac3352295..711ddebb0 100644 --- a/google/cloud/compute_v1/services/node_types/transports/rest.py +++ b/google/cloud/compute_v1/services/node_types/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NodeTypeAggregatedList.from_json(response.content) @@ -197,6 +203,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NodeType.from_json(response.content) @@ -231,11 +240,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -248,6 +257,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NodeTypeList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/packet_mirrorings/client.py b/google/cloud/compute_v1/services/packet_mirrorings/client.py index 0a5a796e3..620631092 100644 --- a/google/cloud/compute_v1/services/packet_mirrorings/client.py +++ b/google/cloud/compute_v1/services/packet_mirrorings/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.packet_mirrorings import pagers from google.cloud.compute_v1.types import compute from .transports.base import PacketMirroringsTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -336,7 +333,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.PacketMirroringAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of packetMirrorings. Args: @@ -357,8 +354,12 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.PacketMirroringAggregatedList: + google.cloud.compute_v1.services.packet_mirrorings.pagers.AggregatedListPager: Contains a list of packetMirrorings. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -390,6 +391,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -711,7 +718,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.PacketMirroringList: + ) -> pagers.ListPager: r"""Retrieves a list of PacketMirroring resources available to the specified project and region. @@ -738,9 +745,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.PacketMirroringList: + google.cloud.compute_v1.services.packet_mirrorings.pagers.ListPager: Contains a list of PacketMirroring resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -775,6 +785,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/packet_mirrorings/pagers.py b/google/cloud/compute_v1/services/packet_mirrorings/pagers.py new file mode 100644 index 000000000..40b54f2d3 --- /dev/null +++ b/google/cloud/compute_v1/services/packet_mirrorings/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.PacketMirroringAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.PacketMirroringAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.PacketMirroringAggregatedList], + request: compute.AggregatedListPacketMirroringsRequest, + response: compute.PacketMirroringAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListPacketMirroringsRequest): + The initial request object. + response (google.cloud.compute_v1.types.PacketMirroringAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListPacketMirroringsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.PacketMirroringAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.PacketMirroringsScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.PacketMirroringsScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.PacketMirroringList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.PacketMirroringList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.PacketMirroringList], + request: compute.ListPacketMirroringsRequest, + response: compute.PacketMirroringList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListPacketMirroringsRequest): + The initial request object. + response (google.cloud.compute_v1.types.PacketMirroringList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListPacketMirroringsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.PacketMirroringList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.PacketMirroring]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/packet_mirrorings/transports/rest.py b/google/cloud/compute_v1/services/packet_mirrorings/transports/rest.py index 2149c5d4a..a73754a93 100644 --- a/google/cloud/compute_v1/services/packet_mirrorings/transports/rest.py +++ b/google/cloud/compute_v1/services/packet_mirrorings/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.PacketMirroringAggregatedList.from_json(response.content) @@ -219,6 +225,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -277,6 +286,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.PacketMirroring.from_json(response.content) @@ -333,7 +345,9 @@ def insert( # Jsonify the request body body = compute.PacketMirroring.to_json( - request.packet_mirroring_resource, including_default_value_fields=False + request.packet_mirroring_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -356,7 +370,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -395,11 +412,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -412,6 +429,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.PacketMirroringList.from_json(response.content) @@ -468,7 +488,9 @@ def patch( # Jsonify the request body body = compute.PacketMirroring.to_json( - request.packet_mirroring_resource, including_default_value_fields=False + request.packet_mirroring_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -494,7 +516,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -526,6 +551,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -549,7 +575,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/projects/client.py b/google/cloud/compute_v1/services/projects/client.py index 35e4e5380..d8105efb1 100644 --- a/google/cloud/compute_v1/services/projects/client.py +++ b/google/cloud/compute_v1/services/projects/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.projects import pagers from google.cloud.compute_v1.types import compute from .transports.base import ProjectsTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -863,7 +860,7 @@ def get_xpn_resources( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.ProjectsGetXpnResources: + ) -> pagers.GetXpnResourcesPager: r"""Gets service resources (a.k.a service project) associated with this host project. @@ -885,7 +882,10 @@ def get_xpn_resources( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.ProjectsGetXpnResources: + google.cloud.compute_v1.services.projects.pagers.GetXpnResourcesPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -918,6 +918,12 @@ def get_xpn_resources( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.GetXpnResourcesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -930,7 +936,7 @@ def list_xpn_hosts( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.XpnHostList: + ) -> pagers.ListXpnHostsPager: r"""Lists all shared VPC host projects visible to the user in an organization. @@ -957,7 +963,10 @@ def list_xpn_hosts( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.XpnHostList: + google.cloud.compute_v1.services.projects.pagers.ListXpnHostsPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -994,6 +1003,12 @@ def list_xpn_hosts( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListXpnHostsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/projects/pagers.py b/google/cloud/compute_v1/services/projects/pagers.py new file mode 100644 index 000000000..f94918d85 --- /dev/null +++ b/google/cloud/compute_v1/services/projects/pagers.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class GetXpnResourcesPager: + """A pager for iterating through ``get_xpn_resources`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.ProjectsGetXpnResources` object, and + provides an ``__iter__`` method to iterate through its + ``resources`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``GetXpnResources`` requests and continue to iterate + through the ``resources`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.ProjectsGetXpnResources` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.ProjectsGetXpnResources], + request: compute.GetXpnResourcesProjectsRequest, + response: compute.ProjectsGetXpnResources, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.GetXpnResourcesProjectsRequest): + The initial request object. + response (google.cloud.compute_v1.types.ProjectsGetXpnResources): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.GetXpnResourcesProjectsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.ProjectsGetXpnResources]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.XpnResourceId]: + for page in self.pages: + yield from page.resources + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListXpnHostsPager: + """A pager for iterating through ``list_xpn_hosts`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.XpnHostList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListXpnHosts`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.XpnHostList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.XpnHostList], + request: compute.ListXpnHostsProjectsRequest, + response: compute.XpnHostList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListXpnHostsProjectsRequest): + The initial request object. + response (google.cloud.compute_v1.types.XpnHostList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListXpnHostsProjectsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.XpnHostList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Project]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/projects/transports/rest.py b/google/cloud/compute_v1/services/projects/transports/rest.py index 54fb49087..b33374d53 100644 --- a/google/cloud/compute_v1/services/projects/transports/rest.py +++ b/google/cloud/compute_v1/services/projects/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def disable_xpn_host( self, @@ -163,6 +166,9 @@ def disable_xpn_host( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -221,6 +227,7 @@ def disable_xpn_resource( body = compute.ProjectsDisableXpnResourceRequest.to_json( request.projects_disable_xpn_resource_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -243,7 +250,10 @@ def disable_xpn_resource( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -321,6 +331,9 @@ def enable_xpn_host( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -379,6 +392,7 @@ def enable_xpn_resource( body = compute.ProjectsEnableXpnResourceRequest.to_json( request.projects_enable_xpn_resource_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -401,7 +415,10 @@ def enable_xpn_resource( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -454,6 +471,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Project.from_json(response.content) @@ -506,6 +526,9 @@ def get_xpn_host( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Project.from_json(response.content) @@ -541,11 +564,11 @@ def get_xpn_resources( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -558,6 +581,9 @@ def get_xpn_resources( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ProjectsGetXpnResources.from_json(response.content) @@ -588,6 +614,7 @@ def list_xpn_hosts( body = compute.ProjectsListXpnHostsRequest.to_json( request.projects_list_xpn_hosts_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -599,11 +626,11 @@ def list_xpn_hosts( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -614,7 +641,10 @@ def list_xpn_hosts( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.XpnHostList.from_json(response.content) @@ -672,7 +702,9 @@ def move_disk( # Jsonify the request body body = compute.DiskMoveRequest.to_json( - request.disk_move_request_resource, including_default_value_fields=False + request.disk_move_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -695,7 +727,10 @@ def move_disk( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -753,7 +788,9 @@ def move_instance( # Jsonify the request body body = compute.InstanceMoveRequest.to_json( - request.instance_move_request_resource, including_default_value_fields=False + request.instance_move_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -776,7 +813,10 @@ def move_instance( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -835,7 +875,9 @@ def set_common_instance_metadata( # Jsonify the request body body = compute.Metadata.to_json( - request.metadata_resource, including_default_value_fields=False + request.metadata_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -858,7 +900,10 @@ def set_common_instance_metadata( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -918,6 +963,7 @@ def set_default_network_tier( body = compute.ProjectsSetDefaultNetworkTierRequest.to_json( request.projects_set_default_network_tier_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -940,7 +986,10 @@ def set_default_network_tier( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -998,7 +1047,9 @@ def set_usage_export_bucket( # Jsonify the request body body = compute.UsageExportLocation.to_json( - request.usage_export_location_resource, including_default_value_fields=False + request.usage_export_location_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1021,7 +1072,10 @@ def set_usage_export_bucket( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_autoscalers/client.py b/google/cloud/compute_v1/services/region_autoscalers/client.py index d31eed4bb..8f0d64cab 100644 --- a/google/cloud/compute_v1/services/region_autoscalers/client.py +++ b/google/cloud/compute_v1/services/region_autoscalers/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_autoscalers import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionAutoscalersTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -659,7 +656,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.RegionAutoscalerList: + ) -> pagers.ListPager: r"""Retrieves a list of autoscalers contained within the specified region. @@ -688,8 +685,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.RegionAutoscalerList: + google.cloud.compute_v1.services.region_autoscalers.pagers.ListPager: Contains a list of autoscalers. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -723,6 +724,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_autoscalers/pagers.py b/google/cloud/compute_v1/services/region_autoscalers/pagers.py new file mode 100644 index 000000000..6b12d3baf --- /dev/null +++ b/google/cloud/compute_v1/services/region_autoscalers/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.RegionAutoscalerList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.RegionAutoscalerList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.RegionAutoscalerList], + request: compute.ListRegionAutoscalersRequest, + response: compute.RegionAutoscalerList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionAutoscalersRequest): + The initial request object. + response (google.cloud.compute_v1.types.RegionAutoscalerList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionAutoscalersRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.RegionAutoscalerList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Autoscaler]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_autoscalers/transports/rest.py b/google/cloud/compute_v1/services/region_autoscalers/transports/rest.py index 18a6c7397..3a39dd6a2 100644 --- a/google/cloud/compute_v1/services/region_autoscalers/transports/rest.py +++ b/google/cloud/compute_v1/services/region_autoscalers/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -166,6 +169,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -235,6 +241,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Autoscaler.from_json(response.content) @@ -291,7 +300,9 @@ def insert( # Jsonify the request body body = compute.Autoscaler.to_json( - request.autoscaler_resource, including_default_value_fields=False + request.autoscaler_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -314,7 +325,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -351,11 +365,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -368,6 +382,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.RegionAutoscalerList.from_json(response.content) @@ -424,7 +441,9 @@ def patch( # Jsonify the request body body = compute.Autoscaler.to_json( - request.autoscaler_resource, including_default_value_fields=False + request.autoscaler_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -436,8 +455,8 @@ def patch( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "autoscaler": request.autoscaler, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -448,7 +467,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -506,7 +528,9 @@ def update( # Jsonify the request body body = compute.Autoscaler.to_json( - request.autoscaler_resource, including_default_value_fields=False + request.autoscaler_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -518,8 +542,8 @@ def update( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "autoscaler": request.autoscaler, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -530,7 +554,10 @@ def update( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.put(url, json=body,) + response = self._session.put(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_backend_services/client.py b/google/cloud/compute_v1/services/region_backend_services/client.py index 4081f7c09..7f0d5ab14 100644 --- a/google/cloud/compute_v1/services/region_backend_services/client.py +++ b/google/cloud/compute_v1/services/region_backend_services/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_backend_services import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionBackendServicesTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -767,7 +764,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.BackendServiceList: + ) -> pagers.ListPager: r"""Retrieves the list of regional BackendService resources available to the specified project in the given region. @@ -797,9 +794,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.BackendServiceList: + google.cloud.compute_v1.services.region_backend_services.pagers.ListPager: Contains a list of BackendService resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -834,6 +834,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_backend_services/pagers.py b/google/cloud/compute_v1/services/region_backend_services/pagers.py new file mode 100644 index 000000000..df369b0ca --- /dev/null +++ b/google/cloud/compute_v1/services/region_backend_services/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.BackendServiceList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.BackendServiceList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.BackendServiceList], + request: compute.ListRegionBackendServicesRequest, + response: compute.BackendServiceList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionBackendServicesRequest): + The initial request object. + response (google.cloud.compute_v1.types.BackendServiceList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionBackendServicesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.BackendServiceList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.BackendService]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_backend_services/transports/rest.py b/google/cloud/compute_v1/services/region_backend_services/transports/rest.py index ce709187a..0e30895e7 100644 --- a/google/cloud/compute_v1/services/region_backend_services/transports/rest.py +++ b/google/cloud/compute_v1/services/region_backend_services/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -166,6 +169,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -237,6 +243,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.BackendService.from_json(response.content) @@ -267,6 +276,7 @@ def get_health( body = compute.ResourceGroupReference.to_json( request.resource_group_reference_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -290,7 +300,10 @@ def get_health( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.BackendServiceGroupHealth.from_json(response.content) @@ -348,7 +361,9 @@ def insert( # Jsonify the request body body = compute.BackendService.to_json( - request.backend_service_resource, including_default_value_fields=False + request.backend_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -371,7 +386,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -410,11 +428,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -427,6 +445,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.BackendServiceList.from_json(response.content) @@ -483,7 +504,9 @@ def patch( # Jsonify the request body body = compute.BackendService.to_json( - request.backend_service_resource, including_default_value_fields=False + request.backend_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -509,7 +532,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -567,7 +593,9 @@ def update( # Jsonify the request body body = compute.BackendService.to_json( - request.backend_service_resource, including_default_value_fields=False + request.backend_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -593,7 +621,10 @@ def update( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.put(url, json=body,) + response = self._session.put(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_commitments/client.py b/google/cloud/compute_v1/services/region_commitments/client.py index eaf0b64b7..37cb7eaa2 100644 --- a/google/cloud/compute_v1/services/region_commitments/client.py +++ b/google/cloud/compute_v1/services/region_commitments/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_commitments import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionCommitmentsTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -338,7 +335,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.CommitmentAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of commitments. Args: @@ -359,7 +356,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.CommitmentAggregatedList: + google.cloud.compute_v1.services.region_commitments.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -392,6 +392,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -603,7 +609,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.CommitmentList: + ) -> pagers.ListPager: r"""Retrieves a list of commitments contained within the specified region. @@ -630,9 +636,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.CommitmentList: + google.cloud.compute_v1.services.region_commitments.pagers.ListPager: Contains a list of Commitment resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -667,6 +676,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_commitments/pagers.py b/google/cloud/compute_v1/services/region_commitments/pagers.py new file mode 100644 index 000000000..54dc0eb1c --- /dev/null +++ b/google/cloud/compute_v1/services/region_commitments/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.CommitmentAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.CommitmentAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.CommitmentAggregatedList], + request: compute.AggregatedListRegionCommitmentsRequest, + response: compute.CommitmentAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListRegionCommitmentsRequest): + The initial request object. + response (google.cloud.compute_v1.types.CommitmentAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListRegionCommitmentsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.CommitmentAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.CommitmentsScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.CommitmentsScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.CommitmentList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.CommitmentList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.CommitmentList], + request: compute.ListRegionCommitmentsRequest, + response: compute.CommitmentList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionCommitmentsRequest): + The initial request object. + response (google.cloud.compute_v1.types.CommitmentList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionCommitmentsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.CommitmentList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Commitment]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_commitments/transports/rest.py b/google/cloud/compute_v1/services/region_commitments/transports/rest.py index 83c9b718e..6b2d56cc5 100644 --- a/google/cloud/compute_v1/services/region_commitments/transports/rest.py +++ b/google/cloud/compute_v1/services/region_commitments/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.CommitmentAggregatedList.from_json(response.content) @@ -198,6 +204,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Commitment.from_json(response.content) @@ -254,7 +263,9 @@ def insert( # Jsonify the request body body = compute.Commitment.to_json( - request.commitment_resource, including_default_value_fields=False + request.commitment_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -277,7 +288,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -316,11 +330,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -333,6 +347,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.CommitmentList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_disk_types/client.py b/google/cloud/compute_v1/services/region_disk_types/client.py index 0f5379058..101d8eada 100644 --- a/google/cloud/compute_v1/services/region_disk_types/client.py +++ b/google/cloud/compute_v1/services/region_disk_types/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_disk_types import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionDiskTypesTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -442,7 +439,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.RegionDiskTypeList: + ) -> pagers.ListPager: r"""Retrieves a list of regional disk types available to the specified project. @@ -471,7 +468,10 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.RegionDiskTypeList: + google.cloud.compute_v1.services.region_disk_types.pagers.ListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -506,6 +506,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_disk_types/pagers.py b/google/cloud/compute_v1/services/region_disk_types/pagers.py new file mode 100644 index 000000000..a0ca2f635 --- /dev/null +++ b/google/cloud/compute_v1/services/region_disk_types/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.RegionDiskTypeList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.RegionDiskTypeList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.RegionDiskTypeList], + request: compute.ListRegionDiskTypesRequest, + response: compute.RegionDiskTypeList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionDiskTypesRequest): + The initial request object. + response (google.cloud.compute_v1.types.RegionDiskTypeList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionDiskTypesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.RegionDiskTypeList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.DiskType]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_disk_types/transports/rest.py b/google/cloud/compute_v1/services/region_disk_types/transports/rest.py index 01fe87fa7..77cb03576 100644 --- a/google/cloud/compute_v1/services/region_disk_types/transports/rest.py +++ b/google/cloud/compute_v1/services/region_disk_types/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def get( self, @@ -156,6 +159,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.DiskType.from_json(response.content) @@ -191,11 +197,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -208,6 +214,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.RegionDiskTypeList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_disks/client.py b/google/cloud/compute_v1/services/region_disks/client.py index 3fcb0e892..c75fe341e 100644 --- a/google/cloud/compute_v1/services/region_disks/client.py +++ b/google/cloud/compute_v1/services/region_disks/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_disks import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionDisksTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -1033,7 +1030,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.DiskList: + ) -> pagers.ListPager: r"""Retrieves the list of persistent disks contained within the specified region. @@ -1060,8 +1057,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.DiskList: + google.cloud.compute_v1.services.region_disks.pagers.ListPager: A list of Disk resources. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -1095,6 +1096,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_disks/pagers.py b/google/cloud/compute_v1/services/region_disks/pagers.py new file mode 100644 index 000000000..e370537fa --- /dev/null +++ b/google/cloud/compute_v1/services/region_disks/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.DiskList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.DiskList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.DiskList], + request: compute.ListRegionDisksRequest, + response: compute.DiskList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionDisksRequest): + The initial request object. + response (google.cloud.compute_v1.types.DiskList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionDisksRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.DiskList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Disk]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_disks/transports/rest.py b/google/cloud/compute_v1/services/region_disks/transports/rest.py index 0157002f5..693cf83b3 100644 --- a/google/cloud/compute_v1/services/region_disks/transports/rest.py +++ b/google/cloud/compute_v1/services/region_disks/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def add_resource_policies( self, @@ -145,6 +148,7 @@ def add_resource_policies( body = compute.RegionDisksAddResourcePoliciesRequest.to_json( request.region_disks_add_resource_policies_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -170,7 +174,10 @@ def add_resource_policies( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -228,7 +235,9 @@ def create_snapshot( # Jsonify the request body body = compute.Snapshot.to_json( - request.snapshot_resource, including_default_value_fields=False + request.snapshot_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -254,7 +263,10 @@ def create_snapshot( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -335,6 +347,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -404,6 +419,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Disk.from_json(response.content) @@ -505,6 +523,9 @@ def get_iam_policy( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Policy.from_json(response.content) @@ -561,7 +582,9 @@ def insert( # Jsonify the request body body = compute.Disk.to_json( - request.disk_resource, including_default_value_fields=False + request.disk_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -585,7 +608,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -622,11 +648,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -639,6 +665,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.DiskList.from_json(response.content) @@ -697,6 +726,7 @@ def remove_resource_policies( body = compute.RegionDisksRemoveResourcePoliciesRequest.to_json( request.region_disks_remove_resource_policies_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -722,7 +752,10 @@ def remove_resource_policies( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -782,6 +815,7 @@ def resize( body = compute.RegionDisksResizeRequest.to_json( request.region_disks_resize_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -807,7 +841,10 @@ def resize( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -889,6 +926,7 @@ def set_iam_policy( body = compute.RegionSetPolicyRequest.to_json( request.region_set_policy_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -912,7 +950,10 @@ def set_iam_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Policy.from_json(response.content) @@ -972,6 +1013,7 @@ def set_labels( body = compute.RegionSetLabelsRequest.to_json( request.region_set_labels_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -997,7 +1039,10 @@ def set_labels( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1029,6 +1074,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1052,7 +1098,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_health_check_services/client.py b/google/cloud/compute_v1/services/region_health_check_services/client.py index 9a2d37e74..28a58ad0a 100644 --- a/google/cloud/compute_v1/services/region_health_check_services/client.py +++ b/google/cloud/compute_v1/services/region_health_check_services/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_health_check_services import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionHealthCheckServicesTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -651,7 +648,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.HealthCheckServicesList: + ) -> pagers.ListPager: r"""Lists all the HealthCheckService resources that have been configured for the specified project in the given region. @@ -681,7 +678,10 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.HealthCheckServicesList: + google.cloud.compute_v1.services.region_health_check_services.pagers.ListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -716,6 +716,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_health_check_services/pagers.py b/google/cloud/compute_v1/services/region_health_check_services/pagers.py new file mode 100644 index 000000000..b0fd2a06b --- /dev/null +++ b/google/cloud/compute_v1/services/region_health_check_services/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.HealthCheckServicesList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.HealthCheckServicesList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.HealthCheckServicesList], + request: compute.ListRegionHealthCheckServicesRequest, + response: compute.HealthCheckServicesList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionHealthCheckServicesRequest): + The initial request object. + response (google.cloud.compute_v1.types.HealthCheckServicesList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionHealthCheckServicesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.HealthCheckServicesList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.HealthCheckService]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_health_check_services/transports/rest.py b/google/cloud/compute_v1/services/region_health_check_services/transports/rest.py index 2decab796..28b958da5 100644 --- a/google/cloud/compute_v1/services/region_health_check_services/transports/rest.py +++ b/google/cloud/compute_v1/services/region_health_check_services/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -166,6 +169,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -219,6 +225,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.HealthCheckService.from_json(response.content) @@ -275,7 +284,9 @@ def insert( # Jsonify the request body body = compute.HealthCheckService.to_json( - request.health_check_service_resource, including_default_value_fields=False + request.health_check_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -298,7 +309,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -335,11 +349,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -352,6 +366,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.HealthCheckServicesList.from_json(response.content) @@ -408,7 +425,9 @@ def patch( # Jsonify the request body body = compute.HealthCheckService.to_json( - request.health_check_service_resource, including_default_value_fields=False + request.health_check_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -434,7 +453,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_health_checks/client.py b/google/cloud/compute_v1/services/region_health_checks/client.py index 03b5eb66a..d0a49b903 100644 --- a/google/cloud/compute_v1/services/region_health_checks/client.py +++ b/google/cloud/compute_v1/services/region_health_checks/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_health_checks import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionHealthChecksTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -673,7 +670,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.HealthCheckList: + ) -> pagers.ListPager: r"""Retrieves the list of HealthCheck resources available to the specified project. @@ -702,9 +699,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.HealthCheckList: + google.cloud.compute_v1.services.region_health_checks.pagers.ListPager: Contains a list of HealthCheck resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -739,6 +739,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_health_checks/pagers.py b/google/cloud/compute_v1/services/region_health_checks/pagers.py new file mode 100644 index 000000000..8e65a25b1 --- /dev/null +++ b/google/cloud/compute_v1/services/region_health_checks/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.HealthCheckList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.HealthCheckList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.HealthCheckList], + request: compute.ListRegionHealthChecksRequest, + response: compute.HealthCheckList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionHealthChecksRequest): + The initial request object. + response (google.cloud.compute_v1.types.HealthCheckList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionHealthChecksRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.HealthCheckList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.HealthCheck]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_health_checks/transports/rest.py b/google/cloud/compute_v1/services/region_health_checks/transports/rest.py index 87c4c31c2..05c7c2382 100644 --- a/google/cloud/compute_v1/services/region_health_checks/transports/rest.py +++ b/google/cloud/compute_v1/services/region_health_checks/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -166,6 +169,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -243,6 +249,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.HealthCheck.from_json(response.content) @@ -299,7 +308,9 @@ def insert( # Jsonify the request body body = compute.HealthCheck.to_json( - request.health_check_resource, including_default_value_fields=False + request.health_check_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -322,7 +333,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -361,11 +375,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -378,6 +392,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.HealthCheckList.from_json(response.content) @@ -434,7 +451,9 @@ def patch( # Jsonify the request body body = compute.HealthCheck.to_json( - request.health_check_resource, including_default_value_fields=False + request.health_check_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -460,7 +479,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -518,7 +540,9 @@ def update( # Jsonify the request body body = compute.HealthCheck.to_json( - request.health_check_resource, including_default_value_fields=False + request.health_check_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -544,7 +568,10 @@ def update( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.put(url, json=body,) + response = self._session.put(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_instance_group_managers/client.py b/google/cloud/compute_v1/services/region_instance_group_managers/client.py index d00ee4e68..6e97236e4 100644 --- a/google/cloud/compute_v1/services/region_instance_group_managers/client.py +++ b/google/cloud/compute_v1/services/region_instance_group_managers/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_instance_group_managers import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionInstanceGroupManagersTransport, DEFAULT_CLIENT_INFO @@ -270,21 +271,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -327,7 +324,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -1376,7 +1373,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.RegionInstanceGroupManagerList: + ) -> pagers.ListPager: r"""Retrieves the list of managed instance groups that are contained within the specified region. @@ -1405,9 +1402,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.RegionInstanceGroupManagerList: + google.cloud.compute_v1.services.region_instance_group_managers.pagers.ListPager: Contains a list of managed instance groups. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1442,6 +1442,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1455,7 +1461,7 @@ def list_errors( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.RegionInstanceGroupManagersListErrorsResponse: + ) -> pagers.ListErrorsPager: r"""Lists all errors thrown by actions on instances for a given regional managed instance group. The filter and orderBy query parameters are not supported. @@ -1494,7 +1500,10 @@ def list_errors( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.RegionInstanceGroupManagersListErrorsResponse: + google.cloud.compute_v1.services.region_instance_group_managers.pagers.ListErrorsPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1533,6 +1542,12 @@ def list_errors( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListErrorsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1546,7 +1561,7 @@ def list_managed_instances( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.RegionInstanceGroupManagersListInstancesResponse: + ) -> pagers.ListManagedInstancesPager: r"""Lists the instances in the managed instance group and instances that are scheduled to be created. The list includes any current actions that the group has @@ -1585,7 +1600,10 @@ def list_managed_instances( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.RegionInstanceGroupManagersListInstancesResponse: + google.cloud.compute_v1.services.region_instance_group_managers.pagers.ListManagedInstancesPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1626,6 +1644,12 @@ def list_managed_instances( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListManagedInstancesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1639,7 +1663,7 @@ def list_per_instance_configs( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.RegionInstanceGroupManagersListInstanceConfigsResp: + ) -> pagers.ListPerInstanceConfigsPager: r"""Lists all of the per-instance configs defined for the managed instance group. The orderBy query parameter is not supported. @@ -1676,7 +1700,10 @@ def list_per_instance_configs( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.RegionInstanceGroupManagersListInstanceConfigsResp: + google.cloud.compute_v1.services.region_instance_group_managers.pagers.ListPerInstanceConfigsPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1719,6 +1746,12 @@ def list_per_instance_configs( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPerInstanceConfigsPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_instance_group_managers/pagers.py b/google/cloud/compute_v1/services/region_instance_group_managers/pagers.py new file mode 100644 index 000000000..190a93fd7 --- /dev/null +++ b/google/cloud/compute_v1/services/region_instance_group_managers/pagers.py @@ -0,0 +1,287 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.RegionInstanceGroupManagerList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.RegionInstanceGroupManagerList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.RegionInstanceGroupManagerList], + request: compute.ListRegionInstanceGroupManagersRequest, + response: compute.RegionInstanceGroupManagerList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionInstanceGroupManagersRequest): + The initial request object. + response (google.cloud.compute_v1.types.RegionInstanceGroupManagerList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionInstanceGroupManagersRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.RegionInstanceGroupManagerList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.InstanceGroupManager]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListErrorsPager: + """A pager for iterating through ``list_errors`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.RegionInstanceGroupManagersListErrorsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListErrors`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.RegionInstanceGroupManagersListErrorsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.RegionInstanceGroupManagersListErrorsResponse], + request: compute.ListErrorsRegionInstanceGroupManagersRequest, + response: compute.RegionInstanceGroupManagersListErrorsResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListErrorsRegionInstanceGroupManagersRequest): + The initial request object. + response (google.cloud.compute_v1.types.RegionInstanceGroupManagersListErrorsResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListErrorsRegionInstanceGroupManagersRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.RegionInstanceGroupManagersListErrorsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.InstanceManagedByIgmError]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListManagedInstancesPager: + """A pager for iterating through ``list_managed_instances`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.RegionInstanceGroupManagersListInstancesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``managed_instances`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListManagedInstances`` requests and continue to iterate + through the ``managed_instances`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.RegionInstanceGroupManagersListInstancesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.RegionInstanceGroupManagersListInstancesResponse], + request: compute.ListManagedInstancesRegionInstanceGroupManagersRequest, + response: compute.RegionInstanceGroupManagersListInstancesResponse, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListManagedInstancesRegionInstanceGroupManagersRequest): + The initial request object. + response (google.cloud.compute_v1.types.RegionInstanceGroupManagersListInstancesResponse): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListManagedInstancesRegionInstanceGroupManagersRequest( + request + ) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages( + self, + ) -> Iterable[compute.RegionInstanceGroupManagersListInstancesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.ManagedInstance]: + for page in self.pages: + yield from page.managed_instances + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPerInstanceConfigsPager: + """A pager for iterating through ``list_per_instance_configs`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.RegionInstanceGroupManagersListInstanceConfigsResp` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListPerInstanceConfigs`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.RegionInstanceGroupManagersListInstanceConfigsResp` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[ + ..., compute.RegionInstanceGroupManagersListInstanceConfigsResp + ], + request: compute.ListPerInstanceConfigsRegionInstanceGroupManagersRequest, + response: compute.RegionInstanceGroupManagersListInstanceConfigsResp, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListPerInstanceConfigsRegionInstanceGroupManagersRequest): + The initial request object. + response (google.cloud.compute_v1.types.RegionInstanceGroupManagersListInstanceConfigsResp): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListPerInstanceConfigsRegionInstanceGroupManagersRequest( + request + ) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages( + self, + ) -> Iterable[compute.RegionInstanceGroupManagersListInstanceConfigsResp]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.PerInstanceConfig]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_instance_group_managers/transports/rest.py b/google/cloud/compute_v1/services/region_instance_group_managers/transports/rest.py index 587c128f2..af94f0a48 100644 --- a/google/cloud/compute_v1/services/region_instance_group_managers/transports/rest.py +++ b/google/cloud/compute_v1/services/region_instance_group_managers/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def abandon_instances( self, @@ -145,6 +148,7 @@ def abandon_instances( body = compute.RegionInstanceGroupManagersAbandonInstancesRequest.to_json( request.region_instance_group_managers_abandon_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -170,7 +174,10 @@ def abandon_instances( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -231,6 +238,7 @@ def apply_updates_to_instances( body = compute.RegionInstanceGroupManagersApplyUpdatesRequest.to_json( request.region_instance_group_managers_apply_updates_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -254,7 +262,10 @@ def apply_updates_to_instances( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -314,6 +325,7 @@ def create_instances( body = compute.RegionInstanceGroupManagersCreateInstancesRequest.to_json( request.region_instance_group_managers_create_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -339,7 +351,10 @@ def create_instances( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -420,6 +435,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -478,6 +496,7 @@ def delete_instances( body = compute.RegionInstanceGroupManagersDeleteInstancesRequest.to_json( request.region_instance_group_managers_delete_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -503,7 +522,10 @@ def delete_instances( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -564,6 +586,7 @@ def delete_per_instance_configs( body = compute.RegionInstanceGroupManagerDeleteInstanceConfigReq.to_json( request.region_instance_group_manager_delete_instance_config_req_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -587,7 +610,10 @@ def delete_per_instance_configs( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -652,6 +678,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceGroupManager.from_json(response.content) @@ -710,6 +739,7 @@ def insert( body = compute.InstanceGroupManager.to_json( request.instance_group_manager_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -732,7 +762,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -771,11 +804,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -788,6 +821,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.RegionInstanceGroupManagerList.from_json(response.content) @@ -826,11 +862,11 @@ def list_errors( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -843,6 +879,9 @@ def list_errors( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.RegionInstanceGroupManagersListErrorsResponse.from_json( response.content @@ -883,11 +922,11 @@ def list_managed_instances( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -900,6 +939,9 @@ def list_managed_instances( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.RegionInstanceGroupManagersListInstancesResponse.from_json( response.content @@ -940,11 +982,11 @@ def list_per_instance_configs( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -957,6 +999,9 @@ def list_per_instance_configs( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.RegionInstanceGroupManagersListInstanceConfigsResp.from_json( response.content @@ -1017,6 +1062,7 @@ def patch( body = compute.InstanceGroupManager.to_json( request.instance_group_manager_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1042,7 +1088,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1103,6 +1152,7 @@ def patch_per_instance_configs( body = compute.RegionInstanceGroupManagerPatchInstanceConfigReq.to_json( request.region_instance_group_manager_patch_instance_config_req_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1128,7 +1178,10 @@ def patch_per_instance_configs( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1188,6 +1241,7 @@ def recreate_instances( body = compute.RegionInstanceGroupManagersRecreateRequest.to_json( request.region_instance_group_managers_recreate_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1213,7 +1267,10 @@ def recreate_instances( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1295,6 +1352,9 @@ def resize( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -1353,6 +1413,7 @@ def set_instance_template( body = compute.RegionInstanceGroupManagersSetTemplateRequest.to_json( request.region_instance_group_managers_set_template_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1378,7 +1439,10 @@ def set_instance_template( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1438,6 +1502,7 @@ def set_target_pools( body = compute.RegionInstanceGroupManagersSetTargetPoolsRequest.to_json( request.region_instance_group_managers_set_target_pools_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1463,7 +1528,10 @@ def set_target_pools( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -1524,6 +1592,7 @@ def update_per_instance_configs( body = compute.RegionInstanceGroupManagerUpdateInstanceConfigReq.to_json( request.region_instance_group_manager_update_instance_config_req_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -1549,7 +1618,10 @@ def update_per_instance_configs( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_instance_groups/client.py b/google/cloud/compute_v1/services/region_instance_groups/client.py index a1345a600..7e5a877d7 100644 --- a/google/cloud/compute_v1/services/region_instance_groups/client.py +++ b/google/cloud/compute_v1/services/region_instance_groups/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_instance_groups import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionInstanceGroupsTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -446,7 +443,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.RegionInstanceGroupList: + ) -> pagers.ListPager: r"""Retrieves the list of instance group resources contained within the specified region. @@ -475,9 +472,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.RegionInstanceGroupList: + google.cloud.compute_v1.services.region_instance_groups.pagers.ListPager: Contains a list of InstanceGroup resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -512,6 +512,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -526,7 +532,7 @@ def list_instances( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.RegionInstanceGroupsListInstances: + ) -> pagers.ListInstancesPager: r"""Lists the instances in the specified instance group and displays information about the named ports. Depending on the specified options, this method can list @@ -570,7 +576,10 @@ def list_instances( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.RegionInstanceGroupsListInstances: + google.cloud.compute_v1.services.region_instance_groups.pagers.ListInstancesPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -618,6 +627,12 @@ def list_instances( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListInstancesPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_instance_groups/pagers.py b/google/cloud/compute_v1/services/region_instance_groups/pagers.py new file mode 100644 index 000000000..99e176e4b --- /dev/null +++ b/google/cloud/compute_v1/services/region_instance_groups/pagers.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.RegionInstanceGroupList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.RegionInstanceGroupList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.RegionInstanceGroupList], + request: compute.ListRegionInstanceGroupsRequest, + response: compute.RegionInstanceGroupList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionInstanceGroupsRequest): + The initial request object. + response (google.cloud.compute_v1.types.RegionInstanceGroupList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionInstanceGroupsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.RegionInstanceGroupList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.InstanceGroup]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListInstancesPager: + """A pager for iterating through ``list_instances`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.RegionInstanceGroupsListInstances` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListInstances`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.RegionInstanceGroupsListInstances` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.RegionInstanceGroupsListInstances], + request: compute.ListInstancesRegionInstanceGroupsRequest, + response: compute.RegionInstanceGroupsListInstances, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListInstancesRegionInstanceGroupsRequest): + The initial request object. + response (google.cloud.compute_v1.types.RegionInstanceGroupsListInstances): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListInstancesRegionInstanceGroupsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.RegionInstanceGroupsListInstances]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.InstanceWithNamedPorts]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_instance_groups/transports/rest.py b/google/cloud/compute_v1/services/region_instance_groups/transports/rest.py index 08d2b2314..5a9bbd38a 100644 --- a/google/cloud/compute_v1/services/region_instance_groups/transports/rest.py +++ b/google/cloud/compute_v1/services/region_instance_groups/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def get( self, @@ -156,6 +159,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.InstanceGroup.from_json(response.content) @@ -193,11 +199,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -210,6 +216,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.RegionInstanceGroupList.from_json(response.content) @@ -240,6 +249,7 @@ def list_instances( body = compute.RegionInstanceGroupsListInstancesRequest.to_json( request.region_instance_groups_list_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -254,11 +264,11 @@ def list_instances( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -269,7 +279,10 @@ def list_instances( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.RegionInstanceGroupsListInstances.from_json(response.content) @@ -329,6 +342,7 @@ def set_named_ports( body = compute.RegionInstanceGroupsSetNamedPortsRequest.to_json( request.region_instance_groups_set_named_ports_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -354,7 +368,10 @@ def set_named_ports( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_network_endpoint_groups/client.py b/google/cloud/compute_v1/services/region_network_endpoint_groups/client.py index e9dd62404..5d97de418 100644 --- a/google/cloud/compute_v1/services/region_network_endpoint_groups/client.py +++ b/google/cloud/compute_v1/services/region_network_endpoint_groups/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_network_endpoint_groups import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionNetworkEndpointGroupsTransport, DEFAULT_CLIENT_INFO @@ -270,21 +271,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -327,7 +324,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -670,7 +667,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NetworkEndpointGroupList: + ) -> pagers.ListPager: r"""Retrieves the list of regional network endpoint groups available to the specified project in the given region. @@ -701,7 +698,10 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NetworkEndpointGroupList: + google.cloud.compute_v1.services.region_network_endpoint_groups.pagers.ListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -736,6 +736,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_network_endpoint_groups/pagers.py b/google/cloud/compute_v1/services/region_network_endpoint_groups/pagers.py new file mode 100644 index 000000000..66fb6aa05 --- /dev/null +++ b/google/cloud/compute_v1/services/region_network_endpoint_groups/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NetworkEndpointGroupList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NetworkEndpointGroupList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NetworkEndpointGroupList], + request: compute.ListRegionNetworkEndpointGroupsRequest, + response: compute.NetworkEndpointGroupList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionNetworkEndpointGroupsRequest): + The initial request object. + response (google.cloud.compute_v1.types.NetworkEndpointGroupList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionNetworkEndpointGroupsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NetworkEndpointGroupList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.NetworkEndpointGroup]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_network_endpoint_groups/transports/rest.py b/google/cloud/compute_v1/services/region_network_endpoint_groups/transports/rest.py index d3ff067af..a395f4264 100644 --- a/google/cloud/compute_v1/services/region_network_endpoint_groups/transports/rest.py +++ b/google/cloud/compute_v1/services/region_network_endpoint_groups/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -166,6 +169,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -228,6 +234,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NetworkEndpointGroup.from_json(response.content) @@ -286,6 +295,7 @@ def insert( body = compute.NetworkEndpointGroup.to_json( request.network_endpoint_group_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -308,7 +318,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -345,11 +358,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -362,6 +375,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NetworkEndpointGroupList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_notification_endpoints/client.py b/google/cloud/compute_v1/services/region_notification_endpoints/client.py index 4d57bd080..45bb0d59a 100644 --- a/google/cloud/compute_v1/services/region_notification_endpoints/client.py +++ b/google/cloud/compute_v1/services/region_notification_endpoints/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_notification_endpoints import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionNotificationEndpointsTransport, DEFAULT_CLIENT_INFO @@ -270,21 +271,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -327,7 +324,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -659,7 +656,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.NotificationEndpointList: + ) -> pagers.ListPager: r"""Lists the NotificationEndpoints for a project in the given region. @@ -688,7 +685,10 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.NotificationEndpointList: + google.cloud.compute_v1.services.region_notification_endpoints.pagers.ListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -723,6 +723,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_notification_endpoints/pagers.py b/google/cloud/compute_v1/services/region_notification_endpoints/pagers.py new file mode 100644 index 000000000..dc3331ed7 --- /dev/null +++ b/google/cloud/compute_v1/services/region_notification_endpoints/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.NotificationEndpointList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.NotificationEndpointList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.NotificationEndpointList], + request: compute.ListRegionNotificationEndpointsRequest, + response: compute.NotificationEndpointList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionNotificationEndpointsRequest): + The initial request object. + response (google.cloud.compute_v1.types.NotificationEndpointList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionNotificationEndpointsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.NotificationEndpointList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.NotificationEndpoint]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_notification_endpoints/transports/rest.py b/google/cloud/compute_v1/services/region_notification_endpoints/transports/rest.py index 73ab1c1c5..b1af05fbe 100644 --- a/google/cloud/compute_v1/services/region_notification_endpoints/transports/rest.py +++ b/google/cloud/compute_v1/services/region_notification_endpoints/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -166,6 +169,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -225,6 +231,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NotificationEndpoint.from_json(response.content) @@ -281,7 +290,9 @@ def insert( # Jsonify the request body body = compute.NotificationEndpoint.to_json( - request.notification_endpoint_resource, including_default_value_fields=False + request.notification_endpoint_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -304,7 +315,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -341,11 +355,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -358,6 +372,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.NotificationEndpointList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_operations/client.py b/google/cloud/compute_v1/services/region_operations/client.py index dbbee0053..98fa6d571 100644 --- a/google/cloud/compute_v1/services/region_operations/client.py +++ b/google/cloud/compute_v1/services/region_operations/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_operations import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionOperationsTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -533,7 +530,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.OperationList: + ) -> pagers.ListPager: r"""Retrieves a list of Operation resources contained within the specified region. @@ -560,9 +557,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.OperationList: + google.cloud.compute_v1.services.region_operations.pagers.ListPager: Contains a list of Operation resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -597,6 +597,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_operations/pagers.py b/google/cloud/compute_v1/services/region_operations/pagers.py new file mode 100644 index 000000000..fc7a72dc0 --- /dev/null +++ b/google/cloud/compute_v1/services/region_operations/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.OperationList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.OperationList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.OperationList], + request: compute.ListRegionOperationsRequest, + response: compute.OperationList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionOperationsRequest): + The initial request object. + response (google.cloud.compute_v1.types.OperationList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionOperationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.OperationList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Operation]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_operations/transports/rest.py b/google/cloud/compute_v1/services/region_operations/transports/rest.py index c6a5b91a8..037c2c31f 100644 --- a/google/cloud/compute_v1/services/region_operations/transports/rest.py +++ b/google/cloud/compute_v1/services/region_operations/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -139,6 +142,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.DeleteRegionOperationResponse.from_json(response.content) @@ -216,6 +222,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -253,11 +262,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -270,6 +279,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.OperationList.from_json(response.content) @@ -347,6 +359,9 @@ def wait( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_ssl_certificates/client.py b/google/cloud/compute_v1/services/region_ssl_certificates/client.py index 1b0804435..3355d6a5a 100644 --- a/google/cloud/compute_v1/services/region_ssl_certificates/client.py +++ b/google/cloud/compute_v1/services/region_ssl_certificates/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_ssl_certificates import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionSslCertificatesTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -672,7 +669,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.SslCertificateList: + ) -> pagers.ListPager: r"""Retrieves the list of SslCertificate resources available to the specified project in the specified region. @@ -702,9 +699,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.SslCertificateList: + google.cloud.compute_v1.services.region_ssl_certificates.pagers.ListPager: Contains a list of SslCertificate resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -739,6 +739,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_ssl_certificates/pagers.py b/google/cloud/compute_v1/services/region_ssl_certificates/pagers.py new file mode 100644 index 000000000..bbe705086 --- /dev/null +++ b/google/cloud/compute_v1/services/region_ssl_certificates/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.SslCertificateList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.SslCertificateList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.SslCertificateList], + request: compute.ListRegionSslCertificatesRequest, + response: compute.SslCertificateList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionSslCertificatesRequest): + The initial request object. + response (google.cloud.compute_v1.types.SslCertificateList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionSslCertificatesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.SslCertificateList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.SslCertificate]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_ssl_certificates/transports/rest.py b/google/cloud/compute_v1/services/region_ssl_certificates/transports/rest.py index 1c9f3a51e..6878e48c3 100644 --- a/google/cloud/compute_v1/services/region_ssl_certificates/transports/rest.py +++ b/google/cloud/compute_v1/services/region_ssl_certificates/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -166,6 +169,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -241,6 +247,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SslCertificate.from_json(response.content) @@ -297,7 +306,9 @@ def insert( # Jsonify the request body body = compute.SslCertificate.to_json( - request.ssl_certificate_resource, including_default_value_fields=False + request.ssl_certificate_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -320,7 +331,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -359,11 +373,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -376,6 +390,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SslCertificateList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_target_http_proxies/client.py b/google/cloud/compute_v1/services/region_target_http_proxies/client.py index 13021aaee..fcd3f671c 100644 --- a/google/cloud/compute_v1/services/region_target_http_proxies/client.py +++ b/google/cloud/compute_v1/services/region_target_http_proxies/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_target_http_proxies import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionTargetHttpProxiesTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -669,7 +666,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetHttpProxyList: + ) -> pagers.ListPager: r"""Retrieves the list of TargetHttpProxy resources available to the specified project in the specified region. @@ -699,8 +696,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetHttpProxyList: + google.cloud.compute_v1.services.region_target_http_proxies.pagers.ListPager: A list of TargetHttpProxy resources. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -734,6 +735,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_target_http_proxies/pagers.py b/google/cloud/compute_v1/services/region_target_http_proxies/pagers.py new file mode 100644 index 000000000..25a545b2c --- /dev/null +++ b/google/cloud/compute_v1/services/region_target_http_proxies/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetHttpProxyList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetHttpProxyList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetHttpProxyList], + request: compute.ListRegionTargetHttpProxiesRequest, + response: compute.TargetHttpProxyList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionTargetHttpProxiesRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetHttpProxyList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionTargetHttpProxiesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetHttpProxyList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.TargetHttpProxy]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_target_http_proxies/transports/rest.py b/google/cloud/compute_v1/services/region_target_http_proxies/transports/rest.py index 9390fab53..c7d1662dd 100644 --- a/google/cloud/compute_v1/services/region_target_http_proxies/transports/rest.py +++ b/google/cloud/compute_v1/services/region_target_http_proxies/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -166,6 +169,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -238,6 +244,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetHttpProxy.from_json(response.content) @@ -294,7 +303,9 @@ def insert( # Jsonify the request body body = compute.TargetHttpProxy.to_json( - request.target_http_proxy_resource, including_default_value_fields=False + request.target_http_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -317,7 +328,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -354,11 +368,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -371,6 +385,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetHttpProxyList.from_json(response.content) @@ -427,7 +444,9 @@ def set_url_map( # Jsonify the request body body = compute.UrlMapReference.to_json( - request.url_map_reference_resource, including_default_value_fields=False + request.url_map_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -453,7 +472,10 @@ def set_url_map( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_target_https_proxies/client.py b/google/cloud/compute_v1/services/region_target_https_proxies/client.py index b7c368780..71935d61f 100644 --- a/google/cloud/compute_v1/services/region_target_https_proxies/client.py +++ b/google/cloud/compute_v1/services/region_target_https_proxies/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_target_https_proxies import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionTargetHttpsProxiesTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -669,7 +666,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetHttpsProxyList: + ) -> pagers.ListPager: r"""Retrieves the list of TargetHttpsProxy resources available to the specified project in the specified region. @@ -699,9 +696,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetHttpsProxyList: + google.cloud.compute_v1.services.region_target_https_proxies.pagers.ListPager: Contains a list of TargetHttpsProxy resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -736,6 +736,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_target_https_proxies/pagers.py b/google/cloud/compute_v1/services/region_target_https_proxies/pagers.py new file mode 100644 index 000000000..2e38cef7d --- /dev/null +++ b/google/cloud/compute_v1/services/region_target_https_proxies/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetHttpsProxyList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetHttpsProxyList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetHttpsProxyList], + request: compute.ListRegionTargetHttpsProxiesRequest, + response: compute.TargetHttpsProxyList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionTargetHttpsProxiesRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetHttpsProxyList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionTargetHttpsProxiesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetHttpsProxyList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.TargetHttpsProxy]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_target_https_proxies/transports/rest.py b/google/cloud/compute_v1/services/region_target_https_proxies/transports/rest.py index c4182ce2f..861a72150 100644 --- a/google/cloud/compute_v1/services/region_target_https_proxies/transports/rest.py +++ b/google/cloud/compute_v1/services/region_target_https_proxies/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -166,6 +169,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -237,6 +243,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetHttpsProxy.from_json(response.content) @@ -293,7 +302,9 @@ def insert( # Jsonify the request body body = compute.TargetHttpsProxy.to_json( - request.target_https_proxy_resource, including_default_value_fields=False + request.target_https_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -316,7 +327,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -355,11 +369,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -372,6 +386,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetHttpsProxyList.from_json(response.content) @@ -430,6 +447,7 @@ def set_ssl_certificates( body = compute.RegionTargetHttpsProxiesSetSslCertificatesRequest.to_json( request.region_target_https_proxies_set_ssl_certificates_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -455,7 +473,10 @@ def set_ssl_certificates( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -513,7 +534,9 @@ def set_url_map( # Jsonify the request body body = compute.UrlMapReference.to_json( - request.url_map_reference_resource, including_default_value_fields=False + request.url_map_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -539,7 +562,10 @@ def set_url_map( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/region_url_maps/client.py b/google/cloud/compute_v1/services/region_url_maps/client.py index 417882da1..45b175910 100644 --- a/google/cloud/compute_v1/services/region_url_maps/client.py +++ b/google/cloud/compute_v1/services/region_url_maps/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.region_url_maps import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionUrlMapsTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -673,7 +670,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.UrlMapList: + ) -> pagers.ListPager: r"""Retrieves the list of UrlMap resources available to the specified project in the specified region. @@ -702,8 +699,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.UrlMapList: + google.cloud.compute_v1.services.region_url_maps.pagers.ListPager: Contains a list of UrlMap resources. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -737,6 +738,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/region_url_maps/pagers.py b/google/cloud/compute_v1/services/region_url_maps/pagers.py new file mode 100644 index 000000000..132a85b4a --- /dev/null +++ b/google/cloud/compute_v1/services/region_url_maps/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.UrlMapList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.UrlMapList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.UrlMapList], + request: compute.ListRegionUrlMapsRequest, + response: compute.UrlMapList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionUrlMapsRequest): + The initial request object. + response (google.cloud.compute_v1.types.UrlMapList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionUrlMapsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.UrlMapList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.UrlMap]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/region_url_maps/transports/rest.py b/google/cloud/compute_v1/services/region_url_maps/transports/rest.py index 48a91299a..a01cc8d55 100644 --- a/google/cloud/compute_v1/services/region_url_maps/transports/rest.py +++ b/google/cloud/compute_v1/services/region_url_maps/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -166,6 +169,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -247,6 +253,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.UrlMap.from_json(response.content) @@ -303,7 +312,9 @@ def insert( # Jsonify the request body body = compute.UrlMap.to_json( - request.url_map_resource, including_default_value_fields=False + request.url_map_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -326,7 +337,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -363,11 +377,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -380,6 +394,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.UrlMapList.from_json(response.content) @@ -436,7 +453,9 @@ def patch( # Jsonify the request body body = compute.UrlMap.to_json( - request.url_map_resource, including_default_value_fields=False + request.url_map_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -462,7 +481,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -520,7 +542,9 @@ def update( # Jsonify the request body body = compute.UrlMap.to_json( - request.url_map_resource, including_default_value_fields=False + request.url_map_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -546,7 +570,10 @@ def update( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.put(url, json=body,) + response = self._session.put(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -578,6 +605,7 @@ def validate( body = compute.RegionUrlMapsValidateRequest.to_json( request.region_url_maps_validate_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -601,7 +629,10 @@ def validate( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.UrlMapsValidateResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/regions/client.py b/google/cloud/compute_v1/services/regions/client.py index e5d83a9a1..9213cbb0b 100644 --- a/google/cloud/compute_v1/services/regions/client.py +++ b/google/cloud/compute_v1/services/regions/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.regions import pagers from google.cloud.compute_v1.types import compute from .transports.base import RegionsTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -414,7 +411,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.RegionList: + ) -> pagers.ListPager: r"""Retrieves the list of region resources available to the specified project. @@ -435,8 +432,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.RegionList: + google.cloud.compute_v1.services.regions.pagers.ListPager: Contains a list of region resources. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -468,6 +469,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/regions/pagers.py b/google/cloud/compute_v1/services/regions/pagers.py new file mode 100644 index 000000000..6e0a595aa --- /dev/null +++ b/google/cloud/compute_v1/services/regions/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.RegionList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.RegionList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.RegionList], + request: compute.ListRegionsRequest, + response: compute.RegionList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRegionsRequest): + The initial request object. + response (google.cloud.compute_v1.types.RegionList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRegionsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.RegionList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Region]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/regions/transports/rest.py b/google/cloud/compute_v1/services/regions/transports/rest.py index ec4e99454..5902801b3 100644 --- a/google/cloud/compute_v1/services/regions/transports/rest.py +++ b/google/cloud/compute_v1/services/regions/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def get( self, @@ -137,6 +140,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Region.from_json(response.content) @@ -171,11 +177,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -188,6 +194,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.RegionList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/reservations/client.py b/google/cloud/compute_v1/services/reservations/client.py index f770e9a2f..0db7fb897 100644 --- a/google/cloud/compute_v1/services/reservations/client.py +++ b/google/cloud/compute_v1/services/reservations/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.reservations import pagers from google.cloud.compute_v1.types import compute from .transports.base import ReservationsTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -334,7 +331,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.ReservationAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of reservations. Args: @@ -355,8 +352,12 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.ReservationAggregatedList: + google.cloud.compute_v1.services.reservations.pagers.AggregatedListPager: Contains a list of reservations. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -388,6 +389,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -837,7 +844,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.ReservationList: + ) -> pagers.ListPager: r"""A list of all the reservations that have been configured for the specified project in specified zone. @@ -864,7 +871,10 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.ReservationList: + google.cloud.compute_v1.services.reservations.pagers.ListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -899,6 +909,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/reservations/pagers.py b/google/cloud/compute_v1/services/reservations/pagers.py new file mode 100644 index 000000000..bf1275a2a --- /dev/null +++ b/google/cloud/compute_v1/services/reservations/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.ReservationAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.ReservationAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.ReservationAggregatedList], + request: compute.AggregatedListReservationsRequest, + response: compute.ReservationAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListReservationsRequest): + The initial request object. + response (google.cloud.compute_v1.types.ReservationAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListReservationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.ReservationAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.ReservationsScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.ReservationsScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.ReservationList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.ReservationList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.ReservationList], + request: compute.ListReservationsRequest, + response: compute.ReservationList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListReservationsRequest): + The initial request object. + response (google.cloud.compute_v1.types.ReservationList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListReservationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.ReservationList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Reservation]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/reservations/transports/rest.py b/google/cloud/compute_v1/services/reservations/transports/rest.py index 6fa471e50..96a8aa664 100644 --- a/google/cloud/compute_v1/services/reservations/transports/rest.py +++ b/google/cloud/compute_v1/services/reservations/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ReservationAggregatedList.from_json(response.content) @@ -219,6 +225,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -273,6 +282,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Reservation.from_json(response.content) @@ -374,6 +386,9 @@ def get_iam_policy( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Policy.from_json(response.content) @@ -430,7 +445,9 @@ def insert( # Jsonify the request body body = compute.Reservation.to_json( - request.reservation_resource, including_default_value_fields=False + request.reservation_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -453,7 +470,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -490,11 +510,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -507,6 +527,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ReservationList.from_json(response.content) @@ -565,6 +588,7 @@ def resize( body = compute.ReservationsResizeRequest.to_json( request.reservations_resize_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -590,7 +614,10 @@ def resize( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -672,6 +699,7 @@ def set_iam_policy( body = compute.ZoneSetPolicyRequest.to_json( request.zone_set_policy_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -695,7 +723,10 @@ def set_iam_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Policy.from_json(response.content) @@ -727,6 +758,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -750,7 +782,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/resource_policies/client.py b/google/cloud/compute_v1/services/resource_policies/client.py index 7b4e1e466..e9e02fcce 100644 --- a/google/cloud/compute_v1/services/resource_policies/client.py +++ b/google/cloud/compute_v1/services/resource_policies/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.resource_policies import pagers from google.cloud.compute_v1.types import compute from .transports.base import ResourcePoliciesTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -336,7 +333,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.ResourcePolicyAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of resource policies. Args: @@ -357,8 +354,12 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.ResourcePolicyAggregatedList: + google.cloud.compute_v1.services.resource_policies.pagers.AggregatedListPager: Contains a list of resourcePolicies. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -390,6 +391,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -843,7 +850,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.ResourcePolicyList: + ) -> pagers.ListPager: r"""A list all the resource policies that have been configured for the specified project in specified region. @@ -871,7 +878,10 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.ResourcePolicyList: + google.cloud.compute_v1.services.resource_policies.pagers.ListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -906,6 +916,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/resource_policies/pagers.py b/google/cloud/compute_v1/services/resource_policies/pagers.py new file mode 100644 index 000000000..79376d997 --- /dev/null +++ b/google/cloud/compute_v1/services/resource_policies/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.ResourcePolicyAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.ResourcePolicyAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.ResourcePolicyAggregatedList], + request: compute.AggregatedListResourcePoliciesRequest, + response: compute.ResourcePolicyAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListResourcePoliciesRequest): + The initial request object. + response (google.cloud.compute_v1.types.ResourcePolicyAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListResourcePoliciesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.ResourcePolicyAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.ResourcePoliciesScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.ResourcePoliciesScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.ResourcePolicyList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.ResourcePolicyList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.ResourcePolicyList], + request: compute.ListResourcePoliciesRequest, + response: compute.ResourcePolicyList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListResourcePoliciesRequest): + The initial request object. + response (google.cloud.compute_v1.types.ResourcePolicyList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListResourcePoliciesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.ResourcePolicyList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.ResourcePolicy]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/resource_policies/transports/rest.py b/google/cloud/compute_v1/services/resource_policies/transports/rest.py index 7392146ad..3336db8de 100644 --- a/google/cloud/compute_v1/services/resource_policies/transports/rest.py +++ b/google/cloud/compute_v1/services/resource_policies/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ResourcePolicyAggregatedList.from_json(response.content) @@ -219,6 +225,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -274,6 +283,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ResourcePolicy.from_json(response.content) @@ -375,6 +387,9 @@ def get_iam_policy( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Policy.from_json(response.content) @@ -431,7 +446,9 @@ def insert( # Jsonify the request body body = compute.ResourcePolicy.to_json( - request.resource_policy_resource, including_default_value_fields=False + request.resource_policy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -454,7 +471,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -491,11 +511,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -508,6 +528,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ResourcePolicyList.from_json(response.content) @@ -588,6 +611,7 @@ def set_iam_policy( body = compute.RegionSetPolicyRequest.to_json( request.region_set_policy_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -611,7 +635,10 @@ def set_iam_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Policy.from_json(response.content) @@ -643,6 +670,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -666,7 +694,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/routers/client.py b/google/cloud/compute_v1/services/routers/client.py index ce0c0632e..e7f60dc33 100644 --- a/google/cloud/compute_v1/services/routers/client.py +++ b/google/cloud/compute_v1/services/routers/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.routers import pagers from google.cloud.compute_v1.types import compute from .transports.base import RoutersTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -334,7 +331,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.RouterAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of routers. Args: @@ -355,8 +352,12 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.RouterAggregatedList: + google.cloud.compute_v1.services.routers.pagers.AggregatedListPager: Contains a list of routers. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -388,6 +389,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -594,7 +601,7 @@ def get_nat_mapping_info( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.VmEndpointNatMappingsList: + ) -> pagers.GetNatMappingInfoPager: r"""Retrieves runtime Nat mapping information of VM endpoints. @@ -629,9 +636,12 @@ def get_nat_mapping_info( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.VmEndpointNatMappingsList: + google.cloud.compute_v1.services.routers.pagers.GetNatMappingInfoPager: Contains a list of VmEndpointNatMappings. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -668,6 +678,12 @@ def get_nat_mapping_info( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.GetNatMappingInfoPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -868,7 +884,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.RouterList: + ) -> pagers.ListPager: r"""Retrieves a list of Router resources available to the specified project. @@ -894,8 +910,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.RouterList: + google.cloud.compute_v1.services.routers.pagers.ListPager: Contains a list of Router resources. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -929,6 +949,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/routers/pagers.py b/google/cloud/compute_v1/services/routers/pagers.py new file mode 100644 index 000000000..d342a7bc2 --- /dev/null +++ b/google/cloud/compute_v1/services/routers/pagers.py @@ -0,0 +1,218 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.RouterAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.RouterAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.RouterAggregatedList], + request: compute.AggregatedListRoutersRequest, + response: compute.RouterAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListRoutersRequest): + The initial request object. + response (google.cloud.compute_v1.types.RouterAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListRoutersRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.RouterAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.RoutersScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.RoutersScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class GetNatMappingInfoPager: + """A pager for iterating through ``get_nat_mapping_info`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.VmEndpointNatMappingsList` object, and + provides an ``__iter__`` method to iterate through its + ``result`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``GetNatMappingInfo`` requests and continue to iterate + through the ``result`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.VmEndpointNatMappingsList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.VmEndpointNatMappingsList], + request: compute.GetNatMappingInfoRoutersRequest, + response: compute.VmEndpointNatMappingsList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.GetNatMappingInfoRoutersRequest): + The initial request object. + response (google.cloud.compute_v1.types.VmEndpointNatMappingsList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.GetNatMappingInfoRoutersRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.VmEndpointNatMappingsList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.VmEndpointNatMappings]: + for page in self.pages: + yield from page.result + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.RouterList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.RouterList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.RouterList], + request: compute.ListRoutersRequest, + response: compute.RouterList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRoutersRequest): + The initial request object. + response (google.cloud.compute_v1.types.RouterList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRoutersRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.RouterList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Router]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/routers/transports/rest.py b/google/cloud/compute_v1/services/routers/transports/rest.py index 6317f21a7..16e5806e4 100644 --- a/google/cloud/compute_v1/services/routers/transports/rest.py +++ b/google/cloud/compute_v1/services/routers/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.RouterAggregatedList.from_json(response.content) @@ -218,6 +224,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -269,6 +278,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Router.from_json(response.content) @@ -309,11 +321,11 @@ def get_nat_mapping_info( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -326,6 +338,9 @@ def get_nat_mapping_info( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.VmEndpointNatMappingsList.from_json(response.content) @@ -375,6 +390,9 @@ def get_router_status( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.RouterStatusResponse.from_json(response.content) @@ -430,7 +448,9 @@ def insert( # Jsonify the request body body = compute.Router.to_json( - request.router_resource, including_default_value_fields=False + request.router_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -453,7 +473,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -489,11 +512,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -506,6 +529,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.RouterList.from_json(response.content) @@ -561,7 +587,9 @@ def patch( # Jsonify the request body body = compute.Router.to_json( - request.router_resource, including_default_value_fields=False + request.router_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -587,7 +615,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -617,7 +648,9 @@ def preview( # Jsonify the request body body = compute.Router.to_json( - request.router_resource, including_default_value_fields=False + request.router_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -641,7 +674,10 @@ def preview( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.RoutersPreviewResponse.from_json(response.content) @@ -698,7 +734,9 @@ def update( # Jsonify the request body body = compute.Router.to_json( - request.router_resource, including_default_value_fields=False + request.router_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -724,7 +762,10 @@ def update( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.put(url, json=body,) + response = self._session.put(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/routes/client.py b/google/cloud/compute_v1/services/routes/client.py index 796bc65f7..fa662b7a5 100644 --- a/google/cloud/compute_v1/services/routes/client.py +++ b/google/cloud/compute_v1/services/routes/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.routes import pagers from google.cloud.compute_v1.types import compute from .transports.base import RoutesTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -609,7 +606,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.RouteList: + ) -> pagers.ListPager: r"""Retrieves the list of Route resources available to the specified project. @@ -630,8 +627,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.RouteList: + google.cloud.compute_v1.services.routes.pagers.ListPager: Contains a list of Route resources. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -663,6 +664,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/routes/pagers.py b/google/cloud/compute_v1/services/routes/pagers.py new file mode 100644 index 000000000..4975a5332 --- /dev/null +++ b/google/cloud/compute_v1/services/routes/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.RouteList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.RouteList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.RouteList], + request: compute.ListRoutesRequest, + response: compute.RouteList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListRoutesRequest): + The initial request object. + response (google.cloud.compute_v1.types.RouteList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListRoutesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.RouteList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Route]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/routes/transports/rest.py b/google/cloud/compute_v1/services/routes/transports/rest.py index 01e628718..e2d2e6332 100644 --- a/google/cloud/compute_v1/services/routes/transports/rest.py +++ b/google/cloud/compute_v1/services/routes/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -162,6 +165,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -214,6 +220,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Route.from_json(response.content) @@ -269,7 +278,9 @@ def insert( # Jsonify the request body body = compute.Route.to_json( - request.route_resource, including_default_value_fields=False + request.route_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -292,7 +303,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -328,11 +342,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -345,6 +359,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.RouteList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/security_policies/client.py b/google/cloud/compute_v1/services/security_policies/client.py index e52499e78..17356c0e9 100644 --- a/google/cloud/compute_v1/services/security_policies/client.py +++ b/google/cloud/compute_v1/services/security_policies/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.security_policies import pagers from google.cloud.compute_v1.types import compute from .transports.base import SecurityPoliciesTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -806,7 +803,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.SecurityPolicyList: + ) -> pagers.ListPager: r"""List all the policies that have been configured for the specified project. @@ -828,7 +825,10 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.SecurityPolicyList: + google.cloud.compute_v1.services.security_policies.pagers.ListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -861,6 +861,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/security_policies/pagers.py b/google/cloud/compute_v1/services/security_policies/pagers.py new file mode 100644 index 000000000..2d5191d8f --- /dev/null +++ b/google/cloud/compute_v1/services/security_policies/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.SecurityPolicyList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.SecurityPolicyList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.SecurityPolicyList], + request: compute.ListSecurityPoliciesRequest, + response: compute.SecurityPolicyList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListSecurityPoliciesRequest): + The initial request object. + response (google.cloud.compute_v1.types.SecurityPolicyList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListSecurityPoliciesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.SecurityPolicyList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.SecurityPolicy]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/security_policies/transports/rest.py b/google/cloud/compute_v1/services/security_policies/transports/rest.py index 6b3600251..fc361f3aa 100644 --- a/google/cloud/compute_v1/services/security_policies/transports/rest.py +++ b/google/cloud/compute_v1/services/security_policies/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def add_rule( self, @@ -143,7 +146,9 @@ def add_rule( # Jsonify the request body body = compute.SecurityPolicyRule.to_json( - request.security_policy_rule_resource, including_default_value_fields=False + request.security_policy_rule_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -166,7 +171,10 @@ def add_rule( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -246,6 +254,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -301,6 +312,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SecurityPolicy.from_json(response.content) @@ -355,6 +369,9 @@ def get_rule( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SecurityPolicyRule.from_json(response.content) @@ -411,7 +428,9 @@ def insert( # Jsonify the request body body = compute.SecurityPolicy.to_json( - request.security_policy_resource, including_default_value_fields=False + request.security_policy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -434,7 +453,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -471,11 +493,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -488,6 +510,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SecurityPolicyList.from_json(response.content) @@ -524,11 +549,11 @@ def list_preconfigured_expression_sets( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -541,6 +566,9 @@ def list_preconfigured_expression_sets( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SecurityPoliciesListPreconfiguredExpressionSetsResponse.from_json( response.content @@ -599,7 +627,9 @@ def patch( # Jsonify the request body body = compute.SecurityPolicy.to_json( - request.security_policy_resource, including_default_value_fields=False + request.security_policy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -624,7 +654,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -682,7 +715,9 @@ def patch_rule( # Jsonify the request body body = compute.SecurityPolicyRule.to_json( - request.security_policy_rule_resource, including_default_value_fields=False + request.security_policy_rule_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -707,7 +742,10 @@ def patch_rule( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -787,6 +825,9 @@ def remove_rule( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/snapshots/client.py b/google/cloud/compute_v1/services/snapshots/client.py index 98bfd8c97..215b9002a 100644 --- a/google/cloud/compute_v1/services/snapshots/client.py +++ b/google/cloud/compute_v1/services/snapshots/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.snapshots import pagers from google.cloud.compute_v1.types import compute from .transports.base import SnapshotsTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -648,7 +645,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.SnapshotList: + ) -> pagers.ListPager: r"""Retrieves the list of Snapshot resources contained within the specified project. @@ -669,9 +666,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.SnapshotList: + google.cloud.compute_v1.services.snapshots.pagers.ListPager: Contains a list of Snapshot resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -704,6 +704,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/snapshots/pagers.py b/google/cloud/compute_v1/services/snapshots/pagers.py new file mode 100644 index 000000000..6a26b39f9 --- /dev/null +++ b/google/cloud/compute_v1/services/snapshots/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.SnapshotList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.SnapshotList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.SnapshotList], + request: compute.ListSnapshotsRequest, + response: compute.SnapshotList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListSnapshotsRequest): + The initial request object. + response (google.cloud.compute_v1.types.SnapshotList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListSnapshotsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.SnapshotList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Snapshot]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/snapshots/transports/rest.py b/google/cloud/compute_v1/services/snapshots/transports/rest.py index 2c3f5bf34..e87786135 100644 --- a/google/cloud/compute_v1/services/snapshots/transports/rest.py +++ b/google/cloud/compute_v1/services/snapshots/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -163,6 +166,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -214,6 +220,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Snapshot.from_json(response.content) @@ -312,6 +321,9 @@ def get_iam_policy( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Policy.from_json(response.content) @@ -348,11 +360,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -365,6 +377,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SnapshotList.from_json(response.content) @@ -445,6 +460,7 @@ def set_iam_policy( body = compute.GlobalSetPolicyRequest.to_json( request.global_set_policy_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -465,7 +481,10 @@ def set_iam_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Policy.from_json(response.content) @@ -525,6 +544,7 @@ def set_labels( body = compute.GlobalSetLabelsRequest.to_json( request.global_set_labels_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -545,7 +565,10 @@ def set_labels( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -577,6 +600,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -597,7 +621,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/ssl_certificates/client.py b/google/cloud/compute_v1/services/ssl_certificates/client.py index 6f1240fda..b65c8d47e 100644 --- a/google/cloud/compute_v1/services/ssl_certificates/client.py +++ b/google/cloud/compute_v1/services/ssl_certificates/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.ssl_certificates import pagers from google.cloud.compute_v1.types import compute from .transports.base import SslCertificatesTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -336,7 +333,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.SslCertificateAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves the list of all SslCertificate resources, regional and global, available to the specified project. @@ -360,7 +357,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.SslCertificateAggregatedList: + google.cloud.compute_v1.services.ssl_certificates.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -393,6 +393,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -705,7 +711,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.SslCertificateList: + ) -> pagers.ListPager: r"""Retrieves the list of SslCertificate resources available to the specified project. @@ -727,9 +733,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.SslCertificateList: + google.cloud.compute_v1.services.ssl_certificates.pagers.ListPager: Contains a list of SslCertificate resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -762,6 +771,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/ssl_certificates/pagers.py b/google/cloud/compute_v1/services/ssl_certificates/pagers.py new file mode 100644 index 000000000..7aab7e108 --- /dev/null +++ b/google/cloud/compute_v1/services/ssl_certificates/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.SslCertificateAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.SslCertificateAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.SslCertificateAggregatedList], + request: compute.AggregatedListSslCertificatesRequest, + response: compute.SslCertificateAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListSslCertificatesRequest): + The initial request object. + response (google.cloud.compute_v1.types.SslCertificateAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListSslCertificatesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.SslCertificateAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.SslCertificatesScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.SslCertificatesScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.SslCertificateList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.SslCertificateList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.SslCertificateList], + request: compute.ListSslCertificatesRequest, + response: compute.SslCertificateList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListSslCertificatesRequest): + The initial request object. + response (google.cloud.compute_v1.types.SslCertificateList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListSslCertificatesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.SslCertificateList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.SslCertificate]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/ssl_certificates/transports/rest.py b/google/cloud/compute_v1/services/ssl_certificates/transports/rest.py index a1032200c..c5278af9c 100644 --- a/google/cloud/compute_v1/services/ssl_certificates/transports/rest.py +++ b/google/cloud/compute_v1/services/ssl_certificates/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SslCertificateAggregatedList.from_json(response.content) @@ -218,6 +224,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -292,6 +301,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SslCertificate.from_json(response.content) @@ -348,7 +360,9 @@ def insert( # Jsonify the request body body = compute.SslCertificate.to_json( - request.ssl_certificate_resource, including_default_value_fields=False + request.ssl_certificate_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -371,7 +385,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -410,11 +427,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -427,6 +444,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SslCertificateList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/ssl_policies/client.py b/google/cloud/compute_v1/services/ssl_policies/client.py index c8f4d847a..8baa448c4 100644 --- a/google/cloud/compute_v1/services/ssl_policies/client.py +++ b/google/cloud/compute_v1/services/ssl_policies/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.ssl_policies import pagers from google.cloud.compute_v1.types import compute from .transports.base import SslPoliciesTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -620,7 +617,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.SslPoliciesList: + ) -> pagers.ListPager: r"""Lists all the SSL policies that have been configured for the specified project. @@ -642,7 +639,10 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.SslPoliciesList: + google.cloud.compute_v1.services.ssl_policies.pagers.ListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -675,6 +675,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/ssl_policies/pagers.py b/google/cloud/compute_v1/services/ssl_policies/pagers.py new file mode 100644 index 000000000..ac8def99e --- /dev/null +++ b/google/cloud/compute_v1/services/ssl_policies/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.SslPoliciesList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.SslPoliciesList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.SslPoliciesList], + request: compute.ListSslPoliciesRequest, + response: compute.SslPoliciesList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListSslPoliciesRequest): + The initial request object. + response (google.cloud.compute_v1.types.SslPoliciesList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListSslPoliciesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.SslPoliciesList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.SslPolicy]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/ssl_policies/transports/rest.py b/google/cloud/compute_v1/services/ssl_policies/transports/rest.py index d3fcaacf7..8a86a8290 100644 --- a/google/cloud/compute_v1/services/ssl_policies/transports/rest.py +++ b/google/cloud/compute_v1/services/ssl_policies/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -163,6 +166,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -216,6 +222,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SslPolicy.from_json(response.content) @@ -272,7 +281,9 @@ def insert( # Jsonify the request body body = compute.SslPolicy.to_json( - request.ssl_policy_resource, including_default_value_fields=False + request.ssl_policy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -295,7 +306,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -332,11 +346,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -349,6 +363,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SslPoliciesList.from_json(response.content) @@ -384,11 +401,11 @@ def list_available_features( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -401,6 +418,9 @@ def list_available_features( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SslPoliciesListAvailableFeaturesResponse.from_json( response.content @@ -459,7 +479,9 @@ def patch( # Jsonify the request body body = compute.SslPolicy.to_json( - request.ssl_policy_resource, including_default_value_fields=False + request.ssl_policy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -482,7 +504,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/subnetworks/client.py b/google/cloud/compute_v1/services/subnetworks/client.py index e445ec7be..de8b447e7 100644 --- a/google/cloud/compute_v1/services/subnetworks/client.py +++ b/google/cloud/compute_v1/services/subnetworks/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.subnetworks import pagers from google.cloud.compute_v1.types import compute from .transports.base import SubnetworksTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -334,7 +331,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.SubnetworkAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of subnetworks. Args: @@ -355,7 +352,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.SubnetworkAggregatedList: + google.cloud.compute_v1.services.subnetworks.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -388,6 +388,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -977,7 +983,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.SubnetworkList: + ) -> pagers.ListPager: r"""Retrieves a list of subnetworks available to the specified project. @@ -1006,9 +1012,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.SubnetworkList: + google.cloud.compute_v1.services.subnetworks.pagers.ListPager: Contains a list of Subnetwork resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1043,6 +1052,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1054,7 +1069,7 @@ def list_usable( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.UsableSubnetworksAggregatedList: + ) -> pagers.ListUsablePager: r"""Retrieves an aggregated list of all usable subnetworks in the project. @@ -1076,7 +1091,10 @@ def list_usable( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.UsableSubnetworksAggregatedList: + google.cloud.compute_v1.services.subnetworks.pagers.ListUsablePager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1109,6 +1127,12 @@ def list_usable( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListUsablePager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/subnetworks/pagers.py b/google/cloud/compute_v1/services/subnetworks/pagers.py new file mode 100644 index 000000000..294a77a9e --- /dev/null +++ b/google/cloud/compute_v1/services/subnetworks/pagers.py @@ -0,0 +1,218 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.SubnetworkAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.SubnetworkAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.SubnetworkAggregatedList], + request: compute.AggregatedListSubnetworksRequest, + response: compute.SubnetworkAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListSubnetworksRequest): + The initial request object. + response (google.cloud.compute_v1.types.SubnetworkAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListSubnetworksRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.SubnetworkAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.SubnetworksScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.SubnetworksScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.SubnetworkList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.SubnetworkList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.SubnetworkList], + request: compute.ListSubnetworksRequest, + response: compute.SubnetworkList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListSubnetworksRequest): + The initial request object. + response (google.cloud.compute_v1.types.SubnetworkList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListSubnetworksRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.SubnetworkList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Subnetwork]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListUsablePager: + """A pager for iterating through ``list_usable`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.UsableSubnetworksAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListUsable`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.UsableSubnetworksAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.UsableSubnetworksAggregatedList], + request: compute.ListUsableSubnetworksRequest, + response: compute.UsableSubnetworksAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListUsableSubnetworksRequest): + The initial request object. + response (google.cloud.compute_v1.types.UsableSubnetworksAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListUsableSubnetworksRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.UsableSubnetworksAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.UsableSubnetwork]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/subnetworks/transports/rest.py b/google/cloud/compute_v1/services/subnetworks/transports/rest.py index ccb2b4b2b..4310a2c62 100644 --- a/google/cloud/compute_v1/services/subnetworks/transports/rest.py +++ b/google/cloud/compute_v1/services/subnetworks/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SubnetworkAggregatedList.from_json(response.content) @@ -219,6 +225,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -277,6 +286,7 @@ def expand_ip_cidr_range( body = compute.SubnetworksExpandIpCidrRangeRequest.to_json( request.subnetworks_expand_ip_cidr_range_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -302,7 +312,10 @@ def expand_ip_cidr_range( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -360,6 +373,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Subnetwork.from_json(response.content) @@ -461,6 +477,9 @@ def get_iam_policy( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Policy.from_json(response.content) @@ -517,7 +536,9 @@ def insert( # Jsonify the request body body = compute.Subnetwork.to_json( - request.subnetwork_resource, including_default_value_fields=False + request.subnetwork_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -540,7 +561,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -579,11 +603,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -596,6 +620,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.SubnetworkList.from_json(response.content) @@ -631,11 +658,11 @@ def list_usable( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -648,6 +675,9 @@ def list_usable( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.UsableSubnetworksAggregatedList.from_json(response.content) @@ -704,7 +734,9 @@ def patch( # Jsonify the request body body = compute.Subnetwork.to_json( - request.subnetwork_resource, including_default_value_fields=False + request.subnetwork_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -719,8 +751,8 @@ def patch( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "drainTimeoutSeconds": request.drain_timeout_seconds, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -731,7 +763,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -813,6 +848,7 @@ def set_iam_policy( body = compute.RegionSetPolicyRequest.to_json( request.region_set_policy_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -836,7 +872,10 @@ def set_iam_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Policy.from_json(response.content) @@ -897,6 +936,7 @@ def set_private_ip_google_access( body = compute.SubnetworksSetPrivateIpGoogleAccessRequest.to_json( request.subnetworks_set_private_ip_google_access_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -922,7 +962,10 @@ def set_private_ip_google_access( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -954,6 +997,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -977,7 +1021,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/target_grpc_proxies/client.py b/google/cloud/compute_v1/services/target_grpc_proxies/client.py index 48f5c8561..c13568178 100644 --- a/google/cloud/compute_v1/services/target_grpc_proxies/client.py +++ b/google/cloud/compute_v1/services/target_grpc_proxies/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.target_grpc_proxies import pagers from google.cloud.compute_v1.types import compute from .transports.base import TargetGrpcProxiesTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -624,7 +621,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetGrpcProxyList: + ) -> pagers.ListPager: r"""Lists the TargetGrpcProxies for a project in the given scope. @@ -646,7 +643,10 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetGrpcProxyList: + google.cloud.compute_v1.services.target_grpc_proxies.pagers.ListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -679,6 +679,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/target_grpc_proxies/pagers.py b/google/cloud/compute_v1/services/target_grpc_proxies/pagers.py new file mode 100644 index 000000000..4602e3f62 --- /dev/null +++ b/google/cloud/compute_v1/services/target_grpc_proxies/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetGrpcProxyList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetGrpcProxyList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetGrpcProxyList], + request: compute.ListTargetGrpcProxiesRequest, + response: compute.TargetGrpcProxyList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListTargetGrpcProxiesRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetGrpcProxyList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListTargetGrpcProxiesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetGrpcProxyList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.TargetGrpcProxy]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/target_grpc_proxies/transports/rest.py b/google/cloud/compute_v1/services/target_grpc_proxies/transports/rest.py index d8080e10a..29f236e8d 100644 --- a/google/cloud/compute_v1/services/target_grpc_proxies/transports/rest.py +++ b/google/cloud/compute_v1/services/target_grpc_proxies/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -165,6 +168,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -221,6 +227,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetGrpcProxy.from_json(response.content) @@ -277,7 +286,9 @@ def insert( # Jsonify the request body body = compute.TargetGrpcProxy.to_json( - request.target_grpc_proxy_resource, including_default_value_fields=False + request.target_grpc_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -300,7 +311,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -337,11 +351,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -354,6 +368,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetGrpcProxyList.from_json(response.content) @@ -410,7 +427,9 @@ def patch( # Jsonify the request body body = compute.TargetGrpcProxy.to_json( - request.target_grpc_proxy_resource, including_default_value_fields=False + request.target_grpc_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -435,7 +454,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/target_http_proxies/client.py b/google/cloud/compute_v1/services/target_http_proxies/client.py index 084737566..5065177ec 100644 --- a/google/cloud/compute_v1/services/target_http_proxies/client.py +++ b/google/cloud/compute_v1/services/target_http_proxies/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.target_http_proxies import pagers from google.cloud.compute_v1.types import compute from .transports.base import TargetHttpProxiesTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -338,7 +335,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetHttpProxyAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves the list of all TargetHttpProxy resources, regional and global, available to the specified project. @@ -362,7 +359,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetHttpProxyAggregatedList: + google.cloud.compute_v1.services.target_http_proxies.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -395,6 +395,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -705,7 +711,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetHttpProxyList: + ) -> pagers.ListPager: r"""Retrieves the list of TargetHttpProxy resources available to the specified project. @@ -727,8 +733,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetHttpProxyList: + google.cloud.compute_v1.services.target_http_proxies.pagers.ListPager: A list of TargetHttpProxy resources. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -760,6 +770,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/target_http_proxies/pagers.py b/google/cloud/compute_v1/services/target_http_proxies/pagers.py new file mode 100644 index 000000000..8d5848858 --- /dev/null +++ b/google/cloud/compute_v1/services/target_http_proxies/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetHttpProxyAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetHttpProxyAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetHttpProxyAggregatedList], + request: compute.AggregatedListTargetHttpProxiesRequest, + response: compute.TargetHttpProxyAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListTargetHttpProxiesRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetHttpProxyAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListTargetHttpProxiesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetHttpProxyAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.TargetHttpProxiesScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.TargetHttpProxiesScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetHttpProxyList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetHttpProxyList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetHttpProxyList], + request: compute.ListTargetHttpProxiesRequest, + response: compute.TargetHttpProxyList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListTargetHttpProxiesRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetHttpProxyList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListTargetHttpProxiesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetHttpProxyList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.TargetHttpProxy]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/target_http_proxies/transports/rest.py b/google/cloud/compute_v1/services/target_http_proxies/transports/rest.py index 6ccef2b6b..5b9cd80cc 100644 --- a/google/cloud/compute_v1/services/target_http_proxies/transports/rest.py +++ b/google/cloud/compute_v1/services/target_http_proxies/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetHttpProxyAggregatedList.from_json(response.content) @@ -218,6 +224,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -289,6 +298,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetHttpProxy.from_json(response.content) @@ -345,7 +357,9 @@ def insert( # Jsonify the request body body = compute.TargetHttpProxy.to_json( - request.target_http_proxy_resource, including_default_value_fields=False + request.target_http_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -368,7 +382,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -405,11 +422,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -422,6 +439,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetHttpProxyList.from_json(response.content) @@ -478,7 +498,9 @@ def patch( # Jsonify the request body body = compute.TargetHttpProxy.to_json( - request.target_http_proxy_resource, including_default_value_fields=False + request.target_http_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -503,7 +525,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -561,7 +586,9 @@ def set_url_map( # Jsonify the request body body = compute.UrlMapReference.to_json( - request.url_map_reference_resource, including_default_value_fields=False + request.url_map_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -586,7 +613,10 @@ def set_url_map( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/target_https_proxies/client.py b/google/cloud/compute_v1/services/target_https_proxies/client.py index c5ea7622e..254494a86 100644 --- a/google/cloud/compute_v1/services/target_https_proxies/client.py +++ b/google/cloud/compute_v1/services/target_https_proxies/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.target_https_proxies import pagers from google.cloud.compute_v1.types import compute from .transports.base import TargetHttpsProxiesTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -338,7 +335,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetHttpsProxyAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves the list of all TargetHttpsProxy resources, regional and global, available to the specified project. @@ -362,7 +359,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetHttpsProxyAggregatedList: + google.cloud.compute_v1.services.target_https_proxies.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -395,6 +395,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -705,7 +711,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetHttpsProxyList: + ) -> pagers.ListPager: r"""Retrieves the list of TargetHttpsProxy resources available to the specified project. @@ -727,9 +733,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetHttpsProxyList: + google.cloud.compute_v1.services.target_https_proxies.pagers.ListPager: Contains a list of TargetHttpsProxy resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -762,6 +771,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/target_https_proxies/pagers.py b/google/cloud/compute_v1/services/target_https_proxies/pagers.py new file mode 100644 index 000000000..d6817b0e5 --- /dev/null +++ b/google/cloud/compute_v1/services/target_https_proxies/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetHttpsProxyAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetHttpsProxyAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetHttpsProxyAggregatedList], + request: compute.AggregatedListTargetHttpsProxiesRequest, + response: compute.TargetHttpsProxyAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListTargetHttpsProxiesRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetHttpsProxyAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListTargetHttpsProxiesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetHttpsProxyAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.TargetHttpsProxiesScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.TargetHttpsProxiesScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetHttpsProxyList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetHttpsProxyList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetHttpsProxyList], + request: compute.ListTargetHttpsProxiesRequest, + response: compute.TargetHttpsProxyList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListTargetHttpsProxiesRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetHttpsProxyList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListTargetHttpsProxiesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetHttpsProxyList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.TargetHttpsProxy]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/target_https_proxies/transports/rest.py b/google/cloud/compute_v1/services/target_https_proxies/transports/rest.py index e07fae2c7..13e0bedf3 100644 --- a/google/cloud/compute_v1/services/target_https_proxies/transports/rest.py +++ b/google/cloud/compute_v1/services/target_https_proxies/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetHttpsProxyAggregatedList.from_json(response.content) @@ -218,6 +224,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -288,6 +297,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetHttpsProxy.from_json(response.content) @@ -344,7 +356,9 @@ def insert( # Jsonify the request body body = compute.TargetHttpsProxy.to_json( - request.target_https_proxy_resource, including_default_value_fields=False + request.target_https_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -367,7 +381,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -406,11 +423,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -423,6 +440,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetHttpsProxyList.from_json(response.content) @@ -481,6 +501,7 @@ def set_quic_override( body = compute.TargetHttpsProxiesSetQuicOverrideRequest.to_json( request.target_https_proxies_set_quic_override_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -505,7 +526,10 @@ def set_quic_override( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -565,6 +589,7 @@ def set_ssl_certificates( body = compute.TargetHttpsProxiesSetSslCertificatesRequest.to_json( request.target_https_proxies_set_ssl_certificates_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -589,7 +614,10 @@ def set_ssl_certificates( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -647,7 +675,9 @@ def set_ssl_policy( # Jsonify the request body body = compute.SslPolicyReference.to_json( - request.ssl_policy_reference_resource, including_default_value_fields=False + request.ssl_policy_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -672,7 +702,10 @@ def set_ssl_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -730,7 +763,9 @@ def set_url_map( # Jsonify the request body body = compute.UrlMapReference.to_json( - request.url_map_reference_resource, including_default_value_fields=False + request.url_map_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -755,7 +790,10 @@ def set_url_map( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/target_instances/client.py b/google/cloud/compute_v1/services/target_instances/client.py index 7d7d021e4..65f231da2 100644 --- a/google/cloud/compute_v1/services/target_instances/client.py +++ b/google/cloud/compute_v1/services/target_instances/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.target_instances import pagers from google.cloud.compute_v1.types import compute from .transports.base import TargetInstancesTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -336,7 +333,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetInstanceAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of target instances. Args: @@ -357,7 +354,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetInstanceAggregatedList: + google.cloud.compute_v1.services.target_instances.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -390,6 +390,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -716,7 +722,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetInstanceList: + ) -> pagers.ListPager: r"""Retrieves a list of TargetInstance resources available to the specified project and zone. @@ -745,9 +751,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetInstanceList: + google.cloud.compute_v1.services.target_instances.pagers.ListPager: Contains a list of TargetInstance resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -782,6 +791,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/target_instances/pagers.py b/google/cloud/compute_v1/services/target_instances/pagers.py new file mode 100644 index 000000000..f31990bbd --- /dev/null +++ b/google/cloud/compute_v1/services/target_instances/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetInstanceAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetInstanceAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetInstanceAggregatedList], + request: compute.AggregatedListTargetInstancesRequest, + response: compute.TargetInstanceAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListTargetInstancesRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetInstanceAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListTargetInstancesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetInstanceAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.TargetInstancesScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.TargetInstancesScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetInstanceList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetInstanceList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetInstanceList], + request: compute.ListTargetInstancesRequest, + response: compute.TargetInstanceList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListTargetInstancesRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetInstanceList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListTargetInstancesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetInstanceList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.TargetInstance]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/target_instances/transports/rest.py b/google/cloud/compute_v1/services/target_instances/transports/rest.py index b612eab80..965535ef9 100644 --- a/google/cloud/compute_v1/services/target_instances/transports/rest.py +++ b/google/cloud/compute_v1/services/target_instances/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetInstanceAggregatedList.from_json(response.content) @@ -219,6 +225,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -276,6 +285,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetInstance.from_json(response.content) @@ -332,7 +344,9 @@ def insert( # Jsonify the request body body = compute.TargetInstance.to_json( - request.target_instance_resource, including_default_value_fields=False + request.target_instance_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -355,7 +369,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -394,11 +411,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -411,6 +428,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetInstanceList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/target_pools/client.py b/google/cloud/compute_v1/services/target_pools/client.py index fff1c6759..a98fe47d6 100644 --- a/google/cloud/compute_v1/services/target_pools/client.py +++ b/google/cloud/compute_v1/services/target_pools/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.target_pools import pagers from google.cloud.compute_v1.types import compute from .transports.base import TargetPoolsTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -583,7 +580,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetPoolAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of target pools. Args: @@ -604,7 +601,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetPoolAggregatedList: + google.cloud.compute_v1.services.target_pools.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -637,6 +637,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -1058,7 +1064,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetPoolList: + ) -> pagers.ListPager: r"""Retrieves a list of target pools available to the specified project and region. @@ -1087,9 +1093,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetPoolList: + google.cloud.compute_v1.services.target_pools.pagers.ListPager: Contains a list of TargetPool resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -1124,6 +1133,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/target_pools/pagers.py b/google/cloud/compute_v1/services/target_pools/pagers.py new file mode 100644 index 000000000..30ade2320 --- /dev/null +++ b/google/cloud/compute_v1/services/target_pools/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetPoolAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetPoolAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetPoolAggregatedList], + request: compute.AggregatedListTargetPoolsRequest, + response: compute.TargetPoolAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListTargetPoolsRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetPoolAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListTargetPoolsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetPoolAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.TargetPoolsScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.TargetPoolsScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetPoolList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetPoolList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetPoolList], + request: compute.ListTargetPoolsRequest, + response: compute.TargetPoolList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListTargetPoolsRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetPoolList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListTargetPoolsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetPoolList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.TargetPool]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/target_pools/transports/rest.py b/google/cloud/compute_v1/services/target_pools/transports/rest.py index 2fe20daea..557e16548 100644 --- a/google/cloud/compute_v1/services/target_pools/transports/rest.py +++ b/google/cloud/compute_v1/services/target_pools/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def add_health_check( self, @@ -145,6 +148,7 @@ def add_health_check( body = compute.TargetPoolsAddHealthCheckRequest.to_json( request.target_pools_add_health_check_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -170,7 +174,10 @@ def add_health_check( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -230,6 +237,7 @@ def add_instance( body = compute.TargetPoolsAddInstanceRequest.to_json( request.target_pools_add_instance_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -255,7 +263,10 @@ def add_instance( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -292,12 +303,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -310,6 +321,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetPoolAggregatedList.from_json(response.content) @@ -389,6 +403,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -446,6 +463,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetPool.from_json(response.content) @@ -474,7 +494,9 @@ def get_health( # Jsonify the request body body = compute.InstanceReference.to_json( - request.instance_reference_resource, including_default_value_fields=False + request.instance_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -498,7 +520,10 @@ def get_health( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TargetPoolInstanceHealth.from_json(response.content) @@ -556,7 +581,9 @@ def insert( # Jsonify the request body body = compute.TargetPool.to_json( - request.target_pool_resource, including_default_value_fields=False + request.target_pool_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -579,7 +606,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -618,11 +648,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -635,6 +665,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetPoolList.from_json(response.content) @@ -693,6 +726,7 @@ def remove_health_check( body = compute.TargetPoolsRemoveHealthCheckRequest.to_json( request.target_pools_remove_health_check_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -718,7 +752,10 @@ def remove_health_check( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -778,6 +815,7 @@ def remove_instance( body = compute.TargetPoolsRemoveInstanceRequest.to_json( request.target_pools_remove_instance_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -803,7 +841,10 @@ def remove_instance( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -861,7 +902,9 @@ def set_backup( # Jsonify the request body body = compute.TargetReference.to_json( - request.target_reference_resource, including_default_value_fields=False + request.target_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -876,8 +919,8 @@ def set_backup( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "requestId": request.request_id, "failoverRatio": request.failover_ratio, + "requestId": request.request_id, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -888,7 +931,10 @@ def set_backup( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/target_ssl_proxies/client.py b/google/cloud/compute_v1/services/target_ssl_proxies/client.py index 01ffde7ec..bac26316e 100644 --- a/google/cloud/compute_v1/services/target_ssl_proxies/client.py +++ b/google/cloud/compute_v1/services/target_ssl_proxies/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.target_ssl_proxies import pagers from google.cloud.compute_v1.types import compute from .transports.base import TargetSslProxiesTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -620,7 +617,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetSslProxyList: + ) -> pagers.ListPager: r"""Retrieves the list of TargetSslProxy resources available to the specified project. @@ -642,9 +639,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetSslProxyList: + google.cloud.compute_v1.services.target_ssl_proxies.pagers.ListPager: Contains a list of TargetSslProxy resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -677,6 +677,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/target_ssl_proxies/pagers.py b/google/cloud/compute_v1/services/target_ssl_proxies/pagers.py new file mode 100644 index 000000000..d995c620f --- /dev/null +++ b/google/cloud/compute_v1/services/target_ssl_proxies/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetSslProxyList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetSslProxyList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetSslProxyList], + request: compute.ListTargetSslProxiesRequest, + response: compute.TargetSslProxyList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListTargetSslProxiesRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetSslProxyList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListTargetSslProxiesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetSslProxyList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.TargetSslProxy]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/target_ssl_proxies/transports/rest.py b/google/cloud/compute_v1/services/target_ssl_proxies/transports/rest.py index d22fde181..8e0e23db3 100644 --- a/google/cloud/compute_v1/services/target_ssl_proxies/transports/rest.py +++ b/google/cloud/compute_v1/services/target_ssl_proxies/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -165,6 +168,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -221,6 +227,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetSslProxy.from_json(response.content) @@ -277,7 +286,9 @@ def insert( # Jsonify the request body body = compute.TargetSslProxy.to_json( - request.target_ssl_proxy_resource, including_default_value_fields=False + request.target_ssl_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -300,7 +311,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -339,11 +353,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -356,6 +370,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetSslProxyList.from_json(response.content) @@ -414,6 +431,7 @@ def set_backend_service( body = compute.TargetSslProxiesSetBackendServiceRequest.to_json( request.target_ssl_proxies_set_backend_service_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -438,7 +456,10 @@ def set_backend_service( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -498,6 +519,7 @@ def set_proxy_header( body = compute.TargetSslProxiesSetProxyHeaderRequest.to_json( request.target_ssl_proxies_set_proxy_header_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -522,7 +544,10 @@ def set_proxy_header( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -582,6 +607,7 @@ def set_ssl_certificates( body = compute.TargetSslProxiesSetSslCertificatesRequest.to_json( request.target_ssl_proxies_set_ssl_certificates_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -606,7 +632,10 @@ def set_ssl_certificates( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -664,7 +693,9 @@ def set_ssl_policy( # Jsonify the request body body = compute.SslPolicyReference.to_json( - request.ssl_policy_reference_resource, including_default_value_fields=False + request.ssl_policy_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -689,7 +720,10 @@ def set_ssl_policy( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/target_tcp_proxies/client.py b/google/cloud/compute_v1/services/target_tcp_proxies/client.py index 4778fc3ec..9afbacfd9 100644 --- a/google/cloud/compute_v1/services/target_tcp_proxies/client.py +++ b/google/cloud/compute_v1/services/target_tcp_proxies/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.target_tcp_proxies import pagers from google.cloud.compute_v1.types import compute from .transports.base import TargetTcpProxiesTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -620,7 +617,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetTcpProxyList: + ) -> pagers.ListPager: r"""Retrieves the list of TargetTcpProxy resources available to the specified project. @@ -642,9 +639,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetTcpProxyList: + google.cloud.compute_v1.services.target_tcp_proxies.pagers.ListPager: Contains a list of TargetTcpProxy resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -677,6 +677,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/target_tcp_proxies/pagers.py b/google/cloud/compute_v1/services/target_tcp_proxies/pagers.py new file mode 100644 index 000000000..b9fcb6753 --- /dev/null +++ b/google/cloud/compute_v1/services/target_tcp_proxies/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetTcpProxyList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetTcpProxyList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetTcpProxyList], + request: compute.ListTargetTcpProxiesRequest, + response: compute.TargetTcpProxyList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListTargetTcpProxiesRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetTcpProxyList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListTargetTcpProxiesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetTcpProxyList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.TargetTcpProxy]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/target_tcp_proxies/transports/rest.py b/google/cloud/compute_v1/services/target_tcp_proxies/transports/rest.py index f5e2e4da7..51ebec033 100644 --- a/google/cloud/compute_v1/services/target_tcp_proxies/transports/rest.py +++ b/google/cloud/compute_v1/services/target_tcp_proxies/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -165,6 +168,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -221,6 +227,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetTcpProxy.from_json(response.content) @@ -277,7 +286,9 @@ def insert( # Jsonify the request body body = compute.TargetTcpProxy.to_json( - request.target_tcp_proxy_resource, including_default_value_fields=False + request.target_tcp_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -300,7 +311,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -339,11 +353,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -356,6 +370,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetTcpProxyList.from_json(response.content) @@ -414,6 +431,7 @@ def set_backend_service( body = compute.TargetTcpProxiesSetBackendServiceRequest.to_json( request.target_tcp_proxies_set_backend_service_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -438,7 +456,10 @@ def set_backend_service( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -498,6 +519,7 @@ def set_proxy_header( body = compute.TargetTcpProxiesSetProxyHeaderRequest.to_json( request.target_tcp_proxies_set_proxy_header_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -522,7 +544,10 @@ def set_proxy_header( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/target_vpn_gateways/client.py b/google/cloud/compute_v1/services/target_vpn_gateways/client.py index c1230be1f..d12a7a27b 100644 --- a/google/cloud/compute_v1/services/target_vpn_gateways/client.py +++ b/google/cloud/compute_v1/services/target_vpn_gateways/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.target_vpn_gateways import pagers from google.cloud.compute_v1.types import compute from .transports.base import TargetVpnGatewaysTransport, DEFAULT_CLIENT_INFO @@ -268,21 +269,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -325,7 +322,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -338,7 +335,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetVpnGatewayAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of target VPN gateways. Args: @@ -359,7 +356,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetVpnGatewayAggregatedList: + google.cloud.compute_v1.services.target_vpn_gateways.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -392,6 +392,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -710,7 +716,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.TargetVpnGatewayList: + ) -> pagers.ListPager: r"""Retrieves a list of target VPN gateways available to the specified project and region. @@ -737,9 +743,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.TargetVpnGatewayList: + google.cloud.compute_v1.services.target_vpn_gateways.pagers.ListPager: Contains a list of TargetVpnGateway resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -774,6 +783,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/target_vpn_gateways/pagers.py b/google/cloud/compute_v1/services/target_vpn_gateways/pagers.py new file mode 100644 index 000000000..e01973f79 --- /dev/null +++ b/google/cloud/compute_v1/services/target_vpn_gateways/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetVpnGatewayAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetVpnGatewayAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetVpnGatewayAggregatedList], + request: compute.AggregatedListTargetVpnGatewaysRequest, + response: compute.TargetVpnGatewayAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListTargetVpnGatewaysRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetVpnGatewayAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListTargetVpnGatewaysRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetVpnGatewayAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.TargetVpnGatewaysScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.TargetVpnGatewaysScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.TargetVpnGatewayList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.TargetVpnGatewayList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.TargetVpnGatewayList], + request: compute.ListTargetVpnGatewaysRequest, + response: compute.TargetVpnGatewayList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListTargetVpnGatewaysRequest): + The initial request object. + response (google.cloud.compute_v1.types.TargetVpnGatewayList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListTargetVpnGatewaysRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.TargetVpnGatewayList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.TargetVpnGateway]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/target_vpn_gateways/transports/rest.py b/google/cloud/compute_v1/services/target_vpn_gateways/transports/rest.py index fcd07d110..89f504e61 100644 --- a/google/cloud/compute_v1/services/target_vpn_gateways/transports/rest.py +++ b/google/cloud/compute_v1/services/target_vpn_gateways/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetVpnGatewayAggregatedList.from_json(response.content) @@ -219,6 +225,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -274,6 +283,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetVpnGateway.from_json(response.content) @@ -330,7 +342,9 @@ def insert( # Jsonify the request body body = compute.TargetVpnGateway.to_json( - request.target_vpn_gateway_resource, including_default_value_fields=False + request.target_vpn_gateway_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -353,7 +367,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -392,11 +409,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -409,6 +426,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.TargetVpnGatewayList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/url_maps/client.py b/google/cloud/compute_v1/services/url_maps/client.py index 60012beb4..c42f3ccd0 100644 --- a/google/cloud/compute_v1/services/url_maps/client.py +++ b/google/cloud/compute_v1/services/url_maps/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.url_maps import pagers from google.cloud.compute_v1.types import compute from .transports.base import UrlMapsTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -334,7 +331,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.UrlMapsAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves the list of all UrlMap resources, regional and global, available to the specified project. @@ -358,7 +355,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.UrlMapsAggregatedList: + google.cloud.compute_v1.services.url_maps.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -391,6 +391,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -821,7 +827,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.UrlMapList: + ) -> pagers.ListPager: r"""Retrieves the list of UrlMap resources available to the specified project. @@ -842,8 +848,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.UrlMapList: + google.cloud.compute_v1.services.url_maps.pagers.ListPager: Contains a list of UrlMap resources. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -875,6 +885,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/url_maps/pagers.py b/google/cloud/compute_v1/services/url_maps/pagers.py new file mode 100644 index 000000000..1ceea1492 --- /dev/null +++ b/google/cloud/compute_v1/services/url_maps/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.UrlMapsAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.UrlMapsAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.UrlMapsAggregatedList], + request: compute.AggregatedListUrlMapsRequest, + response: compute.UrlMapsAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListUrlMapsRequest): + The initial request object. + response (google.cloud.compute_v1.types.UrlMapsAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListUrlMapsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.UrlMapsAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.UrlMapsScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.UrlMapsScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.UrlMapList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.UrlMapList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.UrlMapList], + request: compute.ListUrlMapsRequest, + response: compute.UrlMapList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListUrlMapsRequest): + The initial request object. + response (google.cloud.compute_v1.types.UrlMapList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListUrlMapsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.UrlMapList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.UrlMap]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/url_maps/transports/rest.py b/google/cloud/compute_v1/services/url_maps/transports/rest.py index 664c2610c..49143112f 100644 --- a/google/cloud/compute_v1/services/url_maps/transports/rest.py +++ b/google/cloud/compute_v1/services/url_maps/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.UrlMapsAggregatedList.from_json(response.content) @@ -215,6 +221,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -292,6 +301,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.UrlMap.from_json(response.content) @@ -347,7 +359,9 @@ def insert( # Jsonify the request body body = compute.UrlMap.to_json( - request.url_map_resource, including_default_value_fields=False + request.url_map_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -370,7 +384,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -430,6 +447,7 @@ def invalidate_cache( body = compute.CacheInvalidationRule.to_json( request.cache_invalidation_rule_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -452,7 +470,10 @@ def invalidate_cache( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -488,11 +509,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -505,6 +526,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.UrlMapList.from_json(response.content) @@ -560,7 +584,9 @@ def patch( # Jsonify the request body body = compute.UrlMap.to_json( - request.url_map_resource, including_default_value_fields=False + request.url_map_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -583,7 +609,10 @@ def patch( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.patch(url, json=body,) + response = self._session.patch(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -640,7 +669,9 @@ def update( # Jsonify the request body body = compute.UrlMap.to_json( - request.url_map_resource, including_default_value_fields=False + request.url_map_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -663,7 +694,10 @@ def update( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.put(url, json=body,) + response = self._session.put(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -695,6 +729,7 @@ def validate( body = compute.UrlMapsValidateRequest.to_json( request.url_maps_validate_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -715,7 +750,10 @@ def validate( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.UrlMapsValidateResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/vpn_gateways/client.py b/google/cloud/compute_v1/services/vpn_gateways/client.py index a926aab90..84d7fcfa4 100644 --- a/google/cloud/compute_v1/services/vpn_gateways/client.py +++ b/google/cloud/compute_v1/services/vpn_gateways/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.vpn_gateways import pagers from google.cloud.compute_v1.types import compute from .transports.base import VpnGatewaysTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -334,7 +331,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.VpnGatewayAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of VPN gateways. Args: @@ -355,7 +352,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.VpnGatewayAggregatedList: + google.cloud.compute_v1.services.vpn_gateways.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -388,6 +388,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -784,7 +790,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.VpnGatewayList: + ) -> pagers.ListPager: r"""Retrieves a list of VPN gateways available to the specified project and region. @@ -811,9 +817,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.VpnGatewayList: + google.cloud.compute_v1.services.vpn_gateways.pagers.ListPager: Contains a list of VpnGateway resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -848,6 +857,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/vpn_gateways/pagers.py b/google/cloud/compute_v1/services/vpn_gateways/pagers.py new file mode 100644 index 000000000..8e414a067 --- /dev/null +++ b/google/cloud/compute_v1/services/vpn_gateways/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.VpnGatewayAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.VpnGatewayAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.VpnGatewayAggregatedList], + request: compute.AggregatedListVpnGatewaysRequest, + response: compute.VpnGatewayAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListVpnGatewaysRequest): + The initial request object. + response (google.cloud.compute_v1.types.VpnGatewayAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListVpnGatewaysRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.VpnGatewayAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.VpnGatewaysScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.VpnGatewaysScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.VpnGatewayList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.VpnGatewayList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.VpnGatewayList], + request: compute.ListVpnGatewaysRequest, + response: compute.VpnGatewayList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListVpnGatewaysRequest): + The initial request object. + response (google.cloud.compute_v1.types.VpnGatewayList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListVpnGatewaysRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.VpnGatewayList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.VpnGateway]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/vpn_gateways/transports/rest.py b/google/cloud/compute_v1/services/vpn_gateways/transports/rest.py index fb0900233..73327dfe4 100644 --- a/google/cloud/compute_v1/services/vpn_gateways/transports/rest.py +++ b/google/cloud/compute_v1/services/vpn_gateways/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.VpnGatewayAggregatedList.from_json(response.content) @@ -219,6 +225,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -277,6 +286,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.VpnGateway.from_json(response.content) @@ -326,6 +338,9 @@ def get_status( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.VpnGatewaysGetStatusResponse.from_json(response.content) @@ -382,7 +397,9 @@ def insert( # Jsonify the request body body = compute.VpnGateway.to_json( - request.vpn_gateway_resource, including_default_value_fields=False + request.vpn_gateway_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -405,7 +422,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -444,11 +464,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -461,6 +481,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.VpnGatewayList.from_json(response.content) @@ -519,6 +542,7 @@ def set_labels( body = compute.RegionSetLabelsRequest.to_json( request.region_set_labels_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -544,7 +568,10 @@ def set_labels( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -576,6 +603,7 @@ def test_iam_permissions( body = compute.TestPermissionsRequest.to_json( request.test_permissions_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -599,7 +627,10 @@ def test_iam_permissions( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.TestPermissionsResponse.from_json(response.content) diff --git a/google/cloud/compute_v1/services/vpn_tunnels/client.py b/google/cloud/compute_v1/services/vpn_tunnels/client.py index 53fa3d100..af6863525 100644 --- a/google/cloud/compute_v1/services/vpn_tunnels/client.py +++ b/google/cloud/compute_v1/services/vpn_tunnels/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.vpn_tunnels import pagers from google.cloud.compute_v1.types import compute from .transports.base import VpnTunnelsTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -334,7 +331,7 @@ def aggregated_list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.VpnTunnelAggregatedList: + ) -> pagers.AggregatedListPager: r"""Retrieves an aggregated list of VPN tunnels. Args: @@ -355,7 +352,10 @@ def aggregated_list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.VpnTunnelAggregatedList: + google.cloud.compute_v1.services.vpn_tunnels.pagers.AggregatedListPager: + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -388,6 +388,12 @@ def aggregated_list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.AggregatedListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response @@ -703,7 +709,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.VpnTunnelList: + ) -> pagers.ListPager: r"""Retrieves a list of VpnTunnel resources contained in the specified project and region. @@ -729,9 +735,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.VpnTunnelList: + google.cloud.compute_v1.services.vpn_tunnels.pagers.ListPager: Contains a list of VpnTunnel resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -766,6 +775,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/vpn_tunnels/pagers.py b/google/cloud/compute_v1/services/vpn_tunnels/pagers.py new file mode 100644 index 000000000..13780f8d6 --- /dev/null +++ b/google/cloud/compute_v1/services/vpn_tunnels/pagers.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class AggregatedListPager: + """A pager for iterating through ``aggregated_list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.VpnTunnelAggregatedList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``AggregatedList`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.VpnTunnelAggregatedList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.VpnTunnelAggregatedList], + request: compute.AggregatedListVpnTunnelsRequest, + response: compute.VpnTunnelAggregatedList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.AggregatedListVpnTunnelsRequest): + The initial request object. + response (google.cloud.compute_v1.types.VpnTunnelAggregatedList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.AggregatedListVpnTunnelsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.VpnTunnelAggregatedList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[Tuple[str, compute.VpnTunnelsScopedList]]: + for page in self.pages: + yield from page.items.items() + + def get(self, key: str) -> Optional[compute.VpnTunnelsScopedList]: + return self._response.items.get(key) + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.VpnTunnelList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.VpnTunnelList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.VpnTunnelList], + request: compute.ListVpnTunnelsRequest, + response: compute.VpnTunnelList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListVpnTunnelsRequest): + The initial request object. + response (google.cloud.compute_v1.types.VpnTunnelList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListVpnTunnelsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.VpnTunnelList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.VpnTunnel]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/vpn_tunnels/transports/rest.py b/google/cloud/compute_v1/services/vpn_tunnels/transports/rest.py index ada70a89f..b68a8729a 100644 --- a/google/cloud/compute_v1/services/vpn_tunnels/transports/rest.py +++ b/google/cloud/compute_v1/services/vpn_tunnels/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def aggregated_list( self, @@ -122,12 +125,12 @@ def aggregated_list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, - "orderBy": request.order_by, "includeAllScopes": request.include_all_scopes, - "returnPartialSuccess": request.return_partial_success, "maxResults": request.max_results, + "orderBy": request.order_by, + "pageToken": request.page_token, + "returnPartialSuccess": request.return_partial_success, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -140,6 +143,9 @@ def aggregated_list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.VpnTunnelAggregatedList.from_json(response.content) @@ -219,6 +225,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -271,6 +280,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.VpnTunnel.from_json(response.content) @@ -327,7 +339,9 @@ def insert( # Jsonify the request body body = compute.VpnTunnel.to_json( - request.vpn_tunnel_resource, including_default_value_fields=False + request.vpn_tunnel_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) # TODO(yon-mg): need to handle grpc transcoding and parse url correctly @@ -350,7 +364,10 @@ def insert( url += "?{}".format("&".join(query_params)).replace(" ", "+") # Send the request - response = self._session.post(url, json=body,) + response = self._session.post(url, data=body,) + + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() # Return the response return compute.Operation.from_json(response.content) @@ -389,11 +406,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -406,6 +423,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.VpnTunnelList.from_json(response.content) diff --git a/google/cloud/compute_v1/services/zone_operations/client.py b/google/cloud/compute_v1/services/zone_operations/client.py index 557939bf2..38e20d474 100644 --- a/google/cloud/compute_v1/services/zone_operations/client.py +++ b/google/cloud/compute_v1/services/zone_operations/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.zone_operations import pagers from google.cloud.compute_v1.types import compute from .transports.base import ZoneOperationsTransport, DEFAULT_CLIENT_INFO @@ -266,21 +267,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -323,7 +320,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -533,7 +530,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.OperationList: + ) -> pagers.ListPager: r"""Retrieves a list of Operation resources contained within the specified zone. @@ -560,9 +557,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.OperationList: + google.cloud.compute_v1.services.zone_operations.pagers.ListPager: Contains a list of Operation resources. + Iterating over this object will yield + results and resolve additional pages + automatically. """ # Create or coerce a protobuf request object. @@ -597,6 +597,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/zone_operations/pagers.py b/google/cloud/compute_v1/services/zone_operations/pagers.py new file mode 100644 index 000000000..62db29079 --- /dev/null +++ b/google/cloud/compute_v1/services/zone_operations/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.OperationList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.OperationList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.OperationList], + request: compute.ListZoneOperationsRequest, + response: compute.OperationList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListZoneOperationsRequest): + The initial request object. + response (google.cloud.compute_v1.types.OperationList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListZoneOperationsRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.OperationList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Operation]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/zone_operations/transports/rest.py b/google/cloud/compute_v1/services/zone_operations/transports/rest.py index 38f2aed09..444772246 100644 --- a/google/cloud/compute_v1/services/zone_operations/transports/rest.py +++ b/google/cloud/compute_v1/services/zone_operations/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def delete( self, @@ -139,6 +142,9 @@ def delete( # Send the request response = self._session.delete(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.DeleteZoneOperationResponse.from_json(response.content) @@ -216,6 +222,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) @@ -253,11 +262,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -270,6 +279,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.OperationList.from_json(response.content) @@ -347,6 +359,9 @@ def wait( # Send the request response = self._session.post(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Operation.from_json(response.content) diff --git a/google/cloud/compute_v1/services/zones/client.py b/google/cloud/compute_v1/services/zones/client.py index fbebb7e3a..fec06f719 100644 --- a/google/cloud/compute_v1/services/zones/client.py +++ b/google/cloud/compute_v1/services/zones/client.py @@ -32,6 +32,7 @@ from google.auth.exceptions import MutualTLSChannelError # type: ignore from google.oauth2 import service_account # type: ignore +from google.cloud.compute_v1.services.zones import pagers from google.cloud.compute_v1.types import compute from .transports.base import ZonesTransport, DEFAULT_CLIENT_INFO @@ -264,21 +265,17 @@ def __init__( util.strtobool(os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false")) ) - ssl_credentials = None + client_cert_source_func = None is_mtls = False if use_client_cert: if client_options.client_cert_source: - import grpc # type: ignore - - cert, key = client_options.client_cert_source() - ssl_credentials = grpc.ssl_channel_credentials( - certificate_chain=cert, private_key=key - ) is_mtls = True + client_cert_source_func = client_options.client_cert_source else: - creds = SslCredentials() - is_mtls = creds.is_mtls - ssl_credentials = creds.ssl_credentials if is_mtls else None + is_mtls = mtls.has_default_client_cert_source() + client_cert_source_func = ( + mtls.default_client_cert_source() if is_mtls else None + ) # Figure out which api endpoint to use. if client_options.api_endpoint is not None: @@ -321,7 +318,7 @@ def __init__( credentials_file=client_options.credentials_file, host=api_endpoint, scopes=client_options.scopes, - ssl_channel_credentials=ssl_credentials, + client_cert_source_for_mtls=client_cert_source_func, quota_project_id=client_options.quota_project_id, client_info=client_info, ) @@ -414,7 +411,7 @@ def list( retry: retries.Retry = gapic_v1.method.DEFAULT, timeout: float = None, metadata: Sequence[Tuple[str, str]] = (), - ) -> compute.ZoneList: + ) -> pagers.ListPager: r"""Retrieves the list of Zone resources available to the specified project. @@ -435,8 +432,12 @@ def list( sent along with the request as metadata. Returns: - google.cloud.compute_v1.types.ZoneList: + google.cloud.compute_v1.services.zones.pagers.ListPager: Contains a list of zone resources. + Iterating over this object will yield + results and resolve additional pages + automatically. + """ # Create or coerce a protobuf request object. # Sanity check: If we got a request object, we should *not* have @@ -468,6 +469,12 @@ def list( # Send the request. response = rpc(request, retry=retry, timeout=timeout, metadata=metadata,) + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPager( + method=rpc, request=request, response=response, metadata=metadata, + ) + # Done; return the response. return response diff --git a/google/cloud/compute_v1/services/zones/pagers.py b/google/cloud/compute_v1/services/zones/pagers.py new file mode 100644 index 000000000..8ed009ba3 --- /dev/null +++ b/google/cloud/compute_v1/services/zones/pagers.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import ( + Any, + AsyncIterable, + Awaitable, + Callable, + Iterable, + Sequence, + Tuple, + Optional, +) + +from google.cloud.compute_v1.types import compute + + +class ListPager: + """A pager for iterating through ``list`` requests. + + This class thinly wraps an initial + :class:`google.cloud.compute_v1.types.ZoneList` object, and + provides an ``__iter__`` method to iterate through its + ``items`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``List`` requests and continue to iterate + through the ``items`` field on the + corresponding responses. + + All the usual :class:`google.cloud.compute_v1.types.ZoneList` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., compute.ZoneList], + request: compute.ListZonesRequest, + response: compute.ZoneList, + *, + metadata: Sequence[Tuple[str, str]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.cloud.compute_v1.types.ListZonesRequest): + The initial request object. + response (google.cloud.compute_v1.types.ZoneList): + The initial response object. + metadata (Sequence[Tuple[str, str]]): Strings which should be + sent along with the request as metadata. + """ + self._method = method + self._request = compute.ListZonesRequest(request) + self._response = response + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterable[compute.ZoneList]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method(self._request, metadata=self._metadata) + yield self._response + + def __iter__(self) -> Iterable[compute.Zone]: + for page in self.pages: + yield from page.items + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/google/cloud/compute_v1/services/zones/transports/rest.py b/google/cloud/compute_v1/services/zones/transports/rest.py index a73fb2136..a7ed4ab98 100644 --- a/google/cloud/compute_v1/services/zones/transports/rest.py +++ b/google/cloud/compute_v1/services/zones/transports/rest.py @@ -54,7 +54,7 @@ def __init__( credentials: credentials.Credentials = None, credentials_file: str = None, scopes: Sequence[str] = None, - ssl_channel_credentials: grpc.ChannelCredentials = None, + client_cert_source_for_mtls: Callable[[], Tuple[bytes, bytes]] = None, quota_project_id: Optional[str] = None, client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, ) -> None: @@ -73,8 +73,9 @@ def __init__( This argument is ignored if ``channel`` is provided. scopes (Optional(Sequence[str])): A list of scopes. This argument is ignored if ``channel`` is provided. - ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials - for grpc channel. It is ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. quota_project_id (Optional[str]): An optional project to use for billing and quota. client_info (google.api_core.gapic_v1.client_info.ClientInfo): @@ -89,6 +90,8 @@ def __init__( host=host, credentials=credentials, client_info=client_info, ) self._session = AuthorizedSession(self._credentials) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) def get( self, @@ -139,6 +142,9 @@ def get( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.Zone.from_json(response.content) @@ -173,11 +179,11 @@ def list( # TODO(yon-mg): handle nested fields corerctly rather than using only top level fields # not required for GCE query_params = { - "pageToken": request.page_token, "filter": request.filter, + "maxResults": request.max_results, "orderBy": request.order_by, + "pageToken": request.page_token, "returnPartialSuccess": request.return_partial_success, - "maxResults": request.max_results, } # TODO(yon-mg): further discussion needed whether 'python truthiness' is appropriate here # discards default values @@ -190,6 +196,9 @@ def list( # Send the request response = self._session.get(url) + # Raise requests.exceptions.HTTPError if the status code is >= 400 + response.raise_for_status() + # Return the response return compute.ZoneList.from_json(response.content) diff --git a/scripts/fixup_compute_v1_keywords.py b/scripts/fixup_compute_v1_keywords.py index 98b137933..c15bf5d9c 100644 --- a/scripts/fixup_compute_v1_keywords.py +++ b/scripts/fixup_compute_v1_keywords.py @@ -50,32 +50,32 @@ class computeCallTransformer(cst.CSTTransformer): 'add_peering': ('network', 'networks_add_peering_request_resource', 'project', 'request_id', ), 'add_resource_policies': ('disk', 'disks_add_resource_policies_request_resource', 'project', 'zone', 'request_id', ), 'add_rule': ('project', 'security_policy', 'security_policy_rule_resource', ), - 'add_signed_url_key': ('backend_bucket', 'project', 'signed_url_key_resource', 'request_id', ), + 'add_signed_url_key': ('backend_service', 'project', 'signed_url_key_resource', 'request_id', ), 'aggregated_list': ('project', 'filter', 'include_all_scopes', 'max_results', 'order_by', 'page_token', 'return_partial_success', ), 'apply_updates_to_instances': ('instance_group_manager', 'instance_group_managers_apply_updates_request_resource', 'project', 'zone', ), 'attach_disk': ('attached_disk_resource', 'instance', 'project', 'zone', 'force_attach', 'request_id', ), - 'attach_network_endpoints': ('global_network_endpoint_groups_attach_endpoints_request_resource', 'network_endpoint_group', 'project', 'request_id', ), + 'attach_network_endpoints': ('network_endpoint_group', 'network_endpoint_groups_attach_endpoints_request_resource', 'project', 'zone', 'request_id', ), 'create_instances': ('instance_group_manager', 'instance_group_managers_create_instances_request_resource', 'project', 'zone', 'request_id', ), 'create_snapshot': ('disk', 'project', 'snapshot_resource', 'zone', 'guest_flush', 'request_id', ), - 'delete': ('address', 'project', 'region', 'request_id', ), + 'delete': ('disk', 'project', 'zone', 'request_id', ), 'delete_access_config': ('access_config', 'instance', 'network_interface', 'project', 'zone', 'request_id', ), 'delete_instances': ('instance_group_manager', 'instance_group_managers_delete_instances_request_resource', 'project', 'zone', 'request_id', ), 'delete_nodes': ('node_group', 'node_groups_delete_nodes_request_resource', 'project', 'zone', 'request_id', ), 'delete_per_instance_configs': ('instance_group_manager', 'instance_group_managers_delete_per_instance_configs_req_resource', 'project', 'zone', ), - 'delete_signed_url_key': ('backend_bucket', 'key_name', 'project', 'request_id', ), + 'delete_signed_url_key': ('backend_service', 'key_name', 'project', 'request_id', ), 'deprecate': ('deprecation_status_resource', 'image', 'project', 'request_id', ), 'detach_disk': ('device_name', 'instance', 'project', 'zone', 'request_id', ), - 'detach_network_endpoints': ('global_network_endpoint_groups_detach_endpoints_request_resource', 'network_endpoint_group', 'project', 'request_id', ), + 'detach_network_endpoints': ('network_endpoint_group', 'network_endpoint_groups_detach_endpoints_request_resource', 'project', 'zone', 'request_id', ), 'disable_xpn_host': ('project', 'request_id', ), 'disable_xpn_resource': ('project', 'projects_disable_xpn_resource_request_resource', 'request_id', ), 'enable_xpn_host': ('project', 'request_id', ), 'enable_xpn_resource': ('project', 'projects_enable_xpn_resource_request_resource', 'request_id', ), 'expand_ip_cidr_range': ('project', 'region', 'subnetwork', 'subnetworks_expand_ip_cidr_range_request_resource', 'request_id', ), - 'get': ('accelerator_type', 'project', 'zone', ), + 'get': ('disk', 'project', 'zone', ), 'get_diagnostics': ('interconnect', 'project', ), 'get_from_family': ('family', 'project', ), 'get_guest_attributes': ('instance', 'project', 'zone', 'query_path', 'variable_key', ), - 'get_health': ('backend_service', 'project', 'resource_group_reference_resource', ), + 'get_health': ('backend_service', 'project', 'region', 'resource_group_reference_resource', ), 'get_iam_policy': ('project', 'resource', 'zone', 'options_requested_policy_version', ), 'get_nat_mapping_info': ('project', 'region', 'router', 'filter', 'max_results', 'order_by', 'page_token', 'return_partial_success', ), 'get_router_status': ('project', 'region', 'router', ), @@ -86,14 +86,14 @@ class computeCallTransformer(cst.CSTTransformer): 'get_status': ('project', 'region', 'vpn_gateway', ), 'get_xpn_host': ('project', ), 'get_xpn_resources': ('project', 'filter', 'max_results', 'order_by', 'page_token', 'return_partial_success', ), - 'insert': ('address_resource', 'project', 'region', 'request_id', ), + 'insert': ('disk_resource', 'project', 'zone', 'request_id', 'source_image', ), 'invalidate_cache': ('cache_invalidation_rule_resource', 'project', 'url_map', 'request_id', ), 'list': ('project', 'zone', 'filter', 'max_results', 'order_by', 'page_token', 'return_partial_success', ), 'list_available_features': ('project', 'filter', 'max_results', 'order_by', 'page_token', 'return_partial_success', ), 'list_errors': ('instance_group_manager', 'project', 'zone', 'filter', 'max_results', 'order_by', 'page_token', 'return_partial_success', ), 'list_instances': ('instance_group', 'instance_groups_list_instances_request_resource', 'project', 'zone', 'filter', 'max_results', 'order_by', 'page_token', 'return_partial_success', ), 'list_managed_instances': ('instance_group_manager', 'project', 'zone', 'filter', 'max_results', 'order_by', 'page_token', 'return_partial_success', ), - 'list_network_endpoints': ('network_endpoint_group', 'project', 'filter', 'max_results', 'order_by', 'page_token', 'return_partial_success', ), + 'list_network_endpoints': ('network_endpoint_group', 'network_endpoint_groups_list_endpoints_request_resource', 'project', 'zone', 'filter', 'max_results', 'order_by', 'page_token', 'return_partial_success', ), 'list_nodes': ('node_group', 'project', 'zone', 'filter', 'max_results', 'order_by', 'page_token', 'return_partial_success', ), 'list_peering_routes': ('network', 'project', 'direction', 'filter', 'max_results', 'order_by', 'page_token', 'peering_name', 'region', 'return_partial_success', ), 'list_per_instance_configs': ('instance_group_manager', 'project', 'zone', 'filter', 'max_results', 'order_by', 'page_token', 'return_partial_success', ), @@ -138,12 +138,12 @@ class computeCallTransformer(cst.CSTTransformer): 'set_security_policy': ('backend_service', 'project', 'security_policy_reference_resource', 'request_id', ), 'set_service_account': ('instance', 'instances_set_service_account_request_resource', 'project', 'zone', 'request_id', ), 'set_shielded_instance_integrity_policy': ('instance', 'project', 'shielded_instance_integrity_policy_resource', 'zone', 'request_id', ), - 'set_ssl_certificates': ('project', 'region', 'region_target_https_proxies_set_ssl_certificates_request_resource', 'target_https_proxy', 'request_id', ), + 'set_ssl_certificates': ('project', 'target_https_proxies_set_ssl_certificates_request_resource', 'target_https_proxy', 'request_id', ), 'set_ssl_policy': ('project', 'ssl_policy_reference_resource', 'target_https_proxy', 'request_id', ), 'set_tags': ('instance', 'project', 'tags_resource', 'zone', 'request_id', ), - 'set_target': ('forwarding_rule', 'project', 'region', 'target_reference_resource', 'request_id', ), + 'set_target': ('forwarding_rule', 'project', 'target_reference_resource', 'request_id', ), 'set_target_pools': ('instance_group_manager', 'instance_group_managers_set_target_pools_request_resource', 'project', 'zone', 'request_id', ), - 'set_url_map': ('project', 'region', 'target_http_proxy', 'url_map_reference_resource', 'request_id', ), + 'set_url_map': ('project', 'target_https_proxy', 'url_map_reference_resource', 'request_id', ), 'set_usage_export_bucket': ('project', 'usage_export_location_resource', 'request_id', ), 'simulate_maintenance_event': ('instance', 'project', 'zone', ), 'start': ('instance', 'project', 'zone', 'request_id', ), @@ -158,8 +158,8 @@ class computeCallTransformer(cst.CSTTransformer): 'update_peering': ('network', 'networks_update_peering_request_resource', 'project', 'request_id', ), 'update_per_instance_configs': ('instance_group_manager', 'instance_group_managers_update_per_instance_configs_req_resource', 'project', 'zone', 'request_id', ), 'update_shielded_instance_config': ('instance', 'project', 'shielded_instance_config_resource', 'zone', 'request_id', ), - 'validate': ('project', 'region', 'region_url_maps_validate_request_resource', 'url_map', ), - 'wait': ('operation', 'project', ), + 'validate': ('project', 'url_map', 'url_maps_validate_request_resource', ), + 'wait': ('operation', 'project', 'region', ), } diff --git a/setup.py b/setup.py index ed51a6fbc..fed95e5ba 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ import os import setuptools # type: ignore -version = "0.1.0" +version = "0.2.0" package_root = os.path.abspath(os.path.dirname(__file__)) @@ -46,7 +46,7 @@ ), python_requires=">=3.6", classifiers=[ - "Development Status :: 4 - Beta", + "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "Operating System :: OS Independent", "Programming Language :: Python :: 3.6", diff --git a/synth.metadata b/synth.metadata index 9eb7f0f48..100147148 100644 --- a/synth.metadata +++ b/synth.metadata @@ -3,29 +3,29 @@ { "git": { "name": ".", - "remote": "sso://devrel/cloud/libraries/python/python-compute", - "sha": "432621401ddf808a5fef2d6d565b0e3a75bd2df0" + "remote": "https://github.com/googleapis/python-compute.git", + "sha": "16e7294cd536af9a8bf8e4e99219a883339aa955" } }, { "git": { "name": "googleapis-discovery", "remote": "https://github.com/googleapis/googleapis-discovery.git", - "sha": "ed8fe29ede4a8c5117c955bf0972e6b751800e41" + "sha": "8bde5a64f149bd8f6a3b71efa20d232cb07a55de" } }, { "git": { "name": "synthtool", "remote": "https://github.com/googleapis/synthtool.git", - "sha": "fb53b6fb373b7c3edf4e55f3e8036bc6d73fa483" + "sha": "b259489b06b25f399768b74b8baa943991f38ea7" } }, { "git": { "name": "synthtool", "remote": "https://github.com/googleapis/synthtool.git", - "sha": "fb53b6fb373b7c3edf4e55f3e8036bc6d73fa483" + "sha": "b259489b06b25f399768b74b8baa943991f38ea7" } } ], @@ -39,5 +39,667 @@ "generator": "bazel" } } + ], + "generatedFiles": [ + ".coveragerc", + ".flake8", + ".github/CONTRIBUTING.md", + ".github/ISSUE_TEMPLATE/bug_report.md", + ".github/ISSUE_TEMPLATE/feature_request.md", + ".github/ISSUE_TEMPLATE/support_request.md", + ".github/PULL_REQUEST_TEMPLATE.md", + ".github/release-please.yml", + ".github/snippet-bot.yml", + ".gitignore", + ".kokoro/build.sh", + ".kokoro/continuous/common.cfg", + ".kokoro/continuous/continuous.cfg", + ".kokoro/docker/docs/Dockerfile", + ".kokoro/docker/docs/fetch_gpg_keys.sh", + ".kokoro/docs/common.cfg", + ".kokoro/docs/docs-presubmit.cfg", + ".kokoro/docs/docs.cfg", + ".kokoro/populate-secrets.sh", + ".kokoro/presubmit/common.cfg", + ".kokoro/presubmit/presubmit.cfg", + ".kokoro/publish-docs.sh", + ".kokoro/release.sh", + ".kokoro/release/common.cfg", + ".kokoro/release/release.cfg", + ".kokoro/samples/lint/common.cfg", + ".kokoro/samples/lint/continuous.cfg", + ".kokoro/samples/lint/periodic.cfg", + ".kokoro/samples/lint/presubmit.cfg", + ".kokoro/samples/python3.6/common.cfg", + ".kokoro/samples/python3.6/continuous.cfg", + ".kokoro/samples/python3.6/periodic.cfg", + ".kokoro/samples/python3.6/presubmit.cfg", + ".kokoro/samples/python3.7/common.cfg", + ".kokoro/samples/python3.7/continuous.cfg", + ".kokoro/samples/python3.7/periodic.cfg", + ".kokoro/samples/python3.7/presubmit.cfg", + ".kokoro/samples/python3.8/common.cfg", + ".kokoro/samples/python3.8/continuous.cfg", + ".kokoro/samples/python3.8/periodic.cfg", + ".kokoro/samples/python3.8/presubmit.cfg", + ".kokoro/test-samples.sh", + ".kokoro/trampoline.sh", + ".kokoro/trampoline_v2.sh", + ".pre-commit-config.yaml", + ".trampolinerc", + "CODE_OF_CONDUCT.md", + "CONTRIBUTING.rst", + "LICENSE", + "MANIFEST.in", + "docs/_static/custom.css", + "docs/_templates/layout.html", + "docs/compute_v1/accelerator_types.rst", + "docs/compute_v1/addresses.rst", + "docs/compute_v1/autoscalers.rst", + "docs/compute_v1/backend_buckets.rst", + "docs/compute_v1/backend_services.rst", + "docs/compute_v1/disk_types.rst", + "docs/compute_v1/disks.rst", + "docs/compute_v1/external_vpn_gateways.rst", + "docs/compute_v1/firewalls.rst", + "docs/compute_v1/forwarding_rules.rst", + "docs/compute_v1/global_addresses.rst", + "docs/compute_v1/global_forwarding_rules.rst", + "docs/compute_v1/global_network_endpoint_groups.rst", + "docs/compute_v1/global_operations.rst", + "docs/compute_v1/global_organization_operations.rst", + "docs/compute_v1/health_checks.rst", + "docs/compute_v1/images.rst", + "docs/compute_v1/instance_group_managers.rst", + "docs/compute_v1/instance_groups.rst", + "docs/compute_v1/instance_templates.rst", + "docs/compute_v1/instances.rst", + "docs/compute_v1/interconnect_attachments.rst", + "docs/compute_v1/interconnect_locations.rst", + "docs/compute_v1/interconnects.rst", + "docs/compute_v1/license_codes.rst", + "docs/compute_v1/licenses.rst", + "docs/compute_v1/machine_types.rst", + "docs/compute_v1/network_endpoint_groups.rst", + "docs/compute_v1/networks.rst", + "docs/compute_v1/node_groups.rst", + "docs/compute_v1/node_templates.rst", + "docs/compute_v1/node_types.rst", + "docs/compute_v1/packet_mirrorings.rst", + "docs/compute_v1/projects.rst", + "docs/compute_v1/region_autoscalers.rst", + "docs/compute_v1/region_backend_services.rst", + "docs/compute_v1/region_commitments.rst", + "docs/compute_v1/region_disk_types.rst", + "docs/compute_v1/region_disks.rst", + "docs/compute_v1/region_health_check_services.rst", + "docs/compute_v1/region_health_checks.rst", + "docs/compute_v1/region_instance_group_managers.rst", + "docs/compute_v1/region_instance_groups.rst", + "docs/compute_v1/region_network_endpoint_groups.rst", + "docs/compute_v1/region_notification_endpoints.rst", + "docs/compute_v1/region_operations.rst", + "docs/compute_v1/region_ssl_certificates.rst", + "docs/compute_v1/region_target_http_proxies.rst", + "docs/compute_v1/region_target_https_proxies.rst", + "docs/compute_v1/region_url_maps.rst", + "docs/compute_v1/regions.rst", + "docs/compute_v1/reservations.rst", + "docs/compute_v1/resource_policies.rst", + "docs/compute_v1/routers.rst", + "docs/compute_v1/routes.rst", + "docs/compute_v1/security_policies.rst", + "docs/compute_v1/services.rst", + "docs/compute_v1/snapshots.rst", + "docs/compute_v1/ssl_certificates.rst", + "docs/compute_v1/ssl_policies.rst", + "docs/compute_v1/subnetworks.rst", + "docs/compute_v1/target_grpc_proxies.rst", + "docs/compute_v1/target_http_proxies.rst", + "docs/compute_v1/target_https_proxies.rst", + "docs/compute_v1/target_instances.rst", + "docs/compute_v1/target_pools.rst", + "docs/compute_v1/target_ssl_proxies.rst", + "docs/compute_v1/target_tcp_proxies.rst", + "docs/compute_v1/target_vpn_gateways.rst", + "docs/compute_v1/types.rst", + "docs/compute_v1/url_maps.rst", + "docs/compute_v1/vpn_gateways.rst", + "docs/compute_v1/vpn_tunnels.rst", + "docs/compute_v1/zone_operations.rst", + "docs/compute_v1/zones.rst", + "docs/conf.py", + "docs/multiprocessing.rst", + "google/cloud/compute/__init__.py", + "google/cloud/compute/py.typed", + "google/cloud/compute_v1/__init__.py", + "google/cloud/compute_v1/py.typed", + "google/cloud/compute_v1/services/__init__.py", + "google/cloud/compute_v1/services/accelerator_types/__init__.py", + "google/cloud/compute_v1/services/accelerator_types/client.py", + "google/cloud/compute_v1/services/accelerator_types/pagers.py", + "google/cloud/compute_v1/services/accelerator_types/transports/__init__.py", + "google/cloud/compute_v1/services/accelerator_types/transports/base.py", + "google/cloud/compute_v1/services/accelerator_types/transports/rest.py", + "google/cloud/compute_v1/services/addresses/__init__.py", + "google/cloud/compute_v1/services/addresses/client.py", + "google/cloud/compute_v1/services/addresses/pagers.py", + "google/cloud/compute_v1/services/addresses/transports/__init__.py", + "google/cloud/compute_v1/services/addresses/transports/base.py", + "google/cloud/compute_v1/services/addresses/transports/rest.py", + "google/cloud/compute_v1/services/autoscalers/__init__.py", + "google/cloud/compute_v1/services/autoscalers/client.py", + "google/cloud/compute_v1/services/autoscalers/pagers.py", + "google/cloud/compute_v1/services/autoscalers/transports/__init__.py", + "google/cloud/compute_v1/services/autoscalers/transports/base.py", + "google/cloud/compute_v1/services/autoscalers/transports/rest.py", + "google/cloud/compute_v1/services/backend_buckets/__init__.py", + "google/cloud/compute_v1/services/backend_buckets/client.py", + "google/cloud/compute_v1/services/backend_buckets/pagers.py", + "google/cloud/compute_v1/services/backend_buckets/transports/__init__.py", + "google/cloud/compute_v1/services/backend_buckets/transports/base.py", + "google/cloud/compute_v1/services/backend_buckets/transports/rest.py", + "google/cloud/compute_v1/services/backend_services/__init__.py", + "google/cloud/compute_v1/services/backend_services/client.py", + "google/cloud/compute_v1/services/backend_services/pagers.py", + "google/cloud/compute_v1/services/backend_services/transports/__init__.py", + "google/cloud/compute_v1/services/backend_services/transports/base.py", + "google/cloud/compute_v1/services/backend_services/transports/rest.py", + "google/cloud/compute_v1/services/disk_types/__init__.py", + "google/cloud/compute_v1/services/disk_types/client.py", + "google/cloud/compute_v1/services/disk_types/pagers.py", + "google/cloud/compute_v1/services/disk_types/transports/__init__.py", + "google/cloud/compute_v1/services/disk_types/transports/base.py", + "google/cloud/compute_v1/services/disk_types/transports/rest.py", + "google/cloud/compute_v1/services/disks/__init__.py", + "google/cloud/compute_v1/services/disks/client.py", + "google/cloud/compute_v1/services/disks/pagers.py", + "google/cloud/compute_v1/services/disks/transports/__init__.py", + "google/cloud/compute_v1/services/disks/transports/base.py", + "google/cloud/compute_v1/services/disks/transports/rest.py", + "google/cloud/compute_v1/services/external_vpn_gateways/__init__.py", + "google/cloud/compute_v1/services/external_vpn_gateways/client.py", + "google/cloud/compute_v1/services/external_vpn_gateways/pagers.py", + "google/cloud/compute_v1/services/external_vpn_gateways/transports/__init__.py", + "google/cloud/compute_v1/services/external_vpn_gateways/transports/base.py", + "google/cloud/compute_v1/services/external_vpn_gateways/transports/rest.py", + "google/cloud/compute_v1/services/firewalls/__init__.py", + "google/cloud/compute_v1/services/firewalls/client.py", + "google/cloud/compute_v1/services/firewalls/pagers.py", + "google/cloud/compute_v1/services/firewalls/transports/__init__.py", + "google/cloud/compute_v1/services/firewalls/transports/base.py", + "google/cloud/compute_v1/services/firewalls/transports/rest.py", + "google/cloud/compute_v1/services/forwarding_rules/__init__.py", + "google/cloud/compute_v1/services/forwarding_rules/client.py", + "google/cloud/compute_v1/services/forwarding_rules/pagers.py", + "google/cloud/compute_v1/services/forwarding_rules/transports/__init__.py", + "google/cloud/compute_v1/services/forwarding_rules/transports/base.py", + "google/cloud/compute_v1/services/forwarding_rules/transports/rest.py", + "google/cloud/compute_v1/services/global_addresses/__init__.py", + "google/cloud/compute_v1/services/global_addresses/client.py", + "google/cloud/compute_v1/services/global_addresses/pagers.py", + "google/cloud/compute_v1/services/global_addresses/transports/__init__.py", + "google/cloud/compute_v1/services/global_addresses/transports/base.py", + "google/cloud/compute_v1/services/global_addresses/transports/rest.py", + "google/cloud/compute_v1/services/global_forwarding_rules/__init__.py", + "google/cloud/compute_v1/services/global_forwarding_rules/client.py", + "google/cloud/compute_v1/services/global_forwarding_rules/pagers.py", + "google/cloud/compute_v1/services/global_forwarding_rules/transports/__init__.py", + "google/cloud/compute_v1/services/global_forwarding_rules/transports/base.py", + "google/cloud/compute_v1/services/global_forwarding_rules/transports/rest.py", + "google/cloud/compute_v1/services/global_network_endpoint_groups/__init__.py", + "google/cloud/compute_v1/services/global_network_endpoint_groups/client.py", + "google/cloud/compute_v1/services/global_network_endpoint_groups/pagers.py", + "google/cloud/compute_v1/services/global_network_endpoint_groups/transports/__init__.py", + "google/cloud/compute_v1/services/global_network_endpoint_groups/transports/base.py", + "google/cloud/compute_v1/services/global_network_endpoint_groups/transports/rest.py", + "google/cloud/compute_v1/services/global_operations/__init__.py", + "google/cloud/compute_v1/services/global_operations/client.py", + "google/cloud/compute_v1/services/global_operations/pagers.py", + "google/cloud/compute_v1/services/global_operations/transports/__init__.py", + "google/cloud/compute_v1/services/global_operations/transports/base.py", + "google/cloud/compute_v1/services/global_operations/transports/rest.py", + "google/cloud/compute_v1/services/global_organization_operations/__init__.py", + "google/cloud/compute_v1/services/global_organization_operations/client.py", + "google/cloud/compute_v1/services/global_organization_operations/pagers.py", + "google/cloud/compute_v1/services/global_organization_operations/transports/__init__.py", + "google/cloud/compute_v1/services/global_organization_operations/transports/base.py", + "google/cloud/compute_v1/services/global_organization_operations/transports/rest.py", + "google/cloud/compute_v1/services/health_checks/__init__.py", + "google/cloud/compute_v1/services/health_checks/client.py", + "google/cloud/compute_v1/services/health_checks/pagers.py", + "google/cloud/compute_v1/services/health_checks/transports/__init__.py", + "google/cloud/compute_v1/services/health_checks/transports/base.py", + "google/cloud/compute_v1/services/health_checks/transports/rest.py", + "google/cloud/compute_v1/services/images/__init__.py", + "google/cloud/compute_v1/services/images/client.py", + "google/cloud/compute_v1/services/images/pagers.py", + "google/cloud/compute_v1/services/images/transports/__init__.py", + "google/cloud/compute_v1/services/images/transports/base.py", + "google/cloud/compute_v1/services/images/transports/rest.py", + "google/cloud/compute_v1/services/instance_group_managers/__init__.py", + "google/cloud/compute_v1/services/instance_group_managers/client.py", + "google/cloud/compute_v1/services/instance_group_managers/pagers.py", + "google/cloud/compute_v1/services/instance_group_managers/transports/__init__.py", + "google/cloud/compute_v1/services/instance_group_managers/transports/base.py", + "google/cloud/compute_v1/services/instance_group_managers/transports/rest.py", + "google/cloud/compute_v1/services/instance_groups/__init__.py", + "google/cloud/compute_v1/services/instance_groups/client.py", + "google/cloud/compute_v1/services/instance_groups/pagers.py", + "google/cloud/compute_v1/services/instance_groups/transports/__init__.py", + "google/cloud/compute_v1/services/instance_groups/transports/base.py", + "google/cloud/compute_v1/services/instance_groups/transports/rest.py", + "google/cloud/compute_v1/services/instance_templates/__init__.py", + "google/cloud/compute_v1/services/instance_templates/client.py", + "google/cloud/compute_v1/services/instance_templates/pagers.py", + "google/cloud/compute_v1/services/instance_templates/transports/__init__.py", + "google/cloud/compute_v1/services/instance_templates/transports/base.py", + "google/cloud/compute_v1/services/instance_templates/transports/rest.py", + "google/cloud/compute_v1/services/instances/__init__.py", + "google/cloud/compute_v1/services/instances/client.py", + "google/cloud/compute_v1/services/instances/pagers.py", + "google/cloud/compute_v1/services/instances/transports/__init__.py", + "google/cloud/compute_v1/services/instances/transports/base.py", + "google/cloud/compute_v1/services/instances/transports/rest.py", + "google/cloud/compute_v1/services/interconnect_attachments/__init__.py", + "google/cloud/compute_v1/services/interconnect_attachments/client.py", + "google/cloud/compute_v1/services/interconnect_attachments/pagers.py", + "google/cloud/compute_v1/services/interconnect_attachments/transports/__init__.py", + "google/cloud/compute_v1/services/interconnect_attachments/transports/base.py", + "google/cloud/compute_v1/services/interconnect_attachments/transports/rest.py", + "google/cloud/compute_v1/services/interconnect_locations/__init__.py", + "google/cloud/compute_v1/services/interconnect_locations/client.py", + "google/cloud/compute_v1/services/interconnect_locations/pagers.py", + "google/cloud/compute_v1/services/interconnect_locations/transports/__init__.py", + "google/cloud/compute_v1/services/interconnect_locations/transports/base.py", + "google/cloud/compute_v1/services/interconnect_locations/transports/rest.py", + "google/cloud/compute_v1/services/interconnects/__init__.py", + "google/cloud/compute_v1/services/interconnects/client.py", + "google/cloud/compute_v1/services/interconnects/pagers.py", + "google/cloud/compute_v1/services/interconnects/transports/__init__.py", + "google/cloud/compute_v1/services/interconnects/transports/base.py", + "google/cloud/compute_v1/services/interconnects/transports/rest.py", + "google/cloud/compute_v1/services/license_codes/__init__.py", + "google/cloud/compute_v1/services/license_codes/client.py", + "google/cloud/compute_v1/services/license_codes/transports/__init__.py", + "google/cloud/compute_v1/services/license_codes/transports/base.py", + "google/cloud/compute_v1/services/license_codes/transports/rest.py", + "google/cloud/compute_v1/services/licenses/__init__.py", + "google/cloud/compute_v1/services/licenses/client.py", + "google/cloud/compute_v1/services/licenses/pagers.py", + "google/cloud/compute_v1/services/licenses/transports/__init__.py", + "google/cloud/compute_v1/services/licenses/transports/base.py", + "google/cloud/compute_v1/services/licenses/transports/rest.py", + "google/cloud/compute_v1/services/machine_types/__init__.py", + "google/cloud/compute_v1/services/machine_types/client.py", + "google/cloud/compute_v1/services/machine_types/pagers.py", + "google/cloud/compute_v1/services/machine_types/transports/__init__.py", + "google/cloud/compute_v1/services/machine_types/transports/base.py", + "google/cloud/compute_v1/services/machine_types/transports/rest.py", + "google/cloud/compute_v1/services/network_endpoint_groups/__init__.py", + "google/cloud/compute_v1/services/network_endpoint_groups/client.py", + "google/cloud/compute_v1/services/network_endpoint_groups/pagers.py", + "google/cloud/compute_v1/services/network_endpoint_groups/transports/__init__.py", + "google/cloud/compute_v1/services/network_endpoint_groups/transports/base.py", + "google/cloud/compute_v1/services/network_endpoint_groups/transports/rest.py", + "google/cloud/compute_v1/services/networks/__init__.py", + "google/cloud/compute_v1/services/networks/client.py", + "google/cloud/compute_v1/services/networks/pagers.py", + "google/cloud/compute_v1/services/networks/transports/__init__.py", + "google/cloud/compute_v1/services/networks/transports/base.py", + "google/cloud/compute_v1/services/networks/transports/rest.py", + "google/cloud/compute_v1/services/node_groups/__init__.py", + "google/cloud/compute_v1/services/node_groups/client.py", + "google/cloud/compute_v1/services/node_groups/pagers.py", + "google/cloud/compute_v1/services/node_groups/transports/__init__.py", + "google/cloud/compute_v1/services/node_groups/transports/base.py", + "google/cloud/compute_v1/services/node_groups/transports/rest.py", + "google/cloud/compute_v1/services/node_templates/__init__.py", + "google/cloud/compute_v1/services/node_templates/client.py", + "google/cloud/compute_v1/services/node_templates/pagers.py", + "google/cloud/compute_v1/services/node_templates/transports/__init__.py", + "google/cloud/compute_v1/services/node_templates/transports/base.py", + "google/cloud/compute_v1/services/node_templates/transports/rest.py", + "google/cloud/compute_v1/services/node_types/__init__.py", + "google/cloud/compute_v1/services/node_types/client.py", + "google/cloud/compute_v1/services/node_types/pagers.py", + "google/cloud/compute_v1/services/node_types/transports/__init__.py", + "google/cloud/compute_v1/services/node_types/transports/base.py", + "google/cloud/compute_v1/services/node_types/transports/rest.py", + "google/cloud/compute_v1/services/packet_mirrorings/__init__.py", + "google/cloud/compute_v1/services/packet_mirrorings/client.py", + "google/cloud/compute_v1/services/packet_mirrorings/pagers.py", + "google/cloud/compute_v1/services/packet_mirrorings/transports/__init__.py", + "google/cloud/compute_v1/services/packet_mirrorings/transports/base.py", + "google/cloud/compute_v1/services/packet_mirrorings/transports/rest.py", + "google/cloud/compute_v1/services/projects/__init__.py", + "google/cloud/compute_v1/services/projects/client.py", + "google/cloud/compute_v1/services/projects/pagers.py", + "google/cloud/compute_v1/services/projects/transports/__init__.py", + "google/cloud/compute_v1/services/projects/transports/base.py", + "google/cloud/compute_v1/services/projects/transports/rest.py", + "google/cloud/compute_v1/services/region_autoscalers/__init__.py", + "google/cloud/compute_v1/services/region_autoscalers/client.py", + "google/cloud/compute_v1/services/region_autoscalers/pagers.py", + "google/cloud/compute_v1/services/region_autoscalers/transports/__init__.py", + "google/cloud/compute_v1/services/region_autoscalers/transports/base.py", + "google/cloud/compute_v1/services/region_autoscalers/transports/rest.py", + "google/cloud/compute_v1/services/region_backend_services/__init__.py", + "google/cloud/compute_v1/services/region_backend_services/client.py", + "google/cloud/compute_v1/services/region_backend_services/pagers.py", + "google/cloud/compute_v1/services/region_backend_services/transports/__init__.py", + "google/cloud/compute_v1/services/region_backend_services/transports/base.py", + "google/cloud/compute_v1/services/region_backend_services/transports/rest.py", + "google/cloud/compute_v1/services/region_commitments/__init__.py", + "google/cloud/compute_v1/services/region_commitments/client.py", + "google/cloud/compute_v1/services/region_commitments/pagers.py", + "google/cloud/compute_v1/services/region_commitments/transports/__init__.py", + "google/cloud/compute_v1/services/region_commitments/transports/base.py", + "google/cloud/compute_v1/services/region_commitments/transports/rest.py", + "google/cloud/compute_v1/services/region_disk_types/__init__.py", + "google/cloud/compute_v1/services/region_disk_types/client.py", + "google/cloud/compute_v1/services/region_disk_types/pagers.py", + "google/cloud/compute_v1/services/region_disk_types/transports/__init__.py", + "google/cloud/compute_v1/services/region_disk_types/transports/base.py", + "google/cloud/compute_v1/services/region_disk_types/transports/rest.py", + "google/cloud/compute_v1/services/region_disks/__init__.py", + "google/cloud/compute_v1/services/region_disks/client.py", + "google/cloud/compute_v1/services/region_disks/pagers.py", + "google/cloud/compute_v1/services/region_disks/transports/__init__.py", + "google/cloud/compute_v1/services/region_disks/transports/base.py", + "google/cloud/compute_v1/services/region_disks/transports/rest.py", + "google/cloud/compute_v1/services/region_health_check_services/__init__.py", + "google/cloud/compute_v1/services/region_health_check_services/client.py", + "google/cloud/compute_v1/services/region_health_check_services/pagers.py", + "google/cloud/compute_v1/services/region_health_check_services/transports/__init__.py", + "google/cloud/compute_v1/services/region_health_check_services/transports/base.py", + "google/cloud/compute_v1/services/region_health_check_services/transports/rest.py", + "google/cloud/compute_v1/services/region_health_checks/__init__.py", + "google/cloud/compute_v1/services/region_health_checks/client.py", + "google/cloud/compute_v1/services/region_health_checks/pagers.py", + "google/cloud/compute_v1/services/region_health_checks/transports/__init__.py", + "google/cloud/compute_v1/services/region_health_checks/transports/base.py", + "google/cloud/compute_v1/services/region_health_checks/transports/rest.py", + "google/cloud/compute_v1/services/region_instance_group_managers/__init__.py", + "google/cloud/compute_v1/services/region_instance_group_managers/client.py", + "google/cloud/compute_v1/services/region_instance_group_managers/pagers.py", + "google/cloud/compute_v1/services/region_instance_group_managers/transports/__init__.py", + "google/cloud/compute_v1/services/region_instance_group_managers/transports/base.py", + "google/cloud/compute_v1/services/region_instance_group_managers/transports/rest.py", + "google/cloud/compute_v1/services/region_instance_groups/__init__.py", + "google/cloud/compute_v1/services/region_instance_groups/client.py", + "google/cloud/compute_v1/services/region_instance_groups/pagers.py", + "google/cloud/compute_v1/services/region_instance_groups/transports/__init__.py", + "google/cloud/compute_v1/services/region_instance_groups/transports/base.py", + "google/cloud/compute_v1/services/region_instance_groups/transports/rest.py", + "google/cloud/compute_v1/services/region_network_endpoint_groups/__init__.py", + "google/cloud/compute_v1/services/region_network_endpoint_groups/client.py", + "google/cloud/compute_v1/services/region_network_endpoint_groups/pagers.py", + "google/cloud/compute_v1/services/region_network_endpoint_groups/transports/__init__.py", + "google/cloud/compute_v1/services/region_network_endpoint_groups/transports/base.py", + "google/cloud/compute_v1/services/region_network_endpoint_groups/transports/rest.py", + "google/cloud/compute_v1/services/region_notification_endpoints/__init__.py", + "google/cloud/compute_v1/services/region_notification_endpoints/client.py", + "google/cloud/compute_v1/services/region_notification_endpoints/pagers.py", + "google/cloud/compute_v1/services/region_notification_endpoints/transports/__init__.py", + "google/cloud/compute_v1/services/region_notification_endpoints/transports/base.py", + "google/cloud/compute_v1/services/region_notification_endpoints/transports/rest.py", + "google/cloud/compute_v1/services/region_operations/__init__.py", + "google/cloud/compute_v1/services/region_operations/client.py", + "google/cloud/compute_v1/services/region_operations/pagers.py", + "google/cloud/compute_v1/services/region_operations/transports/__init__.py", + "google/cloud/compute_v1/services/region_operations/transports/base.py", + "google/cloud/compute_v1/services/region_operations/transports/rest.py", + "google/cloud/compute_v1/services/region_ssl_certificates/__init__.py", + "google/cloud/compute_v1/services/region_ssl_certificates/client.py", + "google/cloud/compute_v1/services/region_ssl_certificates/pagers.py", + "google/cloud/compute_v1/services/region_ssl_certificates/transports/__init__.py", + "google/cloud/compute_v1/services/region_ssl_certificates/transports/base.py", + "google/cloud/compute_v1/services/region_ssl_certificates/transports/rest.py", + "google/cloud/compute_v1/services/region_target_http_proxies/__init__.py", + "google/cloud/compute_v1/services/region_target_http_proxies/client.py", + "google/cloud/compute_v1/services/region_target_http_proxies/pagers.py", + "google/cloud/compute_v1/services/region_target_http_proxies/transports/__init__.py", + "google/cloud/compute_v1/services/region_target_http_proxies/transports/base.py", + "google/cloud/compute_v1/services/region_target_http_proxies/transports/rest.py", + "google/cloud/compute_v1/services/region_target_https_proxies/__init__.py", + "google/cloud/compute_v1/services/region_target_https_proxies/client.py", + "google/cloud/compute_v1/services/region_target_https_proxies/pagers.py", + "google/cloud/compute_v1/services/region_target_https_proxies/transports/__init__.py", + "google/cloud/compute_v1/services/region_target_https_proxies/transports/base.py", + "google/cloud/compute_v1/services/region_target_https_proxies/transports/rest.py", + "google/cloud/compute_v1/services/region_url_maps/__init__.py", + "google/cloud/compute_v1/services/region_url_maps/client.py", + "google/cloud/compute_v1/services/region_url_maps/pagers.py", + "google/cloud/compute_v1/services/region_url_maps/transports/__init__.py", + "google/cloud/compute_v1/services/region_url_maps/transports/base.py", + "google/cloud/compute_v1/services/region_url_maps/transports/rest.py", + "google/cloud/compute_v1/services/regions/__init__.py", + "google/cloud/compute_v1/services/regions/client.py", + "google/cloud/compute_v1/services/regions/pagers.py", + "google/cloud/compute_v1/services/regions/transports/__init__.py", + "google/cloud/compute_v1/services/regions/transports/base.py", + "google/cloud/compute_v1/services/regions/transports/rest.py", + "google/cloud/compute_v1/services/reservations/__init__.py", + "google/cloud/compute_v1/services/reservations/client.py", + "google/cloud/compute_v1/services/reservations/pagers.py", + "google/cloud/compute_v1/services/reservations/transports/__init__.py", + "google/cloud/compute_v1/services/reservations/transports/base.py", + "google/cloud/compute_v1/services/reservations/transports/rest.py", + "google/cloud/compute_v1/services/resource_policies/__init__.py", + "google/cloud/compute_v1/services/resource_policies/client.py", + "google/cloud/compute_v1/services/resource_policies/pagers.py", + "google/cloud/compute_v1/services/resource_policies/transports/__init__.py", + "google/cloud/compute_v1/services/resource_policies/transports/base.py", + "google/cloud/compute_v1/services/resource_policies/transports/rest.py", + "google/cloud/compute_v1/services/routers/__init__.py", + "google/cloud/compute_v1/services/routers/client.py", + "google/cloud/compute_v1/services/routers/pagers.py", + "google/cloud/compute_v1/services/routers/transports/__init__.py", + "google/cloud/compute_v1/services/routers/transports/base.py", + "google/cloud/compute_v1/services/routers/transports/rest.py", + "google/cloud/compute_v1/services/routes/__init__.py", + "google/cloud/compute_v1/services/routes/client.py", + "google/cloud/compute_v1/services/routes/pagers.py", + "google/cloud/compute_v1/services/routes/transports/__init__.py", + "google/cloud/compute_v1/services/routes/transports/base.py", + "google/cloud/compute_v1/services/routes/transports/rest.py", + "google/cloud/compute_v1/services/security_policies/__init__.py", + "google/cloud/compute_v1/services/security_policies/client.py", + "google/cloud/compute_v1/services/security_policies/pagers.py", + "google/cloud/compute_v1/services/security_policies/transports/__init__.py", + "google/cloud/compute_v1/services/security_policies/transports/base.py", + "google/cloud/compute_v1/services/security_policies/transports/rest.py", + "google/cloud/compute_v1/services/snapshots/__init__.py", + "google/cloud/compute_v1/services/snapshots/client.py", + "google/cloud/compute_v1/services/snapshots/pagers.py", + "google/cloud/compute_v1/services/snapshots/transports/__init__.py", + "google/cloud/compute_v1/services/snapshots/transports/base.py", + "google/cloud/compute_v1/services/snapshots/transports/rest.py", + "google/cloud/compute_v1/services/ssl_certificates/__init__.py", + "google/cloud/compute_v1/services/ssl_certificates/client.py", + "google/cloud/compute_v1/services/ssl_certificates/pagers.py", + "google/cloud/compute_v1/services/ssl_certificates/transports/__init__.py", + "google/cloud/compute_v1/services/ssl_certificates/transports/base.py", + "google/cloud/compute_v1/services/ssl_certificates/transports/rest.py", + "google/cloud/compute_v1/services/ssl_policies/__init__.py", + "google/cloud/compute_v1/services/ssl_policies/client.py", + "google/cloud/compute_v1/services/ssl_policies/pagers.py", + "google/cloud/compute_v1/services/ssl_policies/transports/__init__.py", + "google/cloud/compute_v1/services/ssl_policies/transports/base.py", + "google/cloud/compute_v1/services/ssl_policies/transports/rest.py", + "google/cloud/compute_v1/services/subnetworks/__init__.py", + "google/cloud/compute_v1/services/subnetworks/client.py", + "google/cloud/compute_v1/services/subnetworks/pagers.py", + "google/cloud/compute_v1/services/subnetworks/transports/__init__.py", + "google/cloud/compute_v1/services/subnetworks/transports/base.py", + "google/cloud/compute_v1/services/subnetworks/transports/rest.py", + "google/cloud/compute_v1/services/target_grpc_proxies/__init__.py", + "google/cloud/compute_v1/services/target_grpc_proxies/client.py", + "google/cloud/compute_v1/services/target_grpc_proxies/pagers.py", + "google/cloud/compute_v1/services/target_grpc_proxies/transports/__init__.py", + "google/cloud/compute_v1/services/target_grpc_proxies/transports/base.py", + "google/cloud/compute_v1/services/target_grpc_proxies/transports/rest.py", + "google/cloud/compute_v1/services/target_http_proxies/__init__.py", + "google/cloud/compute_v1/services/target_http_proxies/client.py", + "google/cloud/compute_v1/services/target_http_proxies/pagers.py", + "google/cloud/compute_v1/services/target_http_proxies/transports/__init__.py", + "google/cloud/compute_v1/services/target_http_proxies/transports/base.py", + "google/cloud/compute_v1/services/target_http_proxies/transports/rest.py", + "google/cloud/compute_v1/services/target_https_proxies/__init__.py", + "google/cloud/compute_v1/services/target_https_proxies/client.py", + "google/cloud/compute_v1/services/target_https_proxies/pagers.py", + "google/cloud/compute_v1/services/target_https_proxies/transports/__init__.py", + "google/cloud/compute_v1/services/target_https_proxies/transports/base.py", + "google/cloud/compute_v1/services/target_https_proxies/transports/rest.py", + "google/cloud/compute_v1/services/target_instances/__init__.py", + "google/cloud/compute_v1/services/target_instances/client.py", + "google/cloud/compute_v1/services/target_instances/pagers.py", + "google/cloud/compute_v1/services/target_instances/transports/__init__.py", + "google/cloud/compute_v1/services/target_instances/transports/base.py", + "google/cloud/compute_v1/services/target_instances/transports/rest.py", + "google/cloud/compute_v1/services/target_pools/__init__.py", + "google/cloud/compute_v1/services/target_pools/client.py", + "google/cloud/compute_v1/services/target_pools/pagers.py", + "google/cloud/compute_v1/services/target_pools/transports/__init__.py", + "google/cloud/compute_v1/services/target_pools/transports/base.py", + "google/cloud/compute_v1/services/target_pools/transports/rest.py", + "google/cloud/compute_v1/services/target_ssl_proxies/__init__.py", + "google/cloud/compute_v1/services/target_ssl_proxies/client.py", + "google/cloud/compute_v1/services/target_ssl_proxies/pagers.py", + "google/cloud/compute_v1/services/target_ssl_proxies/transports/__init__.py", + "google/cloud/compute_v1/services/target_ssl_proxies/transports/base.py", + "google/cloud/compute_v1/services/target_ssl_proxies/transports/rest.py", + "google/cloud/compute_v1/services/target_tcp_proxies/__init__.py", + "google/cloud/compute_v1/services/target_tcp_proxies/client.py", + "google/cloud/compute_v1/services/target_tcp_proxies/pagers.py", + "google/cloud/compute_v1/services/target_tcp_proxies/transports/__init__.py", + "google/cloud/compute_v1/services/target_tcp_proxies/transports/base.py", + "google/cloud/compute_v1/services/target_tcp_proxies/transports/rest.py", + "google/cloud/compute_v1/services/target_vpn_gateways/__init__.py", + "google/cloud/compute_v1/services/target_vpn_gateways/client.py", + "google/cloud/compute_v1/services/target_vpn_gateways/pagers.py", + "google/cloud/compute_v1/services/target_vpn_gateways/transports/__init__.py", + "google/cloud/compute_v1/services/target_vpn_gateways/transports/base.py", + "google/cloud/compute_v1/services/target_vpn_gateways/transports/rest.py", + "google/cloud/compute_v1/services/url_maps/__init__.py", + "google/cloud/compute_v1/services/url_maps/client.py", + "google/cloud/compute_v1/services/url_maps/pagers.py", + "google/cloud/compute_v1/services/url_maps/transports/__init__.py", + "google/cloud/compute_v1/services/url_maps/transports/base.py", + "google/cloud/compute_v1/services/url_maps/transports/rest.py", + "google/cloud/compute_v1/services/vpn_gateways/__init__.py", + "google/cloud/compute_v1/services/vpn_gateways/client.py", + "google/cloud/compute_v1/services/vpn_gateways/pagers.py", + "google/cloud/compute_v1/services/vpn_gateways/transports/__init__.py", + "google/cloud/compute_v1/services/vpn_gateways/transports/base.py", + "google/cloud/compute_v1/services/vpn_gateways/transports/rest.py", + "google/cloud/compute_v1/services/vpn_tunnels/__init__.py", + "google/cloud/compute_v1/services/vpn_tunnels/client.py", + "google/cloud/compute_v1/services/vpn_tunnels/pagers.py", + "google/cloud/compute_v1/services/vpn_tunnels/transports/__init__.py", + "google/cloud/compute_v1/services/vpn_tunnels/transports/base.py", + "google/cloud/compute_v1/services/vpn_tunnels/transports/rest.py", + "google/cloud/compute_v1/services/zone_operations/__init__.py", + "google/cloud/compute_v1/services/zone_operations/client.py", + "google/cloud/compute_v1/services/zone_operations/pagers.py", + "google/cloud/compute_v1/services/zone_operations/transports/__init__.py", + "google/cloud/compute_v1/services/zone_operations/transports/base.py", + "google/cloud/compute_v1/services/zone_operations/transports/rest.py", + "google/cloud/compute_v1/services/zones/__init__.py", + "google/cloud/compute_v1/services/zones/client.py", + "google/cloud/compute_v1/services/zones/pagers.py", + "google/cloud/compute_v1/services/zones/transports/__init__.py", + "google/cloud/compute_v1/services/zones/transports/base.py", + "google/cloud/compute_v1/services/zones/transports/rest.py", + "google/cloud/compute_v1/types/__init__.py", + "google/cloud/compute_v1/types/compute.py", + "mypy.ini", + "noxfile.py", + "renovate.json", + "scripts/decrypt-secrets.sh", + "scripts/fixup_compute_v1_keywords.py", + "scripts/readme-gen/readme_gen.py", + "scripts/readme-gen/templates/README.tmpl.rst", + "scripts/readme-gen/templates/auth.tmpl.rst", + "scripts/readme-gen/templates/auth_api_key.tmpl.rst", + "scripts/readme-gen/templates/install_deps.tmpl.rst", + "scripts/readme-gen/templates/install_portaudio.tmpl.rst", + "setup.cfg", + "testing/.gitignore", + "tests/unit/gapic/compute_v1/__init__.py", + "tests/unit/gapic/compute_v1/test_accelerator_types.py", + "tests/unit/gapic/compute_v1/test_addresses.py", + "tests/unit/gapic/compute_v1/test_autoscalers.py", + "tests/unit/gapic/compute_v1/test_backend_buckets.py", + "tests/unit/gapic/compute_v1/test_backend_services.py", + "tests/unit/gapic/compute_v1/test_disk_types.py", + "tests/unit/gapic/compute_v1/test_disks.py", + "tests/unit/gapic/compute_v1/test_external_vpn_gateways.py", + "tests/unit/gapic/compute_v1/test_firewalls.py", + "tests/unit/gapic/compute_v1/test_forwarding_rules.py", + "tests/unit/gapic/compute_v1/test_global_addresses.py", + "tests/unit/gapic/compute_v1/test_global_forwarding_rules.py", + "tests/unit/gapic/compute_v1/test_global_network_endpoint_groups.py", + "tests/unit/gapic/compute_v1/test_global_operations.py", + "tests/unit/gapic/compute_v1/test_global_organization_operations.py", + "tests/unit/gapic/compute_v1/test_health_checks.py", + "tests/unit/gapic/compute_v1/test_images.py", + "tests/unit/gapic/compute_v1/test_instance_group_managers.py", + "tests/unit/gapic/compute_v1/test_instance_groups.py", + "tests/unit/gapic/compute_v1/test_instance_templates.py", + "tests/unit/gapic/compute_v1/test_instances.py", + "tests/unit/gapic/compute_v1/test_interconnect_attachments.py", + "tests/unit/gapic/compute_v1/test_interconnect_locations.py", + "tests/unit/gapic/compute_v1/test_interconnects.py", + "tests/unit/gapic/compute_v1/test_license_codes.py", + "tests/unit/gapic/compute_v1/test_licenses.py", + "tests/unit/gapic/compute_v1/test_machine_types.py", + "tests/unit/gapic/compute_v1/test_network_endpoint_groups.py", + "tests/unit/gapic/compute_v1/test_networks.py", + "tests/unit/gapic/compute_v1/test_node_groups.py", + "tests/unit/gapic/compute_v1/test_node_templates.py", + "tests/unit/gapic/compute_v1/test_node_types.py", + "tests/unit/gapic/compute_v1/test_packet_mirrorings.py", + "tests/unit/gapic/compute_v1/test_projects.py", + "tests/unit/gapic/compute_v1/test_region_autoscalers.py", + "tests/unit/gapic/compute_v1/test_region_backend_services.py", + "tests/unit/gapic/compute_v1/test_region_commitments.py", + "tests/unit/gapic/compute_v1/test_region_disk_types.py", + "tests/unit/gapic/compute_v1/test_region_disks.py", + "tests/unit/gapic/compute_v1/test_region_health_check_services.py", + "tests/unit/gapic/compute_v1/test_region_health_checks.py", + "tests/unit/gapic/compute_v1/test_region_instance_group_managers.py", + "tests/unit/gapic/compute_v1/test_region_instance_groups.py", + "tests/unit/gapic/compute_v1/test_region_network_endpoint_groups.py", + "tests/unit/gapic/compute_v1/test_region_notification_endpoints.py", + "tests/unit/gapic/compute_v1/test_region_operations.py", + "tests/unit/gapic/compute_v1/test_region_ssl_certificates.py", + "tests/unit/gapic/compute_v1/test_region_target_http_proxies.py", + "tests/unit/gapic/compute_v1/test_region_target_https_proxies.py", + "tests/unit/gapic/compute_v1/test_region_url_maps.py", + "tests/unit/gapic/compute_v1/test_regions.py", + "tests/unit/gapic/compute_v1/test_reservations.py", + "tests/unit/gapic/compute_v1/test_resource_policies.py", + "tests/unit/gapic/compute_v1/test_routers.py", + "tests/unit/gapic/compute_v1/test_routes.py", + "tests/unit/gapic/compute_v1/test_security_policies.py", + "tests/unit/gapic/compute_v1/test_snapshots.py", + "tests/unit/gapic/compute_v1/test_ssl_certificates.py", + "tests/unit/gapic/compute_v1/test_ssl_policies.py", + "tests/unit/gapic/compute_v1/test_subnetworks.py", + "tests/unit/gapic/compute_v1/test_target_grpc_proxies.py", + "tests/unit/gapic/compute_v1/test_target_http_proxies.py", + "tests/unit/gapic/compute_v1/test_target_https_proxies.py", + "tests/unit/gapic/compute_v1/test_target_instances.py", + "tests/unit/gapic/compute_v1/test_target_pools.py", + "tests/unit/gapic/compute_v1/test_target_ssl_proxies.py", + "tests/unit/gapic/compute_v1/test_target_tcp_proxies.py", + "tests/unit/gapic/compute_v1/test_target_vpn_gateways.py", + "tests/unit/gapic/compute_v1/test_url_maps.py", + "tests/unit/gapic/compute_v1/test_vpn_gateways.py", + "tests/unit/gapic/compute_v1/test_vpn_tunnels.py", + "tests/unit/gapic/compute_v1/test_zone_operations.py", + "tests/unit/gapic/compute_v1/test_zones.py" ] } \ No newline at end of file diff --git a/tests/system.py b/tests/system.py new file mode 100644 index 000000000..2a4191546 --- /dev/null +++ b/tests/system.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pytest + +from google.cloud import compute + + +@pytest.fixture(scope="session") +def project_id(): + return os.environ["PROJECT_ID"] + + +def test_list_instances_not_throw_mtls(project_id): + client = compute.InstancesClient() + + # For regular system testing, the following call should never throw. + # For mTLS testing, the call throws if mTLS is not properly configured. + client.list(project=project_id, zone="us-west1-a") diff --git a/tests/unit/gapic/compute_v1/__init__.py b/tests/unit/gapic/compute_v1/__init__.py index 8b1378917..42ffdf2bc 100644 --- a/tests/unit/gapic/compute_v1/__init__.py +++ b/tests/unit/gapic/compute_v1/__init__.py @@ -1 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/unit/gapic/compute_v1/test_accelerator_types.py b/tests/unit/gapic/compute_v1/test_accelerator_types.py index 8edf752e8..fc731eff7 100644 --- a/tests/unit/gapic/compute_v1/test_accelerator_types.py +++ b/tests/unit/gapic/compute_v1/test_accelerator_types.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.accelerator_types import AcceleratorTypesClient +from google.cloud.compute_v1.services.accelerator_types import pagers from google.cloud.compute_v1.services.accelerator_types import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -158,7 +159,7 @@ def test_accelerator_types_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -174,7 +175,7 @@ def test_accelerator_types_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -190,7 +191,7 @@ def test_accelerator_types_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -218,7 +219,7 @@ def test_accelerator_types_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -262,29 +263,25 @@ def test_accelerator_types_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -293,66 +290,53 @@ def test_accelerator_types_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -371,7 +355,7 @@ def test_accelerator_types_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +378,7 @@ def test_accelerator_types_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -434,16 +418,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.AcceleratorTypeAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.AcceleratorTypeAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.AcceleratorTypesScopedList( @@ -474,6 +457,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.AcceleratorTypeAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -485,7 +469,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -501,6 +485,75 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = AcceleratorTypesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.AcceleratorTypeAggregatedList( + items={ + "a": compute.AcceleratorTypesScopedList(), + "b": compute.AcceleratorTypesScopedList(), + "c": compute.AcceleratorTypesScopedList(), + }, + next_page_token="abc", + ), + compute.AcceleratorTypeAggregatedList(items={}, next_page_token="def",), + compute.AcceleratorTypeAggregatedList( + items={"g": compute.AcceleratorTypesScopedList(),}, + next_page_token="ghi", + ), + compute.AcceleratorTypeAggregatedList( + items={ + "h": compute.AcceleratorTypesScopedList(), + "i": compute.AcceleratorTypesScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.AcceleratorTypeAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.AcceleratorTypesScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.AcceleratorTypesScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.AcceleratorTypesScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_get_rest( transport: str = "rest", request_type=compute.GetAcceleratorTypeRequest ): @@ -529,6 +582,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.AcceleratorType.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -563,6 +617,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.AcceleratorType.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -578,7 +633,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -628,16 +683,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.AcceleratorTypeList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.AcceleratorTypeList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.AcceleratorType(creation_timestamp="creation_timestamp_value") @@ -663,6 +717,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.AcceleratorTypeList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -676,7 +731,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -696,6 +751,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = AcceleratorTypesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.AcceleratorTypeList( + items=[ + compute.AcceleratorType(), + compute.AcceleratorType(), + compute.AcceleratorType(), + ], + next_page_token="abc", + ), + compute.AcceleratorTypeList(items=[], next_page_token="def",), + compute.AcceleratorTypeList( + items=[compute.AcceleratorType(),], next_page_token="ghi", + ), + compute.AcceleratorTypeList( + items=[compute.AcceleratorType(), compute.AcceleratorType(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.AcceleratorTypeList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.AcceleratorType) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.AcceleratorTypesRestTransport( @@ -824,6 +930,17 @@ def test_accelerator_types_auth_adc(): ) +def test_accelerator_types_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.AcceleratorTypesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_accelerator_types_host_no_port(): client = AcceleratorTypesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_addresses.py b/tests/unit/gapic/compute_v1/test_addresses.py index bec08f55f..3e2cf6611 100644 --- a/tests/unit/gapic/compute_v1/test_addresses.py +++ b/tests/unit/gapic/compute_v1/test_addresses.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.addresses import AddressesClient +from google.cloud.compute_v1.services.addresses import pagers from google.cloud.compute_v1.services.addresses import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -148,7 +149,7 @@ def test_addresses_client_client_options(client_class, transport_class, transpor credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -164,7 +165,7 @@ def test_addresses_client_client_options(client_class, transport_class, transpor credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -180,7 +181,7 @@ def test_addresses_client_client_options(client_class, transport_class, transpor credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -208,7 +209,7 @@ def test_addresses_client_client_options(client_class, transport_class, transpor credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -240,29 +241,25 @@ def test_addresses_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -271,66 +268,53 @@ def test_addresses_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -349,7 +333,7 @@ def test_addresses_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -372,7 +356,7 @@ def test_addresses_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -408,16 +392,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.AddressAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.AddressAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.AddressesScopedList( @@ -446,6 +429,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.AddressAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -457,7 +441,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -473,6 +457,69 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = AddressesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.AddressAggregatedList( + items={ + "a": compute.AddressesScopedList(), + "b": compute.AddressesScopedList(), + "c": compute.AddressesScopedList(), + }, + next_page_token="abc", + ), + compute.AddressAggregatedList(items={}, next_page_token="def",), + compute.AddressAggregatedList( + items={"g": compute.AddressesScopedList(),}, next_page_token="ghi", + ), + compute.AddressAggregatedList( + items={ + "h": compute.AddressesScopedList(), + "i": compute.AddressesScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.AddressAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.AddressesScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == (str, compute.AddressesScopedList) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.AddressesScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteAddressRequest ): @@ -515,6 +562,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -565,6 +613,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -578,7 +627,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -635,6 +684,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetAddressReques # Wrap the value into a proper Response obj json_return_value = compute.Address.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -677,6 +727,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Address.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -690,7 +741,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -755,6 +806,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -805,6 +857,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -822,14 +875,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.Address.to_json( - address_resource, including_default_value_fields=False + address_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -870,16 +925,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListAddressesRe # Wrap the value into a proper Response obj json_return_value = compute.AddressList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.AddressList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.Address(address="address_value")] assert response.kind == "kind_value" @@ -903,6 +957,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.AddressList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -916,7 +971,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -936,6 +991,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = AddressesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.AddressList( + items=[compute.Address(), compute.Address(), compute.Address(),], + next_page_token="abc", + ), + compute.AddressList(items=[], next_page_token="def",), + compute.AddressList(items=[compute.Address(),], next_page_token="ghi",), + compute.AddressList(items=[compute.Address(), compute.Address(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.AddressList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Address) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.AddressesRestTransport( @@ -1064,6 +1162,17 @@ def test_addresses_auth_adc(): ) +def test_addresses_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.AddressesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_addresses_host_no_port(): client = AddressesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_autoscalers.py b/tests/unit/gapic/compute_v1/test_autoscalers.py index 2ecc7a297..b636e53eb 100644 --- a/tests/unit/gapic/compute_v1/test_autoscalers.py +++ b/tests/unit/gapic/compute_v1/test_autoscalers.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.autoscalers import AutoscalersClient +from google.cloud.compute_v1.services.autoscalers import pagers from google.cloud.compute_v1.services.autoscalers import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -152,7 +153,7 @@ def test_autoscalers_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -168,7 +169,7 @@ def test_autoscalers_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -184,7 +185,7 @@ def test_autoscalers_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -212,7 +213,7 @@ def test_autoscalers_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -244,29 +245,25 @@ def test_autoscalers_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -275,66 +272,53 @@ def test_autoscalers_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -353,7 +337,7 @@ def test_autoscalers_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -376,7 +360,7 @@ def test_autoscalers_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -418,16 +402,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.AutoscalerAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.AutoscalerAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.AutoscalersScopedList( @@ -462,6 +445,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.AutoscalerAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -473,7 +457,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -489,6 +473,72 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = AutoscalersClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.AutoscalerAggregatedList( + items={ + "a": compute.AutoscalersScopedList(), + "b": compute.AutoscalersScopedList(), + "c": compute.AutoscalersScopedList(), + }, + next_page_token="abc", + ), + compute.AutoscalerAggregatedList(items={}, next_page_token="def",), + compute.AutoscalerAggregatedList( + items={"g": compute.AutoscalersScopedList(),}, next_page_token="ghi", + ), + compute.AutoscalerAggregatedList( + items={ + "h": compute.AutoscalersScopedList(), + "i": compute.AutoscalersScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.AutoscalerAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.AutoscalersScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.AutoscalersScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.AutoscalersScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteAutoscalerRequest ): @@ -531,6 +581,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -581,6 +632,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -594,7 +646,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -647,6 +699,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetAutoscalerReq # Wrap the value into a proper Response obj json_return_value = compute.Autoscaler.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -689,6 +742,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Autoscaler.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -702,7 +756,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -767,6 +821,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -817,6 +872,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -836,14 +892,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "zone_value" in http_call[1] + str(body) assert compute.Autoscaler.to_json( - autoscaler_resource, including_default_value_fields=False + autoscaler_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -894,16 +952,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.AutoscalerList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.AutoscalerList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Autoscaler( @@ -931,6 +988,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.AutoscalerList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -944,7 +1002,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -964,6 +1022,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = AutoscalersClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.AutoscalerList( + items=[ + compute.Autoscaler(), + compute.Autoscaler(), + compute.Autoscaler(), + ], + next_page_token="abc", + ), + compute.AutoscalerList(items=[], next_page_token="def",), + compute.AutoscalerList( + items=[compute.Autoscaler(),], next_page_token="ghi", + ), + compute.AutoscalerList( + items=[compute.Autoscaler(), compute.Autoscaler(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.AutoscalerList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Autoscaler) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchAutoscalerRequest ): @@ -1006,6 +1115,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1056,6 +1166,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1075,14 +1186,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "zone_value" in http_call[1] + str(body) assert compute.Autoscaler.to_json( - autoscaler_resource, including_default_value_fields=False + autoscaler_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1144,6 +1257,7 @@ def test_update_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1194,6 +1308,7 @@ def test_update_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1213,14 +1328,16 @@ def test_update_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "zone_value" in http_call[1] + str(body) assert compute.Autoscaler.to_json( - autoscaler_resource, including_default_value_fields=False + autoscaler_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1370,6 +1487,17 @@ def test_autoscalers_auth_adc(): ) +def test_autoscalers_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.AutoscalersRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_autoscalers_host_no_port(): client = AutoscalersClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_backend_buckets.py b/tests/unit/gapic/compute_v1/test_backend_buckets.py index 305858075..35e3019a0 100644 --- a/tests/unit/gapic/compute_v1/test_backend_buckets.py +++ b/tests/unit/gapic/compute_v1/test_backend_buckets.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.backend_buckets import BackendBucketsClient +from google.cloud.compute_v1.services.backend_buckets import pagers from google.cloud.compute_v1.services.backend_buckets import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -157,7 +158,7 @@ def test_backend_buckets_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -173,7 +174,7 @@ def test_backend_buckets_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -189,7 +190,7 @@ def test_backend_buckets_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -217,7 +218,7 @@ def test_backend_buckets_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -251,29 +252,25 @@ def test_backend_buckets_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -282,66 +279,53 @@ def test_backend_buckets_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -360,7 +344,7 @@ def test_backend_buckets_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -383,7 +367,7 @@ def test_backend_buckets_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -431,6 +415,7 @@ def test_add_signed_url_key_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -481,6 +466,7 @@ def test_add_signed_url_key_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -498,14 +484,16 @@ def test_add_signed_url_key_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "backend_bucket_value" in http_call[1] + str(body) assert compute.SignedUrlKey.to_json( - signed_url_key_resource, including_default_value_fields=False + signed_url_key_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -565,6 +553,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -615,6 +604,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -628,7 +618,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -690,6 +680,7 @@ def test_delete_signed_url_key_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -740,6 +731,7 @@ def test_delete_signed_url_key_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -755,7 +747,7 @@ def test_delete_signed_url_key_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -809,6 +801,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.BackendBucket.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -847,6 +840,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.BackendBucket.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -860,7 +854,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -922,6 +916,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -972,6 +967,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -987,12 +983,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.BackendBucket.to_json( - backend_bucket_resource, including_default_value_fields=False + backend_bucket_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1036,16 +1034,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.BackendBucketList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.BackendBucketList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.BackendBucket(bucket_name="bucket_name_value")] assert response.kind == "kind_value" @@ -1069,6 +1066,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.BackendBucketList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1080,7 +1078,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1096,6 +1094,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = BackendBucketsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.BackendBucketList( + items=[ + compute.BackendBucket(), + compute.BackendBucket(), + compute.BackendBucket(), + ], + next_page_token="abc", + ), + compute.BackendBucketList(items=[], next_page_token="def",), + compute.BackendBucketList( + items=[compute.BackendBucket(),], next_page_token="ghi", + ), + compute.BackendBucketList( + items=[compute.BackendBucket(), compute.BackendBucket(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.BackendBucketList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.BackendBucket) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchBackendBucketRequest ): @@ -1138,6 +1187,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1188,6 +1238,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1205,14 +1256,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "backend_bucket_value" in http_call[1] + str(body) assert compute.BackendBucket.to_json( - backend_bucket_resource, including_default_value_fields=False + backend_bucket_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1274,6 +1327,7 @@ def test_update_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1324,6 +1378,7 @@ def test_update_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1341,14 +1396,16 @@ def test_update_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "backend_bucket_value" in http_call[1] + str(body) assert compute.BackendBucket.to_json( - backend_bucket_resource, including_default_value_fields=False + backend_bucket_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1499,6 +1556,17 @@ def test_backend_buckets_auth_adc(): ) +def test_backend_buckets_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.BackendBucketsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_backend_buckets_host_no_port(): client = BackendBucketsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_backend_services.py b/tests/unit/gapic/compute_v1/test_backend_services.py index c53fd388e..5e180766b 100644 --- a/tests/unit/gapic/compute_v1/test_backend_services.py +++ b/tests/unit/gapic/compute_v1/test_backend_services.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.backend_services import BackendServicesClient +from google.cloud.compute_v1.services.backend_services import pagers from google.cloud.compute_v1.services.backend_services import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -157,7 +158,7 @@ def test_backend_services_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -173,7 +174,7 @@ def test_backend_services_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -189,7 +190,7 @@ def test_backend_services_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -217,7 +218,7 @@ def test_backend_services_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -261,29 +262,25 @@ def test_backend_services_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -292,66 +289,53 @@ def test_backend_services_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -370,7 +354,7 @@ def test_backend_services_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -393,7 +377,7 @@ def test_backend_services_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -441,6 +425,7 @@ def test_add_signed_url_key_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -491,6 +476,7 @@ def test_add_signed_url_key_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -508,14 +494,16 @@ def test_add_signed_url_key_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "backend_service_value" in http_call[1] + str(body) assert compute.SignedUrlKey.to_json( - signed_url_key_resource, including_default_value_fields=False + signed_url_key_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -565,16 +553,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.BackendServiceAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.BackendServiceAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.BackendServicesScopedList( @@ -603,6 +590,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.BackendServiceAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -614,7 +602,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -630,6 +618,75 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = BackendServicesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.BackendServiceAggregatedList( + items={ + "a": compute.BackendServicesScopedList(), + "b": compute.BackendServicesScopedList(), + "c": compute.BackendServicesScopedList(), + }, + next_page_token="abc", + ), + compute.BackendServiceAggregatedList(items={}, next_page_token="def",), + compute.BackendServiceAggregatedList( + items={"g": compute.BackendServicesScopedList(),}, + next_page_token="ghi", + ), + compute.BackendServiceAggregatedList( + items={ + "h": compute.BackendServicesScopedList(), + "i": compute.BackendServicesScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.BackendServiceAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.BackendServicesScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.BackendServicesScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.BackendServicesScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteBackendServiceRequest ): @@ -672,6 +729,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -722,6 +780,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -735,7 +794,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -798,6 +857,7 @@ def test_delete_signed_url_key_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -848,6 +908,7 @@ def test_delete_signed_url_key_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -863,7 +924,7 @@ def test_delete_signed_url_key_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -951,6 +1012,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.BackendService.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1031,6 +1093,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.BackendService.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1044,7 +1107,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1088,6 +1151,7 @@ def test_get_health_rest( # Wrap the value into a proper Response obj json_return_value = compute.BackendServiceGroupHealth.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1118,6 +1182,7 @@ def test_get_health_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.BackendServiceGroupHealth.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1137,14 +1202,16 @@ def test_get_health_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "backend_service_value" in http_call[1] + str(body) assert compute.ResourceGroupReference.to_json( - resource_group_reference_resource, including_default_value_fields=False + resource_group_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1206,6 +1273,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1256,6 +1324,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1271,12 +1340,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.BackendService.to_json( - backend_service_resource, including_default_value_fields=False + backend_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1320,16 +1391,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.BackendServiceList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.BackendServiceList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.BackendService(affinity_cookie_ttl_sec=2432)] assert response.kind == "kind_value" @@ -1353,6 +1423,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.BackendServiceList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1364,7 +1435,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1380,6 +1451,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = BackendServicesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.BackendServiceList( + items=[ + compute.BackendService(), + compute.BackendService(), + compute.BackendService(), + ], + next_page_token="abc", + ), + compute.BackendServiceList(items=[], next_page_token="def",), + compute.BackendServiceList( + items=[compute.BackendService(),], next_page_token="ghi", + ), + compute.BackendServiceList( + items=[compute.BackendService(), compute.BackendService(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.BackendServiceList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.BackendService) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchBackendServiceRequest ): @@ -1422,6 +1544,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1472,6 +1595,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1489,14 +1613,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "backend_service_value" in http_call[1] + str(body) assert compute.BackendService.to_json( - backend_service_resource, including_default_value_fields=False + backend_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1558,6 +1684,7 @@ def test_set_security_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1608,6 +1735,7 @@ def test_set_security_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1627,14 +1755,16 @@ def test_set_security_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "backend_service_value" in http_call[1] + str(body) assert compute.SecurityPolicyReference.to_json( - security_policy_reference_resource, including_default_value_fields=False + security_policy_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1696,6 +1826,7 @@ def test_update_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1746,6 +1877,7 @@ def test_update_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1763,14 +1895,16 @@ def test_update_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "backend_service_value" in http_call[1] + str(body) assert compute.BackendService.to_json( - backend_service_resource, including_default_value_fields=False + backend_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1924,6 +2058,17 @@ def test_backend_services_auth_adc(): ) +def test_backend_services_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.BackendServicesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_backend_services_host_no_port(): client = BackendServicesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_disk_types.py b/tests/unit/gapic/compute_v1/test_disk_types.py index 0f2713895..e46c41bc7 100644 --- a/tests/unit/gapic/compute_v1/test_disk_types.py +++ b/tests/unit/gapic/compute_v1/test_disk_types.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.disk_types import DiskTypesClient +from google.cloud.compute_v1.services.disk_types import pagers from google.cloud.compute_v1.services.disk_types import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -150,7 +151,7 @@ def test_disk_types_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -166,7 +167,7 @@ def test_disk_types_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -182,7 +183,7 @@ def test_disk_types_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -210,7 +211,7 @@ def test_disk_types_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -242,29 +243,25 @@ def test_disk_types_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -273,66 +270,53 @@ def test_disk_types_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -351,7 +335,7 @@ def test_disk_types_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -374,7 +358,7 @@ def test_disk_types_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -412,16 +396,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.DiskTypeAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.DiskTypeAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.DiskTypesScopedList( @@ -450,6 +433,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.DiskTypeAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -461,7 +445,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -477,6 +461,69 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = DiskTypesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.DiskTypeAggregatedList( + items={ + "a": compute.DiskTypesScopedList(), + "b": compute.DiskTypesScopedList(), + "c": compute.DiskTypesScopedList(), + }, + next_page_token="abc", + ), + compute.DiskTypeAggregatedList(items={}, next_page_token="def",), + compute.DiskTypeAggregatedList( + items={"g": compute.DiskTypesScopedList(),}, next_page_token="ghi", + ), + compute.DiskTypeAggregatedList( + items={ + "h": compute.DiskTypesScopedList(), + "i": compute.DiskTypesScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.DiskTypeAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.DiskTypesScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == (str, compute.DiskTypesScopedList) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.DiskTypesScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_get_rest(transport: str = "rest", request_type=compute.GetDiskTypeRequest): client = DiskTypesClient( credentials=credentials.AnonymousCredentials(), transport=transport, @@ -505,6 +552,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetDiskTypeReque # Wrap the value into a proper Response obj json_return_value = compute.DiskType.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -541,6 +589,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.DiskType.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -554,7 +603,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -600,16 +649,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListDiskTypesRe # Wrap the value into a proper Response obj json_return_value = compute.DiskTypeList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.DiskTypeList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.DiskType(creation_timestamp="creation_timestamp_value") @@ -635,6 +683,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.DiskTypeList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -648,7 +697,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -666,6 +715,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = DiskTypesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.DiskTypeList( + items=[compute.DiskType(), compute.DiskType(), compute.DiskType(),], + next_page_token="abc", + ), + compute.DiskTypeList(items=[], next_page_token="def",), + compute.DiskTypeList(items=[compute.DiskType(),], next_page_token="ghi",), + compute.DiskTypeList(items=[compute.DiskType(), compute.DiskType(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.DiskTypeList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.DiskType) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.DiskTypesRestTransport( @@ -794,6 +886,17 @@ def test_disk_types_auth_adc(): ) +def test_disk_types_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.DiskTypesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_disk_types_host_no_port(): client = DiskTypesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_disks.py b/tests/unit/gapic/compute_v1/test_disks.py index b9d4af3b0..304dc6c8b 100644 --- a/tests/unit/gapic/compute_v1/test_disks.py +++ b/tests/unit/gapic/compute_v1/test_disks.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.disks import DisksClient +from google.cloud.compute_v1.services.disks import pagers from google.cloud.compute_v1.services.disks import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -147,7 +148,7 @@ def test_disks_client_client_options(client_class, transport_class, transport_na credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -163,7 +164,7 @@ def test_disks_client_client_options(client_class, transport_class, transport_na credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -179,7 +180,7 @@ def test_disks_client_client_options(client_class, transport_class, transport_na credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -207,7 +208,7 @@ def test_disks_client_client_options(client_class, transport_class, transport_na credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -239,29 +240,25 @@ def test_disks_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -270,66 +267,53 @@ def test_disks_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -348,7 +332,7 @@ def test_disks_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -371,7 +355,7 @@ def test_disks_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -419,6 +403,7 @@ def test_add_resource_policies_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -469,6 +454,7 @@ def test_add_resource_policies_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -489,7 +475,7 @@ def test_add_resource_policies_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -500,6 +486,7 @@ def test_add_resource_policies_rest_flattened(): assert compute.DisksAddResourcePoliciesRequest.to_json( disks_add_resource_policies_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -550,16 +537,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.DiskAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.DiskAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.DisksScopedList( @@ -588,6 +574,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.DiskAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -599,7 +586,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -615,6 +602,66 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = DisksClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.DiskAggregatedList( + items={ + "a": compute.DisksScopedList(), + "b": compute.DisksScopedList(), + "c": compute.DisksScopedList(), + }, + next_page_token="abc", + ), + compute.DiskAggregatedList(items={}, next_page_token="def",), + compute.DiskAggregatedList( + items={"g": compute.DisksScopedList(),}, next_page_token="ghi", + ), + compute.DiskAggregatedList( + items={"h": compute.DisksScopedList(), "i": compute.DisksScopedList(),}, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.DiskAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.DisksScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == (str, compute.DisksScopedList) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.DisksScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_create_snapshot_rest( transport: str = "rest", request_type=compute.CreateSnapshotDiskRequest ): @@ -657,6 +704,7 @@ def test_create_snapshot_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -707,6 +755,7 @@ def test_create_snapshot_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -725,7 +774,7 @@ def test_create_snapshot_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -734,7 +783,9 @@ def test_create_snapshot_rest_flattened(): assert "disk_value" in http_call[1] + str(body) assert compute.Snapshot.to_json( - snapshot_resource, including_default_value_fields=False + snapshot_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -793,6 +844,7 @@ def test_delete_rest(transport: str = "rest", request_type=compute.DeleteDiskReq # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -843,6 +895,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -856,7 +909,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -938,6 +991,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetDiskRequest): # Wrap the value into a proper Response obj json_return_value = compute.Disk.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1005,6 +1059,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Disk.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1018,7 +1073,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1074,6 +1129,7 @@ def test_get_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1112,6 +1168,7 @@ def test_get_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1125,7 +1182,7 @@ def test_get_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1188,6 +1245,7 @@ def test_insert_rest(transport: str = "rest", request_type=compute.InsertDiskReq # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1238,6 +1296,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1253,14 +1312,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "zone_value" in http_call[1] + str(body) assert compute.Disk.to_json( - disk_resource, including_default_value_fields=False + disk_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1301,16 +1362,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListDisksReques # Wrap the value into a proper Response obj json_return_value = compute.DiskList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.DiskList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Disk(creation_timestamp="creation_timestamp_value") @@ -1336,6 +1396,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.DiskList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1349,7 +1410,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1367,6 +1428,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = DisksClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.DiskList( + items=[compute.Disk(), compute.Disk(), compute.Disk(),], + next_page_token="abc", + ), + compute.DiskList(items=[], next_page_token="def",), + compute.DiskList(items=[compute.Disk(),], next_page_token="ghi",), + compute.DiskList(items=[compute.Disk(), compute.Disk(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.DiskList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Disk) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_remove_resource_policies_rest( transport: str = "rest", request_type=compute.RemoveResourcePoliciesDiskRequest ): @@ -1409,6 +1513,7 @@ def test_remove_resource_policies_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1459,6 +1564,7 @@ def test_remove_resource_policies_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1479,7 +1585,7 @@ def test_remove_resource_policies_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1490,6 +1596,7 @@ def test_remove_resource_policies_rest_flattened(): assert compute.DisksRemoveResourcePoliciesRequest.to_json( disks_remove_resource_policies_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1550,6 +1657,7 @@ def test_resize_rest(transport: str = "rest", request_type=compute.ResizeDiskReq # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1600,6 +1708,7 @@ def test_resize_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1620,7 +1729,7 @@ def test_resize_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1629,7 +1738,9 @@ def test_resize_rest_flattened(): assert "disk_value" in http_call[1] + str(body) assert compute.DisksResizeRequest.to_json( - disks_resize_request_resource, including_default_value_fields=False + disks_resize_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1683,6 +1794,7 @@ def test_set_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1721,6 +1833,7 @@ def test_set_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1741,7 +1854,7 @@ def test_set_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1750,7 +1863,9 @@ def test_set_iam_policy_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.ZoneSetPolicyRequest.to_json( - zone_set_policy_request_resource, including_default_value_fields=False + zone_set_policy_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1813,6 +1928,7 @@ def test_set_labels_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1863,6 +1979,7 @@ def test_set_labels_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1883,7 +2000,7 @@ def test_set_labels_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1892,7 +2009,9 @@ def test_set_labels_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.ZoneSetLabelsRequest.to_json( - zone_set_labels_request_resource, including_default_value_fields=False + zone_set_labels_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1933,6 +2052,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1959,6 +2079,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1979,7 +2100,7 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1988,7 +2109,9 @@ def test_test_iam_permissions_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2145,6 +2268,17 @@ def test_disks_auth_adc(): ) +def test_disks_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.DisksRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_disks_host_no_port(): client = DisksClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_external_vpn_gateways.py b/tests/unit/gapic/compute_v1/test_external_vpn_gateways.py index 5b5e851ab..212b049e2 100644 --- a/tests/unit/gapic/compute_v1/test_external_vpn_gateways.py +++ b/tests/unit/gapic/compute_v1/test_external_vpn_gateways.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.external_vpn_gateways import ( ExternalVpnGatewaysClient, ) +from google.cloud.compute_v1.services.external_vpn_gateways import pagers from google.cloud.compute_v1.services.external_vpn_gateways import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -160,7 +161,7 @@ def test_external_vpn_gateways_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -176,7 +177,7 @@ def test_external_vpn_gateways_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -192,7 +193,7 @@ def test_external_vpn_gateways_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -220,7 +221,7 @@ def test_external_vpn_gateways_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -264,29 +265,25 @@ def test_external_vpn_gateways_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -295,66 +292,53 @@ def test_external_vpn_gateways_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -373,7 +357,7 @@ def test_external_vpn_gateways_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -396,7 +380,7 @@ def test_external_vpn_gateways_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -444,6 +428,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -494,6 +479,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -507,7 +493,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -556,6 +542,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.ExternalVpnGateway.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -594,6 +581,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ExternalVpnGateway.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -607,7 +595,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -669,6 +657,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -719,6 +708,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -737,12 +727,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.ExternalVpnGateway.to_json( - external_vpn_gateway_resource, including_default_value_fields=False + external_vpn_gateway_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -791,16 +783,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.ExternalVpnGatewayList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.ExternalVpnGatewayList) + assert isinstance(response, pagers.ListPager) assert response.etag == "etag_value" assert response.id == "id_value" assert response.items == [ @@ -827,6 +818,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ExternalVpnGatewayList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -838,7 +830,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -854,6 +846,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = ExternalVpnGatewaysClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.ExternalVpnGatewayList( + items=[ + compute.ExternalVpnGateway(), + compute.ExternalVpnGateway(), + compute.ExternalVpnGateway(), + ], + next_page_token="abc", + ), + compute.ExternalVpnGatewayList(items=[], next_page_token="def",), + compute.ExternalVpnGatewayList( + items=[compute.ExternalVpnGateway(),], next_page_token="ghi", + ), + compute.ExternalVpnGatewayList( + items=[compute.ExternalVpnGateway(), compute.ExternalVpnGateway(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.ExternalVpnGatewayList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.ExternalVpnGateway) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_set_labels_rest( transport: str = "rest", request_type=compute.SetLabelsExternalVpnGatewayRequest ): @@ -896,6 +939,7 @@ def test_set_labels_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -946,6 +990,7 @@ def test_set_labels_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -965,14 +1010,16 @@ def test_set_labels_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "resource_value" in http_call[1] + str(body) assert compute.GlobalSetLabelsRequest.to_json( - global_set_labels_request_resource, including_default_value_fields=False + global_set_labels_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1013,6 +1060,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1039,6 +1087,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1058,14 +1107,16 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1216,6 +1267,17 @@ def test_external_vpn_gateways_auth_adc(): ) +def test_external_vpn_gateways_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.ExternalVpnGatewaysRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_external_vpn_gateways_host_no_port(): client = ExternalVpnGatewaysClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_firewalls.py b/tests/unit/gapic/compute_v1/test_firewalls.py index d88ea1200..636f12c6b 100644 --- a/tests/unit/gapic/compute_v1/test_firewalls.py +++ b/tests/unit/gapic/compute_v1/test_firewalls.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.firewalls import FirewallsClient +from google.cloud.compute_v1.services.firewalls import pagers from google.cloud.compute_v1.services.firewalls import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -148,7 +149,7 @@ def test_firewalls_client_client_options(client_class, transport_class, transpor credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -164,7 +165,7 @@ def test_firewalls_client_client_options(client_class, transport_class, transpor credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -180,7 +181,7 @@ def test_firewalls_client_client_options(client_class, transport_class, transpor credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -208,7 +209,7 @@ def test_firewalls_client_client_options(client_class, transport_class, transpor credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -240,29 +241,25 @@ def test_firewalls_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -271,66 +268,53 @@ def test_firewalls_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -349,7 +333,7 @@ def test_firewalls_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -372,7 +356,7 @@ def test_firewalls_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -420,6 +404,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -470,6 +455,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -483,7 +469,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -539,6 +525,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetFirewallReque # Wrap the value into a proper Response obj json_return_value = compute.Firewall.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -584,6 +571,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Firewall.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -597,7 +585,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -659,6 +647,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -709,6 +698,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -726,12 +716,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.Firewall.to_json( - firewall_resource, including_default_value_fields=False + firewall_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -777,16 +769,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListFirewallsRe # Wrap the value into a proper Response obj json_return_value = compute.FirewallList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.FirewallList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Firewall(allowed=[compute.Allowed(i_p_protocol="i_p_protocol_value")]) @@ -812,6 +803,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.FirewallList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -823,7 +815,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -839,6 +831,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = FirewallsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.FirewallList( + items=[compute.Firewall(), compute.Firewall(), compute.Firewall(),], + next_page_token="abc", + ), + compute.FirewallList(items=[], next_page_token="def",), + compute.FirewallList(items=[compute.Firewall(),], next_page_token="ghi",), + compute.FirewallList(items=[compute.Firewall(), compute.Firewall(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.FirewallList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Firewall) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest(transport: str = "rest", request_type=compute.PatchFirewallRequest): client = FirewallsClient( credentials=credentials.AnonymousCredentials(), transport=transport, @@ -879,6 +914,7 @@ def test_patch_rest(transport: str = "rest", request_type=compute.PatchFirewallR # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -929,6 +965,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -948,14 +985,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "firewall_value" in http_call[1] + str(body) assert compute.Firewall.to_json( - firewall_resource, including_default_value_fields=False + firewall_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1017,6 +1056,7 @@ def test_update_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1067,6 +1107,7 @@ def test_update_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1086,14 +1127,16 @@ def test_update_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "firewall_value" in http_call[1] + str(body) assert compute.Firewall.to_json( - firewall_resource, including_default_value_fields=False + firewall_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1242,6 +1285,17 @@ def test_firewalls_auth_adc(): ) +def test_firewalls_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.FirewallsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_firewalls_host_no_port(): client = FirewallsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_forwarding_rules.py b/tests/unit/gapic/compute_v1/test_forwarding_rules.py index 5f9783c09..af98e35ba 100644 --- a/tests/unit/gapic/compute_v1/test_forwarding_rules.py +++ b/tests/unit/gapic/compute_v1/test_forwarding_rules.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.forwarding_rules import ForwardingRulesClient +from google.cloud.compute_v1.services.forwarding_rules import pagers from google.cloud.compute_v1.services.forwarding_rules import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -157,7 +158,7 @@ def test_forwarding_rules_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -173,7 +174,7 @@ def test_forwarding_rules_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -189,7 +190,7 @@ def test_forwarding_rules_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -217,7 +218,7 @@ def test_forwarding_rules_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -261,29 +262,25 @@ def test_forwarding_rules_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -292,66 +289,53 @@ def test_forwarding_rules_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -370,7 +354,7 @@ def test_forwarding_rules_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -393,7 +377,7 @@ def test_forwarding_rules_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -429,16 +413,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.ForwardingRuleAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.ForwardingRuleAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.ForwardingRulesScopedList( @@ -467,6 +450,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ForwardingRuleAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -478,7 +462,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -494,6 +478,75 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = ForwardingRulesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.ForwardingRuleAggregatedList( + items={ + "a": compute.ForwardingRulesScopedList(), + "b": compute.ForwardingRulesScopedList(), + "c": compute.ForwardingRulesScopedList(), + }, + next_page_token="abc", + ), + compute.ForwardingRuleAggregatedList(items={}, next_page_token="def",), + compute.ForwardingRuleAggregatedList( + items={"g": compute.ForwardingRulesScopedList(),}, + next_page_token="ghi", + ), + compute.ForwardingRuleAggregatedList( + items={ + "h": compute.ForwardingRulesScopedList(), + "i": compute.ForwardingRulesScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.ForwardingRuleAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.ForwardingRulesScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.ForwardingRulesScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.ForwardingRulesScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteForwardingRuleRequest ): @@ -536,6 +589,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -586,6 +640,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -601,7 +656,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -672,6 +727,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.ForwardingRule.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -732,6 +788,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ForwardingRule.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -747,7 +804,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -812,6 +869,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -862,6 +920,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -879,14 +938,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.ForwardingRule.to_json( - forwarding_rule_resource, including_default_value_fields=False + forwarding_rule_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -929,16 +990,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.ForwardingRuleList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.ForwardingRuleList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.ForwardingRule(all_ports=True)] assert response.kind == "kind_value" @@ -962,6 +1022,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ForwardingRuleList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -975,7 +1036,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -995,6 +1056,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = ForwardingRulesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.ForwardingRuleList( + items=[ + compute.ForwardingRule(), + compute.ForwardingRule(), + compute.ForwardingRule(), + ], + next_page_token="abc", + ), + compute.ForwardingRuleList(items=[], next_page_token="def",), + compute.ForwardingRuleList( + items=[compute.ForwardingRule(),], next_page_token="ghi", + ), + compute.ForwardingRuleList( + items=[compute.ForwardingRule(), compute.ForwardingRule(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.ForwardingRuleList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.ForwardingRule) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchForwardingRuleRequest ): @@ -1037,6 +1149,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1087,6 +1200,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1105,7 +1219,7 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1114,7 +1228,9 @@ def test_patch_rest_flattened(): assert "forwarding_rule_value" in http_call[1] + str(body) assert compute.ForwardingRule.to_json( - forwarding_rule_resource, including_default_value_fields=False + forwarding_rule_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1175,6 +1291,7 @@ def test_set_target_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1225,6 +1342,7 @@ def test_set_target_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1243,7 +1361,7 @@ def test_set_target_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1252,7 +1370,9 @@ def test_set_target_rest_flattened(): assert "forwarding_rule_value" in http_call[1] + str(body) assert compute.TargetReference.to_json( - target_reference_resource, including_default_value_fields=False + target_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1401,6 +1521,17 @@ def test_forwarding_rules_auth_adc(): ) +def test_forwarding_rules_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.ForwardingRulesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_forwarding_rules_host_no_port(): client = ForwardingRulesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_global_addresses.py b/tests/unit/gapic/compute_v1/test_global_addresses.py index 898a239e3..90a94e255 100644 --- a/tests/unit/gapic/compute_v1/test_global_addresses.py +++ b/tests/unit/gapic/compute_v1/test_global_addresses.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.global_addresses import GlobalAddressesClient +from google.cloud.compute_v1.services.global_addresses import pagers from google.cloud.compute_v1.services.global_addresses import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -157,7 +158,7 @@ def test_global_addresses_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -173,7 +174,7 @@ def test_global_addresses_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -189,7 +190,7 @@ def test_global_addresses_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -217,7 +218,7 @@ def test_global_addresses_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -261,29 +262,25 @@ def test_global_addresses_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -292,66 +289,53 @@ def test_global_addresses_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -370,7 +354,7 @@ def test_global_addresses_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -393,7 +377,7 @@ def test_global_addresses_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -441,6 +425,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -491,6 +476,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -504,7 +490,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -560,6 +546,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.Address.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -602,6 +589,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Address.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -615,7 +603,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -677,6 +665,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -727,6 +716,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -742,12 +732,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.Address.to_json( - address_resource, including_default_value_fields=False + address_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -789,16 +781,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.AddressList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.AddressList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.Address(address="address_value")] assert response.kind == "kind_value" @@ -822,6 +813,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.AddressList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -833,7 +825,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -849,6 +841,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = GlobalAddressesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.AddressList( + items=[compute.Address(), compute.Address(), compute.Address(),], + next_page_token="abc", + ), + compute.AddressList(items=[], next_page_token="def",), + compute.AddressList(items=[compute.Address(),], next_page_token="ghi",), + compute.AddressList(items=[compute.Address(), compute.Address(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.AddressList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Address) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.GlobalAddressesRestTransport( @@ -976,6 +1011,17 @@ def test_global_addresses_auth_adc(): ) +def test_global_addresses_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.GlobalAddressesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_global_addresses_host_no_port(): client = GlobalAddressesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_global_forwarding_rules.py b/tests/unit/gapic/compute_v1/test_global_forwarding_rules.py index e15724c32..562428ea7 100644 --- a/tests/unit/gapic/compute_v1/test_global_forwarding_rules.py +++ b/tests/unit/gapic/compute_v1/test_global_forwarding_rules.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.global_forwarding_rules import ( GlobalForwardingRulesClient, ) +from google.cloud.compute_v1.services.global_forwarding_rules import pagers from google.cloud.compute_v1.services.global_forwarding_rules import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -166,7 +167,7 @@ def test_global_forwarding_rules_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -182,7 +183,7 @@ def test_global_forwarding_rules_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -198,7 +199,7 @@ def test_global_forwarding_rules_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -226,7 +227,7 @@ def test_global_forwarding_rules_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -270,29 +271,25 @@ def test_global_forwarding_rules_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -301,66 +298,53 @@ def test_global_forwarding_rules_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -385,7 +369,7 @@ def test_global_forwarding_rules_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -414,7 +398,7 @@ def test_global_forwarding_rules_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -462,6 +446,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -514,6 +499,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -527,7 +513,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -597,6 +583,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.ForwardingRule.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -659,6 +646,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ForwardingRule.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -672,7 +660,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -736,6 +724,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -788,6 +777,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -803,12 +793,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.ForwardingRule.to_json( - forwarding_rule_resource, including_default_value_fields=False + forwarding_rule_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -852,16 +844,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.ForwardingRuleList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.ForwardingRuleList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.ForwardingRule(all_ports=True)] assert response.kind == "kind_value" @@ -887,6 +878,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ForwardingRuleList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -898,7 +890,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -916,6 +908,59 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = GlobalForwardingRulesClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.ForwardingRuleList( + items=[ + compute.ForwardingRule(), + compute.ForwardingRule(), + compute.ForwardingRule(), + ], + next_page_token="abc", + ), + compute.ForwardingRuleList(items=[], next_page_token="def",), + compute.ForwardingRuleList( + items=[compute.ForwardingRule(),], next_page_token="ghi", + ), + compute.ForwardingRuleList( + items=[compute.ForwardingRule(), compute.ForwardingRule(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.ForwardingRuleList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.ForwardingRule) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchGlobalForwardingRuleRequest ): @@ -958,6 +1003,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1010,6 +1056,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1027,14 +1074,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "forwarding_rule_value" in http_call[1] + str(body) assert compute.ForwardingRule.to_json( - forwarding_rule_resource, including_default_value_fields=False + forwarding_rule_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1096,6 +1145,7 @@ def test_set_target_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1148,6 +1198,7 @@ def test_set_target_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1165,14 +1216,16 @@ def test_set_target_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "forwarding_rule_value" in http_call[1] + str(body) assert compute.TargetReference.to_json( - target_reference_resource, including_default_value_fields=False + target_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1323,6 +1376,17 @@ def test_global_forwarding_rules_auth_adc(): ) +def test_global_forwarding_rules_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.GlobalForwardingRulesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_global_forwarding_rules_host_no_port(): client = GlobalForwardingRulesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_global_network_endpoint_groups.py b/tests/unit/gapic/compute_v1/test_global_network_endpoint_groups.py index 8e50b11eb..dac41f62b 100644 --- a/tests/unit/gapic/compute_v1/test_global_network_endpoint_groups.py +++ b/tests/unit/gapic/compute_v1/test_global_network_endpoint_groups.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.global_network_endpoint_groups import ( GlobalNetworkEndpointGroupsClient, ) +from google.cloud.compute_v1.services.global_network_endpoint_groups import pagers from google.cloud.compute_v1.services.global_network_endpoint_groups import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -172,7 +173,7 @@ def test_global_network_endpoint_groups_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -188,7 +189,7 @@ def test_global_network_endpoint_groups_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -204,7 +205,7 @@ def test_global_network_endpoint_groups_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -232,7 +233,7 @@ def test_global_network_endpoint_groups_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -276,29 +277,25 @@ def test_global_network_endpoint_groups_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -307,66 +304,53 @@ def test_global_network_endpoint_groups_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -391,7 +375,7 @@ def test_global_network_endpoint_groups_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -420,7 +404,7 @@ def test_global_network_endpoint_groups_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -469,6 +453,7 @@ def test_attach_network_endpoints_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -521,6 +506,7 @@ def test_attach_network_endpoints_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -542,7 +528,7 @@ def test_attach_network_endpoints_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -551,6 +537,7 @@ def test_attach_network_endpoints_rest_flattened(): assert compute.GlobalNetworkEndpointGroupsAttachEndpointsRequest.to_json( global_network_endpoint_groups_attach_endpoints_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -617,6 +604,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -669,6 +657,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -683,7 +672,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -748,6 +737,7 @@ def test_detach_network_endpoints_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -800,6 +790,7 @@ def test_detach_network_endpoints_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -821,7 +812,7 @@ def test_detach_network_endpoints_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -830,6 +821,7 @@ def test_detach_network_endpoints_rest_flattened(): assert compute.GlobalNetworkEndpointGroupsDetachEndpointsRequest.to_json( global_network_endpoint_groups_detach_endpoints_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -891,6 +883,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.NetworkEndpointGroup.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -944,6 +937,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NetworkEndpointGroup.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -958,7 +952,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1023,6 +1017,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1075,6 +1070,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1093,12 +1089,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.NetworkEndpointGroup.to_json( - network_endpoint_group_resource, including_default_value_fields=False + network_endpoint_group_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1146,16 +1144,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.NetworkEndpointGroupList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NetworkEndpointGroupList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.NetworkEndpointGroup(annotations={"key_value": "value_value"}) @@ -1183,6 +1180,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NetworkEndpointGroupList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1194,7 +1192,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1212,6 +1210,59 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = GlobalNetworkEndpointGroupsClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NetworkEndpointGroupList( + items=[ + compute.NetworkEndpointGroup(), + compute.NetworkEndpointGroup(), + compute.NetworkEndpointGroup(), + ], + next_page_token="abc", + ), + compute.NetworkEndpointGroupList(items=[], next_page_token="def",), + compute.NetworkEndpointGroupList( + items=[compute.NetworkEndpointGroup(),], next_page_token="ghi", + ), + compute.NetworkEndpointGroupList( + items=[compute.NetworkEndpointGroup(), compute.NetworkEndpointGroup(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.NetworkEndpointGroupList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.NetworkEndpointGroup) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_network_endpoints_rest( transport: str = "rest", request_type=compute.ListNetworkEndpointsGlobalNetworkEndpointGroupsRequest, @@ -1249,16 +1300,15 @@ def test_list_network_endpoints_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_network_endpoints(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NetworkEndpointGroupsListNetworkEndpoints) + assert isinstance(response, pagers.ListNetworkEndpointsPager) assert response.id == "id_value" assert response.items == [ compute.NetworkEndpointWithHealthStatus( @@ -1295,6 +1345,7 @@ def test_list_network_endpoints_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1309,7 +1360,7 @@ def test_list_network_endpoints_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1331,6 +1382,70 @@ def test_list_network_endpoints_rest_flattened_error(): ) +def test_list_network_endpoints_pager(): + client = GlobalNetworkEndpointGroupsClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NetworkEndpointGroupsListNetworkEndpoints( + items=[ + compute.NetworkEndpointWithHealthStatus(), + compute.NetworkEndpointWithHealthStatus(), + compute.NetworkEndpointWithHealthStatus(), + ], + next_page_token="abc", + ), + compute.NetworkEndpointGroupsListNetworkEndpoints( + items=[], next_page_token="def", + ), + compute.NetworkEndpointGroupsListNetworkEndpoints( + items=[compute.NetworkEndpointWithHealthStatus(),], + next_page_token="ghi", + ), + compute.NetworkEndpointGroupsListNetworkEndpoints( + items=[ + compute.NetworkEndpointWithHealthStatus(), + compute.NetworkEndpointWithHealthStatus(), + ], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.NetworkEndpointGroupsListNetworkEndpoints.to_json(x) + for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_network_endpoints(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all( + isinstance(i, compute.NetworkEndpointWithHealthStatus) for i in results + ) + + pages = list(client.list_network_endpoints(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.GlobalNetworkEndpointGroupsRestTransport( @@ -1463,6 +1578,17 @@ def test_global_network_endpoint_groups_auth_adc(): ) +def test_global_network_endpoint_groups_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.GlobalNetworkEndpointGroupsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_global_network_endpoint_groups_host_no_port(): client = GlobalNetworkEndpointGroupsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_global_operations.py b/tests/unit/gapic/compute_v1/test_global_operations.py index d3f5805c1..6678196c0 100644 --- a/tests/unit/gapic/compute_v1/test_global_operations.py +++ b/tests/unit/gapic/compute_v1/test_global_operations.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.global_operations import GlobalOperationsClient +from google.cloud.compute_v1.services.global_operations import pagers from google.cloud.compute_v1.services.global_operations import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -158,7 +159,7 @@ def test_global_operations_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -174,7 +175,7 @@ def test_global_operations_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -190,7 +191,7 @@ def test_global_operations_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -218,7 +219,7 @@ def test_global_operations_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -262,29 +263,25 @@ def test_global_operations_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -293,66 +290,53 @@ def test_global_operations_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -371,7 +355,7 @@ def test_global_operations_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +378,7 @@ def test_global_operations_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -434,16 +418,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.OperationAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.OperationAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.OperationsScopedList( @@ -474,6 +457,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.OperationAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -485,7 +469,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -501,6 +485,69 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = GlobalOperationsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.OperationAggregatedList( + items={ + "a": compute.OperationsScopedList(), + "b": compute.OperationsScopedList(), + "c": compute.OperationsScopedList(), + }, + next_page_token="abc", + ), + compute.OperationAggregatedList(items={}, next_page_token="def",), + compute.OperationAggregatedList( + items={"g": compute.OperationsScopedList(),}, next_page_token="ghi", + ), + compute.OperationAggregatedList( + items={ + "h": compute.OperationsScopedList(), + "i": compute.OperationsScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.OperationAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.OperationsScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == (str, compute.OperationsScopedList) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.OperationsScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteGlobalOperationRequest ): @@ -519,6 +566,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.DeleteGlobalOperationResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -544,6 +592,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.DeleteGlobalOperationResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -557,7 +606,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -619,6 +668,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -669,6 +719,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -682,7 +733,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -727,16 +778,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.OperationList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.OperationList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Operation(client_operation_id="client_operation_id_value") @@ -762,6 +812,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.OperationList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -773,7 +824,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -789,6 +840,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = GlobalOperationsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.OperationList( + items=[compute.Operation(), compute.Operation(), compute.Operation(),], + next_page_token="abc", + ), + compute.OperationList(items=[], next_page_token="def",), + compute.OperationList(items=[compute.Operation(),], next_page_token="ghi",), + compute.OperationList(items=[compute.Operation(), compute.Operation(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.OperationList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Operation) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_wait_rest( transport: str = "rest", request_type=compute.WaitGlobalOperationRequest ): @@ -831,6 +925,7 @@ def test_wait_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -881,6 +976,7 @@ def test_wait_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -894,7 +990,7 @@ def test_wait_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1042,6 +1138,17 @@ def test_global_operations_auth_adc(): ) +def test_global_operations_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.GlobalOperationsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_global_operations_host_no_port(): client = GlobalOperationsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_global_organization_operations.py b/tests/unit/gapic/compute_v1/test_global_organization_operations.py index 717247804..8527031d6 100644 --- a/tests/unit/gapic/compute_v1/test_global_organization_operations.py +++ b/tests/unit/gapic/compute_v1/test_global_organization_operations.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.global_organization_operations import ( GlobalOrganizationOperationsClient, ) +from google.cloud.compute_v1.services.global_organization_operations import pagers from google.cloud.compute_v1.services.global_organization_operations import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -172,7 +173,7 @@ def test_global_organization_operations_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -188,7 +189,7 @@ def test_global_organization_operations_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -204,7 +205,7 @@ def test_global_organization_operations_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -232,7 +233,7 @@ def test_global_organization_operations_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -276,29 +277,25 @@ def test_global_organization_operations_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -307,66 +304,53 @@ def test_global_organization_operations_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -391,7 +375,7 @@ def test_global_organization_operations_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -420,7 +404,7 @@ def test_global_organization_operations_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -447,6 +431,7 @@ def test_delete_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -476,6 +461,7 @@ def test_delete_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -487,7 +473,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "operation_value" in http_call[1] + str(body) @@ -548,6 +534,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -600,6 +587,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -611,7 +599,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "operation_value" in http_call[1] + str(body) @@ -656,16 +644,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.OperationList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.OperationList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Operation(client_operation_id="client_operation_id_value") @@ -693,6 +680,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.OperationList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -704,7 +692,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") def test_list_rest_flattened_error(): @@ -718,6 +706,51 @@ def test_list_rest_flattened_error(): client.list(compute.ListGlobalOrganizationOperationsRequest(),) +def test_list_pager(): + client = GlobalOrganizationOperationsClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.OperationList( + items=[compute.Operation(), compute.Operation(), compute.Operation(),], + next_page_token="abc", + ), + compute.OperationList(items=[], next_page_token="def",), + compute.OperationList(items=[compute.Operation(),], next_page_token="ghi",), + compute.OperationList(items=[compute.Operation(), compute.Operation(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.OperationList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Operation) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.GlobalOrganizationOperationsRestTransport( @@ -846,6 +879,17 @@ def test_global_organization_operations_auth_adc(): ) +def test_global_organization_operations_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.GlobalOrganizationOperationsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_global_organization_operations_host_no_port(): client = GlobalOrganizationOperationsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_health_checks.py b/tests/unit/gapic/compute_v1/test_health_checks.py index 0bfbb83eb..112b5f214 100644 --- a/tests/unit/gapic/compute_v1/test_health_checks.py +++ b/tests/unit/gapic/compute_v1/test_health_checks.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.health_checks import HealthChecksClient +from google.cloud.compute_v1.services.health_checks import pagers from google.cloud.compute_v1.services.health_checks import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -152,7 +153,7 @@ def test_health_checks_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -168,7 +169,7 @@ def test_health_checks_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -184,7 +185,7 @@ def test_health_checks_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -212,7 +213,7 @@ def test_health_checks_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -244,29 +245,25 @@ def test_health_checks_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -275,66 +272,53 @@ def test_health_checks_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -353,7 +337,7 @@ def test_health_checks_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -376,7 +360,7 @@ def test_health_checks_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -412,16 +396,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.HealthChecksAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.HealthChecksAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.HealthChecksScopedList( @@ -450,6 +433,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.HealthChecksAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -461,7 +445,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -477,6 +461,74 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = HealthChecksClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.HealthChecksAggregatedList( + items={ + "a": compute.HealthChecksScopedList(), + "b": compute.HealthChecksScopedList(), + "c": compute.HealthChecksScopedList(), + }, + next_page_token="abc", + ), + compute.HealthChecksAggregatedList(items={}, next_page_token="def",), + compute.HealthChecksAggregatedList( + items={"g": compute.HealthChecksScopedList(),}, next_page_token="ghi", + ), + compute.HealthChecksAggregatedList( + items={ + "h": compute.HealthChecksScopedList(), + "i": compute.HealthChecksScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.HealthChecksAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.HealthChecksScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.HealthChecksScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.HealthChecksScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteHealthCheckRequest ): @@ -519,6 +571,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -569,6 +622,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -582,7 +636,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -640,6 +694,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetHealthCheckRe # Wrap the value into a proper Response obj json_return_value = compute.HealthCheck.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -686,6 +741,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.HealthCheck.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -699,7 +755,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -761,6 +817,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -811,6 +868,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -826,12 +884,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.HealthCheck.to_json( - health_check_resource, including_default_value_fields=False + health_check_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -873,16 +933,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.HealthCheckList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.HealthCheckList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.HealthCheck(check_interval_sec=1884)] assert response.kind == "kind_value" @@ -906,6 +965,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.HealthCheckList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -917,7 +977,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -933,6 +993,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = HealthChecksClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.HealthCheckList( + items=[ + compute.HealthCheck(), + compute.HealthCheck(), + compute.HealthCheck(), + ], + next_page_token="abc", + ), + compute.HealthCheckList(items=[], next_page_token="def",), + compute.HealthCheckList( + items=[compute.HealthCheck(),], next_page_token="ghi", + ), + compute.HealthCheckList( + items=[compute.HealthCheck(), compute.HealthCheck(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.HealthCheckList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.HealthCheck) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchHealthCheckRequest ): @@ -975,6 +1086,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1025,6 +1137,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1042,14 +1155,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "health_check_value" in http_call[1] + str(body) assert compute.HealthCheck.to_json( - health_check_resource, including_default_value_fields=False + health_check_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1109,6 +1224,7 @@ def test_update_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1159,6 +1275,7 @@ def test_update_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1176,14 +1293,16 @@ def test_update_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "health_check_value" in http_call[1] + str(body) assert compute.HealthCheck.to_json( - health_check_resource, including_default_value_fields=False + health_check_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1331,6 +1450,17 @@ def test_health_checks_auth_adc(): ) +def test_health_checks_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.HealthChecksRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_health_checks_host_no_port(): client = HealthChecksClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_images.py b/tests/unit/gapic/compute_v1/test_images.py index 1b75a086a..86b5bd711 100644 --- a/tests/unit/gapic/compute_v1/test_images.py +++ b/tests/unit/gapic/compute_v1/test_images.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.images import ImagesClient +from google.cloud.compute_v1.services.images import pagers from google.cloud.compute_v1.services.images import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -147,7 +148,7 @@ def test_images_client_client_options(client_class, transport_class, transport_n credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -163,7 +164,7 @@ def test_images_client_client_options(client_class, transport_class, transport_n credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -179,7 +180,7 @@ def test_images_client_client_options(client_class, transport_class, transport_n credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -207,7 +208,7 @@ def test_images_client_client_options(client_class, transport_class, transport_n credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -239,29 +240,25 @@ def test_images_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -270,66 +267,53 @@ def test_images_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -348,7 +332,7 @@ def test_images_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -371,7 +355,7 @@ def test_images_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -417,6 +401,7 @@ def test_delete_rest(transport: str = "rest", request_type=compute.DeleteImageRe # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -467,6 +452,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -480,7 +466,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -540,6 +526,7 @@ def test_deprecate_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -590,6 +577,7 @@ def test_deprecate_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -607,14 +595,16 @@ def test_deprecate_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "image_value" in http_call[1] + str(body) assert compute.DeprecationStatus.to_json( - deprecation_status_resource, including_default_value_fields=False + deprecation_status_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -695,6 +685,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetImageRequest) # Wrap the value into a proper Response obj json_return_value = compute.Image.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -766,6 +757,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Image.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -779,7 +771,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -860,6 +852,7 @@ def test_get_from_family_rest( # Wrap the value into a proper Response obj json_return_value = compute.Image.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -931,6 +924,7 @@ def test_get_from_family_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Image.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -944,7 +938,7 @@ def test_get_from_family_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -997,6 +991,7 @@ def test_get_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1035,6 +1030,7 @@ def test_get_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1048,7 +1044,7 @@ def test_get_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1108,6 +1104,7 @@ def test_insert_rest(transport: str = "rest", request_type=compute.InsertImageRe # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1158,6 +1155,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1173,12 +1171,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.Image.to_json( - image_resource, including_default_value_fields=False + image_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1218,16 +1218,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListImagesReque # Wrap the value into a proper Response obj json_return_value = compute.ImageList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.ImageList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Image(archive_size_bytes="archive_size_bytes_value") @@ -1253,6 +1252,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ImageList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1264,7 +1264,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1280,6 +1280,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = ImagesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.ImageList( + items=[compute.Image(), compute.Image(), compute.Image(),], + next_page_token="abc", + ), + compute.ImageList(items=[], next_page_token="def",), + compute.ImageList(items=[compute.Image(),], next_page_token="ghi",), + compute.ImageList(items=[compute.Image(), compute.Image(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.ImageList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Image) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest(transport: str = "rest", request_type=compute.PatchImageRequest): client = ImagesClient( credentials=credentials.AnonymousCredentials(), transport=transport, @@ -1320,6 +1363,7 @@ def test_patch_rest(transport: str = "rest", request_type=compute.PatchImageRequ # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1370,6 +1414,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1385,14 +1430,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "image_value" in http_call[1] + str(body) assert compute.Image.to_json( - image_resource, including_default_value_fields=False + image_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1443,6 +1490,7 @@ def test_set_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1481,6 +1529,7 @@ def test_set_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1500,14 +1549,16 @@ def test_set_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "resource_value" in http_call[1] + str(body) assert compute.GlobalSetPolicyRequest.to_json( - global_set_policy_request_resource, including_default_value_fields=False + global_set_policy_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1569,6 +1620,7 @@ def test_set_labels_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1619,6 +1671,7 @@ def test_set_labels_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1638,14 +1691,16 @@ def test_set_labels_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "resource_value" in http_call[1] + str(body) assert compute.GlobalSetLabelsRequest.to_json( - global_set_labels_request_resource, including_default_value_fields=False + global_set_labels_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1685,6 +1740,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1711,6 +1767,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1730,14 +1787,16 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1891,6 +1950,17 @@ def test_images_auth_adc(): ) +def test_images_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.ImagesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_images_host_no_port(): client = ImagesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_instance_group_managers.py b/tests/unit/gapic/compute_v1/test_instance_group_managers.py index ad34d7aec..49c9ae908 100644 --- a/tests/unit/gapic/compute_v1/test_instance_group_managers.py +++ b/tests/unit/gapic/compute_v1/test_instance_group_managers.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.instance_group_managers import ( InstanceGroupManagersClient, ) +from google.cloud.compute_v1.services.instance_group_managers import pagers from google.cloud.compute_v1.services.instance_group_managers import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -166,7 +167,7 @@ def test_instance_group_managers_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -182,7 +183,7 @@ def test_instance_group_managers_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -198,7 +199,7 @@ def test_instance_group_managers_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -226,7 +227,7 @@ def test_instance_group_managers_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -270,29 +271,25 @@ def test_instance_group_managers_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -301,66 +298,53 @@ def test_instance_group_managers_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -385,7 +369,7 @@ def test_instance_group_managers_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -414,7 +398,7 @@ def test_instance_group_managers_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -463,6 +447,7 @@ def test_abandon_instances_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -515,6 +500,7 @@ def test_abandon_instances_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -535,7 +521,7 @@ def test_abandon_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -546,6 +532,7 @@ def test_abandon_instances_rest_flattened(): assert compute.InstanceGroupManagersAbandonInstancesRequest.to_json( instance_group_managers_abandon_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -609,16 +596,15 @@ def test_aggregated_list_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InstanceGroupManagerAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.InstanceGroupManagersScopedList( @@ -659,6 +645,7 @@ def test_aggregated_list_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -670,7 +657,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -689,6 +676,79 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = InstanceGroupManagersClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InstanceGroupManagerAggregatedList( + items={ + "a": compute.InstanceGroupManagersScopedList(), + "b": compute.InstanceGroupManagersScopedList(), + "c": compute.InstanceGroupManagersScopedList(), + }, + next_page_token="abc", + ), + compute.InstanceGroupManagerAggregatedList( + items={}, next_page_token="def", + ), + compute.InstanceGroupManagerAggregatedList( + items={"g": compute.InstanceGroupManagersScopedList(),}, + next_page_token="ghi", + ), + compute.InstanceGroupManagerAggregatedList( + items={ + "h": compute.InstanceGroupManagersScopedList(), + "i": compute.InstanceGroupManagersScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.InstanceGroupManagerAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.InstanceGroupManagersScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.InstanceGroupManagersScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.InstanceGroupManagersScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_apply_updates_to_instances_rest( transport: str = "rest", request_type=compute.ApplyUpdatesToInstancesInstanceGroupManagerRequest, @@ -732,6 +792,7 @@ def test_apply_updates_to_instances_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -784,6 +845,7 @@ def test_apply_updates_to_instances_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -804,7 +866,7 @@ def test_apply_updates_to_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -815,6 +877,7 @@ def test_apply_updates_to_instances_rest_flattened(): assert compute.InstanceGroupManagersApplyUpdatesRequest.to_json( instance_group_managers_apply_updates_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -880,6 +943,7 @@ def test_create_instances_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -932,6 +996,7 @@ def test_create_instances_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -952,7 +1017,7 @@ def test_create_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -963,6 +1028,7 @@ def test_create_instances_rest_flattened(): assert compute.InstanceGroupManagersCreateInstancesRequest.to_json( instance_group_managers_create_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1027,6 +1093,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1079,6 +1146,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1094,7 +1162,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1162,6 +1230,7 @@ def test_delete_instances_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1214,6 +1283,7 @@ def test_delete_instances_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1234,7 +1304,7 @@ def test_delete_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1245,6 +1315,7 @@ def test_delete_instances_rest_flattened(): assert compute.InstanceGroupManagersDeleteInstancesRequest.to_json( instance_group_managers_delete_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1310,6 +1381,7 @@ def test_delete_per_instance_configs_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1362,6 +1434,7 @@ def test_delete_per_instance_configs_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1382,7 +1455,7 @@ def test_delete_per_instance_configs_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1393,6 +1466,7 @@ def test_delete_per_instance_configs_rest_flattened(): assert compute.InstanceGroupManagersDeletePerInstanceConfigsReq.to_json( instance_group_managers_delete_per_instance_configs_req_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1476,6 +1550,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroupManager.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1545,6 +1620,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroupManager.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1560,7 +1636,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1627,6 +1703,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1679,6 +1756,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1702,14 +1780,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "zone_value" in http_call[1] + str(body) assert compute.InstanceGroupManager.to_json( - instance_group_manager_resource, including_default_value_fields=False + instance_group_manager_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1768,16 +1848,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroupManagerList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InstanceGroupManagerList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.InstanceGroupManager( @@ -1811,6 +1890,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroupManagerList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1824,7 +1904,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1846,6 +1926,59 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = InstanceGroupManagersClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InstanceGroupManagerList( + items=[ + compute.InstanceGroupManager(), + compute.InstanceGroupManager(), + compute.InstanceGroupManager(), + ], + next_page_token="abc", + ), + compute.InstanceGroupManagerList(items=[], next_page_token="def",), + compute.InstanceGroupManagerList( + items=[compute.InstanceGroupManager(),], next_page_token="ghi", + ), + compute.InstanceGroupManagerList( + items=[compute.InstanceGroupManager(), compute.InstanceGroupManager(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.InstanceGroupManagerList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.InstanceGroupManager) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_errors_rest( transport: str = "rest", request_type=compute.ListErrorsInstanceGroupManagersRequest ): @@ -1875,16 +2008,15 @@ def test_list_errors_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_errors(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InstanceGroupManagersListErrorsResponse) + assert isinstance(response, pagers.ListErrorsPager) assert response.items == [ compute.InstanceManagedByIgmError( error=compute.InstanceManagedByIgmErrorManagedInstanceError( @@ -1914,6 +2046,7 @@ def test_list_errors_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1929,7 +2062,7 @@ def test_list_errors_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1954,6 +2087,66 @@ def test_list_errors_rest_flattened_error(): ) +def test_list_errors_pager(): + client = InstanceGroupManagersClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InstanceGroupManagersListErrorsResponse( + items=[ + compute.InstanceManagedByIgmError(), + compute.InstanceManagedByIgmError(), + compute.InstanceManagedByIgmError(), + ], + next_page_token="abc", + ), + compute.InstanceGroupManagersListErrorsResponse( + items=[], next_page_token="def", + ), + compute.InstanceGroupManagersListErrorsResponse( + items=[compute.InstanceManagedByIgmError(),], next_page_token="ghi", + ), + compute.InstanceGroupManagersListErrorsResponse( + items=[ + compute.InstanceManagedByIgmError(), + compute.InstanceManagedByIgmError(), + ], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.InstanceGroupManagersListErrorsResponse.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_errors(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.InstanceManagedByIgmError) for i in results) + + pages = list(client.list_errors(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_managed_instances_rest( transport: str = "rest", request_type=compute.ListManagedInstancesInstanceGroupManagersRequest, @@ -1982,18 +2175,15 @@ def test_list_managed_instances_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_managed_instances(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance( - response, compute.InstanceGroupManagersListManagedInstancesResponse - ) + assert isinstance(response, pagers.ListManagedInstancesPager) assert response.managed_instances == [ compute.ManagedInstance( current_action=compute.ManagedInstance.CurrentAction.ABANDONING @@ -2021,6 +2211,7 @@ def test_list_managed_instances_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2036,7 +2227,7 @@ def test_list_managed_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2061,6 +2252,67 @@ def test_list_managed_instances_rest_flattened_error(): ) +def test_list_managed_instances_pager(): + client = InstanceGroupManagersClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InstanceGroupManagersListManagedInstancesResponse( + managed_instances=[ + compute.ManagedInstance(), + compute.ManagedInstance(), + compute.ManagedInstance(), + ], + next_page_token="abc", + ), + compute.InstanceGroupManagersListManagedInstancesResponse( + managed_instances=[], next_page_token="def", + ), + compute.InstanceGroupManagersListManagedInstancesResponse( + managed_instances=[compute.ManagedInstance(),], next_page_token="ghi", + ), + compute.InstanceGroupManagersListManagedInstancesResponse( + managed_instances=[ + compute.ManagedInstance(), + compute.ManagedInstance(), + ], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.InstanceGroupManagersListManagedInstancesResponse.to_json(x) + for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_managed_instances(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.ManagedInstance) for i in results) + + pages = list(client.list_managed_instances(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_per_instance_configs_rest( transport: str = "rest", request_type=compute.ListPerInstanceConfigsInstanceGroupManagersRequest, @@ -2086,16 +2338,15 @@ def test_list_per_instance_configs_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_per_instance_configs(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InstanceGroupManagersListPerInstanceConfigsResp) + assert isinstance(response, pagers.ListPerInstanceConfigsPager) assert response.items == [ compute.PerInstanceConfig(fingerprint="fingerprint_value") ] @@ -2122,6 +2373,7 @@ def test_list_per_instance_configs_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2137,7 +2389,7 @@ def test_list_per_instance_configs_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2162,6 +2414,64 @@ def test_list_per_instance_configs_rest_flattened_error(): ) +def test_list_per_instance_configs_pager(): + client = InstanceGroupManagersClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InstanceGroupManagersListPerInstanceConfigsResp( + items=[ + compute.PerInstanceConfig(), + compute.PerInstanceConfig(), + compute.PerInstanceConfig(), + ], + next_page_token="abc", + ), + compute.InstanceGroupManagersListPerInstanceConfigsResp( + items=[], next_page_token="def", + ), + compute.InstanceGroupManagersListPerInstanceConfigsResp( + items=[compute.PerInstanceConfig(),], next_page_token="ghi", + ), + compute.InstanceGroupManagersListPerInstanceConfigsResp( + items=[compute.PerInstanceConfig(), compute.PerInstanceConfig(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.InstanceGroupManagersListPerInstanceConfigsResp.to_json(x) + for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_per_instance_configs(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.PerInstanceConfig) for i in results) + + pages = list(client.list_per_instance_configs(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchInstanceGroupManagerRequest ): @@ -2204,6 +2514,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2256,6 +2567,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2280,7 +2592,7 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2289,7 +2601,9 @@ def test_patch_rest_flattened(): assert "instance_group_manager_value" in http_call[1] + str(body) assert compute.InstanceGroupManager.to_json( - instance_group_manager_resource, including_default_value_fields=False + instance_group_manager_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2359,6 +2673,7 @@ def test_patch_per_instance_configs_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2411,6 +2726,7 @@ def test_patch_per_instance_configs_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2433,7 +2749,7 @@ def test_patch_per_instance_configs_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2444,6 +2760,7 @@ def test_patch_per_instance_configs_rest_flattened(): assert compute.InstanceGroupManagersPatchPerInstanceConfigsReq.to_json( instance_group_managers_patch_per_instance_configs_req_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2511,6 +2828,7 @@ def test_recreate_instances_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2563,6 +2881,7 @@ def test_recreate_instances_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2583,7 +2902,7 @@ def test_recreate_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2594,6 +2913,7 @@ def test_recreate_instances_rest_flattened(): assert compute.InstanceGroupManagersRecreateInstancesRequest.to_json( instance_group_managers_recreate_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2658,6 +2978,7 @@ def test_resize_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2710,6 +3031,7 @@ def test_resize_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2726,7 +3048,7 @@ def test_resize_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2797,6 +3119,7 @@ def test_set_instance_template_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2849,6 +3172,7 @@ def test_set_instance_template_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2869,7 +3193,7 @@ def test_set_instance_template_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2880,6 +3204,7 @@ def test_set_instance_template_rest_flattened(): assert compute.InstanceGroupManagersSetInstanceTemplateRequest.to_json( instance_group_managers_set_instance_template_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2945,6 +3270,7 @@ def test_set_target_pools_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2997,6 +3323,7 @@ def test_set_target_pools_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3017,7 +3344,7 @@ def test_set_target_pools_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -3028,6 +3355,7 @@ def test_set_target_pools_rest_flattened(): assert compute.InstanceGroupManagersSetTargetPoolsRequest.to_json( instance_group_managers_set_target_pools_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -3093,6 +3421,7 @@ def test_update_per_instance_configs_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3145,6 +3474,7 @@ def test_update_per_instance_configs_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3167,7 +3497,7 @@ def test_update_per_instance_configs_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -3178,6 +3508,7 @@ def test_update_per_instance_configs_rest_flattened(): assert compute.InstanceGroupManagersUpdatePerInstanceConfigsReq.to_json( instance_group_managers_update_per_instance_configs_req_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -3347,6 +3678,17 @@ def test_instance_group_managers_auth_adc(): ) +def test_instance_group_managers_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.InstanceGroupManagersRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_instance_group_managers_host_no_port(): client = InstanceGroupManagersClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_instance_groups.py b/tests/unit/gapic/compute_v1/test_instance_groups.py index 8b0dad861..e0d328caf 100644 --- a/tests/unit/gapic/compute_v1/test_instance_groups.py +++ b/tests/unit/gapic/compute_v1/test_instance_groups.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.instance_groups import InstanceGroupsClient +from google.cloud.compute_v1.services.instance_groups import pagers from google.cloud.compute_v1.services.instance_groups import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -157,7 +158,7 @@ def test_instance_groups_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -173,7 +174,7 @@ def test_instance_groups_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -189,7 +190,7 @@ def test_instance_groups_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -217,7 +218,7 @@ def test_instance_groups_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -251,29 +252,25 @@ def test_instance_groups_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -282,66 +279,53 @@ def test_instance_groups_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -360,7 +344,7 @@ def test_instance_groups_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -383,7 +367,7 @@ def test_instance_groups_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -431,6 +415,7 @@ def test_add_instances_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -481,6 +466,7 @@ def test_add_instances_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -501,7 +487,7 @@ def test_add_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -512,6 +498,7 @@ def test_add_instances_rest_flattened(): assert compute.InstanceGroupsAddInstancesRequest.to_json( instance_groups_add_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -566,16 +553,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroupAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InstanceGroupAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.InstanceGroupsScopedList( @@ -606,6 +592,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroupAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -617,7 +604,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -633,6 +620,74 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = InstanceGroupsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InstanceGroupAggregatedList( + items={ + "a": compute.InstanceGroupsScopedList(), + "b": compute.InstanceGroupsScopedList(), + "c": compute.InstanceGroupsScopedList(), + }, + next_page_token="abc", + ), + compute.InstanceGroupAggregatedList(items={}, next_page_token="def",), + compute.InstanceGroupAggregatedList( + items={"g": compute.InstanceGroupsScopedList(),}, next_page_token="ghi", + ), + compute.InstanceGroupAggregatedList( + items={ + "h": compute.InstanceGroupsScopedList(), + "i": compute.InstanceGroupsScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.InstanceGroupAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.InstanceGroupsScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.InstanceGroupsScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.InstanceGroupsScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteInstanceGroupRequest ): @@ -675,6 +730,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -725,6 +781,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -740,7 +797,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -795,6 +852,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroup.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -833,6 +891,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroup.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -848,7 +907,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -913,6 +972,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -963,6 +1023,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -982,14 +1043,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "zone_value" in http_call[1] + str(body) assert compute.InstanceGroup.to_json( - instance_group_resource, including_default_value_fields=False + instance_group_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1036,16 +1099,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroupList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InstanceGroupList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.InstanceGroup(creation_timestamp="creation_timestamp_value") @@ -1071,6 +1133,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroupList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1084,7 +1147,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1104,6 +1167,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = InstanceGroupsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InstanceGroupList( + items=[ + compute.InstanceGroup(), + compute.InstanceGroup(), + compute.InstanceGroup(), + ], + next_page_token="abc", + ), + compute.InstanceGroupList(items=[], next_page_token="def",), + compute.InstanceGroupList( + items=[compute.InstanceGroup(),], next_page_token="ghi", + ), + compute.InstanceGroupList( + items=[compute.InstanceGroup(), compute.InstanceGroup(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.InstanceGroupList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.InstanceGroup) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_instances_rest( transport: str = "rest", request_type=compute.ListInstancesInstanceGroupsRequest ): @@ -1129,16 +1243,15 @@ def test_list_instances_rest( # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroupsListInstances.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_instances(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InstanceGroupsListInstances) + assert isinstance(response, pagers.ListInstancesPager) assert response.id == "id_value" assert response.items == [compute.InstanceWithNamedPorts(instance="instance_value")] assert response.kind == "kind_value" @@ -1162,6 +1275,7 @@ def test_list_instances_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroupsListInstances.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1182,7 +1296,7 @@ def test_list_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1193,6 +1307,7 @@ def test_list_instances_rest_flattened(): assert compute.InstanceGroupsListInstancesRequest.to_json( instance_groups_list_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1213,6 +1328,62 @@ def test_list_instances_rest_flattened_error(): ) +def test_list_instances_pager(): + client = InstanceGroupsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InstanceGroupsListInstances( + items=[ + compute.InstanceWithNamedPorts(), + compute.InstanceWithNamedPorts(), + compute.InstanceWithNamedPorts(), + ], + next_page_token="abc", + ), + compute.InstanceGroupsListInstances(items=[], next_page_token="def",), + compute.InstanceGroupsListInstances( + items=[compute.InstanceWithNamedPorts(),], next_page_token="ghi", + ), + compute.InstanceGroupsListInstances( + items=[ + compute.InstanceWithNamedPorts(), + compute.InstanceWithNamedPorts(), + ], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.InstanceGroupsListInstances.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_instances(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.InstanceWithNamedPorts) for i in results) + + pages = list(client.list_instances(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_remove_instances_rest( transport: str = "rest", request_type=compute.RemoveInstancesInstanceGroupRequest ): @@ -1255,6 +1426,7 @@ def test_remove_instances_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1305,6 +1477,7 @@ def test_remove_instances_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1325,7 +1498,7 @@ def test_remove_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1336,6 +1509,7 @@ def test_remove_instances_rest_flattened(): assert compute.InstanceGroupsRemoveInstancesRequest.to_json( instance_groups_remove_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1398,6 +1572,7 @@ def test_set_named_ports_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1448,6 +1623,7 @@ def test_set_named_ports_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1468,7 +1644,7 @@ def test_set_named_ports_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1479,6 +1655,7 @@ def test_set_named_ports_rest_flattened(): assert compute.InstanceGroupsSetNamedPortsRequest.to_json( instance_groups_set_named_ports_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1631,6 +1808,17 @@ def test_instance_groups_auth_adc(): ) +def test_instance_groups_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.InstanceGroupsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_instance_groups_host_no_port(): client = InstanceGroupsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_instance_templates.py b/tests/unit/gapic/compute_v1/test_instance_templates.py index 2ac641b96..35fde8f9f 100644 --- a/tests/unit/gapic/compute_v1/test_instance_templates.py +++ b/tests/unit/gapic/compute_v1/test_instance_templates.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.instance_templates import InstanceTemplatesClient +from google.cloud.compute_v1.services.instance_templates import pagers from google.cloud.compute_v1.services.instance_templates import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -158,7 +159,7 @@ def test_instance_templates_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -174,7 +175,7 @@ def test_instance_templates_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -190,7 +191,7 @@ def test_instance_templates_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -218,7 +219,7 @@ def test_instance_templates_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -262,29 +263,25 @@ def test_instance_templates_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -293,66 +290,53 @@ def test_instance_templates_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -371,7 +355,7 @@ def test_instance_templates_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +378,7 @@ def test_instance_templates_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -442,6 +426,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -492,6 +477,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -505,7 +491,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -555,6 +541,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.InstanceTemplate.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -591,6 +578,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InstanceTemplate.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -604,7 +592,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -657,6 +645,7 @@ def test_get_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -695,6 +684,7 @@ def test_get_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -708,7 +698,7 @@ def test_get_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -770,6 +760,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -820,6 +811,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -838,12 +830,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.InstanceTemplate.to_json( - instance_template_resource, including_default_value_fields=False + instance_template_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -889,16 +883,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.InstanceTemplateList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InstanceTemplateList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.InstanceTemplate(creation_timestamp="creation_timestamp_value") @@ -924,6 +917,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InstanceTemplateList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -935,7 +929,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -951,6 +945,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = InstanceTemplatesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InstanceTemplateList( + items=[ + compute.InstanceTemplate(), + compute.InstanceTemplate(), + compute.InstanceTemplate(), + ], + next_page_token="abc", + ), + compute.InstanceTemplateList(items=[], next_page_token="def",), + compute.InstanceTemplateList( + items=[compute.InstanceTemplate(),], next_page_token="ghi", + ), + compute.InstanceTemplateList( + items=[compute.InstanceTemplate(), compute.InstanceTemplate(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.InstanceTemplateList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.InstanceTemplate) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_set_iam_policy_rest( transport: str = "rest", request_type=compute.SetIamPolicyInstanceTemplateRequest ): @@ -984,6 +1029,7 @@ def test_set_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1022,6 +1068,7 @@ def test_set_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1041,14 +1088,16 @@ def test_set_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "resource_value" in http_call[1] + str(body) assert compute.GlobalSetPolicyRequest.to_json( - global_set_policy_request_resource, including_default_value_fields=False + global_set_policy_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1089,6 +1138,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1115,6 +1165,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1134,14 +1185,16 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1293,6 +1346,17 @@ def test_instance_templates_auth_adc(): ) +def test_instance_templates_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.InstanceTemplatesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_instance_templates_host_no_port(): client = InstanceTemplatesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_instances.py b/tests/unit/gapic/compute_v1/test_instances.py index 016398b50..7e2d9d0be 100644 --- a/tests/unit/gapic/compute_v1/test_instances.py +++ b/tests/unit/gapic/compute_v1/test_instances.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.instances import InstancesClient +from google.cloud.compute_v1.services.instances import pagers from google.cloud.compute_v1.services.instances import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -148,7 +149,7 @@ def test_instances_client_client_options(client_class, transport_class, transpor credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -164,7 +165,7 @@ def test_instances_client_client_options(client_class, transport_class, transpor credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -180,7 +181,7 @@ def test_instances_client_client_options(client_class, transport_class, transpor credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -208,7 +209,7 @@ def test_instances_client_client_options(client_class, transport_class, transpor credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -240,29 +241,25 @@ def test_instances_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -271,66 +268,53 @@ def test_instances_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -349,7 +333,7 @@ def test_instances_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -372,7 +356,7 @@ def test_instances_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -420,6 +404,7 @@ def test_add_access_config_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -470,6 +455,7 @@ def test_add_access_config_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -489,7 +475,7 @@ def test_add_access_config_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -500,7 +486,9 @@ def test_add_access_config_rest_flattened(): assert "network_interface_value" in http_call[1] + str(body) assert compute.AccessConfig.to_json( - access_config_resource, including_default_value_fields=False + access_config_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -562,6 +550,7 @@ def test_add_resource_policies_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -612,6 +601,7 @@ def test_add_resource_policies_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -632,7 +622,7 @@ def test_add_resource_policies_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -643,6 +633,7 @@ def test_add_resource_policies_rest_flattened(): assert compute.InstancesAddResourcePoliciesRequest.to_json( instances_add_resource_policies_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -693,16 +684,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.InstanceAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InstanceAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.InstancesScopedList( @@ -731,6 +721,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InstanceAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -742,7 +733,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -758,6 +749,69 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = InstancesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InstanceAggregatedList( + items={ + "a": compute.InstancesScopedList(), + "b": compute.InstancesScopedList(), + "c": compute.InstancesScopedList(), + }, + next_page_token="abc", + ), + compute.InstanceAggregatedList(items={}, next_page_token="def",), + compute.InstanceAggregatedList( + items={"g": compute.InstancesScopedList(),}, next_page_token="ghi", + ), + compute.InstanceAggregatedList( + items={ + "h": compute.InstancesScopedList(), + "i": compute.InstancesScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.InstanceAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.InstancesScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == (str, compute.InstancesScopedList) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.InstancesScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_attach_disk_rest( transport: str = "rest", request_type=compute.AttachDiskInstanceRequest ): @@ -800,6 +854,7 @@ def test_attach_disk_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -850,6 +905,7 @@ def test_attach_disk_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -868,7 +924,7 @@ def test_attach_disk_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -877,7 +933,9 @@ def test_attach_disk_rest_flattened(): assert "instance_value" in http_call[1] + str(body) assert compute.AttachedDisk.to_json( - attached_disk_resource, including_default_value_fields=False + attached_disk_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -938,6 +996,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -988,6 +1047,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1001,7 +1061,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1066,6 +1126,7 @@ def test_delete_access_config_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1116,6 +1177,7 @@ def test_delete_access_config_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1133,7 +1195,7 @@ def test_delete_access_config_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1204,6 +1266,7 @@ def test_detach_disk_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1254,6 +1317,7 @@ def test_detach_disk_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1270,7 +1334,7 @@ def test_detach_disk_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1361,6 +1425,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetInstanceReque # Wrap the value into a proper Response obj json_return_value = compute.Instance.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1443,6 +1508,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Instance.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1456,7 +1522,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1506,6 +1572,7 @@ def test_get_guest_attributes_rest( # Wrap the value into a proper Response obj json_return_value = compute.GuestAttributes.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1539,6 +1606,7 @@ def test_get_guest_attributes_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.GuestAttributes.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1552,7 +1620,7 @@ def test_get_guest_attributes_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1608,6 +1676,7 @@ def test_get_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1646,6 +1715,7 @@ def test_get_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1659,7 +1729,7 @@ def test_get_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1700,6 +1770,7 @@ def test_get_screenshot_rest( # Wrap the value into a proper Response obj json_return_value = compute.Screenshot.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1727,6 +1798,7 @@ def test_get_screenshot_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Screenshot.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1740,7 +1812,7 @@ def test_get_screenshot_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1787,6 +1859,7 @@ def test_get_serial_port_output_rest( # Wrap the value into a proper Response obj json_return_value = compute.SerialPortOutput.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1817,6 +1890,7 @@ def test_get_serial_port_output_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.SerialPortOutput.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1830,7 +1904,7 @@ def test_get_serial_port_output_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1878,6 +1952,7 @@ def test_get_shielded_instance_identity_rest( # Wrap the value into a proper Response obj json_return_value = compute.ShieldedInstanceIdentity.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1910,6 +1985,7 @@ def test_get_shielded_instance_identity_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ShieldedInstanceIdentity.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1923,7 +1999,7 @@ def test_get_shielded_instance_identity_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1988,6 +2064,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2038,6 +2115,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2055,14 +2133,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "zone_value" in http_call[1] + str(body) assert compute.Instance.to_json( - instance_resource, including_default_value_fields=False + instance_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2103,16 +2183,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListInstancesRe # Wrap the value into a proper Response obj json_return_value = compute.InstanceList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InstanceList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.Instance(can_ip_forward=True)] assert response.kind == "kind_value" @@ -2136,6 +2215,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InstanceList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2149,7 +2229,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2167,6 +2247,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = InstancesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InstanceList( + items=[compute.Instance(), compute.Instance(), compute.Instance(),], + next_page_token="abc", + ), + compute.InstanceList(items=[], next_page_token="def",), + compute.InstanceList(items=[compute.Instance(),], next_page_token="ghi",), + compute.InstanceList(items=[compute.Instance(), compute.Instance(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.InstanceList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Instance) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_referrers_rest( transport: str = "rest", request_type=compute.ListReferrersInstancesRequest ): @@ -2192,16 +2315,15 @@ def test_list_referrers_rest( # Wrap the value into a proper Response obj json_return_value = compute.InstanceListReferrers.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_referrers(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InstanceListReferrers) + assert isinstance(response, pagers.ListReferrersPager) assert response.id == "id_value" assert response.items == [compute.Reference(kind="kind_value")] assert response.kind == "kind_value" @@ -2225,6 +2347,7 @@ def test_list_referrers_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InstanceListReferrers.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2238,7 +2361,7 @@ def test_list_referrers_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2261,6 +2384,53 @@ def test_list_referrers_rest_flattened_error(): ) +def test_list_referrers_pager(): + client = InstancesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InstanceListReferrers( + items=[compute.Reference(), compute.Reference(), compute.Reference(),], + next_page_token="abc", + ), + compute.InstanceListReferrers(items=[], next_page_token="def",), + compute.InstanceListReferrers( + items=[compute.Reference(),], next_page_token="ghi", + ), + compute.InstanceListReferrers( + items=[compute.Reference(), compute.Reference(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.InstanceListReferrers.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_referrers(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Reference) for i in results) + + pages = list(client.list_referrers(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_remove_resource_policies_rest( transport: str = "rest", request_type=compute.RemoveResourcePoliciesInstanceRequest ): @@ -2303,6 +2473,7 @@ def test_remove_resource_policies_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2353,6 +2524,7 @@ def test_remove_resource_policies_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2373,7 +2545,7 @@ def test_remove_resource_policies_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2384,6 +2556,7 @@ def test_remove_resource_policies_rest_flattened(): assert compute.InstancesRemoveResourcePoliciesRequest.to_json( instances_remove_resource_policies_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2444,6 +2617,7 @@ def test_reset_rest(transport: str = "rest", request_type=compute.ResetInstanceR # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2494,6 +2668,7 @@ def test_reset_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2507,7 +2682,7 @@ def test_reset_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2572,6 +2747,7 @@ def test_set_deletion_protection_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2622,6 +2798,7 @@ def test_set_deletion_protection_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2635,7 +2812,7 @@ def test_set_deletion_protection_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2700,6 +2877,7 @@ def test_set_disk_auto_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2750,6 +2928,7 @@ def test_set_disk_auto_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2767,7 +2946,7 @@ def test_set_disk_auto_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2829,6 +3008,7 @@ def test_set_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2867,6 +3047,7 @@ def test_set_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2887,7 +3068,7 @@ def test_set_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2896,7 +3077,9 @@ def test_set_iam_policy_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.ZoneSetPolicyRequest.to_json( - zone_set_policy_request_resource, including_default_value_fields=False + zone_set_policy_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2959,6 +3142,7 @@ def test_set_labels_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3009,6 +3193,7 @@ def test_set_labels_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3029,7 +3214,7 @@ def test_set_labels_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -3038,7 +3223,9 @@ def test_set_labels_rest_flattened(): assert "instance_value" in http_call[1] + str(body) assert compute.InstancesSetLabelsRequest.to_json( - instances_set_labels_request_resource, including_default_value_fields=False + instances_set_labels_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -3101,6 +3288,7 @@ def test_set_machine_resources_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3151,6 +3339,7 @@ def test_set_machine_resources_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3171,7 +3360,7 @@ def test_set_machine_resources_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -3182,6 +3371,7 @@ def test_set_machine_resources_rest_flattened(): assert compute.InstancesSetMachineResourcesRequest.to_json( instances_set_machine_resources_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -3244,6 +3434,7 @@ def test_set_machine_type_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3294,6 +3485,7 @@ def test_set_machine_type_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3314,7 +3506,7 @@ def test_set_machine_type_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -3325,6 +3517,7 @@ def test_set_machine_type_rest_flattened(): assert compute.InstancesSetMachineTypeRequest.to_json( instances_set_machine_type_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -3387,6 +3580,7 @@ def test_set_metadata_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3437,6 +3631,7 @@ def test_set_metadata_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3455,7 +3650,7 @@ def test_set_metadata_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -3464,7 +3659,9 @@ def test_set_metadata_rest_flattened(): assert "instance_value" in http_call[1] + str(body) assert compute.Metadata.to_json( - metadata_resource, including_default_value_fields=False + metadata_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -3525,6 +3722,7 @@ def test_set_min_cpu_platform_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3575,6 +3773,7 @@ def test_set_min_cpu_platform_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3595,7 +3794,7 @@ def test_set_min_cpu_platform_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -3606,6 +3805,7 @@ def test_set_min_cpu_platform_rest_flattened(): assert compute.InstancesSetMinCpuPlatformRequest.to_json( instances_set_min_cpu_platform_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -3668,6 +3868,7 @@ def test_set_scheduling_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3718,6 +3919,7 @@ def test_set_scheduling_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3736,7 +3938,7 @@ def test_set_scheduling_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -3745,7 +3947,9 @@ def test_set_scheduling_rest_flattened(): assert "instance_value" in http_call[1] + str(body) assert compute.Scheduling.to_json( - scheduling_resource, including_default_value_fields=False + scheduling_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -3806,6 +4010,7 @@ def test_set_service_account_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3856,6 +4061,7 @@ def test_set_service_account_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3876,7 +4082,7 @@ def test_set_service_account_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -3887,6 +4093,7 @@ def test_set_service_account_rest_flattened(): assert compute.InstancesSetServiceAccountRequest.to_json( instances_set_service_account_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -3950,6 +4157,7 @@ def test_set_shielded_instance_integrity_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4000,6 +4208,7 @@ def test_set_shielded_instance_integrity_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4020,7 +4229,7 @@ def test_set_shielded_instance_integrity_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -4031,6 +4240,7 @@ def test_set_shielded_instance_integrity_policy_rest_flattened(): assert compute.ShieldedInstanceIntegrityPolicy.to_json( shielded_instance_integrity_policy_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -4093,6 +4303,7 @@ def test_set_tags_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4143,6 +4354,7 @@ def test_set_tags_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4161,7 +4373,7 @@ def test_set_tags_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -4170,7 +4382,9 @@ def test_set_tags_rest_flattened(): assert "instance_value" in http_call[1] + str(body) assert compute.Tags.to_json( - tags_resource, including_default_value_fields=False + tags_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -4232,6 +4446,7 @@ def test_simulate_maintenance_event_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4282,6 +4497,7 @@ def test_simulate_maintenance_event_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4295,7 +4511,7 @@ def test_simulate_maintenance_event_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -4358,6 +4574,7 @@ def test_start_rest(transport: str = "rest", request_type=compute.StartInstanceR # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4408,6 +4625,7 @@ def test_start_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4421,7 +4639,7 @@ def test_start_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -4486,6 +4704,7 @@ def test_start_with_encryption_key_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4536,6 +4755,7 @@ def test_start_with_encryption_key_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4562,7 +4782,7 @@ def test_start_with_encryption_key_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -4573,6 +4793,7 @@ def test_start_with_encryption_key_rest_flattened(): assert compute.InstancesStartWithEncryptionKeyRequest.to_json( instances_start_with_encryption_key_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -4639,6 +4860,7 @@ def test_stop_rest(transport: str = "rest", request_type=compute.StopInstanceReq # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4689,6 +4911,7 @@ def test_stop_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4702,7 +4925,7 @@ def test_stop_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -4745,6 +4968,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4771,6 +4995,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4791,7 +5016,7 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -4800,7 +5025,9 @@ def test_test_iam_permissions_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -4863,6 +5090,7 @@ def test_update_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4913,6 +5141,7 @@ def test_update_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -4931,7 +5160,7 @@ def test_update_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -4940,7 +5169,9 @@ def test_update_rest_flattened(): assert "instance_value" in http_call[1] + str(body) assert compute.Instance.to_json( - instance_resource, including_default_value_fields=False + instance_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -5001,6 +5232,7 @@ def test_update_access_config_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5051,6 +5283,7 @@ def test_update_access_config_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5070,7 +5303,7 @@ def test_update_access_config_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -5081,7 +5314,9 @@ def test_update_access_config_rest_flattened(): assert "network_interface_value" in http_call[1] + str(body) assert compute.AccessConfig.to_json( - access_config_resource, including_default_value_fields=False + access_config_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -5143,6 +5378,7 @@ def test_update_display_device_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5193,6 +5429,7 @@ def test_update_display_device_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5211,7 +5448,7 @@ def test_update_display_device_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -5220,7 +5457,9 @@ def test_update_display_device_rest_flattened(): assert "instance_value" in http_call[1] + str(body) assert compute.DisplayDevice.to_json( - display_device_resource, including_default_value_fields=False + display_device_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -5281,6 +5520,7 @@ def test_update_network_interface_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5331,6 +5571,7 @@ def test_update_network_interface_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5352,7 +5593,7 @@ def test_update_network_interface_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -5363,7 +5604,9 @@ def test_update_network_interface_rest_flattened(): assert "network_interface_value" in http_call[1] + str(body) assert compute.NetworkInterface.to_json( - network_interface_resource, including_default_value_fields=False + network_interface_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -5428,6 +5671,7 @@ def test_update_shielded_instance_config_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5478,6 +5722,7 @@ def test_update_shielded_instance_config_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -5498,7 +5743,7 @@ def test_update_shielded_instance_config_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -5507,7 +5752,9 @@ def test_update_shielded_instance_config_rest_flattened(): assert "instance_value" in http_call[1] + str(body) assert compute.ShieldedInstanceConfig.to_json( - shielded_instance_config_resource, including_default_value_fields=False + shielded_instance_config_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -5691,6 +5938,17 @@ def test_instances_auth_adc(): ) +def test_instances_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.InstancesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_instances_host_no_port(): client = InstancesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_interconnect_attachments.py b/tests/unit/gapic/compute_v1/test_interconnect_attachments.py index 0d4c0a2c1..14ee4caaa 100644 --- a/tests/unit/gapic/compute_v1/test_interconnect_attachments.py +++ b/tests/unit/gapic/compute_v1/test_interconnect_attachments.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.interconnect_attachments import ( InterconnectAttachmentsClient, ) +from google.cloud.compute_v1.services.interconnect_attachments import pagers from google.cloud.compute_v1.services.interconnect_attachments import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -166,7 +167,7 @@ def test_interconnect_attachments_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -182,7 +183,7 @@ def test_interconnect_attachments_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -198,7 +199,7 @@ def test_interconnect_attachments_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -226,7 +227,7 @@ def test_interconnect_attachments_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -270,29 +271,25 @@ def test_interconnect_attachments_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -301,66 +298,53 @@ def test_interconnect_attachments_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -385,7 +369,7 @@ def test_interconnect_attachments_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -414,7 +398,7 @@ def test_interconnect_attachments_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -455,16 +439,15 @@ def test_aggregated_list_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InterconnectAttachmentAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.InterconnectAttachmentsScopedList( @@ -499,6 +482,7 @@ def test_aggregated_list_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -510,7 +494,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -529,6 +513,79 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = InterconnectAttachmentsClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InterconnectAttachmentAggregatedList( + items={ + "a": compute.InterconnectAttachmentsScopedList(), + "b": compute.InterconnectAttachmentsScopedList(), + "c": compute.InterconnectAttachmentsScopedList(), + }, + next_page_token="abc", + ), + compute.InterconnectAttachmentAggregatedList( + items={}, next_page_token="def", + ), + compute.InterconnectAttachmentAggregatedList( + items={"g": compute.InterconnectAttachmentsScopedList(),}, + next_page_token="ghi", + ), + compute.InterconnectAttachmentAggregatedList( + items={ + "h": compute.InterconnectAttachmentsScopedList(), + "i": compute.InterconnectAttachmentsScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.InterconnectAttachmentAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.InterconnectAttachmentsScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.InterconnectAttachmentsScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.InterconnectAttachmentsScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteInterconnectAttachmentRequest ): @@ -571,6 +628,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -623,6 +681,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -638,7 +697,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -712,6 +771,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.InterconnectAttachment.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -777,6 +837,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InterconnectAttachment.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -792,7 +853,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -859,6 +920,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -911,6 +973,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -930,14 +993,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.InterconnectAttachment.to_json( - interconnect_attachment_resource, including_default_value_fields=False + interconnect_attachment_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -984,16 +1049,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.InterconnectAttachmentList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InterconnectAttachmentList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.InterconnectAttachment(admin_enabled=True)] assert response.kind == "kind_value" @@ -1019,6 +1083,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InterconnectAttachmentList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1032,7 +1097,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1054,6 +1119,64 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = InterconnectAttachmentsClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InterconnectAttachmentList( + items=[ + compute.InterconnectAttachment(), + compute.InterconnectAttachment(), + compute.InterconnectAttachment(), + ], + next_page_token="abc", + ), + compute.InterconnectAttachmentList(items=[], next_page_token="def",), + compute.InterconnectAttachmentList( + items=[compute.InterconnectAttachment(),], next_page_token="ghi", + ), + compute.InterconnectAttachmentList( + items=[ + compute.InterconnectAttachment(), + compute.InterconnectAttachment(), + ], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.InterconnectAttachmentList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.InterconnectAttachment) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchInterconnectAttachmentRequest ): @@ -1096,6 +1219,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1148,6 +1272,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1168,7 +1293,7 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1177,7 +1302,9 @@ def test_patch_rest_flattened(): assert "interconnect_attachment_value" in http_call[1] + str(body) assert compute.InterconnectAttachment.to_json( - interconnect_attachment_resource, including_default_value_fields=False + interconnect_attachment_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1331,6 +1458,17 @@ def test_interconnect_attachments_auth_adc(): ) +def test_interconnect_attachments_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.InterconnectAttachmentsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_interconnect_attachments_host_no_port(): client = InterconnectAttachmentsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_interconnect_locations.py b/tests/unit/gapic/compute_v1/test_interconnect_locations.py index 3f01216d0..a8129229d 100644 --- a/tests/unit/gapic/compute_v1/test_interconnect_locations.py +++ b/tests/unit/gapic/compute_v1/test_interconnect_locations.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.interconnect_locations import ( InterconnectLocationsClient, ) +from google.cloud.compute_v1.services.interconnect_locations import pagers from google.cloud.compute_v1.services.interconnect_locations import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -166,7 +167,7 @@ def test_interconnect_locations_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -182,7 +183,7 @@ def test_interconnect_locations_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -198,7 +199,7 @@ def test_interconnect_locations_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -226,7 +227,7 @@ def test_interconnect_locations_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -270,29 +271,25 @@ def test_interconnect_locations_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -301,66 +298,53 @@ def test_interconnect_locations_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -385,7 +369,7 @@ def test_interconnect_locations_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -414,7 +398,7 @@ def test_interconnect_locations_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -458,6 +442,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.InterconnectLocation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -504,6 +489,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InterconnectLocation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -518,7 +504,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -565,16 +551,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.InterconnectLocationList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InterconnectLocationList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.InterconnectLocation(address="address_value")] assert response.kind == "kind_value" @@ -600,6 +585,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InterconnectLocationList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -611,7 +597,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -629,6 +615,59 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = InterconnectLocationsClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InterconnectLocationList( + items=[ + compute.InterconnectLocation(), + compute.InterconnectLocation(), + compute.InterconnectLocation(), + ], + next_page_token="abc", + ), + compute.InterconnectLocationList(items=[], next_page_token="def",), + compute.InterconnectLocationList( + items=[compute.InterconnectLocation(),], next_page_token="ghi", + ), + compute.InterconnectLocationList( + items=[compute.InterconnectLocation(), compute.InterconnectLocation(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.InterconnectLocationList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.InterconnectLocation) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.InterconnectLocationsRestTransport( @@ -758,6 +797,17 @@ def test_interconnect_locations_auth_adc(): ) +def test_interconnect_locations_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.InterconnectLocationsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_interconnect_locations_host_no_port(): client = InterconnectLocationsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_interconnects.py b/tests/unit/gapic/compute_v1/test_interconnects.py index 3a0cb592a..f0859ddb8 100644 --- a/tests/unit/gapic/compute_v1/test_interconnects.py +++ b/tests/unit/gapic/compute_v1/test_interconnects.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.interconnects import InterconnectsClient +from google.cloud.compute_v1.services.interconnects import pagers from google.cloud.compute_v1.services.interconnects import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -157,7 +158,7 @@ def test_interconnects_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -173,7 +174,7 @@ def test_interconnects_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -189,7 +190,7 @@ def test_interconnects_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -217,7 +218,7 @@ def test_interconnects_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -251,29 +252,25 @@ def test_interconnects_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -282,66 +279,53 @@ def test_interconnects_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -360,7 +344,7 @@ def test_interconnects_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -383,7 +367,7 @@ def test_interconnects_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -431,6 +415,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -481,6 +466,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -494,7 +480,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -561,6 +547,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetInterconnectR # Wrap the value into a proper Response obj json_return_value = compute.Interconnect.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -619,6 +606,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Interconnect.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -632,7 +620,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -680,6 +668,7 @@ def test_get_diagnostics_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -712,6 +701,7 @@ def test_get_diagnostics_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -725,7 +715,7 @@ def test_get_diagnostics_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -787,6 +777,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -837,6 +828,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -852,12 +844,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.Interconnect.to_json( - interconnect_resource, including_default_value_fields=False + interconnect_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -899,16 +893,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.InterconnectList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.InterconnectList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.Interconnect(admin_enabled=True)] assert response.kind == "kind_value" @@ -932,6 +925,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InterconnectList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -943,7 +937,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -959,6 +953,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = InterconnectsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.InterconnectList( + items=[ + compute.Interconnect(), + compute.Interconnect(), + compute.Interconnect(), + ], + next_page_token="abc", + ), + compute.InterconnectList(items=[], next_page_token="def",), + compute.InterconnectList( + items=[compute.Interconnect(),], next_page_token="ghi", + ), + compute.InterconnectList( + items=[compute.Interconnect(), compute.Interconnect(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.InterconnectList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Interconnect) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchInterconnectRequest ): @@ -1001,6 +1046,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1051,6 +1097,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1068,14 +1115,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "interconnect_value" in http_call[1] + str(body) assert compute.Interconnect.to_json( - interconnect_resource, including_default_value_fields=False + interconnect_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1222,6 +1271,17 @@ def test_interconnects_auth_adc(): ) +def test_interconnects_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.InterconnectsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_interconnects_host_no_port(): client = InterconnectsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_license_codes.py b/tests/unit/gapic/compute_v1/test_license_codes.py index a489d1c44..f1c8749c8 100644 --- a/tests/unit/gapic/compute_v1/test_license_codes.py +++ b/tests/unit/gapic/compute_v1/test_license_codes.py @@ -152,7 +152,7 @@ def test_license_codes_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -168,7 +168,7 @@ def test_license_codes_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -184,7 +184,7 @@ def test_license_codes_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -212,7 +212,7 @@ def test_license_codes_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -244,29 +244,25 @@ def test_license_codes_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -275,66 +271,53 @@ def test_license_codes_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -353,7 +336,7 @@ def test_license_codes_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -376,7 +359,7 @@ def test_license_codes_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -410,6 +393,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetLicenseCodeRe # Wrap the value into a proper Response obj json_return_value = compute.LicenseCode.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -447,6 +431,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.LicenseCode.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -460,7 +445,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -500,6 +485,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -526,6 +512,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -545,14 +532,16 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -699,6 +688,17 @@ def test_license_codes_auth_adc(): ) +def test_license_codes_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.LicenseCodesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_license_codes_host_no_port(): client = LicenseCodesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_licenses.py b/tests/unit/gapic/compute_v1/test_licenses.py index f49f070d7..794e8d7dd 100644 --- a/tests/unit/gapic/compute_v1/test_licenses.py +++ b/tests/unit/gapic/compute_v1/test_licenses.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.licenses import LicensesClient +from google.cloud.compute_v1.services.licenses import pagers from google.cloud.compute_v1.services.licenses import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -148,7 +149,7 @@ def test_licenses_client_client_options(client_class, transport_class, transport credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -164,7 +165,7 @@ def test_licenses_client_client_options(client_class, transport_class, transport credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -180,7 +181,7 @@ def test_licenses_client_client_options(client_class, transport_class, transport credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -208,7 +209,7 @@ def test_licenses_client_client_options(client_class, transport_class, transport credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -240,29 +241,25 @@ def test_licenses_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -271,66 +268,53 @@ def test_licenses_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -349,7 +333,7 @@ def test_licenses_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -372,7 +356,7 @@ def test_licenses_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -420,6 +404,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -470,6 +455,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -483,7 +469,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -532,6 +518,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetLicenseReques # Wrap the value into a proper Response obj json_return_value = compute.License.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -571,6 +558,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.License.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -584,7 +572,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -637,6 +625,7 @@ def test_get_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -675,6 +664,7 @@ def test_get_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -688,7 +678,7 @@ def test_get_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -750,6 +740,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -800,6 +791,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -815,12 +807,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.License.to_json( - license_resource, including_default_value_fields=False + license_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -859,16 +853,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListLicensesReq # Wrap the value into a proper Response obj json_return_value = compute.LicensesListResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.LicensesListResponse) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.License(charges_use_fee=True)] assert response.next_page_token == "next_page_token_value" @@ -891,6 +884,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.LicensesListResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -902,7 +896,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -918,6 +912,53 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = LicensesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.LicensesListResponse( + items=[compute.License(), compute.License(), compute.License(),], + next_page_token="abc", + ), + compute.LicensesListResponse(items=[], next_page_token="def",), + compute.LicensesListResponse( + items=[compute.License(),], next_page_token="ghi", + ), + compute.LicensesListResponse( + items=[compute.License(), compute.License(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.LicensesListResponse.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.License) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_set_iam_policy_rest( transport: str = "rest", request_type=compute.SetIamPolicyLicenseRequest ): @@ -951,6 +992,7 @@ def test_set_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -989,6 +1031,7 @@ def test_set_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1008,14 +1051,16 @@ def test_set_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "resource_value" in http_call[1] + str(body) assert compute.GlobalSetPolicyRequest.to_json( - global_set_policy_request_resource, including_default_value_fields=False + global_set_policy_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1055,6 +1100,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1081,6 +1127,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1100,14 +1147,16 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1257,6 +1306,17 @@ def test_licenses_auth_adc(): ) +def test_licenses_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.LicensesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_licenses_host_no_port(): client = LicensesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_machine_types.py b/tests/unit/gapic/compute_v1/test_machine_types.py index a7ecfa5ce..b89798a9a 100644 --- a/tests/unit/gapic/compute_v1/test_machine_types.py +++ b/tests/unit/gapic/compute_v1/test_machine_types.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.machine_types import MachineTypesClient +from google.cloud.compute_v1.services.machine_types import pagers from google.cloud.compute_v1.services.machine_types import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -152,7 +153,7 @@ def test_machine_types_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -168,7 +169,7 @@ def test_machine_types_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -184,7 +185,7 @@ def test_machine_types_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -212,7 +213,7 @@ def test_machine_types_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -244,29 +245,25 @@ def test_machine_types_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -275,66 +272,53 @@ def test_machine_types_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -353,7 +337,7 @@ def test_machine_types_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -376,7 +360,7 @@ def test_machine_types_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -418,16 +402,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.MachineTypeAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.MachineTypeAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.MachineTypesScopedList( @@ -460,6 +443,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.MachineTypeAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -471,7 +455,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -487,6 +471,72 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = MachineTypesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.MachineTypeAggregatedList( + items={ + "a": compute.MachineTypesScopedList(), + "b": compute.MachineTypesScopedList(), + "c": compute.MachineTypesScopedList(), + }, + next_page_token="abc", + ), + compute.MachineTypeAggregatedList(items={}, next_page_token="def",), + compute.MachineTypeAggregatedList( + items={"g": compute.MachineTypesScopedList(),}, next_page_token="ghi", + ), + compute.MachineTypeAggregatedList( + items={ + "h": compute.MachineTypesScopedList(), + "i": compute.MachineTypesScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.MachineTypeAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.MachineTypesScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.MachineTypesScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.MachineTypesScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_get_rest(transport: str = "rest", request_type=compute.GetMachineTypeRequest): client = MachineTypesClient( credentials=credentials.AnonymousCredentials(), transport=transport, @@ -520,6 +570,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetMachineTypeRe # Wrap the value into a proper Response obj json_return_value = compute.MachineType.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -565,6 +616,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.MachineType.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -580,7 +632,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -632,16 +684,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.MachineTypeList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.MachineTypeList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.MachineType( @@ -669,6 +720,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.MachineTypeList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -682,7 +734,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -702,6 +754,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = MachineTypesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.MachineTypeList( + items=[ + compute.MachineType(), + compute.MachineType(), + compute.MachineType(), + ], + next_page_token="abc", + ), + compute.MachineTypeList(items=[], next_page_token="def",), + compute.MachineTypeList( + items=[compute.MachineType(),], next_page_token="ghi", + ), + compute.MachineTypeList( + items=[compute.MachineType(), compute.MachineType(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.MachineTypeList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.MachineType) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.MachineTypesRestTransport( @@ -830,6 +933,17 @@ def test_machine_types_auth_adc(): ) +def test_machine_types_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.MachineTypesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_machine_types_host_no_port(): client = MachineTypesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_network_endpoint_groups.py b/tests/unit/gapic/compute_v1/test_network_endpoint_groups.py index d0c522f12..e9016843e 100644 --- a/tests/unit/gapic/compute_v1/test_network_endpoint_groups.py +++ b/tests/unit/gapic/compute_v1/test_network_endpoint_groups.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.network_endpoint_groups import ( NetworkEndpointGroupsClient, ) +from google.cloud.compute_v1.services.network_endpoint_groups import pagers from google.cloud.compute_v1.services.network_endpoint_groups import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -166,7 +167,7 @@ def test_network_endpoint_groups_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -182,7 +183,7 @@ def test_network_endpoint_groups_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -198,7 +199,7 @@ def test_network_endpoint_groups_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -226,7 +227,7 @@ def test_network_endpoint_groups_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -270,29 +271,25 @@ def test_network_endpoint_groups_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -301,66 +298,53 @@ def test_network_endpoint_groups_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -385,7 +369,7 @@ def test_network_endpoint_groups_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -414,7 +398,7 @@ def test_network_endpoint_groups_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -457,16 +441,15 @@ def test_aggregated_list_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NetworkEndpointGroupAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.NetworkEndpointGroupsScopedList( @@ -501,6 +484,7 @@ def test_aggregated_list_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -512,7 +496,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -531,6 +515,79 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = NetworkEndpointGroupsClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NetworkEndpointGroupAggregatedList( + items={ + "a": compute.NetworkEndpointGroupsScopedList(), + "b": compute.NetworkEndpointGroupsScopedList(), + "c": compute.NetworkEndpointGroupsScopedList(), + }, + next_page_token="abc", + ), + compute.NetworkEndpointGroupAggregatedList( + items={}, next_page_token="def", + ), + compute.NetworkEndpointGroupAggregatedList( + items={"g": compute.NetworkEndpointGroupsScopedList(),}, + next_page_token="ghi", + ), + compute.NetworkEndpointGroupAggregatedList( + items={ + "h": compute.NetworkEndpointGroupsScopedList(), + "i": compute.NetworkEndpointGroupsScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.NetworkEndpointGroupAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.NetworkEndpointGroupsScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.NetworkEndpointGroupsScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.NetworkEndpointGroupsScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_attach_network_endpoints_rest( transport: str = "rest", request_type=compute.AttachNetworkEndpointsNetworkEndpointGroupRequest, @@ -574,6 +631,7 @@ def test_attach_network_endpoints_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -626,6 +684,7 @@ def test_attach_network_endpoints_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -648,7 +707,7 @@ def test_attach_network_endpoints_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -659,6 +718,7 @@ def test_attach_network_endpoints_rest_flattened(): assert compute.NetworkEndpointGroupsAttachEndpointsRequest.to_json( network_endpoint_groups_attach_endpoints_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -725,6 +785,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -777,6 +838,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -792,7 +854,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -860,6 +922,7 @@ def test_detach_network_endpoints_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -912,6 +975,7 @@ def test_detach_network_endpoints_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -934,7 +998,7 @@ def test_detach_network_endpoints_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -945,6 +1009,7 @@ def test_detach_network_endpoints_rest_flattened(): assert compute.NetworkEndpointGroupsDetachEndpointsRequest.to_json( network_endpoint_groups_detach_endpoints_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1007,6 +1072,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.NetworkEndpointGroup.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1060,6 +1126,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NetworkEndpointGroup.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1075,7 +1142,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1142,6 +1209,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1194,6 +1262,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1213,14 +1282,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "zone_value" in http_call[1] + str(body) assert compute.NetworkEndpointGroup.to_json( - network_endpoint_group_resource, including_default_value_fields=False + network_endpoint_group_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1269,16 +1340,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.NetworkEndpointGroupList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NetworkEndpointGroupList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.NetworkEndpointGroup(annotations={"key_value": "value_value"}) @@ -1306,6 +1376,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NetworkEndpointGroupList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1319,7 +1390,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1341,6 +1412,59 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = NetworkEndpointGroupsClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NetworkEndpointGroupList( + items=[ + compute.NetworkEndpointGroup(), + compute.NetworkEndpointGroup(), + compute.NetworkEndpointGroup(), + ], + next_page_token="abc", + ), + compute.NetworkEndpointGroupList(items=[], next_page_token="def",), + compute.NetworkEndpointGroupList( + items=[compute.NetworkEndpointGroup(),], next_page_token="ghi", + ), + compute.NetworkEndpointGroupList( + items=[compute.NetworkEndpointGroup(), compute.NetworkEndpointGroup(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.NetworkEndpointGroupList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.NetworkEndpointGroup) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_network_endpoints_rest( transport: str = "rest", request_type=compute.ListNetworkEndpointsNetworkEndpointGroupsRequest, @@ -1378,16 +1502,15 @@ def test_list_network_endpoints_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_network_endpoints(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NetworkEndpointGroupsListNetworkEndpoints) + assert isinstance(response, pagers.ListNetworkEndpointsPager) assert response.id == "id_value" assert response.items == [ compute.NetworkEndpointWithHealthStatus( @@ -1424,6 +1547,7 @@ def test_list_network_endpoints_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1444,7 +1568,7 @@ def test_list_network_endpoints_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1455,6 +1579,7 @@ def test_list_network_endpoints_rest_flattened(): assert compute.NetworkEndpointGroupsListEndpointsRequest.to_json( network_endpoint_groups_list_endpoints_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1477,6 +1602,70 @@ def test_list_network_endpoints_rest_flattened_error(): ) +def test_list_network_endpoints_pager(): + client = NetworkEndpointGroupsClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NetworkEndpointGroupsListNetworkEndpoints( + items=[ + compute.NetworkEndpointWithHealthStatus(), + compute.NetworkEndpointWithHealthStatus(), + compute.NetworkEndpointWithHealthStatus(), + ], + next_page_token="abc", + ), + compute.NetworkEndpointGroupsListNetworkEndpoints( + items=[], next_page_token="def", + ), + compute.NetworkEndpointGroupsListNetworkEndpoints( + items=[compute.NetworkEndpointWithHealthStatus(),], + next_page_token="ghi", + ), + compute.NetworkEndpointGroupsListNetworkEndpoints( + items=[ + compute.NetworkEndpointWithHealthStatus(), + compute.NetworkEndpointWithHealthStatus(), + ], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.NetworkEndpointGroupsListNetworkEndpoints.to_json(x) + for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_network_endpoints(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all( + isinstance(i, compute.NetworkEndpointWithHealthStatus) for i in results + ) + + pages = list(client.list_network_endpoints(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_test_iam_permissions_rest( transport: str = "rest", request_type=compute.TestIamPermissionsNetworkEndpointGroupRequest, @@ -1498,6 +1687,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1526,6 +1716,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1546,7 +1737,7 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1555,7 +1746,9 @@ def test_test_iam_permissions_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1712,6 +1905,17 @@ def test_network_endpoint_groups_auth_adc(): ) +def test_network_endpoint_groups_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.NetworkEndpointGroupsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_network_endpoint_groups_host_no_port(): client = NetworkEndpointGroupsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_networks.py b/tests/unit/gapic/compute_v1/test_networks.py index 1e7059ced..a2b70a6e2 100644 --- a/tests/unit/gapic/compute_v1/test_networks.py +++ b/tests/unit/gapic/compute_v1/test_networks.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.networks import NetworksClient +from google.cloud.compute_v1.services.networks import pagers from google.cloud.compute_v1.services.networks import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -148,7 +149,7 @@ def test_networks_client_client_options(client_class, transport_class, transport credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -164,7 +165,7 @@ def test_networks_client_client_options(client_class, transport_class, transport credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -180,7 +181,7 @@ def test_networks_client_client_options(client_class, transport_class, transport credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -208,7 +209,7 @@ def test_networks_client_client_options(client_class, transport_class, transport credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -240,29 +241,25 @@ def test_networks_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -271,66 +268,53 @@ def test_networks_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -349,7 +333,7 @@ def test_networks_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -372,7 +356,7 @@ def test_networks_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -420,6 +404,7 @@ def test_add_peering_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -470,6 +455,7 @@ def test_add_peering_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -489,14 +475,16 @@ def test_add_peering_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "network_value" in http_call[1] + str(body) assert compute.NetworksAddPeeringRequest.to_json( - networks_add_peering_request_resource, including_default_value_fields=False + networks_add_peering_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -558,6 +546,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -608,6 +597,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -621,7 +611,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -673,6 +663,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetNetworkReques # Wrap the value into a proper Response obj json_return_value = compute.Network.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -714,6 +705,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Network.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -727,7 +719,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -789,6 +781,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -839,6 +832,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -854,12 +848,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.Network.to_json( - network_resource, including_default_value_fields=False + network_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -899,16 +895,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListNetworksReq # Wrap the value into a proper Response obj json_return_value = compute.NetworkList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NetworkList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.Network(auto_create_subnetworks=True)] assert response.kind == "kind_value" @@ -932,6 +927,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NetworkList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -943,7 +939,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -959,6 +955,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = NetworksClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NetworkList( + items=[compute.Network(), compute.Network(), compute.Network(),], + next_page_token="abc", + ), + compute.NetworkList(items=[], next_page_token="def",), + compute.NetworkList(items=[compute.Network(),], next_page_token="ghi",), + compute.NetworkList(items=[compute.Network(), compute.Network(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.NetworkList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Network) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_peering_routes_rest( transport: str = "rest", request_type=compute.ListPeeringRoutesNetworksRequest ): @@ -984,16 +1023,15 @@ def test_list_peering_routes_rest( # Wrap the value into a proper Response obj json_return_value = compute.ExchangedPeeringRoutesList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_peering_routes(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.ExchangedPeeringRoutesList) + assert isinstance(response, pagers.ListPeeringRoutesPager) assert response.id == "id_value" assert response.items == [ compute.ExchangedPeeringRoute(dest_range="dest_range_value") @@ -1019,6 +1057,7 @@ def test_list_peering_routes_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ExchangedPeeringRoutesList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1032,7 +1071,7 @@ def test_list_peering_routes_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1052,6 +1091,62 @@ def test_list_peering_routes_rest_flattened_error(): ) +def test_list_peering_routes_pager(): + client = NetworksClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.ExchangedPeeringRoutesList( + items=[ + compute.ExchangedPeeringRoute(), + compute.ExchangedPeeringRoute(), + compute.ExchangedPeeringRoute(), + ], + next_page_token="abc", + ), + compute.ExchangedPeeringRoutesList(items=[], next_page_token="def",), + compute.ExchangedPeeringRoutesList( + items=[compute.ExchangedPeeringRoute(),], next_page_token="ghi", + ), + compute.ExchangedPeeringRoutesList( + items=[ + compute.ExchangedPeeringRoute(), + compute.ExchangedPeeringRoute(), + ], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.ExchangedPeeringRoutesList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_peering_routes(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.ExchangedPeeringRoute) for i in results) + + pages = list(client.list_peering_routes(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest(transport: str = "rest", request_type=compute.PatchNetworkRequest): client = NetworksClient( credentials=credentials.AnonymousCredentials(), transport=transport, @@ -1092,6 +1187,7 @@ def test_patch_rest(transport: str = "rest", request_type=compute.PatchNetworkRe # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1142,6 +1238,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1159,14 +1256,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "network_value" in http_call[1] + str(body) assert compute.Network.to_json( - network_resource, including_default_value_fields=False + network_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1226,6 +1325,7 @@ def test_remove_peering_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1276,6 +1376,7 @@ def test_remove_peering_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1295,7 +1396,7 @@ def test_remove_peering_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1304,6 +1405,7 @@ def test_remove_peering_rest_flattened(): assert compute.NetworksRemovePeeringRequest.to_json( networks_remove_peering_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1365,6 +1467,7 @@ def test_switch_to_custom_mode_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1415,6 +1518,7 @@ def test_switch_to_custom_mode_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1428,7 +1532,7 @@ def test_switch_to_custom_mode_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1490,6 +1594,7 @@ def test_update_peering_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1540,6 +1645,7 @@ def test_update_peering_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1559,7 +1665,7 @@ def test_update_peering_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1568,6 +1674,7 @@ def test_update_peering_rest_flattened(): assert compute.NetworksUpdatePeeringRequest.to_json( networks_update_peering_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1720,6 +1827,17 @@ def test_networks_auth_adc(): ) +def test_networks_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.NetworksRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_networks_host_no_port(): client = NetworksClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_node_groups.py b/tests/unit/gapic/compute_v1/test_node_groups.py index 65391ab71..0b02a199b 100644 --- a/tests/unit/gapic/compute_v1/test_node_groups.py +++ b/tests/unit/gapic/compute_v1/test_node_groups.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.node_groups import NodeGroupsClient +from google.cloud.compute_v1.services.node_groups import pagers from google.cloud.compute_v1.services.node_groups import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -152,7 +153,7 @@ def test_node_groups_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -168,7 +169,7 @@ def test_node_groups_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -184,7 +185,7 @@ def test_node_groups_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -212,7 +213,7 @@ def test_node_groups_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -244,29 +245,25 @@ def test_node_groups_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -275,66 +272,53 @@ def test_node_groups_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -353,7 +337,7 @@ def test_node_groups_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -376,7 +360,7 @@ def test_node_groups_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -424,6 +408,7 @@ def test_add_nodes_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -474,6 +459,7 @@ def test_add_nodes_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -494,7 +480,7 @@ def test_add_nodes_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -503,7 +489,9 @@ def test_add_nodes_rest_flattened(): assert "node_group_value" in http_call[1] + str(body) assert compute.NodeGroupsAddNodesRequest.to_json( - node_groups_add_nodes_request_resource, including_default_value_fields=False + node_groups_add_nodes_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -560,16 +548,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.NodeGroupAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NodeGroupAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.NodeGroupsScopedList( @@ -602,6 +589,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NodeGroupAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -613,7 +601,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -629,6 +617,69 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = NodeGroupsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NodeGroupAggregatedList( + items={ + "a": compute.NodeGroupsScopedList(), + "b": compute.NodeGroupsScopedList(), + "c": compute.NodeGroupsScopedList(), + }, + next_page_token="abc", + ), + compute.NodeGroupAggregatedList(items={}, next_page_token="def",), + compute.NodeGroupAggregatedList( + items={"g": compute.NodeGroupsScopedList(),}, next_page_token="ghi", + ), + compute.NodeGroupAggregatedList( + items={ + "h": compute.NodeGroupsScopedList(), + "i": compute.NodeGroupsScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.NodeGroupAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.NodeGroupsScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == (str, compute.NodeGroupsScopedList) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.NodeGroupsScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteNodeGroupRequest ): @@ -671,6 +722,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -721,6 +773,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -734,7 +787,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -799,6 +852,7 @@ def test_delete_nodes_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -849,6 +903,7 @@ def test_delete_nodes_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -869,7 +924,7 @@ def test_delete_nodes_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -880,6 +935,7 @@ def test_delete_nodes_rest_flattened(): assert compute.NodeGroupsDeleteNodesRequest.to_json( node_groups_delete_nodes_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -933,6 +989,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetNodeGroupRequ # Wrap the value into a proper Response obj json_return_value = compute.NodeGroup.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -976,6 +1033,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NodeGroup.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -989,7 +1047,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1045,6 +1103,7 @@ def test_get_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1083,6 +1142,7 @@ def test_get_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1096,7 +1156,7 @@ def test_get_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1161,6 +1221,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1211,6 +1272,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1231,7 +1293,7 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1240,7 +1302,9 @@ def test_insert_rest_flattened(): assert str(1911) in http_call[1] + str(body) assert compute.NodeGroup.to_json( - node_group_resource, including_default_value_fields=False + node_group_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1288,16 +1352,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListNodeGroupsR # Wrap the value into a proper Response obj json_return_value = compute.NodeGroupList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NodeGroupList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.NodeGroup( @@ -1325,6 +1388,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NodeGroupList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1338,7 +1402,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1356,6 +1420,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = NodeGroupsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NodeGroupList( + items=[compute.NodeGroup(), compute.NodeGroup(), compute.NodeGroup(),], + next_page_token="abc", + ), + compute.NodeGroupList(items=[], next_page_token="def",), + compute.NodeGroupList(items=[compute.NodeGroup(),], next_page_token="ghi",), + compute.NodeGroupList(items=[compute.NodeGroup(), compute.NodeGroup(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.NodeGroupList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.NodeGroup) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_nodes_rest( transport: str = "rest", request_type=compute.ListNodesNodeGroupsRequest ): @@ -1385,16 +1492,15 @@ def test_list_nodes_rest( # Wrap the value into a proper Response obj json_return_value = compute.NodeGroupsListNodes.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_nodes(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NodeGroupsListNodes) + assert isinstance(response, pagers.ListNodesPager) assert response.id == "id_value" assert response.items == [ compute.NodeGroupNode( @@ -1422,6 +1528,7 @@ def test_list_nodes_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NodeGroupsListNodes.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1435,7 +1542,7 @@ def test_list_nodes_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1458,6 +1565,57 @@ def test_list_nodes_rest_flattened_error(): ) +def test_list_nodes_pager(): + client = NodeGroupsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NodeGroupsListNodes( + items=[ + compute.NodeGroupNode(), + compute.NodeGroupNode(), + compute.NodeGroupNode(), + ], + next_page_token="abc", + ), + compute.NodeGroupsListNodes(items=[], next_page_token="def",), + compute.NodeGroupsListNodes( + items=[compute.NodeGroupNode(),], next_page_token="ghi", + ), + compute.NodeGroupsListNodes( + items=[compute.NodeGroupNode(), compute.NodeGroupNode(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.NodeGroupsListNodes.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_nodes(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.NodeGroupNode) for i in results) + + pages = list(client.list_nodes(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchNodeGroupRequest ): @@ -1500,6 +1658,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1550,6 +1709,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1570,7 +1730,7 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1579,7 +1739,9 @@ def test_patch_rest_flattened(): assert "node_group_value" in http_call[1] + str(body) assert compute.NodeGroup.to_json( - node_group_resource, including_default_value_fields=False + node_group_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1633,6 +1795,7 @@ def test_set_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1671,6 +1834,7 @@ def test_set_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1691,7 +1855,7 @@ def test_set_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1700,7 +1864,9 @@ def test_set_iam_policy_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.ZoneSetPolicyRequest.to_json( - zone_set_policy_request_resource, including_default_value_fields=False + zone_set_policy_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1763,6 +1929,7 @@ def test_set_node_template_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1813,6 +1980,7 @@ def test_set_node_template_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1833,7 +2001,7 @@ def test_set_node_template_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1844,6 +2012,7 @@ def test_set_node_template_rest_flattened(): assert compute.NodeGroupsSetNodeTemplateRequest.to_json( node_groups_set_node_template_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1884,6 +2053,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1910,6 +2080,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1930,7 +2101,7 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1939,7 +2110,9 @@ def test_test_iam_permissions_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2096,6 +2269,17 @@ def test_node_groups_auth_adc(): ) +def test_node_groups_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.NodeGroupsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_node_groups_host_no_port(): client = NodeGroupsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_node_templates.py b/tests/unit/gapic/compute_v1/test_node_templates.py index 7af2e9159..fc07d6298 100644 --- a/tests/unit/gapic/compute_v1/test_node_templates.py +++ b/tests/unit/gapic/compute_v1/test_node_templates.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.node_templates import NodeTemplatesClient +from google.cloud.compute_v1.services.node_templates import pagers from google.cloud.compute_v1.services.node_templates import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -157,7 +158,7 @@ def test_node_templates_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -173,7 +174,7 @@ def test_node_templates_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -189,7 +190,7 @@ def test_node_templates_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -217,7 +218,7 @@ def test_node_templates_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -251,29 +252,25 @@ def test_node_templates_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -282,66 +279,53 @@ def test_node_templates_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -360,7 +344,7 @@ def test_node_templates_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -383,7 +367,7 @@ def test_node_templates_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -425,16 +409,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.NodeTemplateAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NodeTemplateAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.NodeTemplatesScopedList( @@ -467,6 +450,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NodeTemplateAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -478,7 +462,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -494,6 +478,74 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = NodeTemplatesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NodeTemplateAggregatedList( + items={ + "a": compute.NodeTemplatesScopedList(), + "b": compute.NodeTemplatesScopedList(), + "c": compute.NodeTemplatesScopedList(), + }, + next_page_token="abc", + ), + compute.NodeTemplateAggregatedList(items={}, next_page_token="def",), + compute.NodeTemplateAggregatedList( + items={"g": compute.NodeTemplatesScopedList(),}, next_page_token="ghi", + ), + compute.NodeTemplateAggregatedList( + items={ + "h": compute.NodeTemplatesScopedList(), + "i": compute.NodeTemplatesScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.NodeTemplateAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.NodeTemplatesScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.NodeTemplatesScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.NodeTemplatesScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteNodeTemplateRequest ): @@ -536,6 +588,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -586,6 +639,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -601,7 +655,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -661,6 +715,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetNodeTemplateR # Wrap the value into a proper Response obj json_return_value = compute.NodeTemplate.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -709,6 +764,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NodeTemplate.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -724,7 +780,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -780,6 +836,7 @@ def test_get_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -818,6 +875,7 @@ def test_get_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -831,7 +889,7 @@ def test_get_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -896,6 +954,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -946,6 +1005,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -965,14 +1025,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.NodeTemplate.to_json( - node_template_resource, including_default_value_fields=False + node_template_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1021,16 +1083,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.NodeTemplateList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NodeTemplateList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.NodeTemplate( @@ -1058,6 +1119,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NodeTemplateList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1071,7 +1133,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1091,6 +1153,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = NodeTemplatesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NodeTemplateList( + items=[ + compute.NodeTemplate(), + compute.NodeTemplate(), + compute.NodeTemplate(), + ], + next_page_token="abc", + ), + compute.NodeTemplateList(items=[], next_page_token="def",), + compute.NodeTemplateList( + items=[compute.NodeTemplate(),], next_page_token="ghi", + ), + compute.NodeTemplateList( + items=[compute.NodeTemplate(), compute.NodeTemplate(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.NodeTemplateList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.NodeTemplate) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_set_iam_policy_rest( transport: str = "rest", request_type=compute.SetIamPolicyNodeTemplateRequest ): @@ -1124,6 +1237,7 @@ def test_set_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1162,6 +1276,7 @@ def test_set_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1182,7 +1297,7 @@ def test_set_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1191,7 +1306,9 @@ def test_set_iam_policy_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.RegionSetPolicyRequest.to_json( - region_set_policy_request_resource, including_default_value_fields=False + region_set_policy_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1232,6 +1349,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1258,6 +1376,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1278,7 +1397,7 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1287,7 +1406,9 @@ def test_test_iam_permissions_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1439,6 +1560,17 @@ def test_node_templates_auth_adc(): ) +def test_node_templates_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.NodeTemplatesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_node_templates_host_no_port(): client = NodeTemplatesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_node_types.py b/tests/unit/gapic/compute_v1/test_node_types.py index 93e78b77c..63eb4e27c 100644 --- a/tests/unit/gapic/compute_v1/test_node_types.py +++ b/tests/unit/gapic/compute_v1/test_node_types.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.node_types import NodeTypesClient +from google.cloud.compute_v1.services.node_types import pagers from google.cloud.compute_v1.services.node_types import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -150,7 +151,7 @@ def test_node_types_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -166,7 +167,7 @@ def test_node_types_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -182,7 +183,7 @@ def test_node_types_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -210,7 +211,7 @@ def test_node_types_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -242,29 +243,25 @@ def test_node_types_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -273,66 +270,53 @@ def test_node_types_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -351,7 +335,7 @@ def test_node_types_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -374,7 +358,7 @@ def test_node_types_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -410,16 +394,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.NodeTypeAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NodeTypeAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.NodeTypesScopedList( @@ -448,6 +431,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NodeTypeAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -459,7 +443,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -475,6 +459,69 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = NodeTypesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NodeTypeAggregatedList( + items={ + "a": compute.NodeTypesScopedList(), + "b": compute.NodeTypesScopedList(), + "c": compute.NodeTypesScopedList(), + }, + next_page_token="abc", + ), + compute.NodeTypeAggregatedList(items={}, next_page_token="def",), + compute.NodeTypeAggregatedList( + items={"g": compute.NodeTypesScopedList(),}, next_page_token="ghi", + ), + compute.NodeTypeAggregatedList( + items={ + "h": compute.NodeTypesScopedList(), + "i": compute.NodeTypesScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.NodeTypeAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.NodeTypesScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == (str, compute.NodeTypesScopedList) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.NodeTypesScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_get_rest(transport: str = "rest", request_type=compute.GetNodeTypeRequest): client = NodeTypesClient( credentials=credentials.AnonymousCredentials(), transport=transport, @@ -504,6 +551,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetNodeTypeReque # Wrap the value into a proper Response obj json_return_value = compute.NodeType.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -541,6 +589,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NodeType.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -554,7 +603,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -600,16 +649,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListNodeTypesRe # Wrap the value into a proper Response obj json_return_value = compute.NodeTypeList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NodeTypeList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.NodeType(cpu_platform="cpu_platform_value")] assert response.kind == "kind_value" @@ -633,6 +681,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NodeTypeList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -646,7 +695,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -664,6 +713,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = NodeTypesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NodeTypeList( + items=[compute.NodeType(), compute.NodeType(), compute.NodeType(),], + next_page_token="abc", + ), + compute.NodeTypeList(items=[], next_page_token="def",), + compute.NodeTypeList(items=[compute.NodeType(),], next_page_token="ghi",), + compute.NodeTypeList(items=[compute.NodeType(), compute.NodeType(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.NodeTypeList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.NodeType) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.NodeTypesRestTransport( @@ -792,6 +884,17 @@ def test_node_types_auth_adc(): ) +def test_node_types_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.NodeTypesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_node_types_host_no_port(): client = NodeTypesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_packet_mirrorings.py b/tests/unit/gapic/compute_v1/test_packet_mirrorings.py index 9a022a659..2c0379dac 100644 --- a/tests/unit/gapic/compute_v1/test_packet_mirrorings.py +++ b/tests/unit/gapic/compute_v1/test_packet_mirrorings.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.packet_mirrorings import PacketMirroringsClient +from google.cloud.compute_v1.services.packet_mirrorings import pagers from google.cloud.compute_v1.services.packet_mirrorings import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -158,7 +159,7 @@ def test_packet_mirrorings_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -174,7 +175,7 @@ def test_packet_mirrorings_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -190,7 +191,7 @@ def test_packet_mirrorings_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -218,7 +219,7 @@ def test_packet_mirrorings_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -262,29 +263,25 @@ def test_packet_mirrorings_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -293,66 +290,53 @@ def test_packet_mirrorings_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -371,7 +355,7 @@ def test_packet_mirrorings_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +378,7 @@ def test_packet_mirrorings_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -436,16 +420,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.PacketMirroringAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.PacketMirroringAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.PacketMirroringsScopedList( @@ -480,6 +463,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.PacketMirroringAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -491,7 +475,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -507,6 +491,75 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = PacketMirroringsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.PacketMirroringAggregatedList( + items={ + "a": compute.PacketMirroringsScopedList(), + "b": compute.PacketMirroringsScopedList(), + "c": compute.PacketMirroringsScopedList(), + }, + next_page_token="abc", + ), + compute.PacketMirroringAggregatedList(items={}, next_page_token="def",), + compute.PacketMirroringAggregatedList( + items={"g": compute.PacketMirroringsScopedList(),}, + next_page_token="ghi", + ), + compute.PacketMirroringAggregatedList( + items={ + "h": compute.PacketMirroringsScopedList(), + "i": compute.PacketMirroringsScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.PacketMirroringAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.PacketMirroringsScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.PacketMirroringsScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.PacketMirroringsScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeletePacketMirroringRequest ): @@ -549,6 +602,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -599,6 +653,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -614,7 +669,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -679,6 +734,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.PacketMirroring.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -729,6 +785,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.PacketMirroring.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -744,7 +801,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -809,6 +866,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -859,6 +917,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -880,14 +939,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.PacketMirroring.to_json( - packet_mirroring_resource, including_default_value_fields=False + packet_mirroring_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -940,16 +1001,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.PacketMirroringList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.PacketMirroringList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.PacketMirroring( @@ -979,6 +1039,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.PacketMirroringList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -992,7 +1053,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1012,6 +1073,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = PacketMirroringsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.PacketMirroringList( + items=[ + compute.PacketMirroring(), + compute.PacketMirroring(), + compute.PacketMirroring(), + ], + next_page_token="abc", + ), + compute.PacketMirroringList(items=[], next_page_token="def",), + compute.PacketMirroringList( + items=[compute.PacketMirroring(),], next_page_token="ghi", + ), + compute.PacketMirroringList( + items=[compute.PacketMirroring(), compute.PacketMirroring(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.PacketMirroringList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.PacketMirroring) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchPacketMirroringRequest ): @@ -1054,6 +1166,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1104,6 +1217,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1126,7 +1240,7 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1135,7 +1249,9 @@ def test_patch_rest_flattened(): assert "packet_mirroring_value" in http_call[1] + str(body) assert compute.PacketMirroring.to_json( - packet_mirroring_resource, including_default_value_fields=False + packet_mirroring_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1179,6 +1295,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1205,6 +1322,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1225,7 +1343,7 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1234,7 +1352,9 @@ def test_test_iam_permissions_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1385,6 +1505,17 @@ def test_packet_mirrorings_auth_adc(): ) +def test_packet_mirrorings_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.PacketMirroringsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_packet_mirrorings_host_no_port(): client = PacketMirroringsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_projects.py b/tests/unit/gapic/compute_v1/test_projects.py index 58faf1843..53972df75 100644 --- a/tests/unit/gapic/compute_v1/test_projects.py +++ b/tests/unit/gapic/compute_v1/test_projects.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.projects import ProjectsClient +from google.cloud.compute_v1.services.projects import pagers from google.cloud.compute_v1.services.projects import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -148,7 +149,7 @@ def test_projects_client_client_options(client_class, transport_class, transport credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -164,7 +165,7 @@ def test_projects_client_client_options(client_class, transport_class, transport credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -180,7 +181,7 @@ def test_projects_client_client_options(client_class, transport_class, transport credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -208,7 +209,7 @@ def test_projects_client_client_options(client_class, transport_class, transport credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -240,29 +241,25 @@ def test_projects_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -271,66 +268,53 @@ def test_projects_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -349,7 +333,7 @@ def test_projects_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -372,7 +356,7 @@ def test_projects_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -420,6 +404,7 @@ def test_disable_xpn_host_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -470,6 +455,7 @@ def test_disable_xpn_host_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -481,7 +467,7 @@ def test_disable_xpn_host_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -539,6 +525,7 @@ def test_disable_xpn_resource_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -589,6 +576,7 @@ def test_disable_xpn_resource_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -607,13 +595,14 @@ def test_disable_xpn_resource_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.ProjectsDisableXpnResourceRequest.to_json( projects_disable_xpn_resource_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -674,6 +663,7 @@ def test_enable_xpn_host_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -724,6 +714,7 @@ def test_enable_xpn_host_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -735,7 +726,7 @@ def test_enable_xpn_host_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -793,6 +784,7 @@ def test_enable_xpn_resource_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -843,6 +835,7 @@ def test_enable_xpn_resource_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -861,13 +854,14 @@ def test_enable_xpn_resource_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.ProjectsEnableXpnResourceRequest.to_json( projects_enable_xpn_resource_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -918,6 +912,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetProjectReques # Wrap the value into a proper Response obj json_return_value = compute.Project.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -960,6 +955,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Project.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -971,7 +967,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1021,6 +1017,7 @@ def test_get_xpn_host_rest( # Wrap the value into a proper Response obj json_return_value = compute.Project.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1063,6 +1060,7 @@ def test_get_xpn_host_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Project.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1074,7 +1072,7 @@ def test_get_xpn_host_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1112,16 +1110,15 @@ def test_get_xpn_resources_rest( # Wrap the value into a proper Response obj json_return_value = compute.ProjectsGetXpnResources.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.get_xpn_resources(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.ProjectsGetXpnResources) + assert isinstance(response, pagers.GetXpnResourcesPager) assert response.kind == "kind_value" assert response.next_page_token == "next_page_token_value" assert response.resources == [compute.XpnResourceId(id="id_value")] @@ -1142,6 +1139,7 @@ def test_get_xpn_resources_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ProjectsGetXpnResources.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1153,7 +1151,7 @@ def test_get_xpn_resources_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1169,6 +1167,57 @@ def test_get_xpn_resources_rest_flattened_error(): ) +def test_get_xpn_resources_pager(): + client = ProjectsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.ProjectsGetXpnResources( + resources=[ + compute.XpnResourceId(), + compute.XpnResourceId(), + compute.XpnResourceId(), + ], + next_page_token="abc", + ), + compute.ProjectsGetXpnResources(resources=[], next_page_token="def",), + compute.ProjectsGetXpnResources( + resources=[compute.XpnResourceId(),], next_page_token="ghi", + ), + compute.ProjectsGetXpnResources( + resources=[compute.XpnResourceId(), compute.XpnResourceId(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.ProjectsGetXpnResources.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.get_xpn_resources(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.XpnResourceId) for i in results) + + pages = list(client.get_xpn_resources(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_xpn_hosts_rest( transport: str = "rest", request_type=compute.ListXpnHostsProjectsRequest ): @@ -1200,16 +1249,15 @@ def test_list_xpn_hosts_rest( # Wrap the value into a proper Response obj json_return_value = compute.XpnHostList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_xpn_hosts(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.XpnHostList) + assert isinstance(response, pagers.ListXpnHostsPager) assert response.id == "id_value" assert response.items == [ compute.Project( @@ -1237,6 +1285,7 @@ def test_list_xpn_hosts_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.XpnHostList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1255,13 +1304,14 @@ def test_list_xpn_hosts_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.ProjectsListXpnHostsRequest.to_json( projects_list_xpn_hosts_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1280,6 +1330,49 @@ def test_list_xpn_hosts_rest_flattened_error(): ) +def test_list_xpn_hosts_pager(): + client = ProjectsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.XpnHostList( + items=[compute.Project(), compute.Project(), compute.Project(),], + next_page_token="abc", + ), + compute.XpnHostList(items=[], next_page_token="def",), + compute.XpnHostList(items=[compute.Project(),], next_page_token="ghi",), + compute.XpnHostList(items=[compute.Project(), compute.Project(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.XpnHostList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_xpn_hosts(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Project) for i in results) + + pages = list(client.list_xpn_hosts(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_move_disk_rest( transport: str = "rest", request_type=compute.MoveDiskProjectRequest ): @@ -1322,6 +1415,7 @@ def test_move_disk_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1372,6 +1466,7 @@ def test_move_disk_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1390,12 +1485,14 @@ def test_move_disk_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.DiskMoveRequest.to_json( - disk_move_request_resource, including_default_value_fields=False + disk_move_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1456,6 +1553,7 @@ def test_move_instance_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1506,6 +1604,7 @@ def test_move_instance_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1524,12 +1623,14 @@ def test_move_instance_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.InstanceMoveRequest.to_json( - instance_move_request_resource, including_default_value_fields=False + instance_move_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1591,6 +1692,7 @@ def test_set_common_instance_metadata_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1641,6 +1743,7 @@ def test_set_common_instance_metadata_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1656,12 +1759,14 @@ def test_set_common_instance_metadata_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.Metadata.to_json( - metadata_resource, including_default_value_fields=False + metadata_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1720,6 +1825,7 @@ def test_set_default_network_tier_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1770,6 +1876,7 @@ def test_set_default_network_tier_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1788,13 +1895,14 @@ def test_set_default_network_tier_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.ProjectsSetDefaultNetworkTierRequest.to_json( projects_set_default_network_tier_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1855,6 +1963,7 @@ def test_set_usage_export_bucket_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1905,6 +2014,7 @@ def test_set_usage_export_bucket_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1923,12 +2033,14 @@ def test_set_usage_export_bucket_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.UsageExportLocation.to_json( - usage_export_location_resource, including_default_value_fields=False + usage_export_location_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2083,6 +2195,17 @@ def test_projects_auth_adc(): ) +def test_projects_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.ProjectsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_projects_host_no_port(): client = ProjectsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_autoscalers.py b/tests/unit/gapic/compute_v1/test_region_autoscalers.py index 908a33b73..a13fddcd1 100644 --- a/tests/unit/gapic/compute_v1/test_region_autoscalers.py +++ b/tests/unit/gapic/compute_v1/test_region_autoscalers.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.region_autoscalers import RegionAutoscalersClient +from google.cloud.compute_v1.services.region_autoscalers import pagers from google.cloud.compute_v1.services.region_autoscalers import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -158,7 +159,7 @@ def test_region_autoscalers_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -174,7 +175,7 @@ def test_region_autoscalers_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -190,7 +191,7 @@ def test_region_autoscalers_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -218,7 +219,7 @@ def test_region_autoscalers_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -262,29 +263,25 @@ def test_region_autoscalers_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -293,66 +290,53 @@ def test_region_autoscalers_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -371,7 +355,7 @@ def test_region_autoscalers_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +378,7 @@ def test_region_autoscalers_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -442,6 +426,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -492,6 +477,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -507,7 +493,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -562,6 +548,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.Autoscaler.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -604,6 +591,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Autoscaler.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -619,7 +607,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -684,6 +672,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -734,6 +723,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -753,14 +743,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.Autoscaler.to_json( - autoscaler_resource, including_default_value_fields=False + autoscaler_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -811,16 +803,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.RegionAutoscalerList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.RegionAutoscalerList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Autoscaler( @@ -848,6 +839,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.RegionAutoscalerList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -861,7 +853,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -881,6 +873,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionAutoscalersClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.RegionAutoscalerList( + items=[ + compute.Autoscaler(), + compute.Autoscaler(), + compute.Autoscaler(), + ], + next_page_token="abc", + ), + compute.RegionAutoscalerList(items=[], next_page_token="def",), + compute.RegionAutoscalerList( + items=[compute.Autoscaler(),], next_page_token="ghi", + ), + compute.RegionAutoscalerList( + items=[compute.Autoscaler(), compute.Autoscaler(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.RegionAutoscalerList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Autoscaler) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchRegionAutoscalerRequest ): @@ -923,6 +966,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -973,6 +1017,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -992,14 +1037,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.Autoscaler.to_json( - autoscaler_resource, including_default_value_fields=False + autoscaler_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1061,6 +1108,7 @@ def test_update_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1111,6 +1159,7 @@ def test_update_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1130,14 +1179,16 @@ def test_update_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.Autoscaler.to_json( - autoscaler_resource, including_default_value_fields=False + autoscaler_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1288,6 +1339,17 @@ def test_region_autoscalers_auth_adc(): ) +def test_region_autoscalers_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionAutoscalersRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_autoscalers_host_no_port(): client = RegionAutoscalersClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_backend_services.py b/tests/unit/gapic/compute_v1/test_region_backend_services.py index dd81b8162..1fb40fe1a 100644 --- a/tests/unit/gapic/compute_v1/test_region_backend_services.py +++ b/tests/unit/gapic/compute_v1/test_region_backend_services.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.region_backend_services import ( RegionBackendServicesClient, ) +from google.cloud.compute_v1.services.region_backend_services import pagers from google.cloud.compute_v1.services.region_backend_services import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -166,7 +167,7 @@ def test_region_backend_services_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -182,7 +183,7 @@ def test_region_backend_services_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -198,7 +199,7 @@ def test_region_backend_services_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -226,7 +227,7 @@ def test_region_backend_services_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -270,29 +271,25 @@ def test_region_backend_services_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -301,66 +298,53 @@ def test_region_backend_services_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -385,7 +369,7 @@ def test_region_backend_services_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -414,7 +398,7 @@ def test_region_backend_services_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -462,6 +446,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -514,6 +499,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -529,7 +515,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -619,6 +605,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.BackendService.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -701,6 +688,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.BackendService.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -716,7 +704,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -765,6 +753,7 @@ def test_get_health_rest( # Wrap the value into a proper Response obj json_return_value = compute.BackendServiceGroupHealth.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -797,6 +786,7 @@ def test_get_health_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.BackendServiceGroupHealth.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -817,7 +807,7 @@ def test_get_health_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -826,7 +816,9 @@ def test_get_health_rest_flattened(): assert "backend_service_value" in http_call[1] + str(body) assert compute.ResourceGroupReference.to_json( - resource_group_reference_resource, including_default_value_fields=False + resource_group_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -891,6 +883,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -943,6 +936,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -960,14 +954,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.BackendService.to_json( - backend_service_resource, including_default_value_fields=False + backend_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1014,16 +1010,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.BackendServiceList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.BackendServiceList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.BackendService(affinity_cookie_ttl_sec=2432)] assert response.kind == "kind_value" @@ -1049,6 +1044,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.BackendServiceList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1062,7 +1058,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1084,6 +1080,59 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionBackendServicesClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.BackendServiceList( + items=[ + compute.BackendService(), + compute.BackendService(), + compute.BackendService(), + ], + next_page_token="abc", + ), + compute.BackendServiceList(items=[], next_page_token="def",), + compute.BackendServiceList( + items=[compute.BackendService(),], next_page_token="ghi", + ), + compute.BackendServiceList( + items=[compute.BackendService(), compute.BackendService(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.BackendServiceList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.BackendService) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchRegionBackendServiceRequest ): @@ -1126,6 +1175,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1178,6 +1228,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1196,7 +1247,7 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1205,7 +1256,9 @@ def test_patch_rest_flattened(): assert "backend_service_value" in http_call[1] + str(body) assert compute.BackendService.to_json( - backend_service_resource, including_default_value_fields=False + backend_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1270,6 +1323,7 @@ def test_update_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1322,6 +1376,7 @@ def test_update_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1340,7 +1395,7 @@ def test_update_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1349,7 +1404,9 @@ def test_update_rest_flattened(): assert "backend_service_value" in http_call[1] + str(body) assert compute.BackendService.to_json( - backend_service_resource, including_default_value_fields=False + backend_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1504,6 +1561,17 @@ def test_region_backend_services_auth_adc(): ) +def test_region_backend_services_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionBackendServicesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_backend_services_host_no_port(): client = RegionBackendServicesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_commitments.py b/tests/unit/gapic/compute_v1/test_region_commitments.py index aa4c3c0cf..7bb9d516f 100644 --- a/tests/unit/gapic/compute_v1/test_region_commitments.py +++ b/tests/unit/gapic/compute_v1/test_region_commitments.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.region_commitments import RegionCommitmentsClient +from google.cloud.compute_v1.services.region_commitments import pagers from google.cloud.compute_v1.services.region_commitments import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -158,7 +159,7 @@ def test_region_commitments_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -174,7 +175,7 @@ def test_region_commitments_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -190,7 +191,7 @@ def test_region_commitments_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -218,7 +219,7 @@ def test_region_commitments_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -262,29 +263,25 @@ def test_region_commitments_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -293,66 +290,53 @@ def test_region_commitments_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -371,7 +355,7 @@ def test_region_commitments_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +378,7 @@ def test_region_commitments_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -434,16 +418,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.CommitmentAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.CommitmentAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.CommitmentsScopedList( @@ -476,6 +459,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.CommitmentAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -487,7 +471,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -503,6 +487,72 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = RegionCommitmentsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.CommitmentAggregatedList( + items={ + "a": compute.CommitmentsScopedList(), + "b": compute.CommitmentsScopedList(), + "c": compute.CommitmentsScopedList(), + }, + next_page_token="abc", + ), + compute.CommitmentAggregatedList(items={}, next_page_token="def",), + compute.CommitmentAggregatedList( + items={"g": compute.CommitmentsScopedList(),}, next_page_token="ghi", + ), + compute.CommitmentAggregatedList( + items={ + "h": compute.CommitmentsScopedList(), + "i": compute.CommitmentsScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.CommitmentAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.CommitmentsScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.CommitmentsScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.CommitmentsScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_get_rest( transport: str = "rest", request_type=compute.GetRegionCommitmentRequest ): @@ -540,6 +590,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.Commitment.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -585,6 +636,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Commitment.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -600,7 +652,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -665,6 +717,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -715,6 +768,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -734,14 +788,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.Commitment.to_json( - commitment_resource, including_default_value_fields=False + commitment_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -790,16 +846,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.CommitmentList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.CommitmentList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Commitment(category=compute.Commitment.Category.CATEGORY_UNSPECIFIED) @@ -825,6 +880,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.CommitmentList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -838,7 +894,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -858,6 +914,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionCommitmentsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.CommitmentList( + items=[ + compute.Commitment(), + compute.Commitment(), + compute.Commitment(), + ], + next_page_token="abc", + ), + compute.CommitmentList(items=[], next_page_token="def",), + compute.CommitmentList( + items=[compute.Commitment(),], next_page_token="ghi", + ), + compute.CommitmentList( + items=[compute.Commitment(), compute.Commitment(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.CommitmentList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Commitment) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.RegionCommitmentsRestTransport( @@ -987,6 +1094,17 @@ def test_region_commitments_auth_adc(): ) +def test_region_commitments_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionCommitmentsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_commitments_host_no_port(): client = RegionCommitmentsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_disk_types.py b/tests/unit/gapic/compute_v1/test_region_disk_types.py index 873741822..b43e39df9 100644 --- a/tests/unit/gapic/compute_v1/test_region_disk_types.py +++ b/tests/unit/gapic/compute_v1/test_region_disk_types.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.region_disk_types import RegionDiskTypesClient +from google.cloud.compute_v1.services.region_disk_types import pagers from google.cloud.compute_v1.services.region_disk_types import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -157,7 +158,7 @@ def test_region_disk_types_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -173,7 +174,7 @@ def test_region_disk_types_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -189,7 +190,7 @@ def test_region_disk_types_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -217,7 +218,7 @@ def test_region_disk_types_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -261,29 +262,25 @@ def test_region_disk_types_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -292,66 +289,53 @@ def test_region_disk_types_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -370,7 +354,7 @@ def test_region_disk_types_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -393,7 +377,7 @@ def test_region_disk_types_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -429,6 +413,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.DiskType.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -465,6 +450,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.DiskType.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -478,7 +464,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -526,16 +512,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.RegionDiskTypeList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.RegionDiskTypeList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.DiskType(creation_timestamp="creation_timestamp_value") @@ -561,6 +546,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.RegionDiskTypeList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -574,7 +560,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -594,6 +580,53 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionDiskTypesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.RegionDiskTypeList( + items=[compute.DiskType(), compute.DiskType(), compute.DiskType(),], + next_page_token="abc", + ), + compute.RegionDiskTypeList(items=[], next_page_token="def",), + compute.RegionDiskTypeList( + items=[compute.DiskType(),], next_page_token="ghi", + ), + compute.RegionDiskTypeList( + items=[compute.DiskType(), compute.DiskType(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.RegionDiskTypeList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.DiskType) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.RegionDiskTypesRestTransport( @@ -721,6 +754,17 @@ def test_region_disk_types_auth_adc(): ) +def test_region_disk_types_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionDiskTypesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_disk_types_host_no_port(): client = RegionDiskTypesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_disks.py b/tests/unit/gapic/compute_v1/test_region_disks.py index b64723260..3dac15c24 100644 --- a/tests/unit/gapic/compute_v1/test_region_disks.py +++ b/tests/unit/gapic/compute_v1/test_region_disks.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.region_disks import RegionDisksClient +from google.cloud.compute_v1.services.region_disks import pagers from google.cloud.compute_v1.services.region_disks import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -152,7 +153,7 @@ def test_region_disks_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -168,7 +169,7 @@ def test_region_disks_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -184,7 +185,7 @@ def test_region_disks_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -212,7 +213,7 @@ def test_region_disks_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -244,29 +245,25 @@ def test_region_disks_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -275,66 +272,53 @@ def test_region_disks_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -353,7 +337,7 @@ def test_region_disks_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -376,7 +360,7 @@ def test_region_disks_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -424,6 +408,7 @@ def test_add_resource_policies_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -474,6 +459,7 @@ def test_add_resource_policies_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -494,7 +480,7 @@ def test_add_resource_policies_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -505,6 +491,7 @@ def test_add_resource_policies_rest_flattened(): assert compute.RegionDisksAddResourcePoliciesRequest.to_json( region_disks_add_resource_policies_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -567,6 +554,7 @@ def test_create_snapshot_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -617,6 +605,7 @@ def test_create_snapshot_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -635,7 +624,7 @@ def test_create_snapshot_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -644,7 +633,9 @@ def test_create_snapshot_rest_flattened(): assert "disk_value" in http_call[1] + str(body) assert compute.Snapshot.to_json( - snapshot_resource, including_default_value_fields=False + snapshot_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -705,6 +696,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -755,6 +747,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -768,7 +761,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -850,6 +843,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetRegionDiskReq # Wrap the value into a proper Response obj json_return_value = compute.Disk.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -917,6 +911,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Disk.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -930,7 +925,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -986,6 +981,7 @@ def test_get_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1024,6 +1020,7 @@ def test_get_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1037,7 +1034,7 @@ def test_get_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1102,6 +1099,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1152,6 +1150,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1167,14 +1166,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.Disk.to_json( - disk_resource, including_default_value_fields=False + disk_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1217,16 +1218,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.DiskList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.DiskList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Disk(creation_timestamp="creation_timestamp_value") @@ -1252,6 +1252,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.DiskList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1265,7 +1266,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1285,6 +1286,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionDisksClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.DiskList( + items=[compute.Disk(), compute.Disk(), compute.Disk(),], + next_page_token="abc", + ), + compute.DiskList(items=[], next_page_token="def",), + compute.DiskList(items=[compute.Disk(),], next_page_token="ghi",), + compute.DiskList(items=[compute.Disk(), compute.Disk(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.DiskList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Disk) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_remove_resource_policies_rest( transport: str = "rest", request_type=compute.RemoveResourcePoliciesRegionDiskRequest, @@ -1328,6 +1372,7 @@ def test_remove_resource_policies_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1378,6 +1423,7 @@ def test_remove_resource_policies_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1398,7 +1444,7 @@ def test_remove_resource_policies_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1409,6 +1455,7 @@ def test_remove_resource_policies_rest_flattened(): assert compute.RegionDisksRemoveResourcePoliciesRequest.to_json( region_disks_remove_resource_policies_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1471,6 +1518,7 @@ def test_resize_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1521,6 +1569,7 @@ def test_resize_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1541,7 +1590,7 @@ def test_resize_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1550,7 +1599,9 @@ def test_resize_rest_flattened(): assert "disk_value" in http_call[1] + str(body) assert compute.RegionDisksResizeRequest.to_json( - region_disks_resize_request_resource, including_default_value_fields=False + region_disks_resize_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1604,6 +1655,7 @@ def test_set_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1642,6 +1694,7 @@ def test_set_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1662,7 +1715,7 @@ def test_set_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1671,7 +1724,9 @@ def test_set_iam_policy_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.RegionSetPolicyRequest.to_json( - region_set_policy_request_resource, including_default_value_fields=False + region_set_policy_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1734,6 +1789,7 @@ def test_set_labels_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1784,6 +1840,7 @@ def test_set_labels_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1804,7 +1861,7 @@ def test_set_labels_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1813,7 +1870,9 @@ def test_set_labels_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.RegionSetLabelsRequest.to_json( - region_set_labels_request_resource, including_default_value_fields=False + region_set_labels_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1854,6 +1913,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1880,6 +1940,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1900,7 +1961,7 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1909,7 +1970,9 @@ def test_test_iam_permissions_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2065,6 +2128,17 @@ def test_region_disks_auth_adc(): ) +def test_region_disks_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionDisksRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_disks_host_no_port(): client = RegionDisksClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_health_check_services.py b/tests/unit/gapic/compute_v1/test_region_health_check_services.py index f074bceda..320685c27 100644 --- a/tests/unit/gapic/compute_v1/test_region_health_check_services.py +++ b/tests/unit/gapic/compute_v1/test_region_health_check_services.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.region_health_check_services import ( RegionHealthCheckServicesClient, ) +from google.cloud.compute_v1.services.region_health_check_services import pagers from google.cloud.compute_v1.services.region_health_check_services import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -172,7 +173,7 @@ def test_region_health_check_services_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -188,7 +189,7 @@ def test_region_health_check_services_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -204,7 +205,7 @@ def test_region_health_check_services_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -232,7 +233,7 @@ def test_region_health_check_services_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -276,29 +277,25 @@ def test_region_health_check_services_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -307,66 +304,53 @@ def test_region_health_check_services_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -391,7 +375,7 @@ def test_region_health_check_services_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -420,7 +404,7 @@ def test_region_health_check_services_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -468,6 +452,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -520,6 +505,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -535,7 +521,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -591,6 +577,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.HealthCheckService.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -633,6 +620,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.HealthCheckService.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -648,7 +636,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -715,6 +703,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -767,6 +756,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -786,14 +776,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.HealthCheckService.to_json( - health_check_service_resource, including_default_value_fields=False + health_check_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -844,16 +836,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.HealthCheckServicesList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.HealthCheckServicesList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.HealthCheckService(creation_timestamp="creation_timestamp_value") @@ -881,6 +872,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.HealthCheckServicesList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -894,7 +886,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -916,6 +908,59 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionHealthCheckServicesClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.HealthCheckServicesList( + items=[ + compute.HealthCheckService(), + compute.HealthCheckService(), + compute.HealthCheckService(), + ], + next_page_token="abc", + ), + compute.HealthCheckServicesList(items=[], next_page_token="def",), + compute.HealthCheckServicesList( + items=[compute.HealthCheckService(),], next_page_token="ghi", + ), + compute.HealthCheckServicesList( + items=[compute.HealthCheckService(), compute.HealthCheckService(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.HealthCheckServicesList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.HealthCheckService) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchRegionHealthCheckServiceRequest ): @@ -958,6 +1003,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1010,6 +1056,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1030,7 +1077,7 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1039,7 +1086,9 @@ def test_patch_rest_flattened(): assert "health_check_service_value" in http_call[1] + str(body) assert compute.HealthCheckService.to_json( - health_check_service_resource, including_default_value_fields=False + health_check_service_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1192,6 +1241,17 @@ def test_region_health_check_services_auth_adc(): ) +def test_region_health_check_services_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionHealthCheckServicesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_health_check_services_host_no_port(): client = RegionHealthCheckServicesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_health_checks.py b/tests/unit/gapic/compute_v1/test_region_health_checks.py index d2a6ddc8d..dbe1162c1 100644 --- a/tests/unit/gapic/compute_v1/test_region_health_checks.py +++ b/tests/unit/gapic/compute_v1/test_region_health_checks.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.region_health_checks import ( RegionHealthChecksClient, ) +from google.cloud.compute_v1.services.region_health_checks import pagers from google.cloud.compute_v1.services.region_health_checks import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -160,7 +161,7 @@ def test_region_health_checks_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -176,7 +177,7 @@ def test_region_health_checks_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -192,7 +193,7 @@ def test_region_health_checks_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -220,7 +221,7 @@ def test_region_health_checks_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -264,29 +265,25 @@ def test_region_health_checks_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -295,66 +292,53 @@ def test_region_health_checks_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -373,7 +357,7 @@ def test_region_health_checks_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -396,7 +380,7 @@ def test_region_health_checks_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -444,6 +428,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -494,6 +479,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -509,7 +495,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -572,6 +558,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.HealthCheck.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -618,6 +605,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.HealthCheck.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -633,7 +621,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -698,6 +686,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -748,6 +737,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -765,14 +755,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.HealthCheck.to_json( - health_check_resource, including_default_value_fields=False + health_check_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -815,16 +807,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.HealthCheckList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.HealthCheckList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.HealthCheck(check_interval_sec=1884)] assert response.kind == "kind_value" @@ -848,6 +839,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.HealthCheckList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -861,7 +853,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -881,6 +873,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionHealthChecksClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.HealthCheckList( + items=[ + compute.HealthCheck(), + compute.HealthCheck(), + compute.HealthCheck(), + ], + next_page_token="abc", + ), + compute.HealthCheckList(items=[], next_page_token="def",), + compute.HealthCheckList( + items=[compute.HealthCheck(),], next_page_token="ghi", + ), + compute.HealthCheckList( + items=[compute.HealthCheck(), compute.HealthCheck(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.HealthCheckList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.HealthCheck) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchRegionHealthCheckRequest ): @@ -923,6 +966,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -973,6 +1017,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -991,7 +1036,7 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1000,7 +1045,9 @@ def test_patch_rest_flattened(): assert "health_check_value" in http_call[1] + str(body) assert compute.HealthCheck.to_json( - health_check_resource, including_default_value_fields=False + health_check_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1061,6 +1108,7 @@ def test_update_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1111,6 +1159,7 @@ def test_update_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1129,7 +1178,7 @@ def test_update_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1138,7 +1187,9 @@ def test_update_rest_flattened(): assert "health_check_value" in http_call[1] + str(body) assert compute.HealthCheck.to_json( - health_check_resource, including_default_value_fields=False + health_check_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1288,6 +1339,17 @@ def test_region_health_checks_auth_adc(): ) +def test_region_health_checks_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionHealthChecksRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_health_checks_host_no_port(): client = RegionHealthChecksClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_instance_group_managers.py b/tests/unit/gapic/compute_v1/test_region_instance_group_managers.py index 713e9ef06..abf5c43bd 100644 --- a/tests/unit/gapic/compute_v1/test_region_instance_group_managers.py +++ b/tests/unit/gapic/compute_v1/test_region_instance_group_managers.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.region_instance_group_managers import ( RegionInstanceGroupManagersClient, ) +from google.cloud.compute_v1.services.region_instance_group_managers import pagers from google.cloud.compute_v1.services.region_instance_group_managers import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -172,7 +173,7 @@ def test_region_instance_group_managers_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -188,7 +189,7 @@ def test_region_instance_group_managers_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -204,7 +205,7 @@ def test_region_instance_group_managers_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -232,7 +233,7 @@ def test_region_instance_group_managers_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -276,29 +277,25 @@ def test_region_instance_group_managers_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -307,66 +304,53 @@ def test_region_instance_group_managers_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -391,7 +375,7 @@ def test_region_instance_group_managers_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -420,7 +404,7 @@ def test_region_instance_group_managers_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -469,6 +453,7 @@ def test_abandon_instances_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -521,6 +506,7 @@ def test_abandon_instances_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -541,7 +527,7 @@ def test_abandon_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -552,6 +538,7 @@ def test_abandon_instances_rest_flattened(): assert compute.RegionInstanceGroupManagersAbandonInstancesRequest.to_json( region_instance_group_managers_abandon_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -617,6 +604,7 @@ def test_apply_updates_to_instances_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -669,6 +657,7 @@ def test_apply_updates_to_instances_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -689,7 +678,7 @@ def test_apply_updates_to_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -700,6 +689,7 @@ def test_apply_updates_to_instances_rest_flattened(): assert compute.RegionInstanceGroupManagersApplyUpdatesRequest.to_json( region_instance_group_managers_apply_updates_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -765,6 +755,7 @@ def test_create_instances_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -817,6 +808,7 @@ def test_create_instances_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -837,7 +829,7 @@ def test_create_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -848,6 +840,7 @@ def test_create_instances_rest_flattened(): assert compute.RegionInstanceGroupManagersCreateInstancesRequest.to_json( region_instance_group_managers_create_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -913,6 +906,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -965,6 +959,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -980,7 +975,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1048,6 +1043,7 @@ def test_delete_instances_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1100,6 +1096,7 @@ def test_delete_instances_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1120,7 +1117,7 @@ def test_delete_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1131,6 +1128,7 @@ def test_delete_instances_rest_flattened(): assert compute.RegionInstanceGroupManagersDeleteInstancesRequest.to_json( region_instance_group_managers_delete_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1196,6 +1194,7 @@ def test_delete_per_instance_configs_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1248,6 +1247,7 @@ def test_delete_per_instance_configs_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1268,7 +1268,7 @@ def test_delete_per_instance_configs_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1279,6 +1279,7 @@ def test_delete_per_instance_configs_rest_flattened(): assert compute.RegionInstanceGroupManagerDeleteInstanceConfigReq.to_json( region_instance_group_manager_delete_instance_config_req_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1362,6 +1363,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroupManager.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1431,6 +1433,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroupManager.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1446,7 +1449,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1514,6 +1517,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1566,6 +1570,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1589,14 +1594,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.InstanceGroupManager.to_json( - instance_group_manager_resource, including_default_value_fields=False + instance_group_manager_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1655,16 +1662,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.RegionInstanceGroupManagerList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.RegionInstanceGroupManagerList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.InstanceGroupManager( @@ -1698,6 +1704,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.RegionInstanceGroupManagerList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1711,7 +1718,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1733,6 +1740,61 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionInstanceGroupManagersClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.RegionInstanceGroupManagerList( + items=[ + compute.InstanceGroupManager(), + compute.InstanceGroupManager(), + compute.InstanceGroupManager(), + ], + next_page_token="abc", + ), + compute.RegionInstanceGroupManagerList(items=[], next_page_token="def",), + compute.RegionInstanceGroupManagerList( + items=[compute.InstanceGroupManager(),], next_page_token="ghi", + ), + compute.RegionInstanceGroupManagerList( + items=[compute.InstanceGroupManager(), compute.InstanceGroupManager(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.RegionInstanceGroupManagerList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.InstanceGroupManager) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_errors_rest( transport: str = "rest", request_type=compute.ListErrorsRegionInstanceGroupManagersRequest, @@ -1763,16 +1825,15 @@ def test_list_errors_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_errors(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.RegionInstanceGroupManagersListErrorsResponse) + assert isinstance(response, pagers.ListErrorsPager) assert response.items == [ compute.InstanceManagedByIgmError( error=compute.InstanceManagedByIgmErrorManagedInstanceError( @@ -1802,6 +1863,7 @@ def test_list_errors_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1817,7 +1879,7 @@ def test_list_errors_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1842,6 +1904,67 @@ def test_list_errors_rest_flattened_error(): ) +def test_list_errors_pager(): + client = RegionInstanceGroupManagersClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.RegionInstanceGroupManagersListErrorsResponse( + items=[ + compute.InstanceManagedByIgmError(), + compute.InstanceManagedByIgmError(), + compute.InstanceManagedByIgmError(), + ], + next_page_token="abc", + ), + compute.RegionInstanceGroupManagersListErrorsResponse( + items=[], next_page_token="def", + ), + compute.RegionInstanceGroupManagersListErrorsResponse( + items=[compute.InstanceManagedByIgmError(),], next_page_token="ghi", + ), + compute.RegionInstanceGroupManagersListErrorsResponse( + items=[ + compute.InstanceManagedByIgmError(), + compute.InstanceManagedByIgmError(), + ], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.RegionInstanceGroupManagersListErrorsResponse.to_json(x) + for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_errors(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.InstanceManagedByIgmError) for i in results) + + pages = list(client.list_errors(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_managed_instances_rest( transport: str = "rest", request_type=compute.ListManagedInstancesRegionInstanceGroupManagersRequest, @@ -1870,18 +1993,15 @@ def test_list_managed_instances_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_managed_instances(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance( - response, compute.RegionInstanceGroupManagersListInstancesResponse - ) + assert isinstance(response, pagers.ListManagedInstancesPager) assert response.managed_instances == [ compute.ManagedInstance( current_action=compute.ManagedInstance.CurrentAction.ABANDONING @@ -1909,6 +2029,7 @@ def test_list_managed_instances_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1924,7 +2045,7 @@ def test_list_managed_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1949,6 +2070,67 @@ def test_list_managed_instances_rest_flattened_error(): ) +def test_list_managed_instances_pager(): + client = RegionInstanceGroupManagersClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.RegionInstanceGroupManagersListInstancesResponse( + managed_instances=[ + compute.ManagedInstance(), + compute.ManagedInstance(), + compute.ManagedInstance(), + ], + next_page_token="abc", + ), + compute.RegionInstanceGroupManagersListInstancesResponse( + managed_instances=[], next_page_token="def", + ), + compute.RegionInstanceGroupManagersListInstancesResponse( + managed_instances=[compute.ManagedInstance(),], next_page_token="ghi", + ), + compute.RegionInstanceGroupManagersListInstancesResponse( + managed_instances=[ + compute.ManagedInstance(), + compute.ManagedInstance(), + ], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.RegionInstanceGroupManagersListInstancesResponse.to_json(x) + for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_managed_instances(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.ManagedInstance) for i in results) + + pages = list(client.list_managed_instances(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_per_instance_configs_rest( transport: str = "rest", request_type=compute.ListPerInstanceConfigsRegionInstanceGroupManagersRequest, @@ -1974,18 +2156,15 @@ def test_list_per_instance_configs_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_per_instance_configs(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance( - response, compute.RegionInstanceGroupManagersListInstanceConfigsResp - ) + assert isinstance(response, pagers.ListPerInstanceConfigsPager) assert response.items == [ compute.PerInstanceConfig(fingerprint="fingerprint_value") ] @@ -2012,6 +2191,7 @@ def test_list_per_instance_configs_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2027,7 +2207,7 @@ def test_list_per_instance_configs_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2052,6 +2232,64 @@ def test_list_per_instance_configs_rest_flattened_error(): ) +def test_list_per_instance_configs_pager(): + client = RegionInstanceGroupManagersClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.RegionInstanceGroupManagersListInstanceConfigsResp( + items=[ + compute.PerInstanceConfig(), + compute.PerInstanceConfig(), + compute.PerInstanceConfig(), + ], + next_page_token="abc", + ), + compute.RegionInstanceGroupManagersListInstanceConfigsResp( + items=[], next_page_token="def", + ), + compute.RegionInstanceGroupManagersListInstanceConfigsResp( + items=[compute.PerInstanceConfig(),], next_page_token="ghi", + ), + compute.RegionInstanceGroupManagersListInstanceConfigsResp( + items=[compute.PerInstanceConfig(), compute.PerInstanceConfig(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.RegionInstanceGroupManagersListInstanceConfigsResp.to_json(x) + for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_per_instance_configs(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.PerInstanceConfig) for i in results) + + pages = list(client.list_per_instance_configs(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchRegionInstanceGroupManagerRequest ): @@ -2094,6 +2332,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2146,6 +2385,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2170,7 +2410,7 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2179,7 +2419,9 @@ def test_patch_rest_flattened(): assert "instance_group_manager_value" in http_call[1] + str(body) assert compute.InstanceGroupManager.to_json( - instance_group_manager_resource, including_default_value_fields=False + instance_group_manager_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2249,6 +2491,7 @@ def test_patch_per_instance_configs_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2301,6 +2544,7 @@ def test_patch_per_instance_configs_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2323,7 +2567,7 @@ def test_patch_per_instance_configs_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2334,6 +2578,7 @@ def test_patch_per_instance_configs_rest_flattened(): assert compute.RegionInstanceGroupManagerPatchInstanceConfigReq.to_json( region_instance_group_manager_patch_instance_config_req_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2401,6 +2646,7 @@ def test_recreate_instances_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2453,6 +2699,7 @@ def test_recreate_instances_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2473,7 +2720,7 @@ def test_recreate_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2484,6 +2731,7 @@ def test_recreate_instances_rest_flattened(): assert compute.RegionInstanceGroupManagersRecreateRequest.to_json( region_instance_group_managers_recreate_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2549,6 +2797,7 @@ def test_resize_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2601,6 +2850,7 @@ def test_resize_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2617,7 +2867,7 @@ def test_resize_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2688,6 +2938,7 @@ def test_set_instance_template_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2740,6 +2991,7 @@ def test_set_instance_template_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2760,7 +3012,7 @@ def test_set_instance_template_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2771,6 +3023,7 @@ def test_set_instance_template_rest_flattened(): assert compute.RegionInstanceGroupManagersSetTemplateRequest.to_json( region_instance_group_managers_set_template_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2836,6 +3089,7 @@ def test_set_target_pools_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2888,6 +3142,7 @@ def test_set_target_pools_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -2908,7 +3163,7 @@ def test_set_target_pools_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -2919,6 +3174,7 @@ def test_set_target_pools_rest_flattened(): assert compute.RegionInstanceGroupManagersSetTargetPoolsRequest.to_json( region_instance_group_managers_set_target_pools_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -2984,6 +3240,7 @@ def test_update_per_instance_configs_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3036,6 +3293,7 @@ def test_update_per_instance_configs_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -3058,7 +3316,7 @@ def test_update_per_instance_configs_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -3069,6 +3327,7 @@ def test_update_per_instance_configs_rest_flattened(): assert compute.RegionInstanceGroupManagerUpdateInstanceConfigReq.to_json( region_instance_group_manager_update_instance_config_req_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -3237,6 +3496,17 @@ def test_region_instance_group_managers_auth_adc(): ) +def test_region_instance_group_managers_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionInstanceGroupManagersRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_instance_group_managers_host_no_port(): client = RegionInstanceGroupManagersClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_instance_groups.py b/tests/unit/gapic/compute_v1/test_region_instance_groups.py index 0a85dcfe2..099eeeced 100644 --- a/tests/unit/gapic/compute_v1/test_region_instance_groups.py +++ b/tests/unit/gapic/compute_v1/test_region_instance_groups.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.region_instance_groups import ( RegionInstanceGroupsClient, ) +from google.cloud.compute_v1.services.region_instance_groups import pagers from google.cloud.compute_v1.services.region_instance_groups import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -166,7 +167,7 @@ def test_region_instance_groups_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -182,7 +183,7 @@ def test_region_instance_groups_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -198,7 +199,7 @@ def test_region_instance_groups_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -226,7 +227,7 @@ def test_region_instance_groups_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -270,29 +271,25 @@ def test_region_instance_groups_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -301,66 +298,53 @@ def test_region_instance_groups_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -385,7 +369,7 @@ def test_region_instance_groups_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -414,7 +398,7 @@ def test_region_instance_groups_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -452,6 +436,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroup.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -490,6 +475,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.InstanceGroup.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -505,7 +491,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -555,16 +541,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.RegionInstanceGroupList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.RegionInstanceGroupList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.InstanceGroup(creation_timestamp="creation_timestamp_value") @@ -590,6 +575,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.RegionInstanceGroupList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -603,7 +589,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -623,6 +609,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionInstanceGroupsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.RegionInstanceGroupList( + items=[ + compute.InstanceGroup(), + compute.InstanceGroup(), + compute.InstanceGroup(), + ], + next_page_token="abc", + ), + compute.RegionInstanceGroupList(items=[], next_page_token="def",), + compute.RegionInstanceGroupList( + items=[compute.InstanceGroup(),], next_page_token="ghi", + ), + compute.RegionInstanceGroupList( + items=[compute.InstanceGroup(), compute.InstanceGroup(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.RegionInstanceGroupList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.InstanceGroup) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_instances_rest( transport: str = "rest", request_type=compute.ListInstancesRegionInstanceGroupsRequest, @@ -651,16 +688,15 @@ def test_list_instances_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_instances(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.RegionInstanceGroupsListInstances) + assert isinstance(response, pagers.ListInstancesPager) assert response.id == "id_value" assert response.items == [compute.InstanceWithNamedPorts(instance="instance_value")] assert response.kind == "kind_value" @@ -686,6 +722,7 @@ def test_list_instances_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -706,7 +743,7 @@ def test_list_instances_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -717,6 +754,7 @@ def test_list_instances_rest_flattened(): assert compute.RegionInstanceGroupsListInstancesRequest.to_json( region_instance_groups_list_instances_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -737,6 +775,62 @@ def test_list_instances_rest_flattened_error(): ) +def test_list_instances_pager(): + client = RegionInstanceGroupsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.RegionInstanceGroupsListInstances( + items=[ + compute.InstanceWithNamedPorts(), + compute.InstanceWithNamedPorts(), + compute.InstanceWithNamedPorts(), + ], + next_page_token="abc", + ), + compute.RegionInstanceGroupsListInstances(items=[], next_page_token="def",), + compute.RegionInstanceGroupsListInstances( + items=[compute.InstanceWithNamedPorts(),], next_page_token="ghi", + ), + compute.RegionInstanceGroupsListInstances( + items=[ + compute.InstanceWithNamedPorts(), + compute.InstanceWithNamedPorts(), + ], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.RegionInstanceGroupsListInstances.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_instances(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.InstanceWithNamedPorts) for i in results) + + pages = list(client.list_instances(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_set_named_ports_rest( transport: str = "rest", request_type=compute.SetNamedPortsRegionInstanceGroupRequest, @@ -780,6 +874,7 @@ def test_set_named_ports_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -830,6 +925,7 @@ def test_set_named_ports_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -850,7 +946,7 @@ def test_set_named_ports_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -861,6 +957,7 @@ def test_set_named_ports_rest_flattened(): assert compute.RegionInstanceGroupsSetNamedPortsRequest.to_json( region_instance_groups_set_named_ports_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1010,6 +1107,17 @@ def test_region_instance_groups_auth_adc(): ) +def test_region_instance_groups_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionInstanceGroupsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_instance_groups_host_no_port(): client = RegionInstanceGroupsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_network_endpoint_groups.py b/tests/unit/gapic/compute_v1/test_region_network_endpoint_groups.py index e8990c982..4b095591e 100644 --- a/tests/unit/gapic/compute_v1/test_region_network_endpoint_groups.py +++ b/tests/unit/gapic/compute_v1/test_region_network_endpoint_groups.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.region_network_endpoint_groups import ( RegionNetworkEndpointGroupsClient, ) +from google.cloud.compute_v1.services.region_network_endpoint_groups import pagers from google.cloud.compute_v1.services.region_network_endpoint_groups import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -172,7 +173,7 @@ def test_region_network_endpoint_groups_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -188,7 +189,7 @@ def test_region_network_endpoint_groups_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -204,7 +205,7 @@ def test_region_network_endpoint_groups_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -232,7 +233,7 @@ def test_region_network_endpoint_groups_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -276,29 +277,25 @@ def test_region_network_endpoint_groups_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -307,66 +304,53 @@ def test_region_network_endpoint_groups_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -391,7 +375,7 @@ def test_region_network_endpoint_groups_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -420,7 +404,7 @@ def test_region_network_endpoint_groups_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -469,6 +453,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -521,6 +506,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -536,7 +522,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -599,6 +585,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.NetworkEndpointGroup.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -652,6 +639,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NetworkEndpointGroup.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -667,7 +655,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -735,6 +723,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -787,6 +776,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -806,14 +796,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.NetworkEndpointGroup.to_json( - network_endpoint_group_resource, including_default_value_fields=False + network_endpoint_group_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -862,16 +854,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.NetworkEndpointGroupList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NetworkEndpointGroupList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.NetworkEndpointGroup(annotations={"key_value": "value_value"}) @@ -899,6 +890,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NetworkEndpointGroupList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -912,7 +904,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -934,6 +926,59 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionNetworkEndpointGroupsClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NetworkEndpointGroupList( + items=[ + compute.NetworkEndpointGroup(), + compute.NetworkEndpointGroup(), + compute.NetworkEndpointGroup(), + ], + next_page_token="abc", + ), + compute.NetworkEndpointGroupList(items=[], next_page_token="def",), + compute.NetworkEndpointGroupList( + items=[compute.NetworkEndpointGroup(),], next_page_token="ghi", + ), + compute.NetworkEndpointGroupList( + items=[compute.NetworkEndpointGroup(), compute.NetworkEndpointGroup(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.NetworkEndpointGroupList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.NetworkEndpointGroup) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.RegionNetworkEndpointGroupsRestTransport( @@ -1063,6 +1108,17 @@ def test_region_network_endpoint_groups_auth_adc(): ) +def test_region_network_endpoint_groups_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionNetworkEndpointGroupsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_network_endpoint_groups_host_no_port(): client = RegionNetworkEndpointGroupsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_notification_endpoints.py b/tests/unit/gapic/compute_v1/test_region_notification_endpoints.py index 98dc8d1fa..940d8730e 100644 --- a/tests/unit/gapic/compute_v1/test_region_notification_endpoints.py +++ b/tests/unit/gapic/compute_v1/test_region_notification_endpoints.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.region_notification_endpoints import ( RegionNotificationEndpointsClient, ) +from google.cloud.compute_v1.services.region_notification_endpoints import pagers from google.cloud.compute_v1.services.region_notification_endpoints import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -172,7 +173,7 @@ def test_region_notification_endpoints_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -188,7 +189,7 @@ def test_region_notification_endpoints_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -204,7 +205,7 @@ def test_region_notification_endpoints_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -232,7 +233,7 @@ def test_region_notification_endpoints_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -276,29 +277,25 @@ def test_region_notification_endpoints_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -307,66 +304,53 @@ def test_region_notification_endpoints_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -391,7 +375,7 @@ def test_region_notification_endpoints_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -420,7 +404,7 @@ def test_region_notification_endpoints_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -469,6 +453,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -521,6 +506,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -536,7 +522,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -590,6 +576,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.NotificationEndpoint.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -627,6 +614,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NotificationEndpoint.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -642,7 +630,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -710,6 +698,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -762,6 +751,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -781,14 +771,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.NotificationEndpoint.to_json( - notification_endpoint_resource, including_default_value_fields=False + notification_endpoint_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -839,16 +831,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.NotificationEndpointList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.NotificationEndpointList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.NotificationEndpoint(creation_timestamp="creation_timestamp_value") @@ -876,6 +867,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.NotificationEndpointList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -889,7 +881,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -911,6 +903,59 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionNotificationEndpointsClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.NotificationEndpointList( + items=[ + compute.NotificationEndpoint(), + compute.NotificationEndpoint(), + compute.NotificationEndpoint(), + ], + next_page_token="abc", + ), + compute.NotificationEndpointList(items=[], next_page_token="def",), + compute.NotificationEndpointList( + items=[compute.NotificationEndpoint(),], next_page_token="ghi", + ), + compute.NotificationEndpointList( + items=[compute.NotificationEndpoint(), compute.NotificationEndpoint(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.NotificationEndpointList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.NotificationEndpoint) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.RegionNotificationEndpointsRestTransport( @@ -1040,6 +1085,17 @@ def test_region_notification_endpoints_auth_adc(): ) +def test_region_notification_endpoints_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionNotificationEndpointsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_notification_endpoints_host_no_port(): client = RegionNotificationEndpointsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_operations.py b/tests/unit/gapic/compute_v1/test_region_operations.py index 115b5c686..5ce7bdf38 100644 --- a/tests/unit/gapic/compute_v1/test_region_operations.py +++ b/tests/unit/gapic/compute_v1/test_region_operations.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.region_operations import RegionOperationsClient +from google.cloud.compute_v1.services.region_operations import pagers from google.cloud.compute_v1.services.region_operations import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -158,7 +159,7 @@ def test_region_operations_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -174,7 +175,7 @@ def test_region_operations_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -190,7 +191,7 @@ def test_region_operations_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -218,7 +219,7 @@ def test_region_operations_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -262,29 +263,25 @@ def test_region_operations_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -293,66 +290,53 @@ def test_region_operations_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -371,7 +355,7 @@ def test_region_operations_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +378,7 @@ def test_region_operations_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -418,6 +402,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.DeleteRegionOperationResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -443,6 +428,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.DeleteRegionOperationResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -456,7 +442,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -521,6 +507,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -571,6 +558,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -584,7 +572,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -632,16 +620,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.OperationList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.OperationList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Operation(client_operation_id="client_operation_id_value") @@ -667,6 +654,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.OperationList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -680,7 +668,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -700,6 +688,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionOperationsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.OperationList( + items=[compute.Operation(), compute.Operation(), compute.Operation(),], + next_page_token="abc", + ), + compute.OperationList(items=[], next_page_token="def",), + compute.OperationList(items=[compute.Operation(),], next_page_token="ghi",), + compute.OperationList(items=[compute.Operation(), compute.Operation(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.OperationList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Operation) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_wait_rest( transport: str = "rest", request_type=compute.WaitRegionOperationRequest ): @@ -742,6 +773,7 @@ def test_wait_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -792,6 +824,7 @@ def test_wait_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -805,7 +838,7 @@ def test_wait_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -955,6 +988,17 @@ def test_region_operations_auth_adc(): ) +def test_region_operations_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionOperationsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_operations_host_no_port(): client = RegionOperationsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_ssl_certificates.py b/tests/unit/gapic/compute_v1/test_region_ssl_certificates.py index c0733bddf..0deb0805d 100644 --- a/tests/unit/gapic/compute_v1/test_region_ssl_certificates.py +++ b/tests/unit/gapic/compute_v1/test_region_ssl_certificates.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.region_ssl_certificates import ( RegionSslCertificatesClient, ) +from google.cloud.compute_v1.services.region_ssl_certificates import pagers from google.cloud.compute_v1.services.region_ssl_certificates import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -166,7 +167,7 @@ def test_region_ssl_certificates_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -182,7 +183,7 @@ def test_region_ssl_certificates_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -198,7 +199,7 @@ def test_region_ssl_certificates_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -226,7 +227,7 @@ def test_region_ssl_certificates_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -270,29 +271,25 @@ def test_region_ssl_certificates_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -301,66 +298,53 @@ def test_region_ssl_certificates_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -385,7 +369,7 @@ def test_region_ssl_certificates_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -414,7 +398,7 @@ def test_region_ssl_certificates_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -462,6 +446,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -514,6 +499,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -529,7 +515,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -591,6 +577,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.SslCertificate.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -636,6 +623,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.SslCertificate.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -651,7 +639,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -718,6 +706,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -770,6 +759,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -789,14 +779,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.SslCertificate.to_json( - ssl_certificate_resource, including_default_value_fields=False + ssl_certificate_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -843,16 +835,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.SslCertificateList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.SslCertificateList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.SslCertificate(certificate="certificate_value")] assert response.kind == "kind_value" @@ -878,6 +869,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.SslCertificateList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -891,7 +883,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -913,6 +905,59 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionSslCertificatesClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.SslCertificateList( + items=[ + compute.SslCertificate(), + compute.SslCertificate(), + compute.SslCertificate(), + ], + next_page_token="abc", + ), + compute.SslCertificateList(items=[], next_page_token="def",), + compute.SslCertificateList( + items=[compute.SslCertificate(),], next_page_token="ghi", + ), + compute.SslCertificateList( + items=[compute.SslCertificate(), compute.SslCertificate(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.SslCertificateList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.SslCertificate) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.RegionSslCertificatesRestTransport( @@ -1042,6 +1087,17 @@ def test_region_ssl_certificates_auth_adc(): ) +def test_region_ssl_certificates_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionSslCertificatesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_ssl_certificates_host_no_port(): client = RegionSslCertificatesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_target_http_proxies.py b/tests/unit/gapic/compute_v1/test_region_target_http_proxies.py index 8c6e988a5..d2235532d 100644 --- a/tests/unit/gapic/compute_v1/test_region_target_http_proxies.py +++ b/tests/unit/gapic/compute_v1/test_region_target_http_proxies.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.region_target_http_proxies import ( RegionTargetHttpProxiesClient, ) +from google.cloud.compute_v1.services.region_target_http_proxies import pagers from google.cloud.compute_v1.services.region_target_http_proxies import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -166,7 +167,7 @@ def test_region_target_http_proxies_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -182,7 +183,7 @@ def test_region_target_http_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -198,7 +199,7 @@ def test_region_target_http_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -226,7 +227,7 @@ def test_region_target_http_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -270,29 +271,25 @@ def test_region_target_http_proxies_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -301,66 +298,53 @@ def test_region_target_http_proxies_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -385,7 +369,7 @@ def test_region_target_http_proxies_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -414,7 +398,7 @@ def test_region_target_http_proxies_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -462,6 +446,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -514,6 +499,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -529,7 +515,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -583,6 +569,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpProxy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -621,6 +608,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpProxy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -636,7 +624,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -703,6 +691,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -755,6 +744,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -774,14 +764,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.TargetHttpProxy.to_json( - target_http_proxy_resource, including_default_value_fields=False + target_http_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -830,16 +822,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpProxyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetHttpProxyList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.TargetHttpProxy(creation_timestamp="creation_timestamp_value") @@ -867,6 +858,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpProxyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -880,7 +872,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -902,6 +894,59 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionTargetHttpProxiesClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetHttpProxyList( + items=[ + compute.TargetHttpProxy(), + compute.TargetHttpProxy(), + compute.TargetHttpProxy(), + ], + next_page_token="abc", + ), + compute.TargetHttpProxyList(items=[], next_page_token="def",), + compute.TargetHttpProxyList( + items=[compute.TargetHttpProxy(),], next_page_token="ghi", + ), + compute.TargetHttpProxyList( + items=[compute.TargetHttpProxy(), compute.TargetHttpProxy(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.TargetHttpProxyList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.TargetHttpProxy) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_set_url_map_rest( transport: str = "rest", request_type=compute.SetUrlMapRegionTargetHttpProxyRequest ): @@ -944,6 +989,7 @@ def test_set_url_map_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -996,6 +1042,7 @@ def test_set_url_map_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1014,7 +1061,7 @@ def test_set_url_map_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1023,7 +1070,9 @@ def test_set_url_map_rest_flattened(): assert "target_http_proxy_value" in http_call[1] + str(body) assert compute.UrlMapReference.to_json( - url_map_reference_resource, including_default_value_fields=False + url_map_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1174,6 +1223,17 @@ def test_region_target_http_proxies_auth_adc(): ) +def test_region_target_http_proxies_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionTargetHttpProxiesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_target_http_proxies_host_no_port(): client = RegionTargetHttpProxiesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_target_https_proxies.py b/tests/unit/gapic/compute_v1/test_region_target_https_proxies.py index ea1d126af..df3cab9b9 100644 --- a/tests/unit/gapic/compute_v1/test_region_target_https_proxies.py +++ b/tests/unit/gapic/compute_v1/test_region_target_https_proxies.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.region_target_https_proxies import ( RegionTargetHttpsProxiesClient, ) +from google.cloud.compute_v1.services.region_target_https_proxies import pagers from google.cloud.compute_v1.services.region_target_https_proxies import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -170,7 +171,7 @@ def test_region_target_https_proxies_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -186,7 +187,7 @@ def test_region_target_https_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -202,7 +203,7 @@ def test_region_target_https_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -230,7 +231,7 @@ def test_region_target_https_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -274,29 +275,25 @@ def test_region_target_https_proxies_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -305,66 +302,53 @@ def test_region_target_https_proxies_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -389,7 +373,7 @@ def test_region_target_https_proxies_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -418,7 +402,7 @@ def test_region_target_https_proxies_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -466,6 +450,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -518,6 +503,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -533,7 +519,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -591,6 +577,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpsProxy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -633,6 +620,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpsProxy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -648,7 +636,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -715,6 +703,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -767,6 +756,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -786,14 +776,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.TargetHttpsProxy.to_json( - target_https_proxy_resource, including_default_value_fields=False + target_https_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -844,16 +836,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpsProxyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetHttpsProxyList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.TargetHttpsProxy(authorization_policy="authorization_policy_value") @@ -881,6 +872,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpsProxyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -894,7 +886,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -916,6 +908,59 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionTargetHttpsProxiesClient( + credentials=credentials.AnonymousCredentials(), + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetHttpsProxyList( + items=[ + compute.TargetHttpsProxy(), + compute.TargetHttpsProxy(), + compute.TargetHttpsProxy(), + ], + next_page_token="abc", + ), + compute.TargetHttpsProxyList(items=[], next_page_token="def",), + compute.TargetHttpsProxyList( + items=[compute.TargetHttpsProxy(),], next_page_token="ghi", + ), + compute.TargetHttpsProxyList( + items=[compute.TargetHttpsProxy(), compute.TargetHttpsProxy(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.TargetHttpsProxyList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.TargetHttpsProxy) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_set_ssl_certificates_rest( transport: str = "rest", request_type=compute.SetSslCertificatesRegionTargetHttpsProxyRequest, @@ -959,6 +1004,7 @@ def test_set_ssl_certificates_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1011,6 +1057,7 @@ def test_set_ssl_certificates_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1031,7 +1078,7 @@ def test_set_ssl_certificates_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1042,6 +1089,7 @@ def test_set_ssl_certificates_rest_flattened(): assert compute.RegionTargetHttpsProxiesSetSslCertificatesRequest.to_json( region_target_https_proxies_set_ssl_certificates_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1106,6 +1154,7 @@ def test_set_url_map_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1158,6 +1207,7 @@ def test_set_url_map_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1176,7 +1226,7 @@ def test_set_url_map_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1185,7 +1235,9 @@ def test_set_url_map_rest_flattened(): assert "target_https_proxy_value" in http_call[1] + str(body) assert compute.UrlMapReference.to_json( - url_map_reference_resource, including_default_value_fields=False + url_map_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1337,6 +1389,17 @@ def test_region_target_https_proxies_auth_adc(): ) +def test_region_target_https_proxies_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionTargetHttpsProxiesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_target_https_proxies_host_no_port(): client = RegionTargetHttpsProxiesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_region_url_maps.py b/tests/unit/gapic/compute_v1/test_region_url_maps.py index 73ea91e93..41132a284 100644 --- a/tests/unit/gapic/compute_v1/test_region_url_maps.py +++ b/tests/unit/gapic/compute_v1/test_region_url_maps.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.region_url_maps import RegionUrlMapsClient +from google.cloud.compute_v1.services.region_url_maps import pagers from google.cloud.compute_v1.services.region_url_maps import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -157,7 +158,7 @@ def test_region_url_maps_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -173,7 +174,7 @@ def test_region_url_maps_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -189,7 +190,7 @@ def test_region_url_maps_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -217,7 +218,7 @@ def test_region_url_maps_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -251,29 +252,25 @@ def test_region_url_maps_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -282,66 +279,53 @@ def test_region_url_maps_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -360,7 +344,7 @@ def test_region_url_maps_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -383,7 +367,7 @@ def test_region_url_maps_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -431,6 +415,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -481,6 +466,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -494,7 +480,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -563,6 +549,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetRegionUrlMapR # Wrap the value into a proper Response obj json_return_value = compute.UrlMap.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -617,6 +604,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.UrlMap.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -630,7 +618,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -695,6 +683,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -745,6 +734,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -762,14 +752,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.UrlMap.to_json( - url_map_resource, including_default_value_fields=False + url_map_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -814,16 +806,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.UrlMapList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.UrlMapList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.UrlMap(creation_timestamp="creation_timestamp_value") @@ -849,6 +840,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.UrlMapList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -862,7 +854,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -882,6 +874,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionUrlMapsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.UrlMapList( + items=[compute.UrlMap(), compute.UrlMap(), compute.UrlMap(),], + next_page_token="abc", + ), + compute.UrlMapList(items=[], next_page_token="def",), + compute.UrlMapList(items=[compute.UrlMap(),], next_page_token="ghi",), + compute.UrlMapList(items=[compute.UrlMap(), compute.UrlMap(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.UrlMapList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.UrlMap) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchRegionUrlMapRequest ): @@ -924,6 +959,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -974,6 +1010,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -992,7 +1029,7 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1001,7 +1038,9 @@ def test_patch_rest_flattened(): assert "url_map_value" in http_call[1] + str(body) assert compute.UrlMap.to_json( - url_map_resource, including_default_value_fields=False + url_map_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1064,6 +1103,7 @@ def test_update_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1114,6 +1154,7 @@ def test_update_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1132,7 +1173,7 @@ def test_update_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1141,7 +1182,9 @@ def test_update_rest_flattened(): assert "url_map_value" in http_call[1] + str(body) assert compute.UrlMap.to_json( - url_map_resource, including_default_value_fields=False + url_map_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1182,6 +1225,7 @@ def test_validate_rest( # Wrap the value into a proper Response obj json_return_value = compute.UrlMapsValidateResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1210,6 +1254,7 @@ def test_validate_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.UrlMapsValidateResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1230,7 +1275,7 @@ def test_validate_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1241,6 +1286,7 @@ def test_validate_rest_flattened(): assert compute.RegionUrlMapsValidateRequest.to_json( region_url_maps_validate_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1391,6 +1437,17 @@ def test_region_url_maps_auth_adc(): ) +def test_region_url_maps_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionUrlMapsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_region_url_maps_host_no_port(): client = RegionUrlMapsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_regions.py b/tests/unit/gapic/compute_v1/test_regions.py index c91e7dacf..35598ef81 100644 --- a/tests/unit/gapic/compute_v1/test_regions.py +++ b/tests/unit/gapic/compute_v1/test_regions.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.regions import RegionsClient +from google.cloud.compute_v1.services.regions import pagers from google.cloud.compute_v1.services.regions import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -147,7 +148,7 @@ def test_regions_client_client_options(client_class, transport_class, transport_ credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -163,7 +164,7 @@ def test_regions_client_client_options(client_class, transport_class, transport_ credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -179,7 +180,7 @@ def test_regions_client_client_options(client_class, transport_class, transport_ credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -207,7 +208,7 @@ def test_regions_client_client_options(client_class, transport_class, transport_ credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -239,29 +240,25 @@ def test_regions_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -270,66 +267,53 @@ def test_regions_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -348,7 +332,7 @@ def test_regions_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -371,7 +355,7 @@ def test_regions_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -404,6 +388,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetRegionRequest # Wrap the value into a proper Response obj json_return_value = compute.Region.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -439,6 +424,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Region.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -452,7 +438,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -493,16 +479,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListRegionsRequ # Wrap the value into a proper Response obj json_return_value = compute.RegionList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.RegionList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Region(creation_timestamp="creation_timestamp_value") @@ -528,6 +513,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.RegionList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -539,7 +525,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -555,6 +541,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RegionsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.RegionList( + items=[compute.Region(), compute.Region(), compute.Region(),], + next_page_token="abc", + ), + compute.RegionList(items=[], next_page_token="def",), + compute.RegionList(items=[compute.Region(),], next_page_token="ghi",), + compute.RegionList(items=[compute.Region(), compute.Region(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.RegionList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Region) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.RegionsRestTransport( @@ -682,6 +711,17 @@ def test_regions_auth_adc(): ) +def test_regions_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RegionsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_regions_host_no_port(): client = RegionsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_reservations.py b/tests/unit/gapic/compute_v1/test_reservations.py index 6c4906930..7779697b6 100644 --- a/tests/unit/gapic/compute_v1/test_reservations.py +++ b/tests/unit/gapic/compute_v1/test_reservations.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.reservations import ReservationsClient +from google.cloud.compute_v1.services.reservations import pagers from google.cloud.compute_v1.services.reservations import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -152,7 +153,7 @@ def test_reservations_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -168,7 +169,7 @@ def test_reservations_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -184,7 +185,7 @@ def test_reservations_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -212,7 +213,7 @@ def test_reservations_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -244,29 +245,25 @@ def test_reservations_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -275,66 +272,53 @@ def test_reservations_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -353,7 +337,7 @@ def test_reservations_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -376,7 +360,7 @@ def test_reservations_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -412,16 +396,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.ReservationAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.ReservationAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.ReservationsScopedList( @@ -450,6 +433,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ReservationAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -461,7 +445,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -477,6 +461,72 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = ReservationsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.ReservationAggregatedList( + items={ + "a": compute.ReservationsScopedList(), + "b": compute.ReservationsScopedList(), + "c": compute.ReservationsScopedList(), + }, + next_page_token="abc", + ), + compute.ReservationAggregatedList(items={}, next_page_token="def",), + compute.ReservationAggregatedList( + items={"g": compute.ReservationsScopedList(),}, next_page_token="ghi", + ), + compute.ReservationAggregatedList( + items={ + "h": compute.ReservationsScopedList(), + "i": compute.ReservationsScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.ReservationAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.ReservationsScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.ReservationsScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.ReservationsScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteReservationRequest ): @@ -519,6 +569,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -569,6 +620,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -582,7 +634,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -635,6 +687,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetReservationRe # Wrap the value into a proper Response obj json_return_value = compute.Reservation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -674,6 +727,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Reservation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -687,7 +741,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -743,6 +797,7 @@ def test_get_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -781,6 +836,7 @@ def test_get_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -794,7 +850,7 @@ def test_get_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -859,6 +915,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -909,6 +966,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -926,14 +984,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "zone_value" in http_call[1] + str(body) assert compute.Reservation.to_json( - reservation_resource, including_default_value_fields=False + reservation_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -976,16 +1036,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.ReservationList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.ReservationList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.Reservation(commitment="commitment_value")] assert response.kind == "kind_value" @@ -1009,6 +1068,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ReservationList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1022,7 +1082,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1042,6 +1102,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = ReservationsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.ReservationList( + items=[ + compute.Reservation(), + compute.Reservation(), + compute.Reservation(), + ], + next_page_token="abc", + ), + compute.ReservationList(items=[], next_page_token="def",), + compute.ReservationList( + items=[compute.Reservation(),], next_page_token="ghi", + ), + compute.ReservationList( + items=[compute.Reservation(), compute.Reservation(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.ReservationList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Reservation) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_resize_rest( transport: str = "rest", request_type=compute.ResizeReservationRequest ): @@ -1084,6 +1195,7 @@ def test_resize_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1134,6 +1246,7 @@ def test_resize_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1154,7 +1267,7 @@ def test_resize_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1163,7 +1276,9 @@ def test_resize_rest_flattened(): assert "reservation_value" in http_call[1] + str(body) assert compute.ReservationsResizeRequest.to_json( - reservations_resize_request_resource, including_default_value_fields=False + reservations_resize_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1217,6 +1332,7 @@ def test_set_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1255,6 +1371,7 @@ def test_set_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1275,7 +1392,7 @@ def test_set_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1284,7 +1401,9 @@ def test_set_iam_policy_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.ZoneSetPolicyRequest.to_json( - zone_set_policy_request_resource, including_default_value_fields=False + zone_set_policy_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1325,6 +1444,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1351,6 +1471,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1371,7 +1492,7 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1380,7 +1501,9 @@ def test_test_iam_permissions_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1533,6 +1656,17 @@ def test_reservations_auth_adc(): ) +def test_reservations_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.ReservationsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_reservations_host_no_port(): client = ReservationsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_resource_policies.py b/tests/unit/gapic/compute_v1/test_resource_policies.py index c9973b437..7bc23a452 100644 --- a/tests/unit/gapic/compute_v1/test_resource_policies.py +++ b/tests/unit/gapic/compute_v1/test_resource_policies.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.resource_policies import ResourcePoliciesClient +from google.cloud.compute_v1.services.resource_policies import pagers from google.cloud.compute_v1.services.resource_policies import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -158,7 +159,7 @@ def test_resource_policies_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -174,7 +175,7 @@ def test_resource_policies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -190,7 +191,7 @@ def test_resource_policies_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -218,7 +219,7 @@ def test_resource_policies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -262,29 +263,25 @@ def test_resource_policies_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -293,66 +290,53 @@ def test_resource_policies_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -371,7 +355,7 @@ def test_resource_policies_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +378,7 @@ def test_resource_policies_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -435,16 +419,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.ResourcePolicyAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.ResourcePolicyAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.etag == "etag_value" assert response.id == "id_value" assert response.items == { @@ -476,6 +459,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ResourcePolicyAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -487,7 +471,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -503,6 +487,75 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = ResourcePoliciesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.ResourcePolicyAggregatedList( + items={ + "a": compute.ResourcePoliciesScopedList(), + "b": compute.ResourcePoliciesScopedList(), + "c": compute.ResourcePoliciesScopedList(), + }, + next_page_token="abc", + ), + compute.ResourcePolicyAggregatedList(items={}, next_page_token="def",), + compute.ResourcePolicyAggregatedList( + items={"g": compute.ResourcePoliciesScopedList(),}, + next_page_token="ghi", + ), + compute.ResourcePolicyAggregatedList( + items={ + "h": compute.ResourcePoliciesScopedList(), + "i": compute.ResourcePoliciesScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.ResourcePolicyAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.ResourcePoliciesScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.ResourcePoliciesScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.ResourcePoliciesScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteResourcePolicyRequest ): @@ -545,6 +598,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -595,6 +649,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -610,7 +665,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -668,6 +723,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.ResourcePolicy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -713,6 +769,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ResourcePolicy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -728,7 +785,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -784,6 +841,7 @@ def test_get_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -822,6 +880,7 @@ def test_get_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -835,7 +894,7 @@ def test_get_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -900,6 +959,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -950,6 +1010,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -969,14 +1030,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.ResourcePolicy.to_json( - resource_policy_resource, including_default_value_fields=False + resource_policy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1024,16 +1087,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.ResourcePolicyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.ResourcePolicyList) + assert isinstance(response, pagers.ListPager) assert response.etag == "etag_value" assert response.id == "id_value" assert response.items == [ @@ -1060,6 +1122,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ResourcePolicyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1073,7 +1136,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1093,6 +1156,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = ResourcePoliciesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.ResourcePolicyList( + items=[ + compute.ResourcePolicy(), + compute.ResourcePolicy(), + compute.ResourcePolicy(), + ], + next_page_token="abc", + ), + compute.ResourcePolicyList(items=[], next_page_token="def",), + compute.ResourcePolicyList( + items=[compute.ResourcePolicy(),], next_page_token="ghi", + ), + compute.ResourcePolicyList( + items=[compute.ResourcePolicy(), compute.ResourcePolicy(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.ResourcePolicyList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.ResourcePolicy) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_set_iam_policy_rest( transport: str = "rest", request_type=compute.SetIamPolicyResourcePolicyRequest ): @@ -1126,6 +1240,7 @@ def test_set_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1164,6 +1279,7 @@ def test_set_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1184,7 +1300,7 @@ def test_set_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1193,7 +1309,9 @@ def test_set_iam_policy_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.RegionSetPolicyRequest.to_json( - region_set_policy_request_resource, including_default_value_fields=False + region_set_policy_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1235,6 +1353,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1261,6 +1380,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1281,7 +1401,7 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1290,7 +1410,9 @@ def test_test_iam_permissions_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1442,6 +1564,17 @@ def test_resource_policies_auth_adc(): ) +def test_resource_policies_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.ResourcePoliciesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_resource_policies_host_no_port(): client = ResourcePoliciesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_routers.py b/tests/unit/gapic/compute_v1/test_routers.py index 0d228b57f..00d4348ef 100644 --- a/tests/unit/gapic/compute_v1/test_routers.py +++ b/tests/unit/gapic/compute_v1/test_routers.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.routers import RoutersClient +from google.cloud.compute_v1.services.routers import pagers from google.cloud.compute_v1.services.routers import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -147,7 +148,7 @@ def test_routers_client_client_options(client_class, transport_class, transport_ credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -163,7 +164,7 @@ def test_routers_client_client_options(client_class, transport_class, transport_ credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -179,7 +180,7 @@ def test_routers_client_client_options(client_class, transport_class, transport_ credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -207,7 +208,7 @@ def test_routers_client_client_options(client_class, transport_class, transport_ credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -239,29 +240,25 @@ def test_routers_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -270,66 +267,53 @@ def test_routers_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) - - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -348,7 +332,7 @@ def test_routers_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -371,7 +355,7 @@ def test_routers_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -413,16 +397,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.RouterAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.RouterAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.RoutersScopedList( @@ -457,6 +440,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.RouterAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -468,7 +452,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -484,6 +468,69 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = RoutersClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.RouterAggregatedList( + items={ + "a": compute.RoutersScopedList(), + "b": compute.RoutersScopedList(), + "c": compute.RoutersScopedList(), + }, + next_page_token="abc", + ), + compute.RouterAggregatedList(items={}, next_page_token="def",), + compute.RouterAggregatedList( + items={"g": compute.RoutersScopedList(),}, next_page_token="ghi", + ), + compute.RouterAggregatedList( + items={ + "h": compute.RoutersScopedList(), + "i": compute.RoutersScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.RouterAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.RoutersScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == (str, compute.RoutersScopedList) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.RoutersScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest(transport: str = "rest", request_type=compute.DeleteRouterRequest): client = RoutersClient( credentials=credentials.AnonymousCredentials(), transport=transport, @@ -524,6 +571,7 @@ def test_delete_rest(transport: str = "rest", request_type=compute.DeleteRouterR # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -574,6 +622,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -587,7 +636,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -645,6 +694,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetRouterRequest # Wrap the value into a proper Response obj json_return_value = compute.Router.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -686,6 +736,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Router.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -699,7 +750,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -747,16 +798,15 @@ def test_get_nat_mapping_info_rest( # Wrap the value into a proper Response obj json_return_value = compute.VmEndpointNatMappingsList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.get_nat_mapping_info(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.VmEndpointNatMappingsList) + assert isinstance(response, pagers.GetNatMappingInfoPager) assert response.id == "id_value" assert response.kind == "kind_value" assert response.next_page_token == "next_page_token_value" @@ -782,6 +832,7 @@ def test_get_nat_mapping_info_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.VmEndpointNatMappingsList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -795,7 +846,7 @@ def test_get_nat_mapping_info_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -818,6 +869,60 @@ def test_get_nat_mapping_info_rest_flattened_error(): ) +def test_get_nat_mapping_info_pager(): + client = RoutersClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.VmEndpointNatMappingsList( + result=[ + compute.VmEndpointNatMappings(), + compute.VmEndpointNatMappings(), + compute.VmEndpointNatMappings(), + ], + next_page_token="abc", + ), + compute.VmEndpointNatMappingsList(result=[], next_page_token="def",), + compute.VmEndpointNatMappingsList( + result=[compute.VmEndpointNatMappings(),], next_page_token="ghi", + ), + compute.VmEndpointNatMappingsList( + result=[ + compute.VmEndpointNatMappings(), + compute.VmEndpointNatMappings(), + ], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.VmEndpointNatMappingsList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.get_nat_mapping_info(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.VmEndpointNatMappings) for i in results) + + pages = list(client.get_nat_mapping_info(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_get_router_status_rest( transport: str = "rest", request_type=compute.GetRouterStatusRouterRequest ): @@ -843,6 +948,7 @@ def test_get_router_status_rest( # Wrap the value into a proper Response obj json_return_value = compute.RouterStatusResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -872,6 +978,7 @@ def test_get_router_status_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.RouterStatusResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -885,7 +992,7 @@ def test_get_router_status_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -948,6 +1055,7 @@ def test_insert_rest(transport: str = "rest", request_type=compute.InsertRouterR # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -998,6 +1106,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1017,14 +1126,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.Router.to_json( - router_resource, including_default_value_fields=False + router_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1075,16 +1186,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListRoutersRequ # Wrap the value into a proper Response obj json_return_value = compute.RouterList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.RouterList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Router( @@ -1112,6 +1222,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.RouterList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1125,7 +1236,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1145,6 +1256,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RoutersClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.RouterList( + items=[compute.Router(), compute.Router(), compute.Router(),], + next_page_token="abc", + ), + compute.RouterList(items=[], next_page_token="def",), + compute.RouterList(items=[compute.Router(),], next_page_token="ghi",), + compute.RouterList(items=[compute.Router(), compute.Router(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.RouterList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Router) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest(transport: str = "rest", request_type=compute.PatchRouterRequest): client = RoutersClient( credentials=credentials.AnonymousCredentials(), transport=transport, @@ -1185,6 +1339,7 @@ def test_patch_rest(transport: str = "rest", request_type=compute.PatchRouterReq # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1235,6 +1390,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1255,7 +1411,7 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1264,7 +1420,9 @@ def test_patch_rest_flattened(): assert "router_value" in http_call[1] + str(body) assert compute.Router.to_json( - router_resource, including_default_value_fields=False + router_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1311,6 +1469,7 @@ def test_preview_rest( # Wrap the value into a proper Response obj json_return_value = compute.RoutersPreviewResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1339,6 +1498,7 @@ def test_preview_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.RoutersPreviewResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1359,7 +1519,7 @@ def test_preview_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1368,7 +1528,9 @@ def test_preview_rest_flattened(): assert "router_value" in http_call[1] + str(body) assert compute.Router.to_json( - router_resource, including_default_value_fields=False + router_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1431,6 +1593,7 @@ def test_update_rest(transport: str = "rest", request_type=compute.UpdateRouterR # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1481,6 +1644,7 @@ def test_update_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1501,7 +1665,7 @@ def test_update_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1510,7 +1674,9 @@ def test_update_rest_flattened(): assert "router_value" in http_call[1] + str(body) assert compute.Router.to_json( - router_resource, including_default_value_fields=False + router_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1666,6 +1832,17 @@ def test_routers_auth_adc(): ) +def test_routers_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RoutersRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_routers_host_no_port(): client = RoutersClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_routes.py b/tests/unit/gapic/compute_v1/test_routes.py index a8de2e362..d079d9e0b 100644 --- a/tests/unit/gapic/compute_v1/test_routes.py +++ b/tests/unit/gapic/compute_v1/test_routes.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.routes import RoutesClient +from google.cloud.compute_v1.services.routes import pagers from google.cloud.compute_v1.services.routes import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -147,7 +148,7 @@ def test_routes_client_client_options(client_class, transport_class, transport_n credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -163,7 +164,7 @@ def test_routes_client_client_options(client_class, transport_class, transport_n credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -179,7 +180,7 @@ def test_routes_client_client_options(client_class, transport_class, transport_n credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -207,7 +208,7 @@ def test_routes_client_client_options(client_class, transport_class, transport_n credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -239,29 +240,25 @@ def test_routes_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -270,66 +267,53 @@ def test_routes_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -348,7 +332,7 @@ def test_routes_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -371,7 +355,7 @@ def test_routes_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -417,6 +401,7 @@ def test_delete_rest(transport: str = "rest", request_type=compute.DeleteRouteRe # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -467,6 +452,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -480,7 +466,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -533,6 +519,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetRouteRequest) # Wrap the value into a proper Response obj json_return_value = compute.Route.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -578,6 +565,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Route.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -591,7 +579,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -649,6 +637,7 @@ def test_insert_rest(transport: str = "rest", request_type=compute.InsertRouteRe # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -699,6 +688,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -714,12 +704,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.Route.to_json( - route_resource, including_default_value_fields=False + route_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -759,16 +751,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListRoutesReque # Wrap the value into a proper Response obj json_return_value = compute.RouteList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.RouteList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Route(creation_timestamp="creation_timestamp_value") @@ -794,6 +785,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.RouteList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -805,7 +797,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -821,6 +813,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = RoutesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.RouteList( + items=[compute.Route(), compute.Route(), compute.Route(),], + next_page_token="abc", + ), + compute.RouteList(items=[], next_page_token="def",), + compute.RouteList(items=[compute.Route(),], next_page_token="ghi",), + compute.RouteList(items=[compute.Route(), compute.Route(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.RouteList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Route) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.RoutesRestTransport( @@ -948,6 +983,17 @@ def test_routes_auth_adc(): ) +def test_routes_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RoutesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_routes_host_no_port(): client = RoutesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_security_policies.py b/tests/unit/gapic/compute_v1/test_security_policies.py index be8305a37..3d1f7b03a 100644 --- a/tests/unit/gapic/compute_v1/test_security_policies.py +++ b/tests/unit/gapic/compute_v1/test_security_policies.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.security_policies import SecurityPoliciesClient +from google.cloud.compute_v1.services.security_policies import pagers from google.cloud.compute_v1.services.security_policies import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -158,7 +159,7 @@ def test_security_policies_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -174,7 +175,7 @@ def test_security_policies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -190,7 +191,7 @@ def test_security_policies_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -218,7 +219,7 @@ def test_security_policies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -262,29 +263,25 @@ def test_security_policies_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -293,66 +290,53 @@ def test_security_policies_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -371,7 +355,7 @@ def test_security_policies_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +378,7 @@ def test_security_policies_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -442,6 +426,7 @@ def test_add_rule_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -492,6 +477,7 @@ def test_add_rule_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -511,14 +497,16 @@ def test_add_rule_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "security_policy_value" in http_call[1] + str(body) assert compute.SecurityPolicyRule.to_json( - security_policy_rule_resource, including_default_value_fields=False + security_policy_rule_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -580,6 +568,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -630,6 +619,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -643,7 +633,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -690,6 +680,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.SecurityPolicy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -723,6 +714,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.SecurityPolicy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -736,7 +728,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -785,6 +777,7 @@ def test_get_rule_rest( # Wrap the value into a proper Response obj json_return_value = compute.SecurityPolicyRule.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -821,6 +814,7 @@ def test_get_rule_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.SecurityPolicyRule.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -834,7 +828,7 @@ def test_get_rule_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -896,6 +890,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -946,6 +941,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -963,12 +959,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.SecurityPolicy.to_json( - security_policy_resource, including_default_value_fields=False + security_policy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1013,16 +1011,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.SecurityPolicyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.SecurityPolicyList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.SecurityPolicy(creation_timestamp="creation_timestamp_value") @@ -1047,6 +1044,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.SecurityPolicyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1058,7 +1056,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1074,6 +1072,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = SecurityPoliciesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.SecurityPolicyList( + items=[ + compute.SecurityPolicy(), + compute.SecurityPolicy(), + compute.SecurityPolicy(), + ], + next_page_token="abc", + ), + compute.SecurityPolicyList(items=[], next_page_token="def",), + compute.SecurityPolicyList( + items=[compute.SecurityPolicy(),], next_page_token="ghi", + ), + compute.SecurityPolicyList( + items=[compute.SecurityPolicy(), compute.SecurityPolicy(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.SecurityPolicyList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.SecurityPolicy) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_preconfigured_expression_sets_rest( transport: str = "rest", request_type=compute.ListPreconfiguredExpressionSetsSecurityPoliciesRequest, @@ -1103,6 +1152,7 @@ def test_list_preconfigured_expression_sets_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1137,6 +1187,7 @@ def test_list_preconfigured_expression_sets_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1148,7 +1199,7 @@ def test_list_preconfigured_expression_sets_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1207,6 +1258,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1257,6 +1309,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1276,14 +1329,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "security_policy_value" in http_call[1] + str(body) assert compute.SecurityPolicy.to_json( - security_policy_resource, including_default_value_fields=False + security_policy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1345,6 +1400,7 @@ def test_patch_rule_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1395,6 +1451,7 @@ def test_patch_rule_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1414,14 +1471,16 @@ def test_patch_rule_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "security_policy_value" in http_call[1] + str(body) assert compute.SecurityPolicyRule.to_json( - security_policy_rule_resource, including_default_value_fields=False + security_policy_rule_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1483,6 +1542,7 @@ def test_remove_rule_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1533,6 +1593,7 @@ def test_remove_rule_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1546,7 +1607,7 @@ def test_remove_rule_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1699,6 +1760,17 @@ def test_security_policies_auth_adc(): ) +def test_security_policies_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.SecurityPoliciesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_security_policies_host_no_port(): client = SecurityPoliciesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_snapshots.py b/tests/unit/gapic/compute_v1/test_snapshots.py index 9e9171bce..92cee9796 100644 --- a/tests/unit/gapic/compute_v1/test_snapshots.py +++ b/tests/unit/gapic/compute_v1/test_snapshots.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.snapshots import SnapshotsClient +from google.cloud.compute_v1.services.snapshots import pagers from google.cloud.compute_v1.services.snapshots import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -148,7 +149,7 @@ def test_snapshots_client_client_options(client_class, transport_class, transpor credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -164,7 +165,7 @@ def test_snapshots_client_client_options(client_class, transport_class, transpor credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -180,7 +181,7 @@ def test_snapshots_client_client_options(client_class, transport_class, transpor credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -208,7 +209,7 @@ def test_snapshots_client_client_options(client_class, transport_class, transpor credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -240,29 +241,25 @@ def test_snapshots_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -271,66 +268,53 @@ def test_snapshots_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -349,7 +333,7 @@ def test_snapshots_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -372,7 +356,7 @@ def test_snapshots_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -420,6 +404,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -470,6 +455,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -483,7 +469,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -546,6 +532,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetSnapshotReque # Wrap the value into a proper Response obj json_return_value = compute.Snapshot.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -598,6 +585,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Snapshot.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -611,7 +599,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -664,6 +652,7 @@ def test_get_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -702,6 +691,7 @@ def test_get_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -715,7 +705,7 @@ def test_get_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -758,16 +748,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListSnapshotsRe # Wrap the value into a proper Response obj json_return_value = compute.SnapshotList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.SnapshotList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.Snapshot(auto_created=True)] assert response.kind == "kind_value" @@ -791,6 +780,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.SnapshotList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -802,7 +792,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -818,6 +808,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = SnapshotsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.SnapshotList( + items=[compute.Snapshot(), compute.Snapshot(), compute.Snapshot(),], + next_page_token="abc", + ), + compute.SnapshotList(items=[], next_page_token="def",), + compute.SnapshotList(items=[compute.Snapshot(),], next_page_token="ghi",), + compute.SnapshotList(items=[compute.Snapshot(), compute.Snapshot(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.SnapshotList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Snapshot) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_set_iam_policy_rest( transport: str = "rest", request_type=compute.SetIamPolicySnapshotRequest ): @@ -851,6 +884,7 @@ def test_set_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -889,6 +923,7 @@ def test_set_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -908,14 +943,16 @@ def test_set_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "resource_value" in http_call[1] + str(body) assert compute.GlobalSetPolicyRequest.to_json( - global_set_policy_request_resource, including_default_value_fields=False + global_set_policy_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -977,6 +1014,7 @@ def test_set_labels_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1027,6 +1065,7 @@ def test_set_labels_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1046,14 +1085,16 @@ def test_set_labels_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "resource_value" in http_call[1] + str(body) assert compute.GlobalSetLabelsRequest.to_json( - global_set_labels_request_resource, including_default_value_fields=False + global_set_labels_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1093,6 +1134,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1119,6 +1161,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1138,14 +1181,16 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1295,6 +1340,17 @@ def test_snapshots_auth_adc(): ) +def test_snapshots_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.SnapshotsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_snapshots_host_no_port(): client = SnapshotsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_ssl_certificates.py b/tests/unit/gapic/compute_v1/test_ssl_certificates.py index a41999568..cf106e007 100644 --- a/tests/unit/gapic/compute_v1/test_ssl_certificates.py +++ b/tests/unit/gapic/compute_v1/test_ssl_certificates.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.ssl_certificates import SslCertificatesClient +from google.cloud.compute_v1.services.ssl_certificates import pagers from google.cloud.compute_v1.services.ssl_certificates import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -157,7 +158,7 @@ def test_ssl_certificates_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -173,7 +174,7 @@ def test_ssl_certificates_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -189,7 +190,7 @@ def test_ssl_certificates_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -217,7 +218,7 @@ def test_ssl_certificates_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -261,29 +262,25 @@ def test_ssl_certificates_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -292,66 +289,53 @@ def test_ssl_certificates_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -370,7 +354,7 @@ def test_ssl_certificates_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -393,7 +377,7 @@ def test_ssl_certificates_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -431,16 +415,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.SslCertificateAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.SslCertificateAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.SslCertificatesScopedList( @@ -469,6 +452,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.SslCertificateAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -480,7 +464,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -496,6 +480,75 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = SslCertificatesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.SslCertificateAggregatedList( + items={ + "a": compute.SslCertificatesScopedList(), + "b": compute.SslCertificatesScopedList(), + "c": compute.SslCertificatesScopedList(), + }, + next_page_token="abc", + ), + compute.SslCertificateAggregatedList(items={}, next_page_token="def",), + compute.SslCertificateAggregatedList( + items={"g": compute.SslCertificatesScopedList(),}, + next_page_token="ghi", + ), + compute.SslCertificateAggregatedList( + items={ + "h": compute.SslCertificatesScopedList(), + "i": compute.SslCertificatesScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.SslCertificateAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.SslCertificatesScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.SslCertificatesScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.SslCertificatesScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteSslCertificateRequest ): @@ -538,6 +591,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -588,6 +642,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -601,7 +656,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -658,6 +713,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.SslCertificate.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -701,6 +757,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.SslCertificate.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -714,7 +771,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -776,6 +833,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -826,6 +884,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -843,12 +902,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.SslCertificate.to_json( - ssl_certificate_resource, including_default_value_fields=False + ssl_certificate_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -892,16 +953,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.SslCertificateList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.SslCertificateList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.SslCertificate(certificate="certificate_value")] assert response.kind == "kind_value" @@ -925,6 +985,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.SslCertificateList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -936,7 +997,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -952,6 +1013,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = SslCertificatesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.SslCertificateList( + items=[ + compute.SslCertificate(), + compute.SslCertificate(), + compute.SslCertificate(), + ], + next_page_token="abc", + ), + compute.SslCertificateList(items=[], next_page_token="def",), + compute.SslCertificateList( + items=[compute.SslCertificate(),], next_page_token="ghi", + ), + compute.SslCertificateList( + items=[compute.SslCertificate(), compute.SslCertificate(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.SslCertificateList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.SslCertificate) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.SslCertificatesRestTransport( @@ -1080,6 +1192,17 @@ def test_ssl_certificates_auth_adc(): ) +def test_ssl_certificates_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.SslCertificatesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_ssl_certificates_host_no_port(): client = SslCertificatesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_ssl_policies.py b/tests/unit/gapic/compute_v1/test_ssl_policies.py index c8e6e81e4..83e39611c 100644 --- a/tests/unit/gapic/compute_v1/test_ssl_policies.py +++ b/tests/unit/gapic/compute_v1/test_ssl_policies.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.ssl_policies import SslPoliciesClient +from google.cloud.compute_v1.services.ssl_policies import pagers from google.cloud.compute_v1.services.ssl_policies import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -152,7 +153,7 @@ def test_ssl_policies_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -168,7 +169,7 @@ def test_ssl_policies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -184,7 +185,7 @@ def test_ssl_policies_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -212,7 +213,7 @@ def test_ssl_policies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -244,29 +245,25 @@ def test_ssl_policies_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -275,66 +272,53 @@ def test_ssl_policies_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -353,7 +337,7 @@ def test_ssl_policies_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -376,7 +360,7 @@ def test_ssl_policies_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -424,6 +408,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -474,6 +459,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -487,7 +473,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -536,6 +522,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetSslPolicyRequ # Wrap the value into a proper Response obj json_return_value = compute.SslPolicy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -575,6 +562,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.SslPolicy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -588,7 +576,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -650,6 +638,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -700,6 +689,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -717,12 +707,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.SslPolicy.to_json( - ssl_policy_resource, including_default_value_fields=False + ssl_policy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -766,16 +758,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.SslPoliciesList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.SslPoliciesList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.SslPolicy(creation_timestamp="creation_timestamp_value") @@ -801,6 +792,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.SslPoliciesList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -812,7 +804,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -828,6 +820,51 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = SslPoliciesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.SslPoliciesList( + items=[compute.SslPolicy(), compute.SslPolicy(), compute.SslPolicy(),], + next_page_token="abc", + ), + compute.SslPoliciesList(items=[], next_page_token="def",), + compute.SslPoliciesList( + items=[compute.SslPolicy(),], next_page_token="ghi", + ), + compute.SslPoliciesList(items=[compute.SslPolicy(), compute.SslPolicy(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.SslPoliciesList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.SslPolicy) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_available_features_rest( transport: str = "rest", request_type=compute.ListAvailableFeaturesSslPoliciesRequest, @@ -851,6 +888,7 @@ def test_list_available_features_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -879,6 +917,7 @@ def test_list_available_features_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -890,7 +929,7 @@ def test_list_available_features_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -948,6 +987,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -998,6 +1038,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1017,14 +1058,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "ssl_policy_value" in http_call[1] + str(body) assert compute.SslPolicy.to_json( - ssl_policy_resource, including_default_value_fields=False + ssl_policy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1173,6 +1216,17 @@ def test_ssl_policies_auth_adc(): ) +def test_ssl_policies_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.SslPoliciesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_ssl_policies_host_no_port(): client = SslPoliciesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_subnetworks.py b/tests/unit/gapic/compute_v1/test_subnetworks.py index df3e2645d..128c6d849 100644 --- a/tests/unit/gapic/compute_v1/test_subnetworks.py +++ b/tests/unit/gapic/compute_v1/test_subnetworks.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.subnetworks import SubnetworksClient +from google.cloud.compute_v1.services.subnetworks import pagers from google.cloud.compute_v1.services.subnetworks import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -152,7 +153,7 @@ def test_subnetworks_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -168,7 +169,7 @@ def test_subnetworks_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -184,7 +185,7 @@ def test_subnetworks_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -212,7 +213,7 @@ def test_subnetworks_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -244,29 +245,25 @@ def test_subnetworks_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -275,66 +272,53 @@ def test_subnetworks_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -353,7 +337,7 @@ def test_subnetworks_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -376,7 +360,7 @@ def test_subnetworks_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -416,16 +400,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.SubnetworkAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.SubnetworkAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.SubnetworksScopedList( @@ -456,6 +439,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.SubnetworkAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -467,7 +451,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -483,6 +467,72 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = SubnetworksClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.SubnetworkAggregatedList( + items={ + "a": compute.SubnetworksScopedList(), + "b": compute.SubnetworksScopedList(), + "c": compute.SubnetworksScopedList(), + }, + next_page_token="abc", + ), + compute.SubnetworkAggregatedList(items={}, next_page_token="def",), + compute.SubnetworkAggregatedList( + items={"g": compute.SubnetworksScopedList(),}, next_page_token="ghi", + ), + compute.SubnetworkAggregatedList( + items={ + "h": compute.SubnetworksScopedList(), + "i": compute.SubnetworksScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.SubnetworkAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.SubnetworksScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.SubnetworksScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.SubnetworksScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteSubnetworkRequest ): @@ -525,6 +575,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -575,6 +626,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -590,7 +642,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -655,6 +707,7 @@ def test_expand_ip_cidr_range_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -705,6 +758,7 @@ def test_expand_ip_cidr_range_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -725,7 +779,7 @@ def test_expand_ip_cidr_range_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -736,6 +790,7 @@ def test_expand_ip_cidr_range_rest_flattened(): assert compute.SubnetworksExpandIpCidrRangeRequest.to_json( subnetworks_expand_ip_cidr_range_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -797,6 +852,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetSubnetworkReq # Wrap the value into a proper Response obj json_return_value = compute.Subnetwork.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -851,6 +907,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Subnetwork.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -866,7 +923,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -922,6 +979,7 @@ def test_get_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -960,6 +1018,7 @@ def test_get_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -973,7 +1032,7 @@ def test_get_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1038,6 +1097,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1088,6 +1148,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1107,14 +1168,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.Subnetwork.to_json( - subnetwork_resource, including_default_value_fields=False + subnetwork_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1159,16 +1222,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.SubnetworkList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.SubnetworkList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Subnetwork(creation_timestamp="creation_timestamp_value") @@ -1194,6 +1256,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.SubnetworkList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1207,7 +1270,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1227,6 +1290,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = SubnetworksClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.SubnetworkList( + items=[ + compute.Subnetwork(), + compute.Subnetwork(), + compute.Subnetwork(), + ], + next_page_token="abc", + ), + compute.SubnetworkList(items=[], next_page_token="def",), + compute.SubnetworkList( + items=[compute.Subnetwork(),], next_page_token="ghi", + ), + compute.SubnetworkList( + items=[compute.Subnetwork(), compute.Subnetwork(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.SubnetworkList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Subnetwork) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_list_usable_rest( transport: str = "rest", request_type=compute.ListUsableSubnetworksRequest ): @@ -1254,16 +1368,15 @@ def test_list_usable_rest( return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list_usable(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.UsableSubnetworksAggregatedList) + assert isinstance(response, pagers.ListUsablePager) assert response.id == "id_value" assert response.items == [ compute.UsableSubnetwork(ip_cidr_range="ip_cidr_range_value") @@ -1291,6 +1404,7 @@ def test_list_usable_rest_flattened(): return_value ) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1302,7 +1416,7 @@ def test_list_usable_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1318,6 +1432,59 @@ def test_list_usable_rest_flattened_error(): ) +def test_list_usable_pager(): + client = SubnetworksClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.UsableSubnetworksAggregatedList( + items=[ + compute.UsableSubnetwork(), + compute.UsableSubnetwork(), + compute.UsableSubnetwork(), + ], + next_page_token="abc", + ), + compute.UsableSubnetworksAggregatedList(items=[], next_page_token="def",), + compute.UsableSubnetworksAggregatedList( + items=[compute.UsableSubnetwork(),], next_page_token="ghi", + ), + compute.UsableSubnetworksAggregatedList( + items=[compute.UsableSubnetwork(), compute.UsableSubnetwork(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.UsableSubnetworksAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list_usable(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.UsableSubnetwork) for i in results) + + pages = list(client.list_usable(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchSubnetworkRequest ): @@ -1360,6 +1527,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1410,6 +1578,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1430,7 +1599,7 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1439,7 +1608,9 @@ def test_patch_rest_flattened(): assert "subnetwork_value" in http_call[1] + str(body) assert compute.Subnetwork.to_json( - subnetwork_resource, including_default_value_fields=False + subnetwork_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1493,6 +1664,7 @@ def test_set_iam_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1531,6 +1703,7 @@ def test_set_iam_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Policy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1551,7 +1724,7 @@ def test_set_iam_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1560,7 +1733,9 @@ def test_set_iam_policy_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.RegionSetPolicyRequest.to_json( - region_set_policy_request_resource, including_default_value_fields=False + region_set_policy_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1624,6 +1799,7 @@ def test_set_private_ip_google_access_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1674,6 +1850,7 @@ def test_set_private_ip_google_access_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1694,7 +1871,7 @@ def test_set_private_ip_google_access_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1705,6 +1882,7 @@ def test_set_private_ip_google_access_rest_flattened(): assert compute.SubnetworksSetPrivateIpGoogleAccessRequest.to_json( subnetworks_set_private_ip_google_access_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1745,6 +1923,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1771,6 +1950,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1791,7 +1971,7 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1800,7 +1980,9 @@ def test_test_iam_permissions_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1956,6 +2138,17 @@ def test_subnetworks_auth_adc(): ) +def test_subnetworks_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.SubnetworksRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_subnetworks_host_no_port(): client = SubnetworksClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_target_grpc_proxies.py b/tests/unit/gapic/compute_v1/test_target_grpc_proxies.py index f2ca46eef..4b5bac2df 100644 --- a/tests/unit/gapic/compute_v1/test_target_grpc_proxies.py +++ b/tests/unit/gapic/compute_v1/test_target_grpc_proxies.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.target_grpc_proxies import TargetGrpcProxiesClient +from google.cloud.compute_v1.services.target_grpc_proxies import pagers from google.cloud.compute_v1.services.target_grpc_proxies import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -158,7 +159,7 @@ def test_target_grpc_proxies_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -174,7 +175,7 @@ def test_target_grpc_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -190,7 +191,7 @@ def test_target_grpc_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -218,7 +219,7 @@ def test_target_grpc_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -262,29 +263,25 @@ def test_target_grpc_proxies_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -293,66 +290,53 @@ def test_target_grpc_proxies_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -371,7 +355,7 @@ def test_target_grpc_proxies_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +378,7 @@ def test_target_grpc_proxies_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -442,6 +426,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -492,6 +477,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -505,7 +491,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -554,6 +540,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetGrpcProxy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -590,6 +577,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetGrpcProxy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -603,7 +591,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -665,6 +653,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -715,6 +704,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -733,12 +723,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.TargetGrpcProxy.to_json( - target_grpc_proxy_resource, including_default_value_fields=False + target_grpc_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -784,16 +776,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetGrpcProxyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetGrpcProxyList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.TargetGrpcProxy(creation_timestamp="creation_timestamp_value") @@ -819,6 +810,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetGrpcProxyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -830,7 +822,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -846,6 +838,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = TargetGrpcProxiesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetGrpcProxyList( + items=[ + compute.TargetGrpcProxy(), + compute.TargetGrpcProxy(), + compute.TargetGrpcProxy(), + ], + next_page_token="abc", + ), + compute.TargetGrpcProxyList(items=[], next_page_token="def",), + compute.TargetGrpcProxyList( + items=[compute.TargetGrpcProxy(),], next_page_token="ghi", + ), + compute.TargetGrpcProxyList( + items=[compute.TargetGrpcProxy(), compute.TargetGrpcProxy(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.TargetGrpcProxyList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.TargetGrpcProxy) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchTargetGrpcProxyRequest ): @@ -888,6 +931,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -938,6 +982,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -957,14 +1002,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "target_grpc_proxy_value" in http_call[1] + str(body) assert compute.TargetGrpcProxy.to_json( - target_grpc_proxy_resource, including_default_value_fields=False + target_grpc_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1114,6 +1161,17 @@ def test_target_grpc_proxies_auth_adc(): ) +def test_target_grpc_proxies_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.TargetGrpcProxiesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_target_grpc_proxies_host_no_port(): client = TargetGrpcProxiesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_target_http_proxies.py b/tests/unit/gapic/compute_v1/test_target_http_proxies.py index ee6c93563..683d2de43 100644 --- a/tests/unit/gapic/compute_v1/test_target_http_proxies.py +++ b/tests/unit/gapic/compute_v1/test_target_http_proxies.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.target_http_proxies import TargetHttpProxiesClient +from google.cloud.compute_v1.services.target_http_proxies import pagers from google.cloud.compute_v1.services.target_http_proxies import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -158,7 +159,7 @@ def test_target_http_proxies_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -174,7 +175,7 @@ def test_target_http_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -190,7 +191,7 @@ def test_target_http_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -218,7 +219,7 @@ def test_target_http_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -262,29 +263,25 @@ def test_target_http_proxies_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -293,66 +290,53 @@ def test_target_http_proxies_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -371,7 +355,7 @@ def test_target_http_proxies_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +378,7 @@ def test_target_http_proxies_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -433,16 +417,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpProxyAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetHttpProxyAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.TargetHttpProxiesScopedList( @@ -472,6 +455,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpProxyAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -483,7 +467,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -499,6 +483,75 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = TargetHttpProxiesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetHttpProxyAggregatedList( + items={ + "a": compute.TargetHttpProxiesScopedList(), + "b": compute.TargetHttpProxiesScopedList(), + "c": compute.TargetHttpProxiesScopedList(), + }, + next_page_token="abc", + ), + compute.TargetHttpProxyAggregatedList(items={}, next_page_token="def",), + compute.TargetHttpProxyAggregatedList( + items={"g": compute.TargetHttpProxiesScopedList(),}, + next_page_token="ghi", + ), + compute.TargetHttpProxyAggregatedList( + items={ + "h": compute.TargetHttpProxiesScopedList(), + "i": compute.TargetHttpProxiesScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.TargetHttpProxyAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.TargetHttpProxiesScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.TargetHttpProxiesScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.TargetHttpProxiesScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteTargetHttpProxyRequest ): @@ -541,6 +594,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -591,6 +645,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -604,7 +659,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -653,6 +708,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpProxy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -689,6 +745,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpProxy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -702,7 +759,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -764,6 +821,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -814,6 +872,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -832,12 +891,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.TargetHttpProxy.to_json( - target_http_proxy_resource, including_default_value_fields=False + target_http_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -883,16 +944,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpProxyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetHttpProxyList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.TargetHttpProxy(creation_timestamp="creation_timestamp_value") @@ -918,6 +978,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpProxyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -929,7 +990,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -945,6 +1006,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = TargetHttpProxiesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetHttpProxyList( + items=[ + compute.TargetHttpProxy(), + compute.TargetHttpProxy(), + compute.TargetHttpProxy(), + ], + next_page_token="abc", + ), + compute.TargetHttpProxyList(items=[], next_page_token="def",), + compute.TargetHttpProxyList( + items=[compute.TargetHttpProxy(),], next_page_token="ghi", + ), + compute.TargetHttpProxyList( + items=[compute.TargetHttpProxy(), compute.TargetHttpProxy(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.TargetHttpProxyList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.TargetHttpProxy) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest( transport: str = "rest", request_type=compute.PatchTargetHttpProxyRequest ): @@ -987,6 +1099,7 @@ def test_patch_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1037,6 +1150,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1056,14 +1170,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "target_http_proxy_value" in http_call[1] + str(body) assert compute.TargetHttpProxy.to_json( - target_http_proxy_resource, including_default_value_fields=False + target_http_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1125,6 +1241,7 @@ def test_set_url_map_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1175,6 +1292,7 @@ def test_set_url_map_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1192,14 +1310,16 @@ def test_set_url_map_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "target_http_proxy_value" in http_call[1] + str(body) assert compute.UrlMapReference.to_json( - url_map_reference_resource, including_default_value_fields=False + url_map_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1349,6 +1469,17 @@ def test_target_http_proxies_auth_adc(): ) +def test_target_http_proxies_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.TargetHttpProxiesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_target_http_proxies_host_no_port(): client = TargetHttpProxiesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_target_https_proxies.py b/tests/unit/gapic/compute_v1/test_target_https_proxies.py index d7cae867a..743a4c848 100644 --- a/tests/unit/gapic/compute_v1/test_target_https_proxies.py +++ b/tests/unit/gapic/compute_v1/test_target_https_proxies.py @@ -37,6 +37,7 @@ from google.cloud.compute_v1.services.target_https_proxies import ( TargetHttpsProxiesClient, ) +from google.cloud.compute_v1.services.target_https_proxies import pagers from google.cloud.compute_v1.services.target_https_proxies import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -160,7 +161,7 @@ def test_target_https_proxies_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -176,7 +177,7 @@ def test_target_https_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -192,7 +193,7 @@ def test_target_https_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -220,7 +221,7 @@ def test_target_https_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -264,29 +265,25 @@ def test_target_https_proxies_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -295,66 +292,53 @@ def test_target_https_proxies_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -373,7 +357,7 @@ def test_target_https_proxies_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -396,7 +380,7 @@ def test_target_https_proxies_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -437,16 +421,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpsProxyAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetHttpsProxyAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.TargetHttpsProxiesScopedList( @@ -479,6 +462,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpsProxyAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -490,7 +474,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -506,6 +490,75 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = TargetHttpsProxiesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetHttpsProxyAggregatedList( + items={ + "a": compute.TargetHttpsProxiesScopedList(), + "b": compute.TargetHttpsProxiesScopedList(), + "c": compute.TargetHttpsProxiesScopedList(), + }, + next_page_token="abc", + ), + compute.TargetHttpsProxyAggregatedList(items={}, next_page_token="def",), + compute.TargetHttpsProxyAggregatedList( + items={"g": compute.TargetHttpsProxiesScopedList(),}, + next_page_token="ghi", + ), + compute.TargetHttpsProxyAggregatedList( + items={ + "h": compute.TargetHttpsProxiesScopedList(), + "i": compute.TargetHttpsProxiesScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.TargetHttpsProxyAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.TargetHttpsProxiesScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.TargetHttpsProxiesScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.TargetHttpsProxiesScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteTargetHttpsProxyRequest ): @@ -548,6 +601,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -598,6 +652,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -611,7 +666,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -664,6 +719,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpsProxy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -704,6 +760,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpsProxy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -717,7 +774,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -779,6 +836,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -829,6 +887,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -847,12 +906,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.TargetHttpsProxy.to_json( - target_https_proxy_resource, including_default_value_fields=False + target_https_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -900,16 +961,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpsProxyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetHttpsProxyList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.TargetHttpsProxy(authorization_policy="authorization_policy_value") @@ -935,6 +995,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetHttpsProxyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -946,7 +1007,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -962,6 +1023,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = TargetHttpsProxiesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetHttpsProxyList( + items=[ + compute.TargetHttpsProxy(), + compute.TargetHttpsProxy(), + compute.TargetHttpsProxy(), + ], + next_page_token="abc", + ), + compute.TargetHttpsProxyList(items=[], next_page_token="def",), + compute.TargetHttpsProxyList( + items=[compute.TargetHttpsProxy(),], next_page_token="ghi", + ), + compute.TargetHttpsProxyList( + items=[compute.TargetHttpsProxy(), compute.TargetHttpsProxy(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.TargetHttpsProxyList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.TargetHttpsProxy) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_set_quic_override_rest( transport: str = "rest", request_type=compute.SetQuicOverrideTargetHttpsProxyRequest ): @@ -1004,6 +1116,7 @@ def test_set_quic_override_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1054,6 +1167,7 @@ def test_set_quic_override_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1073,7 +1187,7 @@ def test_set_quic_override_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1082,6 +1196,7 @@ def test_set_quic_override_rest_flattened(): assert compute.TargetHttpsProxiesSetQuicOverrideRequest.to_json( target_https_proxies_set_quic_override_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1144,6 +1259,7 @@ def test_set_ssl_certificates_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1194,6 +1310,7 @@ def test_set_ssl_certificates_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1213,7 +1330,7 @@ def test_set_ssl_certificates_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1222,6 +1339,7 @@ def test_set_ssl_certificates_rest_flattened(): assert compute.TargetHttpsProxiesSetSslCertificatesRequest.to_json( target_https_proxies_set_ssl_certificates_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1283,6 +1401,7 @@ def test_set_ssl_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1333,6 +1452,7 @@ def test_set_ssl_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1352,14 +1472,16 @@ def test_set_ssl_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "target_https_proxy_value" in http_call[1] + str(body) assert compute.SslPolicyReference.to_json( - ssl_policy_reference_resource, including_default_value_fields=False + ssl_policy_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1421,6 +1543,7 @@ def test_set_url_map_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1471,6 +1594,7 @@ def test_set_url_map_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1488,14 +1612,16 @@ def test_set_url_map_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "target_https_proxy_value" in http_call[1] + str(body) assert compute.UrlMapReference.to_json( - url_map_reference_resource, including_default_value_fields=False + url_map_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1647,6 +1773,17 @@ def test_target_https_proxies_auth_adc(): ) +def test_target_https_proxies_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.TargetHttpsProxiesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_target_https_proxies_host_no_port(): client = TargetHttpsProxiesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_target_instances.py b/tests/unit/gapic/compute_v1/test_target_instances.py index 4eb33b909..90696b241 100644 --- a/tests/unit/gapic/compute_v1/test_target_instances.py +++ b/tests/unit/gapic/compute_v1/test_target_instances.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.target_instances import TargetInstancesClient +from google.cloud.compute_v1.services.target_instances import pagers from google.cloud.compute_v1.services.target_instances import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -157,7 +158,7 @@ def test_target_instances_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -173,7 +174,7 @@ def test_target_instances_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -189,7 +190,7 @@ def test_target_instances_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -217,7 +218,7 @@ def test_target_instances_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -261,29 +262,25 @@ def test_target_instances_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -292,66 +289,53 @@ def test_target_instances_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -370,7 +354,7 @@ def test_target_instances_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -393,7 +377,7 @@ def test_target_instances_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -433,16 +417,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetInstanceAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetInstanceAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.TargetInstancesScopedList( @@ -473,6 +456,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetInstanceAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -484,7 +468,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -500,6 +484,75 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = TargetInstancesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetInstanceAggregatedList( + items={ + "a": compute.TargetInstancesScopedList(), + "b": compute.TargetInstancesScopedList(), + "c": compute.TargetInstancesScopedList(), + }, + next_page_token="abc", + ), + compute.TargetInstanceAggregatedList(items={}, next_page_token="def",), + compute.TargetInstanceAggregatedList( + items={"g": compute.TargetInstancesScopedList(),}, + next_page_token="ghi", + ), + compute.TargetInstanceAggregatedList( + items={ + "h": compute.TargetInstancesScopedList(), + "i": compute.TargetInstancesScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.TargetInstanceAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.TargetInstancesScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.TargetInstancesScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.TargetInstancesScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteTargetInstanceRequest ): @@ -542,6 +595,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -592,6 +646,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -607,7 +662,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -658,6 +713,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetInstance.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -692,6 +748,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetInstance.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -707,7 +764,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -772,6 +829,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -822,6 +880,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -841,14 +900,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "zone_value" in http_call[1] + str(body) assert compute.TargetInstance.to_json( - target_instance_resource, including_default_value_fields=False + target_instance_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -895,16 +956,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetInstanceList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetInstanceList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.TargetInstance(creation_timestamp="creation_timestamp_value") @@ -930,6 +990,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetInstanceList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -943,7 +1004,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -963,6 +1024,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = TargetInstancesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetInstanceList( + items=[ + compute.TargetInstance(), + compute.TargetInstance(), + compute.TargetInstance(), + ], + next_page_token="abc", + ), + compute.TargetInstanceList(items=[], next_page_token="def",), + compute.TargetInstanceList( + items=[compute.TargetInstance(),], next_page_token="ghi", + ), + compute.TargetInstanceList( + items=[compute.TargetInstance(), compute.TargetInstance(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.TargetInstanceList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.TargetInstance) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.TargetInstancesRestTransport( @@ -1091,6 +1203,17 @@ def test_target_instances_auth_adc(): ) +def test_target_instances_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.TargetInstancesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_target_instances_host_no_port(): client = TargetInstancesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_target_pools.py b/tests/unit/gapic/compute_v1/test_target_pools.py index e2b9d3d83..d7cfacad7 100644 --- a/tests/unit/gapic/compute_v1/test_target_pools.py +++ b/tests/unit/gapic/compute_v1/test_target_pools.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.target_pools import TargetPoolsClient +from google.cloud.compute_v1.services.target_pools import pagers from google.cloud.compute_v1.services.target_pools import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -152,7 +153,7 @@ def test_target_pools_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -168,7 +169,7 @@ def test_target_pools_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -184,7 +185,7 @@ def test_target_pools_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -212,7 +213,7 @@ def test_target_pools_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -244,29 +245,25 @@ def test_target_pools_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -275,66 +272,53 @@ def test_target_pools_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -353,7 +337,7 @@ def test_target_pools_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -376,7 +360,7 @@ def test_target_pools_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -424,6 +408,7 @@ def test_add_health_check_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -474,6 +459,7 @@ def test_add_health_check_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -496,7 +482,7 @@ def test_add_health_check_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -507,6 +493,7 @@ def test_add_health_check_rest_flattened(): assert compute.TargetPoolsAddHealthCheckRequest.to_json( target_pools_add_health_check_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -571,6 +558,7 @@ def test_add_instance_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -621,6 +609,7 @@ def test_add_instance_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -641,7 +630,7 @@ def test_add_instance_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -652,6 +641,7 @@ def test_add_instance_rest_flattened(): assert compute.TargetPoolsAddInstanceRequest.to_json( target_pools_add_instance_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -702,16 +692,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetPoolAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetPoolAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.TargetPoolsScopedList( @@ -740,6 +729,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetPoolAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -751,7 +741,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -767,6 +757,72 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = TargetPoolsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetPoolAggregatedList( + items={ + "a": compute.TargetPoolsScopedList(), + "b": compute.TargetPoolsScopedList(), + "c": compute.TargetPoolsScopedList(), + }, + next_page_token="abc", + ), + compute.TargetPoolAggregatedList(items={}, next_page_token="def",), + compute.TargetPoolAggregatedList( + items={"g": compute.TargetPoolsScopedList(),}, next_page_token="ghi", + ), + compute.TargetPoolAggregatedList( + items={ + "h": compute.TargetPoolsScopedList(), + "i": compute.TargetPoolsScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.TargetPoolAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.TargetPoolsScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.TargetPoolsScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.TargetPoolsScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteTargetPoolRequest ): @@ -809,6 +865,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -859,6 +916,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -874,7 +932,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -926,6 +984,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetTargetPoolReq # Wrap the value into a proper Response obj json_return_value = compute.TargetPool.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -963,6 +1022,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetPool.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -978,7 +1038,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1024,6 +1084,7 @@ def test_get_health_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetPoolInstanceHealth.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1053,6 +1114,7 @@ def test_get_health_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetPoolInstanceHealth.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1073,7 +1135,7 @@ def test_get_health_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1082,7 +1144,9 @@ def test_get_health_rest_flattened(): assert "target_pool_value" in http_call[1] + str(body) assert compute.InstanceReference.to_json( - instance_reference_resource, including_default_value_fields=False + instance_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1145,6 +1209,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1195,6 +1260,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1212,14 +1278,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.TargetPool.to_json( - target_pool_resource, including_default_value_fields=False + target_pool_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1262,16 +1330,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetPoolList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetPoolList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [compute.TargetPool(backup_pool="backup_pool_value")] assert response.kind == "kind_value" @@ -1295,6 +1362,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetPoolList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1308,7 +1376,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1328,6 +1396,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = TargetPoolsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetPoolList( + items=[ + compute.TargetPool(), + compute.TargetPool(), + compute.TargetPool(), + ], + next_page_token="abc", + ), + compute.TargetPoolList(items=[], next_page_token="def",), + compute.TargetPoolList( + items=[compute.TargetPool(),], next_page_token="ghi", + ), + compute.TargetPoolList( + items=[compute.TargetPool(), compute.TargetPool(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.TargetPoolList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.TargetPool) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_remove_health_check_rest( transport: str = "rest", request_type=compute.RemoveHealthCheckTargetPoolRequest ): @@ -1370,6 +1489,7 @@ def test_remove_health_check_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1420,6 +1540,7 @@ def test_remove_health_check_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1442,7 +1563,7 @@ def test_remove_health_check_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1453,6 +1574,7 @@ def test_remove_health_check_rest_flattened(): assert compute.TargetPoolsRemoveHealthCheckRequest.to_json( target_pools_remove_health_check_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1517,6 +1639,7 @@ def test_remove_instance_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1567,6 +1690,7 @@ def test_remove_instance_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1587,7 +1711,7 @@ def test_remove_instance_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1598,6 +1722,7 @@ def test_remove_instance_rest_flattened(): assert compute.TargetPoolsRemoveInstanceRequest.to_json( target_pools_remove_instance_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1660,6 +1785,7 @@ def test_set_backup_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1710,6 +1836,7 @@ def test_set_backup_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1728,7 +1855,7 @@ def test_set_backup_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1737,7 +1864,9 @@ def test_set_backup_rest_flattened(): assert "target_pool_value" in http_call[1] + str(body) assert compute.TargetReference.to_json( - target_reference_resource, including_default_value_fields=False + target_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1890,6 +2019,17 @@ def test_target_pools_auth_adc(): ) +def test_target_pools_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.TargetPoolsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_target_pools_host_no_port(): client = TargetPoolsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_target_ssl_proxies.py b/tests/unit/gapic/compute_v1/test_target_ssl_proxies.py index 37930fa47..716a64045 100644 --- a/tests/unit/gapic/compute_v1/test_target_ssl_proxies.py +++ b/tests/unit/gapic/compute_v1/test_target_ssl_proxies.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.target_ssl_proxies import TargetSslProxiesClient +from google.cloud.compute_v1.services.target_ssl_proxies import pagers from google.cloud.compute_v1.services.target_ssl_proxies import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -158,7 +159,7 @@ def test_target_ssl_proxies_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -174,7 +175,7 @@ def test_target_ssl_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -190,7 +191,7 @@ def test_target_ssl_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -218,7 +219,7 @@ def test_target_ssl_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -262,29 +263,25 @@ def test_target_ssl_proxies_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -293,66 +290,53 @@ def test_target_ssl_proxies_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -371,7 +355,7 @@ def test_target_ssl_proxies_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +378,7 @@ def test_target_ssl_proxies_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -442,6 +426,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -492,6 +477,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -505,7 +491,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -554,6 +540,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetSslProxy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -589,6 +576,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetSslProxy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -602,7 +590,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -664,6 +652,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -714,6 +703,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -732,12 +722,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.TargetSslProxy.to_json( - target_ssl_proxy_resource, including_default_value_fields=False + target_ssl_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -783,16 +775,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetSslProxyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetSslProxyList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.TargetSslProxy(creation_timestamp="creation_timestamp_value") @@ -818,6 +809,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetSslProxyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -829,7 +821,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -845,6 +837,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = TargetSslProxiesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetSslProxyList( + items=[ + compute.TargetSslProxy(), + compute.TargetSslProxy(), + compute.TargetSslProxy(), + ], + next_page_token="abc", + ), + compute.TargetSslProxyList(items=[], next_page_token="def",), + compute.TargetSslProxyList( + items=[compute.TargetSslProxy(),], next_page_token="ghi", + ), + compute.TargetSslProxyList( + items=[compute.TargetSslProxy(), compute.TargetSslProxy(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.TargetSslProxyList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.TargetSslProxy) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_set_backend_service_rest( transport: str = "rest", request_type=compute.SetBackendServiceTargetSslProxyRequest ): @@ -887,6 +930,7 @@ def test_set_backend_service_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -937,6 +981,7 @@ def test_set_backend_service_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -956,7 +1001,7 @@ def test_set_backend_service_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -965,6 +1010,7 @@ def test_set_backend_service_rest_flattened(): assert compute.TargetSslProxiesSetBackendServiceRequest.to_json( target_ssl_proxies_set_backend_service_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1026,6 +1072,7 @@ def test_set_proxy_header_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1076,6 +1123,7 @@ def test_set_proxy_header_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1095,7 +1143,7 @@ def test_set_proxy_header_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1104,6 +1152,7 @@ def test_set_proxy_header_rest_flattened(): assert compute.TargetSslProxiesSetProxyHeaderRequest.to_json( target_ssl_proxies_set_proxy_header_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1166,6 +1215,7 @@ def test_set_ssl_certificates_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1216,6 +1266,7 @@ def test_set_ssl_certificates_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1235,7 +1286,7 @@ def test_set_ssl_certificates_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1244,6 +1295,7 @@ def test_set_ssl_certificates_rest_flattened(): assert compute.TargetSslProxiesSetSslCertificatesRequest.to_json( target_ssl_proxies_set_ssl_certificates_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1305,6 +1357,7 @@ def test_set_ssl_policy_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1355,6 +1408,7 @@ def test_set_ssl_policy_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1374,14 +1428,16 @@ def test_set_ssl_policy_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "target_ssl_proxy_value" in http_call[1] + str(body) assert compute.SslPolicyReference.to_json( - ssl_policy_reference_resource, including_default_value_fields=False + ssl_policy_reference_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1532,6 +1588,17 @@ def test_target_ssl_proxies_auth_adc(): ) +def test_target_ssl_proxies_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.TargetSslProxiesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_target_ssl_proxies_host_no_port(): client = TargetSslProxiesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_target_tcp_proxies.py b/tests/unit/gapic/compute_v1/test_target_tcp_proxies.py index 6e6c31952..e06fba947 100644 --- a/tests/unit/gapic/compute_v1/test_target_tcp_proxies.py +++ b/tests/unit/gapic/compute_v1/test_target_tcp_proxies.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.target_tcp_proxies import TargetTcpProxiesClient +from google.cloud.compute_v1.services.target_tcp_proxies import pagers from google.cloud.compute_v1.services.target_tcp_proxies import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -158,7 +159,7 @@ def test_target_tcp_proxies_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -174,7 +175,7 @@ def test_target_tcp_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -190,7 +191,7 @@ def test_target_tcp_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -218,7 +219,7 @@ def test_target_tcp_proxies_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -262,29 +263,25 @@ def test_target_tcp_proxies_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -293,66 +290,53 @@ def test_target_tcp_proxies_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -371,7 +355,7 @@ def test_target_tcp_proxies_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +378,7 @@ def test_target_tcp_proxies_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -442,6 +426,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -492,6 +477,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -505,7 +491,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -552,6 +538,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetTcpProxy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -585,6 +572,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetTcpProxy.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -598,7 +586,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -660,6 +648,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -710,6 +699,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -728,12 +718,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.TargetTcpProxy.to_json( - target_tcp_proxy_resource, including_default_value_fields=False + target_tcp_proxy_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -779,16 +771,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetTcpProxyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetTcpProxyList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.TargetTcpProxy(creation_timestamp="creation_timestamp_value") @@ -814,6 +805,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetTcpProxyList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -825,7 +817,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -841,6 +833,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = TargetTcpProxiesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetTcpProxyList( + items=[ + compute.TargetTcpProxy(), + compute.TargetTcpProxy(), + compute.TargetTcpProxy(), + ], + next_page_token="abc", + ), + compute.TargetTcpProxyList(items=[], next_page_token="def",), + compute.TargetTcpProxyList( + items=[compute.TargetTcpProxy(),], next_page_token="ghi", + ), + compute.TargetTcpProxyList( + items=[compute.TargetTcpProxy(), compute.TargetTcpProxy(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.TargetTcpProxyList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.TargetTcpProxy) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_set_backend_service_rest( transport: str = "rest", request_type=compute.SetBackendServiceTargetTcpProxyRequest ): @@ -883,6 +926,7 @@ def test_set_backend_service_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -933,6 +977,7 @@ def test_set_backend_service_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -952,7 +997,7 @@ def test_set_backend_service_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -961,6 +1006,7 @@ def test_set_backend_service_rest_flattened(): assert compute.TargetTcpProxiesSetBackendServiceRequest.to_json( target_tcp_proxies_set_backend_service_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1022,6 +1068,7 @@ def test_set_proxy_header_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1072,6 +1119,7 @@ def test_set_proxy_header_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1091,7 +1139,7 @@ def test_set_proxy_header_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1100,6 +1148,7 @@ def test_set_proxy_header_rest_flattened(): assert compute.TargetTcpProxiesSetProxyHeaderRequest.to_json( target_tcp_proxies_set_proxy_header_request_resource, including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1248,6 +1297,17 @@ def test_target_tcp_proxies_auth_adc(): ) +def test_target_tcp_proxies_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.TargetTcpProxiesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_target_tcp_proxies_host_no_port(): client = TargetTcpProxiesClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_target_vpn_gateways.py b/tests/unit/gapic/compute_v1/test_target_vpn_gateways.py index 72f6b979d..8e696dc65 100644 --- a/tests/unit/gapic/compute_v1/test_target_vpn_gateways.py +++ b/tests/unit/gapic/compute_v1/test_target_vpn_gateways.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.target_vpn_gateways import TargetVpnGatewaysClient +from google.cloud.compute_v1.services.target_vpn_gateways import pagers from google.cloud.compute_v1.services.target_vpn_gateways import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -158,7 +159,7 @@ def test_target_vpn_gateways_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -174,7 +175,7 @@ def test_target_vpn_gateways_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -190,7 +191,7 @@ def test_target_vpn_gateways_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -218,7 +219,7 @@ def test_target_vpn_gateways_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -262,29 +263,25 @@ def test_target_vpn_gateways_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -293,66 +290,53 @@ def test_target_vpn_gateways_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -371,7 +355,7 @@ def test_target_vpn_gateways_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -394,7 +378,7 @@ def test_target_vpn_gateways_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -434,16 +418,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetVpnGatewayAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetVpnGatewayAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.TargetVpnGatewaysScopedList( @@ -474,6 +457,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetVpnGatewayAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -485,7 +469,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -501,6 +485,75 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = TargetVpnGatewaysClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetVpnGatewayAggregatedList( + items={ + "a": compute.TargetVpnGatewaysScopedList(), + "b": compute.TargetVpnGatewaysScopedList(), + "c": compute.TargetVpnGatewaysScopedList(), + }, + next_page_token="abc", + ), + compute.TargetVpnGatewayAggregatedList(items={}, next_page_token="def",), + compute.TargetVpnGatewayAggregatedList( + items={"g": compute.TargetVpnGatewaysScopedList(),}, + next_page_token="ghi", + ), + compute.TargetVpnGatewayAggregatedList( + items={ + "h": compute.TargetVpnGatewaysScopedList(), + "i": compute.TargetVpnGatewaysScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + compute.TargetVpnGatewayAggregatedList.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.TargetVpnGatewaysScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.TargetVpnGatewaysScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.TargetVpnGatewaysScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteTargetVpnGatewayRequest ): @@ -543,6 +596,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -593,6 +647,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -608,7 +663,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -661,6 +716,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetVpnGateway.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -697,6 +753,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetVpnGateway.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -712,7 +769,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -777,6 +834,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -827,6 +885,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -846,14 +905,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.TargetVpnGateway.to_json( - target_vpn_gateway_resource, including_default_value_fields=False + target_vpn_gateway_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -900,16 +961,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.TargetVpnGatewayList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.TargetVpnGatewayList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.TargetVpnGateway(creation_timestamp="creation_timestamp_value") @@ -935,6 +995,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TargetVpnGatewayList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -948,7 +1009,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -968,6 +1029,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = TargetVpnGatewaysClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.TargetVpnGatewayList( + items=[ + compute.TargetVpnGateway(), + compute.TargetVpnGateway(), + compute.TargetVpnGateway(), + ], + next_page_token="abc", + ), + compute.TargetVpnGatewayList(items=[], next_page_token="def",), + compute.TargetVpnGatewayList( + items=[compute.TargetVpnGateway(),], next_page_token="ghi", + ), + compute.TargetVpnGatewayList( + items=[compute.TargetVpnGateway(), compute.TargetVpnGateway(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.TargetVpnGatewayList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.TargetVpnGateway) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.TargetVpnGatewaysRestTransport( @@ -1098,6 +1210,17 @@ def test_target_vpn_gateways_auth_adc(): ) +def test_target_vpn_gateways_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.TargetVpnGatewaysRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_target_vpn_gateways_host_no_port(): client = TargetVpnGatewaysClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_url_maps.py b/tests/unit/gapic/compute_v1/test_url_maps.py index ecb7d2475..6a1d79757 100644 --- a/tests/unit/gapic/compute_v1/test_url_maps.py +++ b/tests/unit/gapic/compute_v1/test_url_maps.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.url_maps import UrlMapsClient +from google.cloud.compute_v1.services.url_maps import pagers from google.cloud.compute_v1.services.url_maps import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -147,7 +148,7 @@ def test_url_maps_client_client_options(client_class, transport_class, transport credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -163,7 +164,7 @@ def test_url_maps_client_client_options(client_class, transport_class, transport credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -179,7 +180,7 @@ def test_url_maps_client_client_options(client_class, transport_class, transport credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -207,7 +208,7 @@ def test_url_maps_client_client_options(client_class, transport_class, transport credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -239,29 +240,25 @@ def test_url_maps_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -270,66 +267,53 @@ def test_url_maps_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -348,7 +332,7 @@ def test_url_maps_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -371,7 +355,7 @@ def test_url_maps_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -409,16 +393,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.UrlMapsAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.UrlMapsAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.UrlMapsScopedList( @@ -447,6 +430,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.UrlMapsAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -458,7 +442,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -474,6 +458,69 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = UrlMapsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.UrlMapsAggregatedList( + items={ + "a": compute.UrlMapsScopedList(), + "b": compute.UrlMapsScopedList(), + "c": compute.UrlMapsScopedList(), + }, + next_page_token="abc", + ), + compute.UrlMapsAggregatedList(items={}, next_page_token="def",), + compute.UrlMapsAggregatedList( + items={"g": compute.UrlMapsScopedList(),}, next_page_token="ghi", + ), + compute.UrlMapsAggregatedList( + items={ + "h": compute.UrlMapsScopedList(), + "i": compute.UrlMapsScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.UrlMapsAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.UrlMapsScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == (str, compute.UrlMapsScopedList) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.UrlMapsScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest(transport: str = "rest", request_type=compute.DeleteUrlMapRequest): client = UrlMapsClient( credentials=credentials.AnonymousCredentials(), transport=transport, @@ -514,6 +561,7 @@ def test_delete_rest(transport: str = "rest", request_type=compute.DeleteUrlMapR # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -564,6 +612,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -577,7 +626,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -643,6 +692,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetUrlMapRequest # Wrap the value into a proper Response obj json_return_value = compute.UrlMap.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -697,6 +747,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.UrlMap.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -710,7 +761,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -770,6 +821,7 @@ def test_insert_rest(transport: str = "rest", request_type=compute.InsertUrlMapR # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -820,6 +872,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -835,12 +888,14 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert compute.UrlMap.to_json( - url_map_resource, including_default_value_fields=False + url_map_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -901,6 +956,7 @@ def test_invalidate_cache_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -951,6 +1007,7 @@ def test_invalidate_cache_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -970,14 +1027,16 @@ def test_invalidate_cache_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "url_map_value" in http_call[1] + str(body) assert compute.CacheInvalidationRule.to_json( - cache_invalidation_rule_resource, including_default_value_fields=False + cache_invalidation_rule_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1020,16 +1079,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListUrlMapsRequ # Wrap the value into a proper Response obj json_return_value = compute.UrlMapList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.UrlMapList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.UrlMap(creation_timestamp="creation_timestamp_value") @@ -1055,6 +1113,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.UrlMapList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1066,7 +1125,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1082,6 +1141,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = UrlMapsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.UrlMapList( + items=[compute.UrlMap(), compute.UrlMap(), compute.UrlMap(),], + next_page_token="abc", + ), + compute.UrlMapList(items=[], next_page_token="def",), + compute.UrlMapList(items=[compute.UrlMap(),], next_page_token="ghi",), + compute.UrlMapList(items=[compute.UrlMap(), compute.UrlMap(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.UrlMapList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.UrlMap) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_patch_rest(transport: str = "rest", request_type=compute.PatchUrlMapRequest): client = UrlMapsClient( credentials=credentials.AnonymousCredentials(), transport=transport, @@ -1122,6 +1224,7 @@ def test_patch_rest(transport: str = "rest", request_type=compute.PatchUrlMapReq # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1172,6 +1275,7 @@ def test_patch_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1189,14 +1293,16 @@ def test_patch_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "url_map_value" in http_call[1] + str(body) assert compute.UrlMap.to_json( - url_map_resource, including_default_value_fields=False + url_map_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1256,6 +1362,7 @@ def test_update_rest(transport: str = "rest", request_type=compute.UpdateUrlMapR # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1306,6 +1413,7 @@ def test_update_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1323,14 +1431,16 @@ def test_update_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "url_map_value" in http_call[1] + str(body) assert compute.UrlMap.to_json( - url_map_resource, including_default_value_fields=False + url_map_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1370,6 +1480,7 @@ def test_validate_rest( # Wrap the value into a proper Response obj json_return_value = compute.UrlMapsValidateResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1398,6 +1509,7 @@ def test_validate_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.UrlMapsValidateResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1417,14 +1529,16 @@ def test_validate_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "url_map_value" in http_call[1] + str(body) assert compute.UrlMapsValidateRequest.to_json( - url_maps_validate_request_resource, including_default_value_fields=False + url_maps_validate_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1576,6 +1690,17 @@ def test_url_maps_auth_adc(): ) +def test_url_maps_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.UrlMapsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_url_maps_host_no_port(): client = UrlMapsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_vpn_gateways.py b/tests/unit/gapic/compute_v1/test_vpn_gateways.py index 5f7a2c840..cbd4ba4bf 100644 --- a/tests/unit/gapic/compute_v1/test_vpn_gateways.py +++ b/tests/unit/gapic/compute_v1/test_vpn_gateways.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.vpn_gateways import VpnGatewaysClient +from google.cloud.compute_v1.services.vpn_gateways import pagers from google.cloud.compute_v1.services.vpn_gateways import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -152,7 +153,7 @@ def test_vpn_gateways_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -168,7 +169,7 @@ def test_vpn_gateways_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -184,7 +185,7 @@ def test_vpn_gateways_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -212,7 +213,7 @@ def test_vpn_gateways_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -244,29 +245,25 @@ def test_vpn_gateways_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -275,66 +272,53 @@ def test_vpn_gateways_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -353,7 +337,7 @@ def test_vpn_gateways_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -376,7 +360,7 @@ def test_vpn_gateways_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -416,16 +400,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.VpnGatewayAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.VpnGatewayAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.VpnGatewaysScopedList( @@ -456,6 +439,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.VpnGatewayAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -467,7 +451,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -483,6 +467,72 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = VpnGatewaysClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.VpnGatewayAggregatedList( + items={ + "a": compute.VpnGatewaysScopedList(), + "b": compute.VpnGatewaysScopedList(), + "c": compute.VpnGatewaysScopedList(), + }, + next_page_token="abc", + ), + compute.VpnGatewayAggregatedList(items={}, next_page_token="def",), + compute.VpnGatewayAggregatedList( + items={"g": compute.VpnGatewaysScopedList(),}, next_page_token="ghi", + ), + compute.VpnGatewayAggregatedList( + items={ + "h": compute.VpnGatewaysScopedList(), + "i": compute.VpnGatewaysScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.VpnGatewayAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.VpnGatewaysScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == ( + str, + compute.VpnGatewaysScopedList, + ) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.VpnGatewaysScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteVpnGatewayRequest ): @@ -525,6 +575,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -575,6 +626,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -590,7 +642,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -641,6 +693,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetVpnGatewayReq # Wrap the value into a proper Response obj json_return_value = compute.VpnGateway.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -677,6 +730,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.VpnGateway.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -692,7 +746,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -741,6 +795,7 @@ def test_get_status_rest( # Wrap the value into a proper Response obj json_return_value = compute.VpnGatewaysGetStatusResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -773,6 +828,7 @@ def test_get_status_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.VpnGatewaysGetStatusResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -788,7 +844,7 @@ def test_get_status_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -853,6 +909,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -903,6 +960,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -922,14 +980,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.VpnGateway.to_json( - vpn_gateway_resource, including_default_value_fields=False + vpn_gateway_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -974,16 +1034,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.VpnGatewayList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.VpnGatewayList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.VpnGateway(creation_timestamp="creation_timestamp_value") @@ -1009,6 +1068,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.VpnGatewayList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1022,7 +1082,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1042,6 +1102,57 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = VpnGatewaysClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.VpnGatewayList( + items=[ + compute.VpnGateway(), + compute.VpnGateway(), + compute.VpnGateway(), + ], + next_page_token="abc", + ), + compute.VpnGatewayList(items=[], next_page_token="def",), + compute.VpnGatewayList( + items=[compute.VpnGateway(),], next_page_token="ghi", + ), + compute.VpnGatewayList( + items=[compute.VpnGateway(), compute.VpnGateway(),], + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.VpnGatewayList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.VpnGateway) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_set_labels_rest( transport: str = "rest", request_type=compute.SetLabelsVpnGatewayRequest ): @@ -1084,6 +1195,7 @@ def test_set_labels_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1134,6 +1246,7 @@ def test_set_labels_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1154,7 +1267,7 @@ def test_set_labels_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1163,7 +1276,9 @@ def test_set_labels_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.RegionSetLabelsRequest.to_json( - region_set_labels_request_resource, including_default_value_fields=False + region_set_labels_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1204,6 +1319,7 @@ def test_test_iam_permissions_rest( # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1230,6 +1346,7 @@ def test_test_iam_permissions_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.TestPermissionsResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -1250,7 +1367,7 @@ def test_test_iam_permissions_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -1259,7 +1376,9 @@ def test_test_iam_permissions_rest_flattened(): assert "resource_value" in http_call[1] + str(body) assert compute.TestPermissionsRequest.to_json( - test_permissions_request_resource, including_default_value_fields=False + test_permissions_request_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -1411,6 +1530,17 @@ def test_vpn_gateways_auth_adc(): ) +def test_vpn_gateways_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.VpnGatewaysRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_vpn_gateways_host_no_port(): client = VpnGatewaysClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_vpn_tunnels.py b/tests/unit/gapic/compute_v1/test_vpn_tunnels.py index b049bc189..b3d1b1068 100644 --- a/tests/unit/gapic/compute_v1/test_vpn_tunnels.py +++ b/tests/unit/gapic/compute_v1/test_vpn_tunnels.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.vpn_tunnels import VpnTunnelsClient +from google.cloud.compute_v1.services.vpn_tunnels import pagers from google.cloud.compute_v1.services.vpn_tunnels import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -152,7 +153,7 @@ def test_vpn_tunnels_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -168,7 +169,7 @@ def test_vpn_tunnels_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -184,7 +185,7 @@ def test_vpn_tunnels_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -212,7 +213,7 @@ def test_vpn_tunnels_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -244,29 +245,25 @@ def test_vpn_tunnels_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -275,66 +272,53 @@ def test_vpn_tunnels_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -353,7 +337,7 @@ def test_vpn_tunnels_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -376,7 +360,7 @@ def test_vpn_tunnels_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -414,16 +398,15 @@ def test_aggregated_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.VpnTunnelAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.aggregated_list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.VpnTunnelAggregatedList) + assert isinstance(response, pagers.AggregatedListPager) assert response.id == "id_value" assert response.items == { "key_value": compute.VpnTunnelsScopedList( @@ -454,6 +437,7 @@ def test_aggregated_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.VpnTunnelAggregatedList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -465,7 +449,7 @@ def test_aggregated_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -481,6 +465,69 @@ def test_aggregated_list_rest_flattened_error(): ) +def test_aggregated_list_pager(): + client = VpnTunnelsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.VpnTunnelAggregatedList( + items={ + "a": compute.VpnTunnelsScopedList(), + "b": compute.VpnTunnelsScopedList(), + "c": compute.VpnTunnelsScopedList(), + }, + next_page_token="abc", + ), + compute.VpnTunnelAggregatedList(items={}, next_page_token="def",), + compute.VpnTunnelAggregatedList( + items={"g": compute.VpnTunnelsScopedList(),}, next_page_token="ghi", + ), + compute.VpnTunnelAggregatedList( + items={ + "h": compute.VpnTunnelsScopedList(), + "i": compute.VpnTunnelsScopedList(), + }, + ), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.VpnTunnelAggregatedList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.aggregated_list(request={}) + + assert pager._metadata == metadata + + assert isinstance(pager.get("a"), compute.VpnTunnelsScopedList) + assert pager.get("h") is None + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, tuple) for i in results) + for result in results: + assert isinstance(result, tuple) + assert tuple(type(t) for t in result) == (str, compute.VpnTunnelsScopedList) + + assert pager.get("a") is None + assert isinstance(pager.get("h"), compute.VpnTunnelsScopedList) + + pages = list(client.aggregated_list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_delete_rest( transport: str = "rest", request_type=compute.DeleteVpnTunnelRequest ): @@ -523,6 +570,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -573,6 +621,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -588,7 +637,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -650,6 +699,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetVpnTunnelRequ # Wrap the value into a proper Response obj json_return_value = compute.VpnTunnel.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -697,6 +747,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.VpnTunnel.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -712,7 +763,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -777,6 +828,7 @@ def test_insert_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -827,6 +879,7 @@ def test_insert_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -846,14 +899,16 @@ def test_insert_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) assert "region_value" in http_call[1] + str(body) assert compute.VpnTunnel.to_json( - vpn_tunnel_resource, including_default_value_fields=False + vpn_tunnel_resource, + including_default_value_fields=False, + use_integers_for_enums=False, ) in http_call[1] + str(body) @@ -896,16 +951,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListVpnTunnelsR # Wrap the value into a proper Response obj json_return_value = compute.VpnTunnelList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.VpnTunnelList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.VpnTunnel(creation_timestamp="creation_timestamp_value") @@ -931,6 +985,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.VpnTunnelList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -944,7 +999,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -964,6 +1019,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = VpnTunnelsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.VpnTunnelList( + items=[compute.VpnTunnel(), compute.VpnTunnel(), compute.VpnTunnel(),], + next_page_token="abc", + ), + compute.VpnTunnelList(items=[], next_page_token="def",), + compute.VpnTunnelList(items=[compute.VpnTunnel(),], next_page_token="ghi",), + compute.VpnTunnelList(items=[compute.VpnTunnel(), compute.VpnTunnel(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.VpnTunnelList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.VpnTunnel) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.VpnTunnelsRestTransport( @@ -1092,6 +1190,17 @@ def test_vpn_tunnels_auth_adc(): ) +def test_vpn_tunnels_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.VpnTunnelsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_vpn_tunnels_host_no_port(): client = VpnTunnelsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_zone_operations.py b/tests/unit/gapic/compute_v1/test_zone_operations.py index b60441c59..5fccf058a 100644 --- a/tests/unit/gapic/compute_v1/test_zone_operations.py +++ b/tests/unit/gapic/compute_v1/test_zone_operations.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.zone_operations import ZoneOperationsClient +from google.cloud.compute_v1.services.zone_operations import pagers from google.cloud.compute_v1.services.zone_operations import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -157,7 +158,7 @@ def test_zone_operations_client_client_options( credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -173,7 +174,7 @@ def test_zone_operations_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -189,7 +190,7 @@ def test_zone_operations_client_client_options( credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -217,7 +218,7 @@ def test_zone_operations_client_client_options( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -251,29 +252,25 @@ def test_zone_operations_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -282,66 +279,53 @@ def test_zone_operations_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -360,7 +344,7 @@ def test_zone_operations_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -383,7 +367,7 @@ def test_zone_operations_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -407,6 +391,7 @@ def test_delete_rest( # Wrap the value into a proper Response obj json_return_value = compute.DeleteZoneOperationResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -432,6 +417,7 @@ def test_delete_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.DeleteZoneOperationResponse.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -445,7 +431,7 @@ def test_delete_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -510,6 +496,7 @@ def test_get_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -560,6 +547,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -573,7 +561,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -621,16 +609,15 @@ def test_list_rest( # Wrap the value into a proper Response obj json_return_value = compute.OperationList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.OperationList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Operation(client_operation_id="client_operation_id_value") @@ -656,6 +643,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.OperationList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -669,7 +657,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -689,6 +677,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = ZoneOperationsClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.OperationList( + items=[compute.Operation(), compute.Operation(), compute.Operation(),], + next_page_token="abc", + ), + compute.OperationList(items=[], next_page_token="def",), + compute.OperationList(items=[compute.Operation(),], next_page_token="ghi",), + compute.OperationList(items=[compute.Operation(), compute.Operation(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.OperationList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Operation) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_wait_rest( transport: str = "rest", request_type=compute.WaitZoneOperationRequest ): @@ -731,6 +762,7 @@ def test_wait_rest( # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -781,6 +813,7 @@ def test_wait_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Operation.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -794,7 +827,7 @@ def test_wait_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -944,6 +977,17 @@ def test_zone_operations_auth_adc(): ) +def test_zone_operations_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.ZoneOperationsRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_zone_operations_host_no_port(): client = ZoneOperationsClient( credentials=credentials.AnonymousCredentials(), diff --git a/tests/unit/gapic/compute_v1/test_zones.py b/tests/unit/gapic/compute_v1/test_zones.py index 35fc08fba..25d3ec804 100644 --- a/tests/unit/gapic/compute_v1/test_zones.py +++ b/tests/unit/gapic/compute_v1/test_zones.py @@ -35,6 +35,7 @@ from google.auth import credentials from google.auth.exceptions import MutualTLSChannelError from google.cloud.compute_v1.services.zones import ZonesClient +from google.cloud.compute_v1.services.zones import pagers from google.cloud.compute_v1.services.zones import transports from google.cloud.compute_v1.types import compute from google.oauth2 import service_account @@ -147,7 +148,7 @@ def test_zones_client_client_options(client_class, transport_class, transport_na credentials_file=None, host="squid.clam.whelk", scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -163,7 +164,7 @@ def test_zones_client_client_options(client_class, transport_class, transport_na credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -179,7 +180,7 @@ def test_zones_client_client_options(client_class, transport_class, transport_na credentials_file=None, host=client.DEFAULT_MTLS_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -207,7 +208,7 @@ def test_zones_client_client_options(client_class, transport_class, transport_na credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id="octopus", client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -239,29 +240,25 @@ def test_zones_client_mtls_env_auto( client_cert_source=client_cert_source_callback ) with mock.patch.object(transport_class, "__init__") as patched: - ssl_channel_creds = mock.Mock() - with mock.patch( - "grpc.ssl_channel_credentials", return_value=ssl_channel_creds - ): - patched.return_value = None - client = client_class(client_options=options) + patched.return_value = None + client = client_class(client_options=options) - if use_client_cert_env == "false": - expected_ssl_channel_creds = None - expected_host = client.DEFAULT_ENDPOINT - else: - expected_ssl_channel_creds = ssl_channel_creds - expected_host = client.DEFAULT_MTLS_ENDPOINT + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client.DEFAULT_ENDPOINT + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) # Check the case ADC client cert is provided. Whether client cert is used depends on # GOOGLE_API_USE_CLIENT_CERTIFICATE value. @@ -270,66 +267,53 @@ def test_zones_client_mtls_env_auto( ): with mock.patch.object(transport_class, "__init__") as patched: with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, ): with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.ssl_credentials", - new_callable=mock.PropertyMock, - ) as ssl_credentials_mock: - if use_client_cert_env == "false": - is_mtls_mock.return_value = False - ssl_credentials_mock.return_value = None - expected_host = client.DEFAULT_ENDPOINT - expected_ssl_channel_creds = None - else: - is_mtls_mock.return_value = True - ssl_credentials_mock.return_value = mock.Mock() - expected_host = client.DEFAULT_MTLS_ENDPOINT - expected_ssl_channel_creds = ( - ssl_credentials_mock.return_value - ) - - patched.return_value = None - client = client_class() - patched.assert_called_once_with( - credentials=None, - credentials_file=None, - host=expected_host, - scopes=None, - ssl_channel_credentials=expected_ssl_channel_creds, - quota_project_id=None, - client_info=transports.base.DEFAULT_CLIENT_INFO, - ) + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client.DEFAULT_ENDPOINT + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback - # Check the case client_cert_source and ADC client cert are not provided. - with mock.patch.dict( - os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} - ): - with mock.patch.object(transport_class, "__init__") as patched: - with mock.patch( - "google.auth.transport.grpc.SslCredentials.__init__", return_value=None - ): - with mock.patch( - "google.auth.transport.grpc.SslCredentials.is_mtls", - new_callable=mock.PropertyMock, - ) as is_mtls_mock: - is_mtls_mock.return_value = False patched.return_value = None client = client_class() patched.assert_called_once_with( credentials=None, credentials_file=None, - host=client.DEFAULT_ENDPOINT, + host=expected_host, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=expected_client_cert_source, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class() + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + ) + @pytest.mark.parametrize( "client_class,transport_class,transport_name", @@ -348,7 +332,7 @@ def test_zones_client_client_options_scopes( credentials_file=None, host=client.DEFAULT_ENDPOINT, scopes=["1", "2"], - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -371,7 +355,7 @@ def test_zones_client_client_options_credentials_file( credentials_file="credentials.json", host=client.DEFAULT_ENDPOINT, scopes=None, - ssl_channel_credentials=None, + client_cert_source_for_mtls=None, quota_project_id=None, client_info=transports.base.DEFAULT_CLIENT_INFO, ) @@ -404,6 +388,7 @@ def test_get_rest(transport: str = "rest", request_type=compute.GetZoneRequest): # Wrap the value into a proper Response obj json_return_value = compute.Zone.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -439,6 +424,7 @@ def test_get_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.Zone.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -452,7 +438,7 @@ def test_get_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -495,16 +481,15 @@ def test_list_rest(transport: str = "rest", request_type=compute.ListZonesReques # Wrap the value into a proper Response obj json_return_value = compute.ZoneList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value response = client.list(request) - assert response.raw_page is response - # Establish that the response is the type that we expect. - assert isinstance(response, compute.ZoneList) + assert isinstance(response, pagers.ListPager) assert response.id == "id_value" assert response.items == [ compute.Zone(available_cpu_platforms=["available_cpu_platforms_value"]) @@ -530,6 +515,7 @@ def test_list_rest_flattened(): # Wrap the value into a proper Response obj json_return_value = compute.ZoneList.to_json(return_value) response_value = Response() + response_value.status_code = 200 response_value._content = json_return_value.encode("UTF-8") req.return_value = response_value @@ -541,7 +527,7 @@ def test_list_rest_flattened(): # request object values. assert len(req.mock_calls) == 1 _, http_call, http_params = req.mock_calls[0] - body = http_params.get("json") + body = http_params.get("data") assert "project_value" in http_call[1] + str(body) @@ -557,6 +543,49 @@ def test_list_rest_flattened_error(): ) +def test_list_pager(): + client = ZonesClient(credentials=credentials.AnonymousCredentials(),) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Set the response as a series of pages + + response = ( + compute.ZoneList( + items=[compute.Zone(), compute.Zone(), compute.Zone(),], + next_page_token="abc", + ), + compute.ZoneList(items=[], next_page_token="def",), + compute.ZoneList(items=[compute.Zone(),], next_page_token="ghi",), + compute.ZoneList(items=[compute.Zone(), compute.Zone(),],), + ) + + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(compute.ZoneList.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + metadata = () + pager = client.list(request={}) + + assert pager._metadata == metadata + + results = list(pager) + assert len(results) == 6 + + assert all(isinstance(i, compute.Zone) for i in results) + + pages = list(client.list(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + def test_credentials_transport_error(): # It is an error to provide credentials and a transport instance. transport = transports.ZonesRestTransport( @@ -684,6 +713,17 @@ def test_zones_auth_adc(): ) +def test_zones_http_transport_client_cert_source_for_mtls(): + cred = credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.ZonesRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + def test_zones_host_no_port(): client = ZonesClient( credentials=credentials.AnonymousCredentials(),