Skip to content
/ ape Public
forked from ApeWorX/ape

Commit 12a281a

Browse files
authored
fix: unnecessarily was using explorer when providing proxy info manually (ApeWorX#2524)
1 parent a32e621 commit 12a281a

11 files changed

+190
-79
lines changed

.pre-commit-config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ repos:
1616
name: black
1717

1818
- repo: https://github.com/pycqa/flake8
19-
rev: 7.1.1
19+
rev: 7.1.2
2020
hooks:
2121
- id: flake8
2222
additional_dependencies: [flake8-breakpoint, flake8-print, flake8-pydantic, flake8-type-checking]

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ force_grid_wrap = 0
4949
include_trailing_comma = true
5050
multi_line_output = 3
5151
use_parentheses = true
52+
skip = ["version.py"]
5253

5354
[tool.mdformat]
5455
number = true

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
"types-toml", # Needed due to mypy typeshed
3232
"types-SQLAlchemy>=1.4.49", # Needed due to mypy typeshed
3333
"types-python-dateutil", # Needed due to mypy typeshed
34-
"flake8>=7.1.1,<8", # Style linter
34+
"flake8>=7.1.2,<8", # Style linter
3535
"flake8-breakpoint>=1.1.0,<2", # Detect breakpoints left in code
3636
"flake8-print>=4.0.1,<5", # Detect print statements left in code
3737
"flake8-pydantic", # For detecting issues with Pydantic models

src/ape/managers/_contractscache.py

+119-44
Original file line numberDiff line numberDiff line change
@@ -554,48 +554,39 @@ def get(
554554
return None
555555

556556
if contract_type := self.contract_types[address_key]:
557+
# The ContractType was previously cached.
557558
if default and default != contract_type:
558-
# Replacing contract type
559-
self.contract_types[address_key] = default
560-
return default
559+
# The given default ContractType is different than the cached one.
560+
# Merge the two and cache the merged result.
561+
combined_contract_type = _merge_contract_types(contract_type, default)
562+
self.contract_types[address_key] = combined_contract_type
563+
return combined_contract_type
561564

562565
return contract_type
563566

564-
else:
565-
# Contract is not cached yet. Check broader sources, such as an explorer.
566-
if not proxy_info and detect_proxy:
567-
# Proxy info not provided. Attempt to detect.
568-
if not (proxy_info := self.proxy_infos[address_key]):
569-
if proxy_info := self.provider.network.ecosystem.get_proxy_info(address_key):
570-
self.proxy_infos[address_key] = proxy_info
571-
572-
if proxy_info:
573-
# Contract is a proxy (either was detected or provided).
574-
implementation_contract_type = self.get(proxy_info.target, default=default)
575-
proxy_contract_type = (
576-
self._get_contract_type_from_explorer(address_key)
577-
if fetch_from_explorer
578-
else None
579-
)
580-
if proxy_contract_type is not None and implementation_contract_type is not None:
581-
combined_contract = _get_combined_contract_type(
582-
proxy_contract_type, proxy_info, implementation_contract_type
583-
)
584-
self.contract_types[address_key] = combined_contract
585-
return combined_contract
586-
587-
elif implementation_contract_type is not None:
588-
contract_type_to_cache = implementation_contract_type
589-
self.contract_types[address_key] = implementation_contract_type
590-
return contract_type_to_cache
591-
592-
elif proxy_contract_type is not None:
593-
self.contract_types[address_key] = proxy_contract_type
594-
return proxy_contract_type
595-
596-
# Also gets cached to disk for faster lookup next time.
597-
if fetch_from_explorer:
598-
contract_type = self._get_contract_type_from_explorer(address_key)
567+
# Contract is not cached yet. Check broader sources, such as an explorer.
568+
if not proxy_info and detect_proxy:
569+
# Proxy info not provided. Attempt to detect.
570+
if not (proxy_info := self.proxy_infos[address_key]):
571+
if proxy_info := self.provider.network.ecosystem.get_proxy_info(address_key):
572+
self.proxy_infos[address_key] = proxy_info
573+
574+
if proxy_info:
575+
if proxy_contract_type := self._get_proxy_contract_type(
576+
address_key,
577+
proxy_info,
578+
fetch_from_explorer=fetch_from_explorer,
579+
default=default,
580+
):
581+
# `proxy_contract_type` is one of the following:
582+
# 1. A ContractType with the combined proxy and implementation ABIs
583+
# 2. Implementation-only ABI ContractType (like forwarder proxies)
584+
# 3. Proxy only ABI (e.g. unverified implementation ContractType)
585+
return proxy_contract_type
586+
587+
# Also gets cached to disk for faster lookup next time.
588+
if fetch_from_explorer:
589+
contract_type = self._get_contract_type_from_explorer(address_key)
599590

600591
# Cache locally for faster in-session look-up.
601592
if contract_type:
@@ -606,6 +597,65 @@ def get(
606597

607598
return contract_type
608599

600+
def _get_proxy_contract_type(
601+
self,
602+
address: AddressType,
603+
proxy_info: ProxyInfoAPI,
604+
fetch_from_explorer: bool = True,
605+
default: Optional[ContractType] = None,
606+
) -> Optional[ContractType]:
607+
"""
608+
Combines the discoverable ABIs from the proxy contract and its implementation.
609+
"""
610+
implementation_contract_type = self._get_contract_type(
611+
proxy_info.target,
612+
fetch_from_explorer=fetch_from_explorer,
613+
default=default,
614+
)
615+
proxy_contract_type = self._get_contract_type(
616+
address, fetch_from_explorer=fetch_from_explorer
617+
)
618+
if proxy_contract_type is not None and implementation_contract_type is not None:
619+
combined_contract = _get_combined_contract_type(
620+
proxy_contract_type, proxy_info, implementation_contract_type
621+
)
622+
self.contract_types[address] = combined_contract
623+
return combined_contract
624+
625+
elif implementation_contract_type is not None:
626+
contract_type_to_cache = implementation_contract_type
627+
self.contract_types[address] = implementation_contract_type
628+
return contract_type_to_cache
629+
630+
elif proxy_contract_type is not None:
631+
# In this case, the implementation ContactType was not discovered.
632+
# However, we were able to discover the ContractType of the proxy.
633+
# Proceed with caching the proxy; the user can update the type later
634+
# when the implementation is discoverable.
635+
self.contract_types[address] = proxy_contract_type
636+
return proxy_contract_type
637+
638+
logger.warning(f"Unable to determine the ContractType for the proxy at '{address}'.")
639+
return None
640+
641+
def _get_contract_type(
642+
self,
643+
address: AddressType,
644+
fetch_from_explorer: bool = True,
645+
default: Optional[ContractType] = None,
646+
) -> Optional[ContractType]:
647+
"""
648+
Get the _exact_ ContractType for a given address. For proxy contracts, returns
649+
the proxy ABIs if there are any and not the implementation ABIs.
650+
"""
651+
if contract_type := self.contract_types[address]:
652+
return contract_type
653+
654+
elif fetch_from_explorer:
655+
return self._get_contract_type_from_explorer(address)
656+
657+
return default
658+
609659
@classmethod
610660
def get_container(cls, contract_type: ContractType) -> ContractContainer:
611661
"""
@@ -859,6 +909,16 @@ def _get_contract_type_from_explorer(self, address: AddressType) -> Optional[Con
859909

860910
if contract_type:
861911
# Cache contract so faster look-up next time.
912+
if not isinstance(contract_type, ContractType):
913+
explorer_name = self.provider.network.explorer.name
914+
wrong_type = type(contract_type)
915+
wrong_type_str = getattr(wrong_type, "__name__", f"{wrong_type}")
916+
logger.warning(
917+
f"Explorer '{explorer_name}' returned unexpected "
918+
f"type '{wrong_type_str}' ContractType."
919+
)
920+
return None
921+
862922
self.contract_types[address] = contract_type
863923

864924
return contract_type
@@ -869,16 +929,31 @@ def _get_combined_contract_type(
869929
proxy_info: ProxyInfoAPI,
870930
implementation_contract_type: ContractType,
871931
) -> ContractType:
872-
proxy_abis = [
873-
abi for abi in proxy_contract_type.abi if abi.type in ("error", "event", "function")
874-
]
932+
proxy_abis = _get_relevant_additive_abis(proxy_contract_type)
875933

876934
# Include "hidden" ABIs, such as Safe's `masterCopy()`.
877935
if proxy_info.abi and proxy_info.abi.signature not in [
878936
abi.signature for abi in implementation_contract_type.abi
879937
]:
880938
proxy_abis.append(proxy_info.abi)
881939

882-
combined_contract_type = implementation_contract_type.model_copy(deep=True)
883-
combined_contract_type.abi.extend(proxy_abis)
884-
return combined_contract_type
940+
return _merge_abis(implementation_contract_type, proxy_abis)
941+
942+
943+
def _get_relevant_additive_abis(contract_type: ContractType) -> list[ABI]:
944+
# Get ABIs you would want to add to a base contract as extra,
945+
# such as unique ABIs from proxies.
946+
return [abi for abi in contract_type.abi if abi.type in ("error", "event", "function")]
947+
948+
949+
def _merge_abis(base_contract: ContractType, extra_abis: list[ABI]) -> ContractType:
950+
contract_type = base_contract.model_copy(deep=True)
951+
contract_type.abi.extend(extra_abis)
952+
return contract_type
953+
954+
955+
def _merge_contract_types(
956+
base_contract_type: ContractType, additive_contract_type: ContractType
957+
) -> ContractType:
958+
relevant_abis = _get_relevant_additive_abis(additive_contract_type)
959+
return _merge_abis(base_contract_type, relevant_abis)

tests/functional/conftest.py

+7
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,13 @@ def dummy_live_network(chain):
472472
chain.provider.network.name = original_network
473473

474474

475+
@pytest.fixture
476+
def dummy_live_network_with_explorer(dummy_live_network, mock_explorer):
477+
dummy_live_network.__dict__["explorer"] = mock_explorer
478+
yield dummy_live_network
479+
dummy_live_network.__dict__.pop("explorer", None)
480+
481+
475482
@pytest.fixture(scope="session")
476483
def proxy_contract_container(get_contract_type):
477484
return ContractContainer(get_contract_type("proxy"))

tests/functional/test_accounts.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -314,19 +314,19 @@ def test_deploy_and_publish_live_network_no_explorer(owner, contract_container,
314314

315315

316316
@explorer_test
317-
def test_deploy_and_publish(owner, contract_container, dummy_live_network, mock_explorer):
318-
dummy_live_network.__dict__["explorer"] = mock_explorer
317+
def test_deploy_and_publish(
318+
owner, contract_container, dummy_live_network_with_explorer, mock_explorer
319+
):
319320
contract = owner.deploy(contract_container, 0, publish=True, required_confirmations=0)
320321
mock_explorer.publish_contract.assert_called_once_with(contract.address)
321-
dummy_live_network.__dict__["explorer"] = None
322322

323323

324324
@explorer_test
325-
def test_deploy_and_not_publish(owner, contract_container, dummy_live_network, mock_explorer):
326-
dummy_live_network.__dict__["explorer"] = mock_explorer
325+
def test_deploy_and_not_publish(
326+
owner, contract_container, dummy_live_network_with_explorer, mock_explorer
327+
):
327328
owner.deploy(contract_container, 0, publish=True, required_confirmations=0)
328329
assert not mock_explorer.call_count
329-
dummy_live_network.__dict__["explorer"] = None
330330

331331

332332
def test_deploy_proxy(owner, vyper_contract_instance, proxy_contract_container, chain):

tests/functional/test_contract.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,13 @@ def test_Contract_at_unknown_address(networks_connected_to_tester, address):
8585

8686

8787
def test_Contract_specify_contract_type(
88-
solidity_contract_instance, vyper_contract_type, owner, networks_connected_to_tester
88+
vyper_contract_instance, solidity_contract_type, owner, networks_connected_to_tester
8989
):
90-
# Vyper contract type is very close to solidity's.
90+
# Solidity's contract type is very close to Vyper's.
9191
# This test purposely uses the other just to show we are able to specify it externally.
92-
contract = Contract(solidity_contract_instance.address, contract_type=vyper_contract_type)
93-
assert contract.address == solidity_contract_instance.address
94-
assert contract.contract_type == vyper_contract_type
95-
assert contract.setNumber(2, sender=owner)
96-
assert contract.myNumber() == 2
92+
contract = Contract(vyper_contract_instance.address, contract_type=solidity_contract_type)
93+
assert contract.address == vyper_contract_instance.address
94+
95+
abis = [abi.name for abi in contract.contract_type.abi if hasattr(abi, "name")]
96+
assert "setNumber" in abis # Shared ABI.
97+
assert "ACustomError" in abis # SolidityContract-defined ABI.

tests/functional/test_contract_container.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ProjectError,
1111
)
1212
from ape_ethereum.ecosystem import ProxyType
13+
from tests.conftest import explorer_test
1314

1415

1516
def test_deploy(
@@ -55,18 +56,20 @@ def test_deploy_and_publish_live_network_no_explorer(owner, contract_container,
5556
contract_container.deploy(0, sender=owner, publish=True, required_confirmations=0)
5657

5758

58-
def test_deploy_and_publish(owner, contract_container, dummy_live_network, mock_explorer):
59-
dummy_live_network.__dict__["explorer"] = mock_explorer
59+
@explorer_test
60+
def test_deploy_and_publish(
61+
owner, contract_container, dummy_live_network_with_explorer, mock_explorer
62+
):
6063
contract = contract_container.deploy(0, sender=owner, publish=True, required_confirmations=0)
6164
mock_explorer.publish_contract.assert_called_once_with(contract.address)
62-
dummy_live_network.__dict__["explorer"] = None
6365

6466

65-
def test_deploy_and_not_publish(owner, contract_container, dummy_live_network, mock_explorer):
66-
dummy_live_network.__dict__["explorer"] = mock_explorer
67+
@explorer_test
68+
def test_deploy_and_not_publish(
69+
owner, contract_container, dummy_live_network_with_explorer, mock_explorer
70+
):
6771
contract_container.deploy(0, sender=owner, publish=False, required_confirmations=0)
6872
assert not mock_explorer.call_count
69-
dummy_live_network.__dict__["explorer"] = None
7073

7174

7275
def test_deploy_privately(owner, contract_container):

tests/functional/test_contract_instance.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -880,9 +880,9 @@ def test_value_to_non_payable_fallback_and_no_receive(
880880
break
881881

882882
new_contract_type = ContractType.model_validate(contract_type_data)
883-
contract = owner.chain_manager.contracts.instance_at(
884-
vyper_fallback_contract.address, contract_type=new_contract_type
885-
)
883+
contract = owner.chain_manager.contracts.instance_at(vyper_fallback_contract.address)
884+
contract.contract_type = new_contract_type # Setting to completely override instead of merge.
885+
886886
expected = (
887887
r"Contract's fallback is non-payable and there is no receive ABI\. Unable to send value\."
888888
)

0 commit comments

Comments
 (0)