@@ -58,6 +58,7 @@ struct bpf_stab {
 	struct bpf_map map;
 	struct sock **sock_map;
 	struct bpf_sock_progs progs;
+	raw_spinlock_t lock;
 };
 
 struct bucket {
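
The new raw_spinlock_t serializes every writer to the stab->sock_map slot array; the lockless xchg()/cmpxchg() accesses that previously guarded the slots become plain reads and writes under this lock in the hunks below. A minimal userspace sketch of the pattern change, with a pthread mutex standing in for the kernel's raw spinlock and a bare pointer standing in for a map slot (all names here are illustrative, not from the patch):

#include <pthread.h>
#include <stdatomic.h>
#include <stddef.h>

/* Before: each writer claims the slot with one atomic exchange. */
static _Atomic(void *) slot_lockless;

static void *delete_lockless(void)
{
	return atomic_exchange(&slot_lockless, NULL);
}

/* After: reads and writes of the slot happen under one lock, so a
 * writer can inspect the current value and act on what it sees before
 * any other writer can touch the slot.
 */
static void *slot_locked;
static pthread_mutex_t slot_lock = PTHREAD_MUTEX_INITIALIZER;

static void *delete_locked(void)
{
	void *old;

	pthread_mutex_lock(&slot_lock);
	old = slot_locked;
	slot_locked = NULL;
	pthread_mutex_unlock(&slot_lock);
	return old;
}

The lock buys composability: once an update involves the slot plus extra bookkeeping (the psock->maps list below), a single xchg()/cmpxchg() can no longer express the whole transition atomically, but a critical section can.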
@@ -89,9 +90,9 @@ enum smap_psock_state {
 
 struct smap_psock_map_entry {
 	struct list_head list;
+	struct bpf_map *map;
 	struct sock **entry;
 	struct htab_elem __rcu *hash_link;
-	struct bpf_htab __rcu *htab;
 };
 
 struct smap_psock {
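
With the entry now holding a generic struct bpf_map pointer, both branches of bpf_tcp_close() below recover the concrete map type with container_of() instead of keeping a typed, RCU-managed bpf_htab pointer. A self-contained sketch of that idiom (struct names invented for illustration; this mirrors how bpf_stab and bpf_htab both embed struct bpf_map):

#include <stddef.h>
#include <stdio.h>

#define container_of(ptr, type, member) \
	((type *)((char *)(ptr) - offsetof(type, member)))

struct map_hdr  { int map_type; };
struct array_map { struct map_hdr hdr; int nslots; };
struct hash_map  { struct map_hdr hdr; int nbuckets; };

int main(void)
{
	struct array_map a = { .hdr = { .map_type = 1 }, .nslots = 16 };
	struct map_hdr *m = &a.hdr;	/* generic pointer, as stored in the entry */

	/* Recover the containing structure from the embedded header. */
	struct array_map *back = container_of(m, struct array_map, hdr);
	printf("%d\n", back->nslots);	/* prints 16 */
	return 0;
}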
@@ -343,13 +344,18 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
 	e = psock_map_pop(sk, psock);
 	while (e) {
 		if (e->entry) {
-			osk = cmpxchg(e->entry, sk, NULL);
+			struct bpf_stab *stab = container_of(e->map, struct bpf_stab, map);
+
+			raw_spin_lock_bh(&stab->lock);
+			osk = *e->entry;
 			if (osk == sk) {
+				*e->entry = NULL;
 				smap_release_sock(psock, sk);
 			}
+			raw_spin_unlock_bh(&stab->lock);
 		} else {
 			struct htab_elem *link = rcu_dereference(e->hash_link);
-			struct bpf_htab *htab = rcu_dereference(e->htab);
+			struct bpf_htab *htab = container_of(e->map, struct bpf_htab, map);
 			struct hlist_head *head;
 			struct htab_elem *l;
 			struct bucket *b;
@@ -370,6 +376,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
 			}
 			raw_spin_unlock_bh(&b->lock);
 		}
+		kfree(e);
 		e = psock_map_pop(sk, psock);
 	}
 	rcu_read_unlock();
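
The close path above replaces cmpxchg(e->entry, sk, NULL) with a read, compare, and clear done under stab->lock, and frees each popped entry once it has been processed. A hedged userspace sketch of just the test-and-clear step (pthread mutex as a stand-in; identifiers illustrative):

#include <pthread.h>
#include <stdbool.h>

static pthread_mutex_t map_lock = PTHREAD_MUTEX_INITIALIZER;

/* Clear the slot only if it still points at us; same effect as
 * cmpxchg(entry, self, NULL), but done under the map lock so it
 * composes with slot updates that are no longer single atomic ops.
 */
static bool clear_if_owner(void **entry, void *self)
{
	bool released = false;

	pthread_mutex_lock(&map_lock);
	if (*entry == self) {
		*entry = NULL;
		released = true;	/* caller drops its reference */
	}
	pthread_mutex_unlock(&map_lock);
	return released;
}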
@@ -1641,6 +1648,7 @@ static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
 		return ERR_PTR(-ENOMEM);
 
 	bpf_map_init_from_attr(&stab->map, attr);
+	raw_spin_lock_init(&stab->lock);
 
 	/* make sure page count doesn't overflow */
 	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
@@ -1675,8 +1683,10 @@ static void smap_list_map_remove(struct smap_psock *psock,
 
 	spin_lock_bh(&psock->maps_lock);
 	list_for_each_entry_safe(e, tmp, &psock->maps, list) {
-		if (e->entry == entry)
+		if (e->entry == entry) {
 			list_del(&e->list);
+			kfree(e);
+		}
 	}
 	spin_unlock_bh(&psock->maps_lock);
 }
@@ -1690,8 +1700,10 @@ static void smap_list_hash_remove(struct smap_psock *psock,
 	list_for_each_entry_safe(e, tmp, &psock->maps, list) {
 		struct htab_elem *c = rcu_dereference(e->hash_link);
 
-		if (c == hash_link)
+		if (c == hash_link) {
 			list_del(&e->list);
+			kfree(e);
+		}
 	}
 	spin_unlock_bh(&psock->maps_lock);
 }
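
Both list-removal helpers now free the entry they unlink, so ownership of a smap_psock_map_entry rests with the psock->maps list: whoever removes an entry from that list also releases it. A small userspace model of remove-and-free under the list lock (singly linked list and names invented for illustration):

#include <pthread.h>
#include <stdlib.h>

struct map_entry {
	struct map_entry *next;
	void **slot;
};

static pthread_mutex_t maps_lock = PTHREAD_MUTEX_INITIALIZER;

/* Unlink and free every entry that references the given slot. */
static void list_remove_slot(struct map_entry **head, void **slot)
{
	struct map_entry **pp, *e;

	pthread_mutex_lock(&maps_lock);
	pp = head;
	while ((e = *pp) != NULL) {
		if (e->slot == slot) {
			*pp = e->next;	/* unlink */
			free(e);	/* the list owned it, so free it */
		} else {
			pp = &e->next;
		}
	}
	pthread_mutex_unlock(&maps_lock);
}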
@@ -1711,14 +1723,15 @@ static void sock_map_free(struct bpf_map *map)
 	 * and a grace period expire to ensure psock is really safe to remove.
 	 */
 	rcu_read_lock();
+	raw_spin_lock_bh(&stab->lock);
 	for (i = 0; i < stab->map.max_entries; i++) {
 		struct smap_psock *psock;
 		struct sock *sock;
 
-		sock = xchg(&stab->sock_map[i], NULL);
+		sock = stab->sock_map[i];
 		if (!sock)
 			continue;
-
+		stab->sock_map[i] = NULL;
 		psock = smap_psock_sk(sock);
 		/* This check handles a racing sock event that can get the
 		 * sk_callback_lock before this case but after xchg happens
@@ -1730,6 +1743,7 @@ static void sock_map_free(struct bpf_map *map)
 			smap_release_sock(psock, sock);
 		}
 	}
+	raw_spin_unlock_bh(&stab->lock);
 	rcu_read_unlock();
 
 	sock_map_remove_complete(stab);
@@ -1773,19 +1787,23 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key)
 	if (k >= map->max_entries)
 		return -EINVAL;
 
-	sock = xchg(&stab->sock_map[k], NULL);
+	raw_spin_lock_bh(&stab->lock);
+	sock = stab->sock_map[k];
+	stab->sock_map[k] = NULL;
+	raw_spin_unlock_bh(&stab->lock);
 	if (!sock)
 		return -EINVAL;
 
 	psock = smap_psock_sk(sock);
 	if (!psock)
-		goto out;
-
-	if (psock->bpf_parse)
+		return 0;
+	if (psock->bpf_parse) {
+		write_lock_bh(&sock->sk_callback_lock);
 		smap_stop_sock(psock, sock);
+		write_unlock_bh(&sock->sk_callback_lock);
+	}
 	smap_list_map_remove(psock, &stab->sock_map[k]);
 	smap_release_sock(psock, sock);
-out:
 	return 0;
 }
 
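
Besides the locked read-and-clear of the slot, this hunk wraps smap_stop_sock() in the socket's sk_callback_lock, which fits since that helper tears down the socket's callback pointers. A rough userspace analogue of quiescing a callback under a writer lock (pthread rwlock as a stand-in; names illustrative):

#include <pthread.h>

static pthread_rwlock_t cb_lock = PTHREAD_RWLOCK_INITIALIZER;
static void (*data_ready)(void);	/* stand-in for a socket callback */

static void quiesce_callback(void)
{
	/* The write lock is exclusive, so once we hold it no reader can
	 * be mid-flight in the old callback; it is then safe to detach.
	 */
	pthread_rwlock_wrlock(&cb_lock);
	data_ready = NULL;
	pthread_rwlock_unlock(&cb_lock);
}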
@@ -1821,11 +1839,9 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key)
 static int __sock_map_ctx_update_elem(struct bpf_map *map,
 				      struct bpf_sock_progs *progs,
 				      struct sock *sock,
-				      struct sock **map_link,
 				      void *key)
 {
 	struct bpf_prog *verdict, *parse, *tx_msg;
-	struct smap_psock_map_entry *e = NULL;
 	struct smap_psock *psock;
 	bool new = false;
 	int err = 0;
@@ -1898,14 +1914,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
 		new = true;
 	}
 
-	if (map_link) {
-		e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
-		if (!e) {
-			err = -ENOMEM;
-			goto out_free;
-		}
-	}
-
 	/* 3. At this point we have a reference to a valid psock that is
 	 * running. Attach any BPF programs needed.
 	 */
@@ -1927,17 +1935,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
 		write_unlock_bh(&sock->sk_callback_lock);
 	}
 
-	/* 4. Place psock in sockmap for use and stop any programs on
-	 * the old sock assuming its not the same sock we are replacing
-	 * it with. Because we can only have a single set of programs if
-	 * old_sock has a strp we can stop it.
-	 */
-	if (map_link) {
-		e->entry = map_link;
-		spin_lock_bh(&psock->maps_lock);
-		list_add_tail(&e->list, &psock->maps);
-		spin_unlock_bh(&psock->maps_lock);
-	}
 	return err;
 out_free:
 	smap_release_sock(psock, sock);
@@ -1948,7 +1945,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
 	}
 	if (tx_msg)
 		bpf_prog_put(tx_msg);
-	kfree(e);
 	return err;
 }
 
@@ -1958,36 +1954,57 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
 {
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
 	struct bpf_sock_progs *progs = &stab->progs;
-	struct sock *osock, *sock;
+	struct sock *osock, *sock = skops->sk;
+	struct smap_psock_map_entry *e;
+	struct smap_psock *psock;
 	u32 i = *(u32 *)key;
 	int err;
 
 	if (unlikely(flags > BPF_EXIST))
 		return -EINVAL;
-
 	if (unlikely(i >= stab->map.max_entries))
 		return -E2BIG;
 
-	sock = READ_ONCE(stab->sock_map[i]);
-	if (flags == BPF_EXIST && !sock)
-		return -ENOENT;
-	else if (flags == BPF_NOEXIST && sock)
-		return -EEXIST;
+	e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
+	if (!e)
+		return -ENOMEM;
 
-	sock = skops->sk;
-	err = __sock_map_ctx_update_elem(map, progs, sock, &stab->sock_map[i],
-					 key);
+	err = __sock_map_ctx_update_elem(map, progs, sock, key);
 	if (err)
 		goto out;
 
-	osock = xchg(&stab->sock_map[i], sock);
-	if (osock) {
-		struct smap_psock *opsock = smap_psock_sk(osock);
+	/* psock guaranteed to be present. */
+	psock = smap_psock_sk(sock);
+	raw_spin_lock_bh(&stab->lock);
+	osock = stab->sock_map[i];
+	if (osock && flags == BPF_NOEXIST) {
+		err = -EEXIST;
+		goto out_unlock;
+	}
+	if (!osock && flags == BPF_EXIST) {
+		err = -ENOENT;
+		goto out_unlock;
+	}
 
-		smap_list_map_remove(opsock, &stab->sock_map[i]);
-		smap_release_sock(opsock, osock);
+	e->entry = &stab->sock_map[i];
+	e->map = map;
+	spin_lock_bh(&psock->maps_lock);
+	list_add_tail(&e->list, &psock->maps);
+	spin_unlock_bh(&psock->maps_lock);
+
+	stab->sock_map[i] = sock;
+	if (osock) {
+		psock = smap_psock_sk(osock);
+		smap_list_map_remove(psock, &stab->sock_map[i]);
+		smap_release_sock(psock, osock);
 	}
+	raw_spin_unlock_bh(&stab->lock);
+	return 0;
+out_unlock:
+	smap_release_sock(psock, sock);
+	raw_spin_unlock_bh(&stab->lock);
 out:
+	kfree(e);
 	return err;
 }
 
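
The update path previously tested BPF_EXIST/BPF_NOEXIST against a READ_ONCE() snapshot and installed the socket with a later xchg(), leaving a window in which the slot could change between check and install. Now both happen in one critical section, and the map entry is allocated up front because no allocation can be done while the raw spinlock is held. A compact userspace sketch of that shape (errno-style returns; all names are illustrative, not from the patch):

#include <errno.h>
#include <pthread.h>
#include <stdlib.h>

enum { ANY, NOEXIST, EXIST };	/* stand-ins for 0/BPF_NOEXIST/BPF_EXIST */

struct entry {			/* stand-in for smap_psock_map_entry */
	void **slot;
	struct entry *next;
};

static pthread_mutex_t map_lock = PTHREAD_MUTEX_INITIALIZER;
static struct entry *owner_list;	/* the psock's maps list in the patch */

static int update_slot(void **slot, void *new, int flags)
{
	/* Allocate before taking the lock: no sleeping allocation is
	 * allowed inside the critical section.
	 */
	struct entry *e = calloc(1, sizeof(*e));
	int err = 0;

	if (!e)
		return -ENOMEM;

	pthread_mutex_lock(&map_lock);
	/* Flag checks and the install share one critical section, so the
	 * slot cannot change between the test and the write.
	 */
	if (*slot && flags == NOEXIST) {
		err = -EEXIST;
		goto out_unlock;
	}
	if (!*slot && flags == EXIST) {
		err = -ENOENT;
		goto out_unlock;
	}
	e->slot = slot;			/* link the entry to its owner */
	e->next = owner_list;
	owner_list = e;
	*slot = new;
	pthread_mutex_unlock(&map_lock);
	return 0;
out_unlock:
	pthread_mutex_unlock(&map_lock);
	free(e);			/* not installed, so not owned */
	return err;
}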
@@ -2350,7 +2367,7 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
 	b = __select_bucket(htab, hash);
 	head = &b->head;
 
-	err = __sock_map_ctx_update_elem(map, progs, sock, NULL, key);
+	err = __sock_map_ctx_update_elem(map, progs, sock, key);
 	if (err)
 		goto err;
 
@@ -2376,8 +2393,7 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
 	}
 
 	rcu_assign_pointer(e->hash_link, l_new);
-	rcu_assign_pointer(e->htab,
-			   container_of(map, struct bpf_htab, map));
+	e->map = map;
 	spin_lock_bh(&psock->maps_lock);
 	list_add_tail(&e->list, &psock->maps);
 	spin_unlock_bh(&psock->maps_lock);