Skip to content
/ ape Public
forked from ApeWorX/ape

Commit e69daba

Browse files
authored
fix: avoid attempting to cache None for a contract-type (ApeWorX#2517)
1 parent d504d18 commit e69daba

File tree

2 files changed

+65
-5
lines changed

2 files changed

+65
-5
lines changed

src/ape/managers/_contractscache.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -577,15 +577,21 @@ def get(
577577
if fetch_from_explorer
578578
else None
579579
)
580-
if proxy_contract_type:
581-
contract_type_to_cache = _get_combined_contract_type(
580+
if proxy_contract_type is not None and implementation_contract_type is not None:
581+
combined_contract = _get_combined_contract_type(
582582
proxy_contract_type, proxy_info, implementation_contract_type
583583
)
584-
else:
584+
self.contract_types[address_key] = combined_contract
585+
return combined_contract
586+
587+
elif implementation_contract_type is not None:
585588
contract_type_to_cache = implementation_contract_type
589+
self.contract_types[address_key] = implementation_contract_type
590+
return contract_type_to_cache
586591

587-
self.contract_types[address_key] = contract_type_to_cache
588-
return contract_type_to_cache
592+
elif proxy_contract_type is not None:
593+
self.contract_types[address_key] = proxy_contract_type
594+
return proxy_contract_type
589595

590596
# Also gets cached to disk for faster lookup next time.
591597
if fetch_from_explorer:

tests/functional/test_contracts_cache.py

+54
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,60 @@ def test_cache_non_checksum_address(chain, vyper_contract_instance):
474474
assert chain.contracts[vyper_contract_instance.address] == vyper_contract_instance.contract_type
475475

476476

477+
def test_get_when_proxy(chain, owner, minimal_proxy_container, vyper_contract_instance):
478+
placeholder = "0xBEbeBeBEbeBebeBeBEBEbebEBeBeBebeBeBebebe"
479+
if placeholder in chain.contracts:
480+
del chain.contracts[placeholder]
481+
482+
minimal_proxy = owner.deploy(minimal_proxy_container, sender=owner)
483+
chain.provider.network.__dict__["explorer"] = None # Ensure no explorer, messes up test.
484+
485+
actual = chain.contracts.get(minimal_proxy.address)
486+
assert actual == minimal_proxy.contract_type
487+
488+
489+
def test_get_when_proxy_but_implementation_missing(chain, owner, vyper_contract_container):
490+
"""
491+
Proxy is cached but implementation is missing.
492+
"""
493+
placeholder = vyper_contract_container.deploy(1001, sender=owner)
494+
assert chain.contracts[placeholder.address] # This must be cached!
495+
496+
proxy_container = _make_minimal_proxy(placeholder.address)
497+
minimal_proxy = owner.deploy(proxy_container, sender=owner)
498+
chain.provider.network.__dict__["explorer"] = None # Ensure no explorer, messes up test.
499+
500+
if minimal_proxy.address in chain.contracts:
501+
# Delete the proxy but make sure it does not delete the implementation!
502+
# (which it normally does here).
503+
del chain.contracts[minimal_proxy.address]
504+
chain.contracts[placeholder.address] = placeholder
505+
506+
actual = chain.contracts.get(minimal_proxy.address)
507+
assert actual == minimal_proxy.contract_type
508+
509+
510+
def test_get_pass_along_proxy_info(chain, owner, minimal_proxy_container, ethereum):
511+
placeholder = "0xBEbeBeBEbeBebeBeBEBEbebEBeBeBebeBeBebebe"
512+
if placeholder in chain.contracts:
513+
del chain.contracts[placeholder]
514+
515+
minimal_proxy = owner.deploy(minimal_proxy_container, sender=owner)
516+
chain.provider.network.__dict__["explorer"] = None # Ensure no explorer, messes up test.
517+
info = ethereum.get_proxy_info(minimal_proxy.address)
518+
assert info
519+
520+
# Ensure not already cached.
521+
if minimal_proxy.address in chain.contracts:
522+
del chain.contracts[minimal_proxy.address]
523+
524+
actual = chain.contracts.get(minimal_proxy.address, proxy_info=info)
525+
assert actual is None # It can't find the contact anymore.
526+
527+
# Ensure it does store 'None' (was a bug where it did).
528+
assert minimal_proxy.address not in chain.contracts.contract_types
529+
530+
477531
def test_get_creation_metadata(chain, vyper_contract_instance, owner):
478532
address = vyper_contract_instance.address
479533
creation = chain.contracts.get_creation_metadata(address)

0 commit comments

Comments
 (0)